Implementing a GCN for Multi-Label Graph Classification: A PyG Example
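This example shows how to build a GCN with the PyG (PyTorch Geometric) library for a multi-label classification task on graphs, in which every node of every graph receives an 8-dimensional label vector.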
Problem description:
Given:
- num_graphs = 42
- num_nodes = 37
- image_size = 40
- num_labels = 8
- num_edges = 61
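The node features are the pixel values of the images 'C:\Users\jh\Desktop\data\input\images{i}.png_{j}.png'. Each node has 8 labels, stored space-separated in the text file 'C:\Users\jh\Desktop\data\input\labels{i}{j}.txt'; for example, one node's label vector is 2 2 1 1 3 1 2 1, and the label vector in 5_21.txt is 1 3 4 1 3 1 1 3. The true label values fall into only five classes (0, 1, 2, 3 and 4), but each node's label is an 8-dimensional label vector. The edges are stored in the CSV file 'C:\Users\jh\Desktop\data\input\edges_L.csv', which has no header; the first column is the source node and the second column is the target node, for a total of 61 undirected edges.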
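The label and edge formats described above can be parsed with the standard library alone. The sketch below is only an illustration: parse_label_vector and parse_undirected_edges are hypothetical helper names, and it assumes the CSV holds plain integer node ids with no header, as stated.

```python
from __future__ import annotations

import csv
import io

NUM_LABELS = 8   # entries per label vector
NUM_CLASSES = 5  # allowed values: 0, 1, 2, 3, 4


def parse_label_vector(text: str) -> list[int]:
    """Parse a space-separated label line such as '2 2 1 1 3 1 2 1'."""
    values = [int(tok) for tok in text.split()]
    if len(values) != NUM_LABELS:
        raise ValueError(f"expected {NUM_LABELS} labels, got {len(values)}")
    if any(not 0 <= v < NUM_CLASSES for v in values):
        raise ValueError(f"label values must lie in 0..{NUM_CLASSES - 1}")
    return values


def parse_undirected_edges(csv_text: str) -> list[tuple[int, int]]:
    """Read header-less 'src,dst' rows and return every edge in both directions."""
    directed = []
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue  # skip blank lines
        src, dst = int(row[0]), int(row[1])
        directed.append((src, dst))
        directed.append((dst, src))
    return directed


print(parse_label_vector("1 3 4 1 3 1 1 3"))  # [1, 3, 4, 1, 3, 1, 1, 3]
print(parse_undirected_edges("0,1\n1,2\n"))    # [(0, 1), (1, 0), (1, 2), (2, 1)]
```

Because each undirected row yields two directed pairs, the 61 edges in edges_L.csv become 122 columns once they are turned into a PyG edge_index.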
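Goals:
- Output a predicted feature vector for each node.
- Derive a predicted label vector from these predicted features so that it matches the true label vector; each node's prediction must be an 8-dimensional vector of labels, not a vector of probabilities.
- In every graph, add the color features of the first 30 nodes to the training mask and those of the last 7 nodes to the validation mask.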
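Because every position of the 8-dimensional label vector is a choice among 5 classes, one way to produce label vectors rather than probabilities is to let the network emit 8 × 5 logits per node and take an argmax at each position. A minimal PyTorch sketch of that decoding and of a matching loss follows; the 8 × 5 output layout is a modelling assumption (the problem statement does not fix it), and decode_labels and multi_output_loss are hypothetical helper names.

```python
import torch
import torch.nn.functional as F

NUM_LABELS = 8   # positions in each node's label vector
NUM_CLASSES = 5  # possible values per position: 0..4


def decode_labels(logits: torch.Tensor) -> torch.Tensor:
    """Map [num_nodes, NUM_LABELS * NUM_CLASSES] logits to integer label vectors."""
    scores = logits.view(-1, NUM_LABELS, NUM_CLASSES)  # [N, 8, 5]
    return scores.argmax(dim=-1)                        # [N, 8], values in 0..4


def multi_output_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy over all 8 label positions of all nodes (targets: [N, 8] long)."""
    scores = logits.view(-1, NUM_CLASSES)               # [N * 8, 5]
    return F.cross_entropy(scores, targets.view(-1))


torch.manual_seed(0)
logits = torch.randn(37, NUM_LABELS * NUM_CLASSES)      # e.g. one graph with 37 nodes
targets = torch.randint(0, NUM_CLASSES, (37, NUM_LABELS))
print(decode_labels(logits).shape)  # torch.Size([37, 8])
print(multi_output_loss(logits, targets).item())
```

Binary cross-entropy (F.binary_cross_entropy_with_logits) is a poor fit for this layout: it treats every output as an independent yes/no probability with a target in [0, 1], while these targets are class indices from 0 to 4.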
Code:
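import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # since PyG 2.0, DataLoader lives in torch_geometric.loader
from torch_geometric.nn import GCNConv


# Define the GCN model: two GCN layers, then one group of num_classes logits per label position
class GCN(nn.Module):
    def __init__(self, num_features, num_labels, num_classes, hidden_channels=16):
        super().__init__()
        self.num_labels = num_labels
        self.num_classes = num_classes
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_labels * num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # Predicted vector per node, shaped [num_nodes, num_labels, num_classes]
        return x.view(-1, self.num_labels, self.num_classes)


# Load node features and labels
num_graphs = 42
num_nodes = 37      # nodes per graph
image_size = 40
num_labels = 8      # length of each node's label vector
num_classes = 5     # every label entry is one of 0, 1, 2, 3, 4
num_edges = 61
DATA_DIR = r'C:\Users\jh\Desktop\data\input'  # raw string: '\U' inside a normal literal is a SyntaxError

node_features = []
for i in range(num_graphs):
    graph_features = []
    for j in range(num_nodes):
        image_path = os.path.join(DATA_DIR, 'images', f'{i}.png_{j}.png')
        image = ...  # load and preprocess the image into a 1-D float tensor (same length for every node)
        graph_features.append(image)
    node_features.append(torch.stack(graph_features))
node_features = torch.stack(node_features)  # [num_graphs, num_nodes, num_features]
num_features = node_features.size(-1)       # input size of the first GCN layer

labels = []
for i in range(num_graphs):
    graph_labels = []
    for j in range(num_nodes):
        # File pattern as given in the problem statement; (1, 11) and (11, 1) both map to
        # 'labels111.txt', so check it against the real file names (e.g. '5_21.txt')
        label_path = os.path.join(DATA_DIR, f'labels{i}{j}.txt')
        with open(label_path, 'r') as file:
            label_vector = [int(label) for label in file.read().split()]
        graph_labels.append(label_vector)
    labels.append(graph_labels)
# Class indices must be integers for cross-entropy: [num_graphs, num_nodes, num_labels]
labels = torch.tensor(labels, dtype=torch.long)

# Load the edge index (no header; column 1 = source node, column 2 = target node)
edge_index_path = os.path.join(DATA_DIR, 'edges_L.csv')
edge_index = []
with open(edge_index_path, 'r') as file:
    for line in file:
        if not line.strip():
            continue  # skip blank lines
        src, dst = line.strip().split(',')
        # Undirected graph: store every edge in both directions
        edge_index.append([int(src), int(dst)])
        edge_index.append([int(dst), int(src)])
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()  # [2, 2 * num_edges]
# Assumes 0-based node ids (0 .. num_nodes - 1)
assert edge_index.size(1) == 2 * num_edges and int(edge_index.max()) < num_nodes

# Node masks: in every graph, the first 30 nodes train and the last 7 validate
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:30] = True
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask[30:] = True

# Create the PyG data objects
# Assumption: all graphs share the single 61-edge topology from edges_L.csv
data_list = []
for i in range(num_graphs):
    data = Data(x=node_features[i], y=labels[i], edge_index=edge_index,
                train_mask=train_mask, val_mask=val_mask)
    data_list.append(data)

# Create the data loaders; the masks are stored on each graph, so they stay aligned after batching
train_loader = DataLoader(data_list, batch_size=1, shuffle=True)
val_loader = DataLoader(data_list, batch_size=1, shuffle=False)

# Initialize the model and optimizer
model = GCN(num_features=num_features, num_labels=num_labels, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def masked_loss(out, y, mask):
    # Cross-entropy over every label position of the masked nodes
    return F.cross_entropy(out[mask].reshape(-1, num_classes), y[mask].reshape(-1))


# Training loop
for epoch in range(100):
    model.train()  # set every epoch, because model.eval() below switches dropout off
    total_loss = 0.0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = masked_loss(out, data.y, data.train_mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_loader:
            out = model(data.x, data.edge_index)
            val_loss += masked_loss(out, data.y, data.val_mask).item()
    print(f'Epoch: {epoch}, Train Loss: {total_loss / len(train_loader):.4f}, '
          f'Val Loss: {val_loss / len(val_loader):.4f}')

# Predicted label vectors: argmax over the 5 classes at each of the 8 positions
with torch.no_grad():
    preds = torch.stack([model(d.x, d.edge_index).argmax(dim=-1) for d in data_list])
print(preds[0, 30:])  # predicted label vectors of the 7 validation nodes of graph 0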
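To check how closely the predicted label vectors match the true ones on the validation nodes, two simple measures are per-position accuracy and exact-match accuracy (all 8 entries correct). A small sketch, assuming predictions and targets are integer tensors shaped [num_nodes, 8] for one graph plus a boolean node mask; label_vector_accuracy is a hypothetical helper, not a PyG function.

```python
from __future__ import annotations

import torch


def label_vector_accuracy(pred: torch.Tensor, target: torch.Tensor,
                          mask: torch.Tensor) -> tuple[float, float]:
    """Return (per-position accuracy, exact-match accuracy) over the masked nodes."""
    correct = pred[mask] == target[mask]                     # [M, 8] booleans
    per_position = correct.float().mean().item()             # share of correct entries
    exact_match = correct.all(dim=1).float().mean().item()   # share of fully correct nodes
    return per_position, exact_match


val_mask = torch.zeros(37, dtype=torch.bool)
val_mask[30:] = True                          # last 7 nodes, as in the problem statement
target = torch.zeros(37, 8, dtype=torch.long)
pred = target.clone()
pred[30, 0] = 4                               # one wrong entry in one validation node
print(label_vector_accuracy(pred, target, val_mask))  # roughly (55/56, 6/7)
```

Exact-match accuracy is the stricter of the two and corresponds directly to the goal that the whole predicted vector equals the true vector.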
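Notes:
- Image loading and preprocessing are left as a placeholder (image = ...) in this example; implement them for your data so that every node ends up with a 1-D float feature tensor of the same length.
- The training hyperparameters (hidden size, dropout, learning rate, number of epochs) can be tuned for your task.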
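One possible way to fill in the image-loading step is sketched below with Pillow and NumPy. It assumes each node image is converted to grayscale, resized to image_size × image_size (40 × 40) and flattened into 1600 values scaled to [0, 1]; these preprocessing choices and the name load_node_image are assumptions for illustration, so adapt them to your images (for example, keep the three RGB channels).

```python
import os
import tempfile

import numpy as np
import torch
from PIL import Image

IMAGE_SIZE = 40  # image_size from the problem statement, assumed to be the side length


def load_node_image(path: str, size: int = IMAGE_SIZE) -> torch.Tensor:
    """Load one node image as a flat float tensor of size * size values in [0, 1]."""
    with Image.open(path) as img:
        gray = img.convert("L").resize((size, size))  # assumed: grayscale, fixed resolution
    pixels = np.asarray(gray, dtype=np.float32) / 255.0  # scale 0..255 to 0..1
    return torch.from_numpy(pixels).flatten()             # shape: [size * size]


# Demo: write a synthetic 64x48 RGB image to a temporary folder and load it
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "0.png_0.png")
Image.new("RGB", (64, 48), color=(10, 20, 30)).save(demo_path)
features = load_node_image(demo_path)
print(features.shape)  # torch.Size([1600])
```

With this choice every node has 1600 features, so the first GCNConv layer needs 1600 input channels rather than image_size = 40.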
Output:
The code prints the training and validation loss of each epoch to the console.
Summary:
This example shows how to build a GCN with PyG for a multi-label classification task on graphs. Use it as a starting point and adapt it to your own task.
Source: https://www.cveoy.top/t/topic/plBA. Copyright belongs to the original author; do not repost or scrape.