使用PyG库构建GCN网络实现多标签分类任务
以下代码使用PyG库构建GCN网络，结合CNN进行特征降维，实现对42个图的节点进行多标签分类任务。数据包含42个图，每个图有37个节点，节点特征为40×40像素的RGB图像，每个节点有8个标签，边关系存储在CSV文件中。
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# Dataset that yields the node image and label vector for each graph
class NodeFeatureDataset(Dataset):
    """Load one image tensor and one label list per index from ``root_dir``.

    For index ``idx`` (0-based) and ``n = idx + 1`` the files read are
    ``<root_dir>/images/<n>.png_<n>.png`` and ``<root_dir>/labels/<n>_j.txt``.

    NOTE(review): the label file name contains a literal ``j``, which looks
    like an unsubstituted node-index placeholder -- confirm against the files
    on disk before relying on it.
    NOTE(review): only a single image is loaded per index; if a graph is meant
    to hold several node images they would have to be stacked here -- confirm.

    Args:
        root_dir: Folder that contains the ``images`` and ``labels`` folders.
        num_graphs: Value reported by ``len()``; defaults to 42.
    """

    def __init__(self, root_dir, num_graphs=42):
        self.root_dir = root_dir
        self.num_graphs = num_graphs

    def __len__(self):
        return self.num_graphs

    def _image_path(self, idx):
        """Return the image path for 0-based index ``idx``."""
        # os.path.join inserts the separator that plain string concatenation
        # of root_dir and 'images\\...' was missing.
        return os.path.join(
            self.root_dir, 'images', f'{idx + 1}.png_{idx + 1}.png'
        )

    def _label_path(self, idx):
        """Return the label path for 0-based index ``idx``."""
        return os.path.join(self.root_dir, 'labels', f'{idx + 1}_j.txt')

    @staticmethod
    def _read_label(label_path):
        """Parse the first line of ``label_path`` into a list of ints."""
        with open(label_path, 'r', encoding='utf-8') as f:
            return [int(x) for x in f.readline().split()]

    def __getitem__(self, idx):
        """Return ``(img, label)`` for index ``idx``.

        ``img`` is the tensor produced by ``transforms.ToTensor()`` from the
        RGB-converted image; ``label`` is a list of ints.

        Raises:
            FileNotFoundError: If the image or label file does not exist.
            ValueError: If the label line contains a non-integer token.
        """
        # The context manager closes the file handle; convert('RGB') guarantees
        # three channels even for RGBA, palette, or grayscale PNGs.
        with Image.open(self._image_path(idx)) as pil_img:
            img = transforms.ToTensor()(pil_img.convert('RGB'))
        label = self._read_label(self._label_path(idx))
        return img, label
# Dataset that yields the (shared) edge index for each graph
class GraphDataset(Dataset):
    """Serve the edge list stored in ``<root_dir>/edges_L.csv``.

    The CSV is read once at construction time; every index returns the same
    ``edge_index`` tensor.

    Args:
        root_dir: Folder that contains ``edges_L.csv``.
        num_graphs: Value reported by ``len()``; defaults to 42.
    """

    def __init__(self, root_dir, num_graphs=42):
        self.root_dir = root_dir
        self.num_graphs = num_graphs
        # os.path.join avoids the invalid '\e' escape of the literal
        # '\edges_L.csv' and works on non-Windows hosts as well.
        self.edge_file = os.path.join(self.root_dir, 'edges_L.csv')
        self.edge_index = self.read_edge_index()

    def read_edge_index(self):
        """Parse ``self.edge_file`` into a ``[2, num_edges]`` long tensor.

        Each non-blank line must be ``source,target`` with 1-based node ids;
        they are shifted to 0-based indices.

        Raises:
            FileNotFoundError: If the edge file does not exist.
            ValueError: If a line is not two comma-separated integers.
        """
        pairs = []
        # 'utf-8-sig' drops a leading BOM (as written by some spreadsheet
        # exports) that would otherwise make int() fail on the first field.
        with open(self.edge_file, 'r', encoding='utf-8-sig') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines.
                    continue
                source, target = line.split(',')
                pairs.append((int(source) - 1, int(target) - 1))
        # reshape(-1, 2) keeps the result 2-D, so an empty file gives [2, 0]
        # rather than a 1-D empty tensor.
        edges = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)
        return edges.t().contiguous()

    def __len__(self):
        return self.num_graphs

    def __getitem__(self, idx):
        # The same edge index is returned for every idx.
        return self.edge_index
# Small CNN that reduces the spatial size of the node images
class CNN(nn.Module):
    """One 3x3 convolution (3 -> 16 channels) followed by ReLU and 2x2 max-pool.

    The convolution uses stride 1 and padding 1, so it keeps the spatial size;
    the pooling layer then halves height and width. An input of shape
    ``(N, 3, H, W)`` therefore yields ``(N, 16, H // 2, W // 2)``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the pooled, rectified feature map for image batch ``x``."""
        activated = F.relu(self.conv1(x))
        return self.pool(activated)
# Two-layer GCN used for multi-label node classification
class GCN(nn.Module):
    """Two ``GCNConv`` layers with a 64-unit hidden layer and sigmoid output.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Number of output values per node.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, out_channels)

    def forward(self, x, edge_index):
        """Return per-node sigmoid activations in the open interval (0, 1).

        ``x`` is presumably ``[num_nodes, in_channels]`` and ``edge_index``
        ``[2, num_edges]`` -- TODO confirm against the caller.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active in training mode (F.dropout's default rate).
        hidden = F.dropout(hidden, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return torch.sigmoid(logits)
# Loss for multi-label classification. GCN.forward already applies a sigmoid,
# so the loss has to consume probabilities: MultiLabelSoftMarginLoss expects
# raw logits and would apply the sigmoid a second time.
loss_fn = nn.BCELoss()

# Root folder of the input data. A raw string is required: in a normal string
# literal "\U" starts a unicode escape and the path is a SyntaxError.
DATA_ROOT = r"C:\Users\jh\Desktop\data\input"

# Node-feature dataset and graph (edge) dataset
node_feature_dataset = NodeFeatureDataset(DATA_ROOT)
graph_dataset = GraphDataset(DATA_ROOT)

# CNN feature extractor and GCN classifier
cnn_net = CNN()
gcn_net = GCN(in_channels=16, out_channels=8)

# Optimize both networks; with only the GCN parameters the CNN would stay at
# its random initialization for the whole run.
optimizer = optim.Adam(
    list(cnn_net.parameters()) + list(gcn_net.parameters()), lr=0.01
)

# Hold out the last graphs for validation. Indexing past len(dataset), as the
# previous range(42, 45) did, refers to graphs 43-45 that the dataset does not
# report having.
NUM_VAL_GRAPHS = 3
num_graphs = len(node_feature_dataset)
num_train_graphs = num_graphs - NUM_VAL_GRAPHS


def run_graph(i):
    """Run graph ``i`` through CNN + GCN; return ``(output, target, label)``.

    NOTE(review): the GCN needs one feature row per node referenced in
    ``edge_index``. If the dataset yields a single image per graph this call
    will fail inside the GCN -- confirm that the dataset supplies all node
    images of a graph (a 4-D ``[num_nodes, 3, H, W]`` tensor is accepted).
    """
    img, label = node_feature_dataset[i]
    edge_index = graph_dataset[i]
    if img.dim() == 3:
        # Single image: add the batch (node) dimension the CNN expects.
        img = img.unsqueeze(0)
    # Global average pooling over H and W turns the [N, 16, h, w] feature map
    # into the [N, 16] node-feature matrix that GCN(in_channels=16) consumes.
    node_features = cnn_net(img).mean(dim=(2, 3))
    output = gcn_net(node_features, edge_index)
    # BCELoss requires the target to have exactly the output's shape.
    target = torch.as_tensor(label, dtype=torch.float).reshape(output.shape)
    return output, target, label


# Training loop
cnn_net.train()
gcn_net.train()
for epoch in range(10):
    for i in range(num_train_graphs):
        output, target, _ = run_graph(i)
        # Compute the loss
        loss = loss_fn(output, target)
        # Back-propagate and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Report the loss value
        print(f"Epoch {epoch+1}, Graph {i+1}, Loss: {loss.item()}")

# Validation: eval mode disables dropout, no_grad skips autograd bookkeeping.
cnn_net.eval()
gcn_net.eval()
with torch.no_grad():
    for i in range(num_train_graphs, num_graphs):
        output, _, label = run_graph(i)
        print(f"Graph {i+1}, Predicted Label: {output.detach().numpy()}, True Label: {label}")
请注意，上述代码假设您已经安装了必要的库：torch、torchvision、torch_geometric，以及读取图像所需的 Pillow（通常会随 torchvision 一并安装）。如果没有安装，可以通过以下命令进行安装：
pip install torch torchvision torch-geometric
另外，请将数据文件夹的路径更改为您实际存储数据的路径。
原文地址: https://www.cveoy.top/t/topic/pfWM 著作权归作者所有。请勿转载和采集!