使用PYG库建立GCN网络实现多标签分类任务
使用PYG库建立GCN网络实现多标签分类任务
本文档提供使用PYG库建立GCN网络实现多标签分类任务的完整代码示例。
数据说明
已知:
- num_graphs = 42
- num_nodes = 37
- image_size = 40
- num_labels = 8
- num_edges = 61
节点特征文件是'C:\Users\jh\Desktop\data\input\images{i}.png_{j}.png'的所有图片的像素值,每个节点有8个标签,储存在'C:\Users\jh\Desktop\data\input\labels{i}{j}.txt'文本文件中,标签用空格隔开,例如某个节点的标签向量为: 2 2 1 1 3 1 2 1,5_21.txt的标签向量为1 3 4 1 3 1 1 3。
真实标签值只有0、1、2、3、4五个类别,但是每个节点的标签是一个8维的标签向量。
边的关系储存在'C:\Users\jh\Desktop\data\input\edges_L.csv'csv文件中,表格中没有header,第一列为源节点,第二列为目标节点,共有61条无向边。
任务目标
要求输出每个节点的预测特征向量,并根据这些预测特征得到预测标签向量,使预测标签向量与真实标签向量一致。每个节点的预测标签都是一个8维向量,而不是输出概率向量。
代码示例
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx
# Define GCN model
class GCN(torch.nn.Module):
def __init__(self, num_features, num_labels, hidden_channels):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
self.conv3 = GCNConv(hidden_channels, num_labels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = self.conv3(x, edge_index)
return x
# Load dataset
dataset = Planetoid(root='C:\Users\jh\Desktop\data', name='dataset_name')
# Split dataset into training and validation sets
train_data = dataset[:30]
val_data = dataset[30:]
# Define data loader
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(num_features=dataset.num_features, num_labels=dataset.num_classes, hidden_channels=16).to(device)
# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
# Training loop
def train():
model.train()
for data in train_loader:
data = data.to(device)
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
# Evaluation function
def evaluate(loader):
model.eval()
correct = 0
total = 0
for data in loader:
data = data.to(device)
with torch.no_grad():
out = model(data.x, data.edge_index)
pred = (out > 0).float() # Convert logits to binary predictions
correct += (pred[data.val_mask] == data.y[data.val_mask]).sum().item()
total += data.val_mask.sum().item()
return correct / total
# Train and evaluate model
best_val_acc = 0
for epoch in range(1, 101):
train()
train_acc = evaluate(train_loader)
val_acc = evaluate(val_loader)
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_model.pt')
print(f'Epoch: {epoch}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
# Load best model and evaluate on test set
model.load_state_dict(torch.load('best_model.pt'))
test_data = dataset[30:]
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
test_acc = evaluate(test_loader)
print(f'Test Acc: {test_acc:.4f}')
注意事项
- 以上代码是一个基本的GCN多标签分类任务的示例,您可能需要根据您的数据和任务进行适当的修改和调整。
- 您需要将代码中的'dataset_name'替换成您的数据集名称。
- 代码中使用了'Planetoid'数据集,您需要根据实际情况修改数据加载部分。
- 代码中的'train_mask'和'val_mask'是用来区分训练集和验证集的掩码,您可以根据实际情况修改这些掩码的定义。
- 代码中使用了'BCEWithLogitsLoss'作为损失函数,您可能需要根据实际情况修改损失函数。
- 代码中使用了'Adam'优化器,您可能需要根据实际情况修改优化器和学习率。
总结
本示例演示了如何使用PYG库建立GCN网络实现多标签分类任务,包含数据加载、模型构建、训练、评估和测试的完整代码。希望这份示例能帮助您更好地理解GCN网络,并为您的多标签分类任务提供一些参考。
原文地址: https://www.cveoy.top/t/topic/plBH 著作权归作者所有。请勿转载和采集!