基于PYG库的GCN多标签分类:节点特征降维与验证集构建

本文介绍使用PyTorch Geometric (PYG) 库构建图卷积神经网络 (GCN) 进行多标签分类任务。我们将使用CNN对节点像素特征进行降维,并利用PYG库建立GCN网络。文章将提供完整的代码示例,包括数据加载、模型构建、训练和测试等环节,并解释如何构建验证集以及如何使用MultiLabelSoftMarginLoss作为损失函数。

数据说明:

  • 一共有42个时刻的图,而且边的连接关系相同。
  • 每个图都有37个节点,节点特征文件是'C:\Users\jh\Desktop\data\input\images\i.png_j.png'的所有图片的RGB像素值,其中i表示图,i从1到42,j表示节点,j从0到36。
  • 特征图片的尺寸为40 x 40。
  • 每个节点有8个标签,储存在'C:\Users\jh\Desktop\data\input\labels\i_j.txt'文本文件中,标签用空格隔开。
  • 边的关系储存在'C:\Users\jh\Desktop\data\input\edges_L.csv'csv文件中,表格中没有header,第一列为源节点,第二列为目标节点,共有61条无向边。

任务目标:

  • 建立一个CNN网络对节点像素特征x进行降维。
  • 将后面3个图的某些节点作为验证集。
  • 使用PYG库建立GCN网络实现多标签分类任务。
  • 损失函数由torch.nn模块中的MultiLabelSoftMarginLoss来实现。

代码示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import train_test_split_edges
from torch_geometric.loader import DataLoader
import pandas as pd
from PIL import Image
from torchvision.transforms import ToTensor

# 定义CNN网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32*10*10, 128)
        self.fc2 = nn.Linear(128, 8)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义GCN网络
class GCN(nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 加载节点特征
def load_node_features():
    node_features = []
    for i in range(1, 43):
        image_path = f'C:\Users\jh\Desktop\data\input\images\{i}.png'
        img = Image.open(image_path)
        img_tensor = ToTensor()(img)
        node_features.append(img_tensor)

    return torch.stack(node_features)

# 加载节点标签
def load_node_labels():
    node_labels = []
    for i in range(1, 43):
        for j in range(37):
            label_path = f'C:\Users\jh\Desktop\data\input\labels\{i}_{j}.txt'
            with open(label_path, 'r') as f:
                labels = f.read().split()
                labels = [int(label) for label in labels]
                node_labels.append(labels)

    return torch.tensor(node_labels)

# 加载边的关系
def load_edges():
    edges = pd.read_csv('C:\Users\jh\Desktop\data\input\edges_L.csv', header=None)
    edge_index = torch.tensor(edges.values, dtype=torch.long).t().contiguous()

    return edge_index

# 创建数据集
def create_dataset():
    node_features = load_node_features()
    node_labels = load_node_labels()
    edge_index = load_edges()
    data = Data(x=node_features, y=node_labels, edge_index=edge_index)

    return data

# 创建模型
def create_model():
    model = GCN(num_features=8, hidden_channels=16, num_classes=8)
    return model

# 训练模型
def train_model(model, data, train_edges):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MultiLabelSoftMarginLoss()

    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[train_edges], data.y[train_edges])
    loss.backward()
    optimizer.step()

    return loss.item()

# 测试模型
def test_model(model, data, test_edges):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out[test_edges].round()
    correct = pred.eq(data.y[test_edges]).sum().item()
    return correct / test_edges.numel()

# 主函数
def main():
    data = create_dataset()
    train_edges, test_edges = train_test_split_edges(data.edge_index)

    model = create_model()
    for epoch in range(100):
        loss = train_model(model, data, train_edges)
        accuracy = test_model(model, data, test_edges)
        print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

if __name__ == '__main__':
    main()

代码说明:

  • 代码首先定义了CNN和GCN两个网络结构,分别用于节点特征降维和图卷积操作。
  • 接着定义了三个函数用于加载节点特征、节点标签和边的关系。
  • create_dataset() 函数将所有数据整合为一个Data对象。
  • create_model() 函数创建GCN模型。
  • train_model() 函数进行模型训练,使用MultiLabelSoftMarginLoss 作为损失函数。
  • test_model() 函数对模型进行评估。
  • 最后,main() 函数进行模型训练和测试。

注意:

  • 代码中的路径需要根据实际情况进行修改。
  • 代码中使用的CNN网络结构仅为示例,可以根据具体需求进行调整。
  • 代码中使用的超参数也需要根据实际情况进行调整。

总结:

本文提供了一个使用PYG库进行GCN多标签分类任务的完整示例。通过CNN对节点特征进行降维,并利用PYG库建立GCN网络,我们可以有效地对图数据进行分析和预测。该示例代码可以作为基础框架,并根据具体需求进行修改和扩展。

基于PYG库的GCN多标签分类:节点特征降维与验证集构建

原文地址: https://www.cveoy.top/t/topic/pfZb 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录