A Multi-Label Node Classification Model Based on CNN and GCN

This project uses a CNN to reduce the dimensionality of each node's pixel features, and builds a GCN with the PyG library to perform multi-label classification over the 37 nodes of each of the 42 graphs.

Data description:

  • num_graphs = 42: 42 graphs in total
  • num_nodes = 37: each graph has 37 nodes
  • image_size = 40: each node's image is 40x40 pixels
  • num_labels = 8: each node has 8 labels
  • num_edges = 61: 61 undirected edges in total

Data storage:

  • The node features are the RGB pixel values of the images 'C:\Users\jh\Desktop\data\input\images\i.png_j.png', where i is the graph index (1 to 42) and j is the node index (0 to 36).
  • Each node's 8 labels are stored in the text file 'C:\Users\jh\Desktop\data\input\labels\i_j.txt', with the labels separated by spaces.
  • The edges are stored in the CSV file 'C:\Users\jh\Desktop\data\input\edges_L.csv'; the file has no header, the first column is the source node and the second column is the target node (see the example layout below).
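
For example, the expected on-disk layout looks like this (the image file names assume the 'i.png_j.png' pattern above, and the label values are made up for illustration):

  images/1.png_0.png  ...  images/42.png_36.png   (one 40x40 RGB image per node)
  labels/1_0.txt      ...  labels/42_36.txt       (e.g. "0 1 0 0 1 0 0 0")
  edges_L.csv                                     (61 rows, two columns, no header)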

Model architecture:

  • A CNN reduces the dimensionality of each node's pixel features, mapping each 40x40 RGB image to a 64-dimensional embedding.
  • A GCN built with the PyG library takes these embeddings as node features and performs the multi-label classification; the two networks are wired together in the code below.

Training strategy:

  • For each graph, the first 30 nodes are assigned to the training mask and the remaining 7 nodes to the validation mask.
  • MultiLabelSoftMarginLoss from torch.nn is used as the loss function (see the shape check below).
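
MultiLabelSoftMarginLoss applies a sigmoid internally, so the model should output raw logits of shape [num_nodes, num_labels], paired with 0/1 label tensors of the same shape. A minimal shape check with dummy values, for illustration only:

  import torch
  import torch.nn as nn

  loss_fn = nn.MultiLabelSoftMarginLoss()
  logits = torch.randn(30, 8)                      # 30 training nodes, 8 labels (raw logits, no sigmoid)
  targets = torch.randint(0, 2, (30, 8)).float()   # 0/1 multi-label targets
  print(loss_fn(logits, targets))                  # scalar loss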

Code implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

import os
import numpy as np
from PIL import Image
import pandas as pd

# Define the CNN network for node pixel feature reduction
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        # 40x40 input -> conv1 -> 38x38 -> pool -> 19x19 -> conv2 -> 17x17 -> pool -> 8x8
        self.fc1 = nn.Linear(32*8*8, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the GCN network for multi-label classification
class GCN(nn.Module):
    def __init__(self, num_features, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_size)
        self.conv2 = GCNConv(hidden_size, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # Return raw logits; MultiLabelSoftMarginLoss applies the sigmoid internally
        return x
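
# NOTE: the following wrapper is an added sketch (not spelled out in the original description):
# one way to compose "CNN for feature reduction" with "GCN for classification" end to end.
# Each node's image is embedded by the CNN into a 64-dim vector, and the GCN then performs
# multi-label classification over those embeddings.
class CNNGCN(nn.Module):
    def __init__(self, cnn, gcn):
        super(CNNGCN, self).__init__()
        self.cnn = cnn
        self.gcn = gcn

    def forward(self, x, edge_index):
        # x: [num_nodes, 3, 40, 40] node images -> [num_nodes, 64] features -> [num_nodes, 8] logits
        features = self.cnn(x)
        return self.gcn(features, edge_index)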

# Load node pixel features: one 40x40 RGB image per node, shaped [num_nodes, 3, 40, 40] per graph
def load_node_features():
    node_features = []
    for i in range(1, num_graphs+1):
        graph_images = []
        for j in range(num_nodes):
            # File name follows the data description pattern 'i.png_j.png'
            image_path = f"C:/Users/jh/Desktop/data/input/images/{i}.png_{j}.png"
            image = Image.open(image_path).convert('RGB')
            image = image.resize((image_size, image_size))
            pixels = np.array(image, dtype=np.float32) / 255.0
            # HWC -> CHW so the tensor matches the CNN's Conv2d input layout
            graph_images.append(torch.from_numpy(pixels).permute(2, 0, 1))
        node_features.append(torch.stack(graph_images))
    return node_features

# Load node labels
def load_node_labels():
    node_labels = []
    for i in range(1, num_graphs+1):
        graph_labels = []
        for j in range(num_nodes):
            label_path = f"C:\Users\jh\Desktop\data\input\labels\{i}_{j}.txt"
            with open(label_path, 'r') as f:
                labels = f.read().split()
                labels = [int(label) for label in labels]
                graph_labels.append(torch.tensor(labels).float())
        node_labels.append(torch.stack(graph_labels))
    return node_labels

# Load the shared edge index: the same 61 undirected edges are used for every graph
def load_edge_index():
    edge_file = "C:/Users/jh/Desktop/data/input/edges_L.csv"
    edges = pd.read_csv(edge_file, header=None).values
    edge_index = torch.tensor(edges.T, dtype=torch.long)
    # The edges are undirected, so add the reverse direction for message passing
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    return edge_index

# Create one PyG Data object per graph, plus shared per-graph node masks
def create_data_objects():
    node_features = load_node_features()
    node_labels = load_node_labels()
    edge_index = load_edge_index()

    data_list = []
    for i in range(num_graphs):
        data = Data(x=node_features[i], y=node_labels[i], edge_index=edge_index)
        data_list.append(data)

    # The first 30 nodes of each graph are for training, the remaining 7 for validation
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:30] = True
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[30:] = True

    return data_list, train_mask, val_mask

# Train and evaluate the model
def train(model, data_list, train_mask, val_mask, epochs, lr):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MultiLabelSoftMarginLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for data in data_list:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = loss_fn(out[train_mask], data.y[train_mask])
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        avg_loss = total_loss / len(data_list)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data in data_list:
                out = model(data.x, data.edge_index)
                loss = loss_fn(out[val_mask], data.y[val_mask])
                val_loss += loss.item()

        avg_val_loss = val_loss / len(data_list)
        print(f"Validation Loss: {avg_val_loss:.4f}")

# Main function
if __name__ == '__main__':
    num_graphs = 42
    num_nodes = 37
    image_size = 40
    num_labels = 8
    num_edges = 61

    data_list, train_mask, val_mask = create_data_objects()

    # Compose the CNN feature extractor (64-dim output) with the GCN classifier (8 labels)
    model = CNNGCN(CNN(), GCN(64, 64, num_labels))
    train(model, data_list, train_mask, val_mask, epochs=10, lr=0.001)