使用 PYG 库构建 GCN 网络实现多标签分类任务

本代码演示了如何使用 PYG 库构建 GCN 网络,实现多标签分类任务,根据节点特征和边关系预测每个节点的标签向量。

已知条件:

  • 图的数量:num_graphs = 42
  • 每个图的节点数量:num_nodes = 37
  • 图片大小:image_size = 40
  • 标签数量:num_labels = 8
  • 边数量:num_edges = 61
  • 节点特征文件:'C:\Users\jh\Desktop\data\input\images\{i}.png_{j}.png' (所有图片的像素值,每个节点有 8 个特征)
  • 节点标签文件:'C:\Users\jh\Desktop\data\input\labels{i}{j}.txt' (每个节点有 8 个标签,用空格隔开)
  • 标签类别:0, 1, 2, 3, 4
  • 边关系文件:'C:\Users\jh\Desktop\data\input\edges_L.csv' (无向边,表格中没有 header,第一列为源节点,第二列为目标节点)

目标:

  • 输出每个节点的预测特征向量
  • 根据预测特征向量得到预测标签向量,使预测标签向量与真实标签向量一致
  • 每个节点的预测标签是一个 8 维向量,而不是输出概率向量
  • 将每个图的前 30 个节点颜色特征加入训练掩码,后 7 个节点颜色特征加入验证掩码

代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv

from PIL import Image
import numpy as np
import pandas as pd

# 定义 GCN 模型
class GCN(nn.Module):
    def __init__(self, num_features, hidden_size, num_labels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_size)
        self.conv2 = GCNConv(hidden_size, num_labels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

# 加载节点特征
node_features = []
for i in range(1, num_graphs + 1):
    for j in range(1, num_nodes + 1):
        image_path = f"C:\Users\jh\Desktop\data\input\images\{i}.png_{j}.png"
        image = Image.open(image_path)
        image_data = np.array(image)
        node_features.append(image_data.flatten())

node_features = torch.tensor(node_features, dtype=torch.float)

# 加载标签
labels = []
for i in range(1, num_graphs + 1):
    for j in range(1, num_nodes + 1):
        label_path = f"C:\Users\jh\Desktop\data\input\labels{i}{j}.txt"
        with open(label_path, 'r') as f:
            label = f.read().split()
            label = [int(x) for x in label]
            labels.append(label)

labels = torch.tensor(labels, dtype=torch.float)

# 加载边关系
edges = pd.read_csv("C:\Users\jh\Desktop\data\input\edges_L.csv", header=None)
edge_index = torch.tensor(edges.values.T, dtype=torch.long)

# 准备训练和验证掩码
train_mask = torch.zeros(num_nodes * num_graphs, dtype=torch.bool)
train_mask[:30 * num_graphs] = 1
val_mask = ~train_mask

# 创建数据集和数据加载器
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, node_features, labels, edge_index):
        self.node_features = node_features
        self.labels = labels
        self.edge_index = edge_index

    def __getitem__(self, idx):
        x = self.node_features[idx]
        y = self.labels[idx]
        edge_index = self.edge_index
        return x, y, edge_index

    def __len__(self):
        return len(self.node_features)

dataset = CustomDataset(node_features, labels, edge_index)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# 定义模型、优化器和损失函数
model = GCN(num_features=node_features.shape[1], hidden_size=16, num_labels=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# 训练循环
for epoch in range(100):
    model.train()

    for data in dataloader:
        x, y, edge_index = data
        optimizer.zero_grad()
        out = model(x, edge_index)
        loss = criterion(out[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        out = model(node_features, edge_index)
        pred = out[val_mask].round()
        acc = (pred == labels[val_mask]).sum().item() / val_mask.sum().item()
        print(f"Epoch: {epoch + 1}, Validation Accuracy: {acc:.4f}")

注意:

  • 运行代码之前,请确保已安装以下库:torch, torch_geometric, PIL, numpy, pandas
  • 请根据实际情况修改文件路径和其他超参数,以适应您的数据集和任务要求

代码功能:

  1. 加载节点特征和标签数据
  2. 使用 GCN 模型构建图神经网络
  3. 使用 DataLoader 构建数据加载器
  4. 使用 Adam 优化器和 BCEWithLogitsLoss 损失函数进行训练
  5. 计算验证集的准确率

结果:

代码将输出每个 epoch 的验证集准确率,并根据训练结果得到每个节点的预测标签向量。

扩展:

  • 可以尝试使用不同的 GCN 模型,例如 GATSAGE
  • 可以尝试使用不同的优化器和损失函数
  • 可以尝试使用不同的训练策略,例如 early stopping 或 dropout
  • 可以尝试对预测结果进行可视化,以更好地理解模型的预测能力
使用 PYG 库构建 GCN 网络实现多标签分类任务

原文地址: https://www.cveoy.top/t/topic/plzD 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录