Building a GCN with PyG for Multi-Label Classification

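This code uses the PyG (PyTorch Geometric) library to build a GCN for multi-label classification on a dataset of 42 graphs, each with 37 nodes and 8 labels per node. Each node's feature is a 40x40-pixel RGB image, and MultiLabelSoftMarginLoss is used as the loss function.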

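Data description:

  • num_graphs = 42: the number of graphs.

  • num_nodes = 37: the number of nodes in each graph.

  • image_size = 40: the side length, in pixels, of each node image.

  • num_labels = 8: the number of labels per node.

  • num_edges = 61: the number of edges in the graph.

  • Node features are the RGB pixel values of all images at 'C:\Users\jh\Desktop\data\input\images\i.png_j.png', where i is the graph index (1 to 42) and j is the node index (0 to 36).

  • Each node has 8 labels, stored in the text file 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' and separated by spaces.

  • Edges are stored in the CSV file 'C:\Users\jh\Desktop\data\input\edges_L.csv'. The file has no header; the first column is the source node and the second column is the target node.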
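
To make the file formats above concrete, here is a minimal sketch that parses the content of one label file and of a header-less edge CSV from in-memory strings. The helper names parse_label_line and parse_edge_csv and the sample values are illustrative assumptions, not part of the original dataset or code.

```python
import csv
import io


def parse_label_line(text, num_labels=8):
    """Parse one space-separated label line into a list of ints."""
    values = [int(v) for v in text.split()]
    if len(values) != num_labels:
        raise ValueError(f"expected {num_labels} labels, got {len(values)}")
    return values


def parse_edge_csv(text):
    """Parse header-less CSV text into (sources, targets) lists.

    Column 0 is the source node and column 1 the target node. Stacking
    the two lists gives the [2, num_edges] layout PyG uses for edge_index.
    """
    sources, targets = [], []
    for row in csv.reader(io.StringIO(text)):
        if not row:  # skip blank lines
            continue
        sources.append(int(row[0]))
        targets.append(int(row[1]))
    return sources, targets


# Hypothetical sample contents, made up for illustration
labels = parse_label_line("1 0 0 1 0 0 1 0")
sources, targets = parse_edge_csv("0,1\n1,2\n2,36\n")
print(labels)
print([sources, targets])
```

Checking the label count at parse time catches malformed files early, before they surface as a tensor shape error.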

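Task:

Build a CNN to reduce the dimensionality of the node pixel features x. In each graph, put the first 30 nodes in the training mask and the last 7 nodes in the validation mask. Use PyG to build a GCN for the multi-label classification task, with the loss implemented by MultiLabelSoftMarginLoss from the torch.nn module.

Complete code: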
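
As a rough illustration of the dimensionality reduction and the node split, the sketch below tracks how a 40x40 image shrinks through a hypothetical stack of two 3x3 convolutions (padding 1), each followed by 2x2 max pooling, using the standard output-size formula floor((n + 2p - k) / s) + 1. The layer choices and the channel count are assumptions for illustration only; the task statement does not prescribe an architecture. The 30/7 split is built as plain boolean lists.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


image_size = 40
raw_features = 3 * image_size * image_size      # flattened RGB pixels per node

size = image_size
for _ in range(2):
    size = conv_out(size, kernel=3, padding=1)  # 3x3 conv, padding 1: size unchanged
    size = conv_out(size, kernel=2, stride=2)   # 2x2 max pool: size halved

channels = 32                                   # assumed final channel count
flat_features = channels * size * size          # values left after the conv stack

# Node split from the task: first 30 nodes train, last 7 validate
num_nodes, num_train = 37, 30
train_mask = [j < num_train for j in range(num_nodes)]
val_mask = [not m for m in train_mask]

print(raw_features, size, flat_features)
print(sum(train_mask), sum(val_mask))
```

Under these assumptions, a final linear layer could map the remaining values to whatever feature width the GCN takes as input.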

import csv
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
# torch_geometric.data.DataLoader is deprecated in PyG 2.x; use the loader module
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

class NodeCNN(nn.Module):
    """Small CNN that compresses one 3 x H x W node image into a feature vector.

    The layer sizes are illustrative choices, not tuned values.
    """
    def __init__(self, image_size, out_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Two 2x2 poolings divide the spatial size by 4
        self.fc = nn.Linear(32 * (image_size // 4) ** 2, out_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return F.relu(self.fc(x.flatten(1)))

class GCN(nn.Module):
    def __init__(self, image_size, cnn_dim, hidden_dim, num_labels):
        super().__init__()
        self.cnn = NodeCNN(image_size, cnn_dim)
        self.conv1 = GCNConv(cnn_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_labels)

    def forward(self, x, edge_index):
        x = self.cnn(x)  # reduce each node image to cnn_dim features
        x = F.relu(self.conv1(x, edge_index))
        # Return raw logits: MultiLabelSoftMarginLoss applies the sigmoid itself
        return self.conv2(x, edge_index)

def load_node_features(image_path, num_graphs, num_nodes):
    """Return a float tensor [num_graphs, num_nodes, 3, H, W] scaled to [0, 1]."""
    graphs = []
    for i in range(1, num_graphs + 1):
        nodes = []
        for j in range(num_nodes):
            with Image.open(image_path.format(i=i, j=j)) as img:
                pixels = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
            nodes.append(pixels.transpose(2, 0, 1))  # HWC -> CHW for Conv2d
        graphs.append(np.stack(nodes))
    return torch.from_numpy(np.stack(graphs))

def load_labels(label_path, num_graphs, num_nodes):
    """Return a float tensor [num_graphs, num_nodes, num_labels]."""
    labels = []
    for i in range(1, num_graphs + 1):
        graph_labels = []
        for j in range(num_nodes):
            # One file per node, labels separated by spaces
            with open(label_path.format(i=i, j=j), 'r') as file:
                graph_labels.append([int(v) for v in file.read().split()])
        labels.append(graph_labels)
    return torch.tensor(labels, dtype=torch.float)

def load_edges(edge_file_path):
    """Return a long tensor [2, num_edges] read from a header-less CSV."""
    edges = []
    with open(edge_file_path, 'r', newline='') as file:
        for row in csv.reader(file):
            if row:
                edges.append([int(row[0]), int(row[1])])
    return torch.tensor(edges, dtype=torch.long).t().contiguous()

# Define paths (raw strings, so backslashes are not read as escape sequences;
# {i} is the graph index and {j} the node index)
node_feature_path = r'C:\Users\jh\Desktop\data\input\images\{i}.png_{j}.png'
label_path = r'C:\Users\jh\Desktop\data\input\labels\{i}_{j}.txt'
edge_file_path = r'C:\Users\jh\Desktop\data\input\edges_L.csv'

# Define hyperparameters
num_graphs = 42
num_nodes = 37
num_train_nodes = 30  # first 30 nodes train, last 7 validate
image_size = 40
num_labels = 8
num_edges = 61
cnn_dim = 128
hidden_dim = 64
num_epochs = 10
learning_rate = 0.01

# Load data
node_features = load_node_features(node_feature_path, num_graphs, num_nodes)
labels = load_labels(label_path, num_graphs, num_nodes)
# Edges are used as listed (source -> target); add the reverse edges as well
# if the graph is meant to be undirected
edges = load_edges(edge_file_path)

# Node-level masks, shared by every graph
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:num_train_nodes] = True
val_mask = ~train_mask

# Create a graph dataset
data_list = []
for i in range(num_graphs):
    data = Data(x=node_features[i], y=labels[i], edge_index=edges,
                train_mask=train_mask, val_mask=val_mask)
    data_list.append(data)

# Create data loaders (both iterate over all graphs; the masks select the nodes)
train_loader = DataLoader(data_list, batch_size=1, shuffle=True)
val_loader = DataLoader(data_list, batch_size=1)

# Create the model
model = GCN(image_size=image_size, cnn_dim=cnn_dim, hidden_dim=hidden_dim, num_labels=num_labels)

# Define the loss function
criterion = nn.MultiLabelSoftMarginLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        # Only the training nodes contribute to the loss
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, avg_loss))

# Validation loop
model.eval()
num_correct = 0
num_total = 0
with torch.no_grad():
    for data in val_loader:
        out = model(data.x, data.edge_index)
        target = data.y[data.val_mask]
        # A logit above 0 is the same as a sigmoid probability above 0.5
        predicted_labels = (out[data.val_mask] > 0).float()
        num_correct += (predicted_labels == target).sum().item()
        num_total += target.numel()  # count every (node, label) entry
accuracy = num_correct / num_total
print('Validation Accuracy: {:.2f}%'.format(accuracy * 100))

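Make sure to replace the image, label, and edge file paths with the correct paths on your machine. The code loads the node features, labels, and edges, creates the GCN model, trains it on the training set, and evaluates it on the validation set.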
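
For multi-label outputs, "accuracy" can mean different things. The sketch below contrasts element-wise accuracy (every node-label entry counts once) with exact-match accuracy (a node counts only if all of its labels are right). The function names and the sample predictions are made up for illustration.

```python
def elementwise_accuracy(pred, target):
    """Fraction of individual (node, label) entries predicted correctly."""
    correct = sum(p == t
                  for pred_row, target_row in zip(pred, target)
                  for p, t in zip(pred_row, target_row))
    total = sum(len(target_row) for target_row in target)
    return correct / total


def exact_match_accuracy(pred, target):
    """Fraction of nodes whose whole label vector is predicted correctly."""
    matches = sum(pred_row == target_row
                  for pred_row, target_row in zip(pred, target))
    return matches / len(target)


# Four hypothetical nodes with three labels each
pred = [[1, 0, 1], [0, 0, 1], [1, 1, 0], [0, 1, 0]]
target = [[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]]

print(elementwise_accuracy(pred, target))
print(exact_match_accuracy(pred, target))
```

Element-wise accuracy is never lower than exact-match accuracy, so it helps to state which one a reported number refers to.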

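Notes:

  • The load_node_features function loads each node's image features into a PyTorch tensor.
  • The load_labels function loads each node's labels into a PyTorch tensor.
  • The load_edges function loads the graph's edges into a PyTorch tensor.
  • The GCN class defines the structure of the GCN network.
  • The code uses MultiLabelSoftMarginLoss as the loss function because it is designed for multi-label classification.
  • The code uses torch.optim.Adam as the optimizer.
  • The code trains the model on the training set and evaluates it on the validation set.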
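
MultiLabelSoftMarginLoss takes raw scores (logits) and applies the sigmoid internally, so with the default mean reduction it amounts to a binary cross-entropy averaged over labels and then over samples. The plain-Python sketch below re-implements that formula for illustration only; the function names are hypothetical, and in practice you would call torch.nn.MultiLabelSoftMarginLoss directly.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def multilabel_soft_margin_loss(logits, targets):
    """Per-label binary cross-entropy on logits, averaged over labels and samples."""
    total = 0.0
    for x_row, y_row in zip(logits, targets):
        row = 0.0
        for x, y in zip(x_row, y_row):
            # log(1 - sigmoid(x)) equals log(sigmoid(-x))
            row -= y * log_sigmoid(x) + (1 - y) * log_sigmoid(-x)
        total += row / len(x_row)
    return total / len(logits)


# One sample with two labels: an uninformed guess vs. a confident, correct one
uninformed = multilabel_soft_margin_loss([[0.0, 0.0]], [[1, 0]])
confident = multilabel_soft_margin_loss([[6.0, -6.0]], [[1, 0]])
print(uninformed, confident)
```

Because the sigmoid is already inside the loss, passing sigmoid outputs instead of logits squashes the scores twice and weakens the gradient.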

Original article: https://www.cveoy.top/t/topic/pguh. Copyright belongs to the author. Do not repost or scrape.
