以下代码使用PyG库构建GCN网络,结合CNN进行特征降维,实现对42个图的节点进行多标签分类任务。数据包含42个图,每个图有37个节点,节点特征为40x40像素的RGB图像,每个节点有8个标签,边关系存储在CSV文件中。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from PIL import Image

# 定义节点特征数据集类
class NodeFeatureDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        
    def __len__(self):
        return 42
    
    def __getitem__(self, idx):
        img_path = self.root_dir + f'images\{idx+1}.png_{idx+1}.png'
        label_path = self.root_dir + f'labels\{idx+1}_j.txt'
        
        img = Image.open(img_path)
        img = transforms.ToTensor()(img)
        
        with open(label_path, 'r') as f:
            label = [int(x) for x in f.readline().split()]
            
        return img, label

# 定义图数据集类
class GraphDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.edge_file = self.root_dir + '\edges_L.csv'
        self.edge_index = self.read_edge_index()
        
    def read_edge_index(self):
        edge_index = []
        with open(self.edge_file, 'r') as f:
            for line in f:
                source, target = line.strip().split(',')
                edge_index.append([int(source)-1, int(target)-1])  # 索引从0开始,需要减去1
        return torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
    def __len__(self):
        return 42
    
    def __getitem__(self, idx):
        return self.edge_index

# 定义CNN网络对节点像素特征进行降维
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        
        return x

# 定义GCN网络实现多标签分类任务
class GCN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, out_channels)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        
        return torch.sigmoid(x)

# 定义损失函数
loss_fn = nn.MultiLabelSoftMarginLoss()

# 创建节点特征数据集对象和图数据集对象
node_feature_dataset = NodeFeatureDataset("C:\Users\jh\Desktop\data\input")
graph_dataset = GraphDataset("C:\Users\jh\Desktop\data\input")

# 创建CNN网络和GCN网络对象
cnn_net = CNN()
gcn_net = GCN(in_channels=16, out_channels=8)

# 定义优化器
optimizer = optim.Adam(gcn_net.parameters(), lr=0.01)

# 训练循环
for epoch in range(10):
    for i in range(len(node_feature_dataset)):
        # 获取节点特征和边关系
        img, label = node_feature_dataset[i]
        edge_index = graph_dataset[i]
        
        # 通过CNN网络进行降维
        x = cnn_net(img.unsqueeze(0))
        
        # 通过GCN网络进行多标签分类
        output = gcn_net(x.squeeze(), edge_index)
        
        # 计算损失函数
        loss = loss_fn(output.unsqueeze(0), torch.tensor(label).unsqueeze(0).float())
        
        # 反向传播和参数更新
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 打印损失函数值
        print(f"Epoch {epoch+1}, Graph {i+1}, Loss: {loss.item()}")

# 验证
for i in range(42, 45):
    img, label = node_feature_dataset[i]
    edge_index = graph_dataset[i]
    
    x = cnn_net(img.unsqueeze(0))
    output = gcn_net(x.squeeze(), edge_index)
    
    print(f"Graph {i+1}, Predicted Label: {output.detach().numpy()}, True Label: {label}")

请注意,上述代码假设您已经安装了必要的库:torch、torchvision、torch_geometric。如果没有安装,可以通过以下命令进行安装:

pip install torch torchvision torch-geometric

另外,请将数据文件夹的路径更改为您实际存储数据的路径。

使用PyG库构建GCN网络实现多标签分类任务

原文地址: https://www.cveoy.top/t/topic/pfWM 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录