GCN 多标签分类任务:使用 PYG 库和 CNN 预处理图像

本教程演示了如何使用 PyTorch Geometric (PYG) 库构建一个图卷积网络 (GCN) 模型,用于解决多标签分类任务。模型通过 CNN 网络对图像进行预处理,然后使用 GCN 模型进行特征学习。代码示例包含数据加载、模型构建、训练和测试步骤。

数据集

数据集包含 42 张图像,每张图像包含 37 个节点。

  • 节点特征: 每个节点的特征由其对应图像的 RGB 像素值表示,存储在 'C:\Users\jh\Desktop\data\input\images\i.png_j.png' 路径下,其中 i 表示图像编号 (1 到 42),j 表示节点编号 (0 到 36)。每个图像的大小为 40 x 40 像素。
  • 节点标签: 每个节点有 8 个标签,存储在 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' 路径下,标签之间用空格隔开。
  • 边关系: 图像之间的连接关系存储在 'C:\Users\jh\Desktop\data\input\edges_L.csv' 路径下,是一个 CSV 文件,没有表头,第一列为源节点,第二列为目标节点,共有 61 条无向边。

任务

任务是使用 GCN 模型对每个节点进行多标签分类,预测节点的 8 个标签。

预处理

在将图像特征输入 GCN 模型之前,需要使用 CNN 模型进行预处理,以提取更具辨别性的特征。

模型

模型由两个部分组成:

  1. CNN 模型: 用于对图像进行预处理,提取特征。
  2. GCN 模型: 用于学习节点之间的关系,并进行多标签分类。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx

# 定义 CNN 模型,用于图像预处理
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 10 * 10, 128)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 32 * 10 * 10)
        x = F.relu(self.fc1(x))
        return x

# 定义 GCN 模型
class GCNModel(nn.Module):
    def __init__(self, in_features, hidden_dim, num_classes):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(in_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

# 加载节点特征和标签
def load_data():
    node_features = []
    labels = []
    for i in range(1, 43):
        for j in range(37):
            image_path = f'C:\Users\jh\Desktop\data\input\images\{i}.png_{j}.png'
            label_path = f'C:\Users\jh\Desktop\data\input\labels\{i}_{j}.txt'
            
            # 使用 CNN 模型预处理图像
            image = preprocess_image(image_path)
            node_features.append(image)
            
            # 加载标签
            with open(label_path, 'r') as file:
                label = file.read().split()
                label = [int(l) for l in label]
                labels.append(label)
    
    node_features = torch.stack(node_features)
    labels = torch.tensor(labels)
    
    return node_features, labels

# 使用 CNN 模型预处理图像
def preprocess_image(image_path):
    image = # 加载并使用您自己的方法预处理图像
    image = torch.tensor(image)
    image = image.permute(2, 0, 1)
    image = image.unsqueeze(0)
    image = cnn_model(image)
    return image

# 加载边关系
def load_edges():
    edge_index = []
    with open('C:\Users\jh\Desktop\data\input\edges_L.csv', 'r') as file:
        for line in file:
            edge = line.strip().split(',')
            edge = [int(e) for e in edge]
            edge_index.append(edge)
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
    return edge_index

# 设置 GCN 模型
def setup_model():
    in_features = 128  # CNN 模型的输出大小
    hidden_dim = 64
    num_classes = 8
    
    gcn_model = GCNModel(in_features, hidden_dim, num_classes)
    return gcn_model

# 训练 GCN 模型
def train(model, data):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    return loss.item()

# 测试 GCN 模型
def test(model, data):
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    
    return accs

# 主代码
if __name__ == '__main__':
    # 加载 CNN 模型
    cnn_model = CNNModel()
    cnn_model.eval()
    
    # 加载节点特征和标签
    node_features, labels = load_data()
    
    # 加载边关系
    edge_index = load_edges()
    
    # 准备 GCN 模型的数据
    data = Data(x=node_features, edge_index=edge_index, y=labels)
    data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.train_mask[:30] = 1
    data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.val_mask[30:37] = 1
    data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.test_mask[37:] = 1
    
    # 设置 GCN 模型
    gcn_model = setup_model()
    
    # 训练和测试 GCN 模型
    for epoch in range(200):
        loss = train(gcn_model, data)
        accs = test(gcn_model, data)
        print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Train Accuracy: {accs[0]:.4f}, Val Accuracy: {accs[1]:.4f}, Test Accuracy: {accs[2]:.4f}')

注意:

  • 代码中的preprocess_image函数需要根据您自己的图像预处理方法进行实现。
  • 确保已安装 torchtorch_geometrictorchvision 库。

运行代码

  1. 将代码保存为 Python 文件 (例如:gcn_multilabel.py)。
  2. 确保数据已准备就绪,并将其放置在代码中指定的路径下。
  3. 使用以下命令运行代码:
python gcn_multilabel.py

代码将训练 GCN 模型,并输出每个 epoch 的损失和准确率。

总结

本教程演示了如何使用 PYG 库和 CNN 预处理图像构建 GCN 模型,用于解决多标签分类任务。您可以根据自己的需求修改代码,例如使用不同的 CNN 模型、优化超参数等。

GCN 多标签分类任务:使用 PYG 库和 CNN 预处理图像

原文地址: https://www.cveoy.top/t/topic/pfHT 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录