Building a GCN with the PyG Library for Multi-Label Classification
This code uses the PyG library to build a GCN for multi-label classification on a graph dataset with 42 graphs, 37 nodes per graph, and 8 labels per node. Each node's feature is a 40x40-pixel RGB image, and MultiLabelSoftMarginLoss is used as the loss function.
Data description:
- num_graphs = 42: number of graphs.
- num_nodes = 37: number of nodes in each graph.
- image_size = 40: size of each node image (40x40 pixels).
- num_labels = 8: number of labels per node.
- num_edges = 61: number of edges in the graph.
- The node features are the RGB pixel values of the images 'C:\Users\jh\Desktop\data\input\images\i.png_j.png', where i is the graph index (1 to 42) and j is the node index (0 to 36).
- Each node has 8 labels, stored in the text file 'C:\Users\jh\Desktop\data\input\labels\i_j.txt', with the labels separated by spaces.
- The edge relations are stored in the CSV file 'C:\Users\jh\Desktop\data\input\edges_L.csv', which has no header; the first column is the source node and the second column is the target node. (A small sanity-check snippet for these formats follows this list.)
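To make the on-disk formats concrete, here is a minimal sanity-check sketch. It assumes the files exist exactly as described above; the concrete indices (graph 1, node 0) and variable names are only illustrative.

import csv
import numpy as np
from PIL import Image

# One node image: 40x40 RGB pixels, e.g. graph 1, node 0.
img = np.array(Image.open(r'C:\Users\jh\Desktop\data\input\images\1.png_0.png').convert('RGB'))
print(img.shape)  # expected: (40, 40, 3)

# One label file: 8 space-separated label values for a single node.
with open(r'C:\Users\jh\Desktop\data\input\labels\1_0.txt', 'r') as f:
    labels = [int(v) for v in f.read().split()]
print(len(labels))  # expected: 8

# The edge list: no header, two columns (source node, target node).
with open(r'C:\Users\jh\Desktop\data\input\edges_L.csv', 'r') as f:
    edges = [(int(row[0]), int(row[1])) for row in csv.reader(f)]
print(len(edges))  # expected: 61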
Task:
Build a CNN to reduce the dimensionality of the node pixel features x, putting the color features of the first 30 nodes of each graph into the training mask and those of the last 7 nodes into the validation mask. Use the PyG library to build a GCN for the multi-label classification task, with the loss implemented by MultiLabelSoftMarginLoss from the torch.nn module. (A minimal sketch of the CNN encoder and node-level masks follows the complete code.)
Complete code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
import os
import numpy as np
import csv
from PIL import Image
class GCN(nn.Module):
    def __init__(self, num_features, hidden_dim, num_labels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_labels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        # Return raw logits: MultiLabelSoftMarginLoss applies the sigmoid internally,
        # so applying it here as well would distort the loss.
        return x
def load_node_features(image_path, num_nodes):
    # image_path is a template containing '{i}' (graph index) and '{j}' (node index).
    node_features = []
    for i in range(1, num_graphs + 1):
        graph_features = []
        for j in range(num_nodes):
            node_image_path = image_path.format(i=i, j=j)
            node_image = Image.open(node_image_path).convert('RGB')
            node_feature = np.array(node_image).flatten()
            graph_features.append(node_feature)
        node_features.append(graph_features)
    return torch.tensor(np.array(node_features), dtype=torch.float)
def load_labels(label_path, num_nodes, num_labels):
    # label_path is a template containing '{i}' (graph index) and '{j}' (node index);
    # each file holds the space-separated labels of one node.
    labels = []
    for i in range(1, num_graphs + 1):
        graph_labels = []
        for j in range(num_nodes):
            label_file_path = label_path.format(i=i, j=j)
            with open(label_file_path, 'r') as file:
                label = [int(l) for l in file.read().split()]
            graph_labels.append(label)
        labels.append(graph_labels)
    return torch.tensor(labels, dtype=torch.float)
def load_edges(edge_file_path):
    # The CSV has no header; column 0 is the source node, column 1 the target node.
    edges = []
    with open(edge_file_path, 'r') as file:
        lines = csv.reader(file)
        for line in lines:
            edges.append([int(line[0]), int(line[1])])
    return torch.tensor(edges, dtype=torch.long).t().contiguous()
# Define paths (raw strings keep the backslashes from being read as escape sequences;
# '{i}' and '{j}' are filled in with the graph and node indices)
node_feature_path = r'C:\Users\jh\Desktop\data\input\images\{i}.png_{j}.png'
label_path = r'C:\Users\jh\Desktop\data\input\labels\{i}_{j}.txt'
edge_file_path = r'C:\Users\jh\Desktop\data\input\edges_L.csv'
# Define hyperparameters
num_graphs = 42
num_nodes = 37
image_size = 40
num_labels = 8
num_edges = 61
hidden_dim = 64
num_epochs = 10
learning_rate = 0.01
# Load data
node_features = load_node_features(node_feature_path, num_nodes)
labels = load_labels(label_path, num_nodes, num_labels)
edges = load_edges(edge_file_path)
# Create a graph dataset
data_list = []
for i in range(num_graphs):
x = node_features[i]
y = labels[i]
edge_index = edges
data = Data(x=x, y=y, edge_index=edge_index)
data_list.append(data)
# Split data into train and validation sets
train_data = data_list[:30]
val_data = data_list[30:]
# Create data loaders
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1)
# Create the model
model = GCN(num_features=image_size*image_size*3, hidden_dim=hidden_dim, num_labels=num_labels)
# Define the loss function
criterion = nn.MultiLabelSoftMarginLoss()
# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Training loop
for epoch in range(num_epochs):
model.train()
total_loss = 0
for data in train_loader:
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out, data.y)
total_loss += loss.item()
loss.backward()
optimizer.step()
avg_loss = total_loss / len(train_loader)
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, avg_loss))
# Validation loop
model.eval()
with torch.no_grad():
    num_correct = 0
    num_labels_total = 0
    for data in val_loader:
        out = model(data.x, data.edge_index)
        # The model outputs raw logits, so apply the sigmoid before thresholding.
        predicted_labels = (torch.sigmoid(out) > 0.5).float()
        num_correct += (predicted_labels == data.y).sum().item()
        num_labels_total += data.y.numel()
    accuracy = num_correct / num_labels_total
    print('Validation Accuracy: {:.2f}%'.format(accuracy * 100))
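The task statement also asks for a CNN that compresses the raw pixel features and for node-level train/validation masks (the first 30 nodes of each graph for training, the last 7 for validation), while the complete code above feeds the flattened pixels directly and splits whole graphs instead. The following is only a minimal, self-contained sketch of how those two pieces could look; the CNNEncoder module, its layer sizes, and the dummy tensors are illustrative assumptions, not part of the original code.

import torch
import torch.nn as nn

class CNNEncoder(nn.Module):
    # Illustrative encoder: compresses a 3x40x40 node image into a short feature vector.
    def __init__(self, out_dim=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # -> 16 x 20 x 20
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> 32 x 10 x 10
        )
        self.fc = nn.Linear(32 * 10 * 10, out_dim)

    def forward(self, x):
        # x: [num_nodes, 3, 40, 40] -> [num_nodes, out_dim]
        return self.fc(self.conv(x).flatten(start_dim=1))

# Node-level masks for one 37-node graph: first 30 nodes train, last 7 validate.
num_nodes = 37
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:30] = True
val_mask = ~train_mask

# Masked loss on dummy data (8 labels per node); real images and labels would replace these.
encoder = CNNEncoder(out_dim=64)
images = torch.rand(num_nodes, 3, 40, 40)              # stand-in for the real node images
targets = torch.randint(0, 2, (num_nodes, 8)).float()  # stand-in for the real multi-hot labels
node_features = encoder(images)                        # these would be fed to the GCN
logits = nn.Linear(64, 8)(node_features)               # stand-in for the GCN output logits
criterion = nn.MultiLabelSoftMarginLoss()
loss = criterion(logits[train_mask], targets[train_mask])
print(loss.item())

In a full pipeline the encoder output, rather than the flattened pixels, would be passed to the GCNConv layers, and the encoder and GCN would be trained jointly with the same masked loss.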
In the complete code, make sure to replace node_feature_path, label_path, and edge_file_path with the correct paths. It loads the node features, labels, and edges, builds the GCN model, trains it on the training set, and evaluates it on the validation set.
Notes:
- The load_node_features function loads the image features of each node into a PyTorch tensor.
- The load_labels function loads the labels of each node into a PyTorch tensor.
- The load_edges function loads the edge information of the graph into a PyTorch tensor.
- The GCN class defines the structure of the GCN network.
- The code uses MultiLabelSoftMarginLoss as the loss function because it is suited to multi-label classification (a small standalone example follows these notes).
- The code uses torch.optim.Adam as the optimizer.
- The code trains the model on the training set and evaluates it on the validation set.
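As a small standalone illustration of that loss (the tensors below are made-up numbers, not project data): MultiLabelSoftMarginLoss expects raw logits and a multi-hot target of the same shape and applies the sigmoid internally, which is why the GCN above returns logits and the validation loop applies the sigmoid only when thresholding predictions.

import torch
import torch.nn as nn

criterion = nn.MultiLabelSoftMarginLoss()
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0, 1.2, 0.0, -0.7, 2.5]])  # raw scores for the 8 labels of one node
target = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0]])     # multi-hot ground truth
loss = criterion(logits, target)
print(loss.item())

# For predictions, threshold the sigmoid of the logits at 0.5.
predicted = (torch.sigmoid(logits) > 0.5).float()
print(predicted)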
Original source: https://www.cveoy.top/t/topic/pguh. Copyright belongs to the author. Please do not repost or scrape.