使用PYG库建立GCN网络实现多标签分类任务

本文介绍使用PyTorch Geometric库建立GCN网络,实现多标签分类任务,并利用CNN对节点像素特征进行降维。数据包含42个时刻的图,每个图有37个节点,节点特征为40x40像素的RGB值,每个节点有8个标签。

数据说明:

  • 共有42个时刻的图,边的连接关系相同。
  • 每个图有37个节点。
  • 节点特征文件为'C:/Users/jh/Desktop/data/input/images/i.png_j.png',其中i表示图 (1到42),j表示节点 (0到36),特征为RGB像素值,尺寸为40 x 40。
  • 每个节点有8个标签,储存在'C:/Users/jh/Desktop/data/input/labels/i_j.txt'文本文件中,标签用空格隔开。
  • 边的关系储存在'C:/Users/jh/Desktop/data/input/edges_L.csv' csv文件中,表格中没有header,第一列为源节点,第二列为目标节点,共有61条无向边。

任务:

  • 建立一个CNN网络对节点像素特征x进行降维。
  • 使用前38个图作为训练集,剩余4个图的一部分节点作为测试集。
  • 使用PYG库建立GCN网络实现多标签分类任务,预测测试集节点的标签。
  • 损失函数使用torch.nn模块中的MultiLabelSoftMarginLoss。

代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader

# 定义GCN模型
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

# 定义数据集
class GraphDataset(torch.utils.data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]

# 加载数据
data_list = []
for i in range(1, 39):
    for j in range(37):
        # 加载节点特征
        feature_path = f"C:/Users/jh/Desktop/data/input/images/{i}.png_{j}.png"
        features = torch.Tensor(load_features(feature_path))  # 实现加载特征的函数
        
        # 加载节点标签
        label_path = f"C:/Users/jh/Desktop/data/input/labels/{i}_{j}.txt"
        with open(label_path, 'r') as f:
            labels = [int(label) for label in f.read().split()]
        labels = torch.Tensor(labels)
        
        # 创建图数据对象
        edge_index = load_edges("C:/Users/jh/Desktop/data/input/edges_L.csv")  # 实现加载边的函数
        data = Data(x=features, y=labels, edge_index=edge_index)
        data_list.append(data)

# 将数据集划分为训练集和测试集
train_dataset = GraphDataset(data_list[:38*37])
test_dataset = GraphDataset(data_list[38*37:])

# 定义数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 初始化模型
model = GCN(input_dim=40*40, hidden_dim=64, output_dim=8)

# 定义损失函数
criterion = nn.MultiLabelSoftMarginLoss()

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
model.train()
for epoch in range(10):
    running_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x.view(-1, 40*40), batch.edge_index)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {running_loss}")

# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for batch in test_loader:
        out = model(batch.x.view(-1, 40*40), batch.edge_index)
        predicted_labels = torch.round(torch.sigmoid(out))
        total += batch.y.size(0)
        correct += (predicted_labels == batch.y).sum().item()
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy}")

注意:

  • 代码中的load_features函数和load_edges函数需要根据你的实际情况进行实现,以加载节点特征和边的关系。
  • 你可能需要调整模型的超参数和优化器的学习率等,以获得更好的结果。

代码解析:

  1. 定义GCN模型: 使用GCNConv层构建GCN模型,模型包含两层卷积层,第一层将节点特征从40*40降维到64维,第二层将64维特征映射到8个标签的输出。
  2. 定义数据集: 定义GraphDataset类,用于加载每个图的节点特征、标签和边关系,并将其封装成Data对象。
  3. 加载数据: 循环遍历每个图,加载节点特征和标签,并使用load_edges函数加载边关系,创建Data对象并添加到data_list中。
  4. 划分数据集:data_list划分为训练集和测试集。
  5. 定义数据加载器: 使用DataLoader将训练集和测试集封装成数据加载器,以便在训练和测试时方便地获取数据。
  6. 初始化模型、损失函数和优化器: 初始化GCN模型,使用MultiLabelSoftMarginLoss作为损失函数,使用Adam优化器进行优化。
  7. 训练模型: 使用训练集数据进行训练,并在每个epoch后打印损失值。
  8. 测试模型: 使用测试集数据进行测试,计算模型的准确率。

总结:

本文介绍了如何使用PYG库建立GCN网络,实现多标签分类任务,并利用CNN对节点像素特征进行降维。代码实现了数据集的加载、模型的构建、训练和测试,并给出了完整的代码内容。你可以根据自己的需求和数据进行调整和修改。

基于PYG库的GCN网络多标签分类任务实现:节点特征降维及预测

原文地址: https://www.cveoy.top/t/topic/pfZt 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录