使用PYG库建立GCN网络实现多标签分类任务

本文档提供使用PYG库建立GCN网络实现多标签分类任务的完整代码示例。

数据说明

已知:

  • num_graphs = 42
  • num_nodes = 37
  • image_size = 40
  • num_labels = 8
  • num_edges = 61

节点特征文件是'C:\Users\jh\Desktop\data\input\images{i}.png_{j}.png'的所有图片的像素值,每个节点有8个标签,储存在'C:\Users\jh\Desktop\data\input\labels{i}{j}.txt'文本文件中,标签用空格隔开,例如某个节点的标签向量为: 2 2 1 1 3 1 2 1,5_21.txt的标签向量为1 3 4 1 3 1 1 3。

真实标签值只有0、1、2、3、4五个类别,但是每个节点的标签是一个8维的标签向量。

边的关系储存在'C:\Users\jh\Desktop\data\input\edges_L.csv'csv文件中,表格中没有header,第一列为源节点,第二列为目标节点,共有61条无向边。

任务目标

要求输出每个节点的预测特征向量,并根据这些预测特征得到预测标签向量,使预测标签向量与真实标签向量一致。每个节点的预测标签都是一个8维向量,而不是输出概率向量。

代码示例

import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx

# Define GCN model
class GCN(torch.nn.Module):
    def __init__(self, num_features, num_labels, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, num_labels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        return x

# Load dataset
dataset = Planetoid(root='C:\Users\jh\Desktop\data', name='dataset_name')

# Split dataset into training and validation sets
train_data = dataset[:30]
val_data = dataset[30:]

# Define data loader
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(num_features=dataset.num_features, num_labels=dataset.num_classes, hidden_channels=16).to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

# Training loop
def train():
    model.train()
    for data in train_loader:
        data = data.to(device)
        
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

# Evaluation function
def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    for data in loader:
        data = data.to(device)
        
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            pred = (out > 0).float()  # Convert logits to binary predictions
            
        correct += (pred[data.val_mask] == data.y[data.val_mask]).sum().item()
        total += data.val_mask.sum().item()
    
    return correct / total

# Train and evaluate model
best_val_acc = 0
for epoch in range(1, 101):
    train()
    train_acc = evaluate(train_loader)
    val_acc = evaluate(val_loader)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pt')
    
    print(f'Epoch: {epoch}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# Load best model and evaluate on test set
model.load_state_dict(torch.load('best_model.pt'))
test_data = dataset[30:]
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
test_acc = evaluate(test_loader)
print(f'Test Acc: {test_acc:.4f}')

注意事项

  • 以上代码是一个基本的GCN多标签分类任务的示例,您可能需要根据您的数据和任务进行适当的修改和调整。
  • 您需要将代码中的'dataset_name'替换成您的数据集名称。
  • 代码中使用了'Planetoid'数据集,您需要根据实际情况修改数据加载部分。
  • 代码中的'train_mask'和'val_mask'是用来区分训练集和验证集的掩码,您可以根据实际情况修改这些掩码的定义。
  • 代码中使用了'BCEWithLogitsLoss'作为损失函数,您可能需要根据实际情况修改损失函数。
  • 代码中使用了'Adam'优化器,您可能需要根据实际情况修改优化器和学习率。

总结

本示例演示了如何使用PYG库建立GCN网络实现多标签分类任务,包含数据加载、模型构建、训练、评估和测试的完整代码。希望这份示例能帮助您更好地理解GCN网络,并为您的多标签分类任务提供一些参考。

使用PYG库建立GCN网络实现多标签分类任务

原文地址: https://www.cveoy.top/t/topic/plBH 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录