A CNN-GCN Model for Multi-Label Graph Classification

This project uses a CNN to extract pixel features for each node and a GCN to perform multi-label classification. The data consists of 42 graphs, each with 37 nodes, and each node carries 8 labels.

Data description:

  • num_graphs = 42: total number of graphs.
  • num_nodes = 37: number of nodes per graph.
  • image_size = 40: size of each node image.
  • num_labels = 8: number of labels per node.
  • num_edges = 61: total number of edges in the graph.
  • Node feature files: 'C:\Users\jh\Desktop\data\input\images\i_j.png', the RGB pixel values of node j of graph i, where i ranges from 1 to 42 and j from 0 to 36.
  • Label files: 'C:\Users\jh\Desktop\data\input\labels\i_j.txt', the labels of node j of graph i, separated by spaces.
  • Edge file: 'C:\Users\jh\Desktop\data\input\edges_L.csv', the edge relations of the graph; the first column is the source node and the second column the target node (see the loading sketch after this list).
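For reference, here is a minimal loading sketch for one node's files and the shared edge list. The file-name pattern '<graph>_<node>.png' / '<graph>_<node>.txt' is an assumption taken from the description above; adjust it if the actual naming differs.

import numpy as np
import pandas as pd
import torch
from PIL import Image

def load_node(graph_id, node_id, root='C:/Users/jh/Desktop/data/input'):
    # RGB pixel values of node `node_id` in graph `graph_id` (assumed name: '<graph>_<node>.png')
    image = Image.open(f'{root}/images/{graph_id}_{node_id}.png').convert('RGB')
    # Space-separated multi-label vector of the node (assumed name: '<graph>_<node>.txt')
    labels = np.loadtxt(f'{root}/labels/{graph_id}_{node_id}.txt')
    return image, labels

# The edge list is shared by all graphs: column 0 = source node, column 1 = target node
edges = pd.read_csv('C:/Users/jh/Desktop/data/input/edges_L.csv', header=None)
edge_index = torch.tensor(edges.values, dtype=torch.long).t().contiguous()  # shape [2, 61]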

Model architecture:

  1. CNN: extracts pixel features for each node.
  2. GCN: learns the relations between nodes and performs multi-label classification (see the shape sketch after this list).
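A quick sketch of how the shapes flow through the two stages, using the CNN and GCN classes defined in the code example below; the edge list here is only a random placeholder:

import torch

cnn = CNN()                                        # defined in the code example below
gcn = GCN(num_features=64, hidden_size=32, num_labels=8)
edge_index = torch.randint(0, 37, (2, 61))         # placeholder edge list with 61 edges

images = torch.randn(37, 3, 40, 40)                # the 37 node images of one graph
features = cnn(images)                             # [37, 64] node feature vectors
logits = gcn(features, edge_index)                 # [37, 8] label logits per node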

Training process:

  • For every graph, the features of the first 30 nodes are placed in the training mask and the features of the last 7 nodes in the validation mask (see the sketch after this list).
  • The GCN network is built with the PyG library.
  • MultiLabelSoftMarginLoss from the torch.nn module is used as the loss function.
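A standalone sketch of the masking and loss described above, with random tensors standing in for the real GCN outputs and labels:

import torch
import torch.nn as nn

num_nodes, num_labels = 37, 8
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:30] = True                           # first 30 nodes of a graph: training nodes
val_mask = ~train_mask                           # last 7 nodes: validation nodes

logits = torch.randn(num_nodes, num_labels)      # stand-in for the GCN output (raw logits)
targets = torch.randint(0, 2, (num_nodes, num_labels)).float()  # stand-in for the 0/1 labels

loss_fn = nn.MultiLabelSoftMarginLoss()          # applies the sigmoid internally, so pass logits
train_loss = loss_fn(logits[train_mask], targets[train_mask])
val_loss = loss_fn(logits[val_mask], targets[val_mask])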

Code example:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import pandas as pd
import numpy as np
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv

# Define the CNN network for node pixel feature extraction
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # halves the spatial size: 40 -> 20 -> 10
        self.fc1 = nn.Linear(32 * 10 * 10, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        # x: [N, 3, 40, 40] batch of node images
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # 64-dimensional feature vector per node
        return x

# Define the GCN network for multi-label classification
class GCN(nn.Module):
    def __init__(self, num_features, hidden_size, num_labels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_size)
        self.conv2 = GCNConv(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.fc(x)
        # Return raw logits: MultiLabelSoftMarginLoss applies the sigmoid internally
        return x

# Define the dataset class
class GraphDataset(Dataset):
    def __init__(self, num_graphs, num_nodes, image_size, num_labels, num_edges, image_folder, label_folder, edge_file):
        self.num_graphs = num_graphs
        self.num_nodes = num_nodes
        self.image_size = image_size
        self.num_labels = num_labels
        self.num_edges = num_edges
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.edge_file = edge_file
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return self.num_graphs

    def __getitem__(self, idx):
        graph_id = idx + 1  # graphs are numbered 1..42, nodes 0..36
        images, labels = [], []
        for node_id in range(self.num_nodes):
            # Load the RGB pixel image of the node
            # (assumes node images are named '<graph>_<node>.png'; adjust to the actual naming)
            image_path = os.path.join(self.image_folder, f'{graph_id}_{node_id}.png')
            image = Image.open(image_path).convert('RGB')
            image = image.resize((self.image_size, self.image_size))
            images.append(self.transform(image))

            # Load the space-separated multi-label vector of the node
            label_path = os.path.join(self.label_folder, f'{graph_id}_{node_id}.txt')
            labels.append(torch.from_numpy(np.loadtxt(label_path)).float())

        # images: [num_nodes, 3, H, W], labels: [num_nodes, num_labels]
        return torch.stack(images), torch.stack(labels)

# Dataset constants (from the data description above)
num_graphs = 42
num_nodes = 37
image_size = 40
num_labels = 8
num_edges = 61

# Load the edge data (the same edge list is shared by all graphs)
edge_data = pd.read_csv('C:/Users/jh/Desktop/data/input/edges_L.csv', header=None)
edge_index = torch.tensor(edge_data.values, dtype=torch.long).t().contiguous()

# Create the dataset
dataset = GraphDataset(num_graphs, num_nodes, image_size, num_labels, num_edges,
                       'C:/Users/jh/Desktop/data/input/images',
                       'C:/Users/jh/Desktop/data/input/labels',
                       'C:/Users/jh/Desktop/data/input/edges_L.csv')

# Node-level train/validation masks: the first 30 nodes of every graph are used
# for training and the last 7 nodes for validation
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:30] = True
val_mask = ~train_mask

# Create a data loader over all graphs (the train/validation split is per node, not per graph)
loader = DataLoader(dataset, batch_size=6, shuffle=True)

# Initialize the CNN and GCN models
cnn_model = CNN()
gcn_model = GCN(num_features=64, hidden_size=32, num_labels=num_labels)

# Define the optimizer (over both models' parameters) and the loss function
optimizer = optim.Adam(list(cnn_model.parameters()) + list(gcn_model.parameters()), lr=0.001)
loss_fn = nn.MultiLabelSoftMarginLoss()

# Training loop: train on the masked nodes of every graph and validate on the held-out nodes
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0

    for images, labels in loader:
        # images: [batch, num_nodes, 3, H, W], labels: [batch, num_nodes, num_labels]
        batch_size = images.size(0)

        # Training pass on the first 30 nodes of each graph
        cnn_model.train()
        gcn_model.train()
        optimizer.zero_grad()

        # Extract a 64-dimensional feature vector for every node using the CNN model
        node_features = cnn_model(images.view(-1, 3, image_size, image_size))
        node_features = node_features.view(batch_size, num_nodes, -1)

        # Create batched graph data (one Data object per graph in the batch)
        batched_data = Batch.from_data_list(
            [Data(x=node_features[i], edge_index=edge_index) for i in range(batch_size)])

        # Forward pass through the GCN; the loss only covers the training nodes
        output = gcn_model(batched_data.x, batched_data.edge_index).view(batch_size, num_nodes, -1)
        loss = loss_fn(output[:, train_mask].reshape(-1, num_labels),
                       labels[:, train_mask].reshape(-1, num_labels))

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * batch_size

        # Validation pass on the last 7 nodes of each graph
        cnn_model.eval()
        gcn_model.eval()
        with torch.no_grad():
            node_features = cnn_model(images.view(-1, 3, image_size, image_size))
            node_features = node_features.view(batch_size, num_nodes, -1)
            batched_data = Batch.from_data_list(
                [Data(x=node_features[i], edge_index=edge_index) for i in range(batch_size)])
            output = gcn_model(batched_data.x, batched_data.edge_index).view(batch_size, num_nodes, -1)
            loss = loss_fn(output[:, val_mask].reshape(-1, num_labels),
                           labels[:, val_mask].reshape(-1, num_labels))
            val_loss += loss.item() * batch_size

    # Calculate average losses per graph
    train_loss /= num_graphs
    val_loss /= num_graphs

    # Print epoch results
    print(f'Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Results:

Running the code prints the training loss and the validation loss for every epoch.
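If a validation metric is also of interest, the sigmoid probabilities can be thresholded at 0.5 to obtain predicted label vectors. A sketch appended after the training loop above, reusing output, labels and val_mask from the last processed batch:

with torch.no_grad():
    probs = torch.sigmoid(output[:, val_mask])         # model outputs are raw logits
    preds = (probs > 0.5).float()                      # predicted 0/1 label vector per node
    accuracy = (preds == labels[:, val_mask]).float().mean().item()
    print(f'Per-label accuracy on the validation nodes of the last batch: {accuracy:.4f}')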

Summary:

This project implements a multi-label graph classification task with a combined CNN-GCN model. By extracting pixel features for each node and learning the relations between nodes, the model can effectively recognize the label categories in each graph.
