GCN 多标签分类任务:使用 PYG 库和 CNN 预处理图像
GCN 多标签分类任务:使用 PYG 库和 CNN 预处理图像
本教程演示了如何使用 PyTorch Geometric (PYG) 库构建一个图卷积网络 (GCN) 模型,用于解决多标签分类任务。模型通过 CNN 网络对图像进行预处理,然后使用 GCN 模型进行特征学习。代码示例包含数据加载、模型构建、训练和测试步骤。
数据集
数据集包含 42 张图像,每张图像包含 37 个节点。
- 节点特征: 每个节点的特征由其对应图像的 RGB 像素值表示,存储在 'C:\Users\jh\Desktop\data\input\images\i.png_j.png' 路径下,其中 i 表示图像编号 (1 到 42),j 表示节点编号 (0 到 36)。每个图像的大小为 40 x 40 像素。
- 节点标签: 每个节点有 8 个标签,存储在 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' 路径下,标签之间用空格隔开。
- 边关系: 图像之间的连接关系存储在 'C:\Users\jh\Desktop\data\input\edges_L.csv' 路径下,是一个 CSV 文件,没有表头,第一列为源节点,第二列为目标节点,共有 61 条无向边。
任务
任务是使用 GCN 模型对每个节点进行多标签分类,预测节点的 8 个标签。
预处理
在将图像特征输入 GCN 模型之前,需要使用 CNN 模型进行预处理,以提取更具辨别性的特征。
模型
模型由两个部分组成:
- CNN 模型: 用于对图像进行预处理,提取特征。
- GCN 模型: 用于学习节点之间的关系,并进行多标签分类。
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx
# 定义 CNN 模型,用于图像预处理
class CNNModel(nn.Module):
def __init__(self):
super(CNNModel, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(32 * 10 * 10, 128)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = x.view(-1, 32 * 10 * 10)
x = F.relu(self.fc1(x))
return x
# 定义 GCN 模型
class GCNModel(nn.Module):
def __init__(self, in_features, hidden_dim, num_classes):
super(GCNModel, self).__init__()
self.conv1 = GCNConv(in_features, hidden_dim)
self.conv2 = GCNConv(hidden_dim, num_classes)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return x
# 加载节点特征和标签
def load_data():
node_features = []
labels = []
for i in range(1, 43):
for j in range(37):
image_path = f'C:\Users\jh\Desktop\data\input\images\{i}.png_{j}.png'
label_path = f'C:\Users\jh\Desktop\data\input\labels\{i}_{j}.txt'
# 使用 CNN 模型预处理图像
image = preprocess_image(image_path)
node_features.append(image)
# 加载标签
with open(label_path, 'r') as file:
label = file.read().split()
label = [int(l) for l in label]
labels.append(label)
node_features = torch.stack(node_features)
labels = torch.tensor(labels)
return node_features, labels
# 使用 CNN 模型预处理图像
def preprocess_image(image_path):
image = # 加载并使用您自己的方法预处理图像
image = torch.tensor(image)
image = image.permute(2, 0, 1)
image = image.unsqueeze(0)
image = cnn_model(image)
return image
# 加载边关系
def load_edges():
edge_index = []
with open('C:\Users\jh\Desktop\data\input\edges_L.csv', 'r') as file:
for line in file:
edge = line.strip().split(',')
edge = [int(e) for e in edge]
edge_index.append(edge)
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
return edge_index
# 设置 GCN 模型
def setup_model():
in_features = 128 # CNN 模型的输出大小
hidden_dim = 64
num_classes = 8
gcn_model = GCNModel(in_features, hidden_dim, num_classes)
return gcn_model
# 训练 GCN 模型
def train(model, data):
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
# 测试 GCN 模型
def test(model, data):
model.eval()
logits, accs = model(data.x, data.edge_index), []
for _, mask in data('train_mask', 'val_mask', 'test_mask'):
pred = logits[mask].max(1)[1]
acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
accs.append(acc)
return accs
# 主代码
if __name__ == '__main__':
# 加载 CNN 模型
cnn_model = CNNModel()
cnn_model.eval()
# 加载节点特征和标签
node_features, labels = load_data()
# 加载边关系
edge_index = load_edges()
# 准备 GCN 模型的数据
data = Data(x=node_features, edge_index=edge_index, y=labels)
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.train_mask[:30] = 1
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.val_mask[30:37] = 1
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask[37:] = 1
# 设置 GCN 模型
gcn_model = setup_model()
# 训练和测试 GCN 模型
for epoch in range(200):
loss = train(gcn_model, data)
accs = test(gcn_model, data)
print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Train Accuracy: {accs[0]:.4f}, Val Accuracy: {accs[1]:.4f}, Test Accuracy: {accs[2]:.4f}')
注意:
- 代码中的
preprocess_image函数需要根据您自己的图像预处理方法进行实现。 - 确保已安装
torch、torch_geometric和torchvision库。
运行代码
- 将代码保存为 Python 文件 (例如:
gcn_multilabel.py)。 - 确保数据已准备就绪,并将其放置在代码中指定的路径下。
- 使用以下命令运行代码:
python gcn_multilabel.py
代码将训练 GCN 模型,并输出每个 epoch 的损失和准确率。
总结
本教程演示了如何使用 PYG 库和 CNN 预处理图像构建 GCN 模型,用于解决多标签分类任务。您可以根据自己的需求修改代码,例如使用不同的 CNN 模型、优化超参数等。
原文地址: https://www.cveoy.top/t/topic/pfHT 著作权归作者所有。请勿转载和采集!