Code example: multi-label classification with a GCN model in PyG
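This example shows how to build a GCN model with the PyTorch Geometric (PyG) library and use image features, labels, and edge information for a multi-label classification task. It covers data loading, model construction, training, and validation.
Data description
- The dataset contains graph data for 42 time steps. Each time step has 37 images, and each image represents one node.
- Each image is 40x40 pixels and is stored as 'C:\Users\jh\Desktop\data\input\images\i.png_j.png', where i is the time step (1 to 42) and j is the node index (0 to 36).
- Each node has 8 labels, stored in the text file 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' and separated by spaces. Feature values and labels are integers.
- The connectivity between nodes is the same at every time step and is stored in 'C:\Users\jh\Desktop\data\input\edges_L.csv'. The first column is the source node and the second column is the target node; edges are undirected.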
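The edge and label formats described above can be parsed with the standard library alone. The following is a minimal sketch, assuming the CSV has no header row and uses the same 0-based node indices as the image files; the helper names `parse_edges` and `parse_labels` and the sample strings are hypothetical.

```python
import csv
import io


def parse_edges(csv_text):
    """Parse 'src,tgt' rows and return both directions of every edge."""
    sources, targets = [], []
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue  # skip blank lines
        src, tgt = int(row[0]), int(row[1])
        # An undirected edge is stored as two directed edges.
        sources += [src, tgt]
        targets += [tgt, src]
    return [sources, targets]


def parse_labels(line):
    """Parse one line of space-separated integer labels."""
    return [int(token) for token in line.split()]


# Hypothetical sample content, not taken from the real files.
edge_index = parse_edges("0,1\n1,2\n")
labels = parse_labels("1 0 0 1 0 1 1 0")
print(edge_index)
print(labels)
```

The two-row layout matches the [2, num_edges] shape that PyG expects for edge_index, so the result can be converted with torch.tensor(edge_index, dtype=torch.long).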
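Model construction
- The GCN model is built from the GCNConv layer in PyG.
- The model takes node features and edge information as input and outputs a prediction for each label of every node.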
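For reference, the computation inside each graph convolution layer can be written with the standard GCN propagation rule. The symbols below are notation introduced here for illustration (bias terms omitted): A is the 37x37 adjacency matrix built from the edge file, H^{(0)} = X is the node feature matrix, and W^{(l)} is the weight matrix of layer l.

```latex
\hat{A} = A + I, \qquad
\hat{D}_{ii} = \sum_{j} \hat{A}_{ij}, \qquad
H^{(l+1)} = \sigma\!\left( \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(l)} W^{(l)} \right)
```

In a two-layer model of this kind, \sigma is ReLU after the first layer and the identity after the second, so the output H^{(2)} has shape 37 x 8 and holds one raw logit per node and label.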
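Training and validation
- In each time step, the first 30 nodes form the training set and the last 7 nodes form the validation set.
- Training uses the Adam optimizer and the BCEWithLogitsLoss loss function.
- The validation loss is evaluated and printed during training.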
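To make the split and the loss concrete, the sketch below reproduces both in plain Python. It assumes the 8 labels are binary (0 or 1), which is what a binary cross-entropy loss requires; the function `bce_with_logits` and the all-zero sample data are hypothetical stand-ins, and the formula is the mean-reduced quantity that BCEWithLogitsLoss computes by default.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (node, label) pairs, from raw logits."""
    total, count = 0.0, 0
    for logit_row, target_row in zip(logits, targets):
        for x, y in zip(logit_row, target_row):
            # Numerically stable form of -[y*log(s(x)) + (1-y)*log(1-s(x))],
            # where s is the sigmoid function.
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


NUM_NODES, NUM_TRAIN = 37, 30
train_idx = list(range(NUM_TRAIN))            # nodes 0..29
val_idx = list(range(NUM_TRAIN, NUM_NODES))   # nodes 30..36

# Hypothetical logits and targets: 37 nodes x 8 labels, all zeros.
logits = [[0.0] * 8 for _ in range(NUM_NODES)]
targets = [[0] * 8 for _ in range(NUM_NODES)]

train_loss = bce_with_logits([logits[i] for i in train_idx],
                             [targets[i] for i in train_idx])
val_loss = bce_with_logits([logits[i] for i in val_idx],
                           [targets[i] for i in val_idx])
print(len(train_idx), len(val_idx))
print(round(train_loss, 4), round(val_loss, 4))
```

Because the split is by node rather than by time step, every time step contributes to both losses; only the rows selected by the index lists differ.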
Code example
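import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # location of DataLoader in PyG 2.x
from torch_geometric.nn import GCNConv

NUM_STEPS = 42   # time steps, numbered 1..42
NUM_NODES = 37   # nodes per time step, numbered 0..36
NUM_TRAIN = 30   # first 30 nodes train, last 7 nodes validate

# File name patterns as written in the data description; adjust them if the
# files on disk are named differently.
IMAGE_NAME = "{i}.png_{j}.png"
LABEL_NAME = "{i}_{j}.txt"


# Define the GCN model
class GCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        self.conv2 = GCNConv(hid_feats, out_feats)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x  # raw logits; BCEWithLogitsLoss applies the sigmoid itself


# Read the features of one image
def read_image_features(image_path):
    # Assumption: the flattened grayscale pixels (40 x 40 = 1600 values,
    # scaled to [0, 1]) serve as the feature vector. Replace this with the
    # feature extraction your data actually needs.
    image = Image.open(image_path).convert("L")
    return np.asarray(image, dtype=np.float32).reshape(-1) / 255.0


# Read the labels of one node
def read_labels(label_path):
    # Assumes one line of space-separated integers; BCEWithLogitsLoss further
    # requires that they are 0 or 1.
    with open(label_path, "r") as file:
        return [int(label) for label in file.readline().split()]


# Read the edge list
def read_edges(edge_path):
    # Assumes no header row and node indices in 0..36.
    edges = []
    with open(edge_path, "r") as file:
        for line in file:
            if not line.strip():
                continue
            src, tgt = line.strip().split(",")[:2]
            # The edges are undirected, so store both directions.
            edges.append((int(src), int(tgt)))
            edges.append((int(tgt), int(src)))
    # PyG expects edge_index with shape [2, num_edges].
    return torch.tensor(edges, dtype=torch.long).t().contiguous()


# Build one PyG graph per time step from image features, labels, and edges
def create_dataset(image_dir, label_dir, edge_file):
    edge_index = read_edges(edge_file)
    train_mask = torch.zeros(NUM_NODES, dtype=torch.bool)
    train_mask[:NUM_TRAIN] = True
    dataset = []
    for i in range(1, NUM_STEPS + 1):
        features, labels = [], []
        for j in range(NUM_NODES):
            image_path = os.path.join(image_dir, IMAGE_NAME.format(i=i, j=j))
            label_path = os.path.join(label_dir, LABEL_NAME.format(i=i, j=j))
            features.append(read_image_features(image_path))
            labels.append(read_labels(label_path))
        data = Data(
            x=torch.tensor(np.stack(features)),          # [37, 1600]
            y=torch.tensor(labels, dtype=torch.float),   # [37, 8]
            edge_index=edge_index,
            train_mask=train_mask,
            val_mask=~train_mask,
        )
        dataset.append(data)
    return dataset


# Create the GCN model
in_feats = 40 * 40  # must match the length returned by read_image_features
hid_feats = 64
out_feats = 8
model = GCN(in_feats, hid_feats, out_feats)

# Load the dataset; the train/validation split is carried by the node masks
image_dir = "C:/Users/jh/Desktop/data/input/images"
label_dir = "C:/Users/jh/Desktop/data/input/labels"
edge_file = "C:/Users/jh/Desktop/data/input/edges_L.csv"
dataset = create_dataset(image_dir, label_dir, edge_file)

# Define the data loaders (both iterate over all 42 graphs)
batch_size = 64
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Define the optimizer and the loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# Training loop
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(num_epochs):
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        # Only the training nodes contribute to the loss.
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index)
            val_loss += criterion(out[data.val_mask], data.y[data.val_mask]).item()
    val_loss /= len(val_loader)  # average over batches
    print(f"Epoch: {epoch+1}, Val Loss: {val_loss:.4f}")

# Predict on test data
# TODO: implement test data loading and prediction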
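Note that read_image_features must be implemented or adapted to the feature extraction the actual data needs, and that read_labels assumes a single line of space-separated integers per file. Loading test data and running prediction also remain to be implemented.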
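For the prediction step, the model's raw outputs still have to be turned into label decisions. One common approach, sketched here under the assumption that the 8 labels are binary, is to apply a sigmoid to each logit and compare it with a threshold. The helper `predict_labels`, the 0.5 threshold, and the sample logits are illustrative choices, not part of the original code.

```python
import math


def sigmoid(x):
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict_labels(logits, threshold=0.5):
    """Turn per-node logits into 0/1 label vectors."""
    return [[1 if sigmoid(x) >= threshold else 0 for x in row] for row in logits]


# Hypothetical logits for two nodes with 8 labels each.
logits = [
    [2.0, -1.0, 0.5, -3.0, 0.1, 1.5, -0.2, 4.0],
    [-2.0, 1.0, -0.5, 3.0, -0.1, -1.5, 0.2, -4.0],
]
predictions = predict_labels(logits)
print(predictions)
```

With tensors, the same decision can be written as (torch.sigmoid(out) >= 0.5).int(); the threshold can be tuned per label on the validation nodes if the labels are imbalanced.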
Original source: https://www.cveoy.top/t/topic/pbNp. Copyright belongs to the author. Do not repost or scrape.