基于CNN-GCN的多标签图分类模型
基于CNN-GCN的多标签图分类模型
本项目利用CNN提取节点图像的像素特征,并使用GCN对图中的节点进行多标签分类。数据集包含42个图,每个图包含37个节点,每个节点有8个标签。
数据说明:
- num_graphs = 42: 图的总数。
- num_nodes = 37: 每个图的节点数。
- image_size = 40: 节点图像的大小。
- num_labels = 8: 每个节点的标签数。
- num_edges = 61: 图中边的总数。
- 节点特征文件: 'C:\Users\jh\Desktop\data\input\images\i.png_j.png',表示第i个图的第j个节点的RGB像素值。其中i从1到42,j从0到36。
- 标签文件: 'C:\Users\jh\Desktop\data\input\labels\i_j.txt',表示第i个图的第j个节点的标签。标签用空格隔开。
- 边关系文件: 'C:\Users\jh\Desktop\data\input\edges_L.csv',表示图中边的关系。第一列为源节点,第二列为目标节点。
模型架构:
- CNN: 用于提取节点像素特征。
- GCN: 用于学习节点之间的关系,并进行多标签分类。
训练过程:
- 将每个图的前30个节点颜色特征加入训练掩码,后7个节点颜色特征加入验证掩码。
- 使用PYG库建立GCN网络。
- 使用torch.nn模块中的MultiLabelSoftMarginLoss作为损失函数。
代码示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
# Define the CNN network for node pixel feature extraction
class CNN(nn.Module):
    """Convolutional encoder that maps one node image to a 64-d feature vector.

    The first fully connected layer is sized for a 32 x 10 x 10 feature map.
    Because both convolutions preserve spatial size (3x3 kernel, stride 1,
    padding 1), a 2x2 max-pool follows each of them so that a 40 x 40 input
    is reduced 40 -> 20 -> 10 before flattening. Without the pooling the
    flattened size would be 32 * 40 * 40 and ``fc1`` would raise a shape error.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 10 * 10, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        """Encode a batch of RGB images.

        Args:
            x: Float tensor of shape (batch, 3, 40, 40).

        Returns:
            Float tensor of shape (batch, 64).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # (batch, 16, 20, 20)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # (batch, 32, 10, 10)
        x = x.flatten(1)  # (batch, 32 * 10 * 10), matches fc1
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Define the GCN network for multi-label classification
class GCN(nn.Module):
    """Two-layer graph convolutional network with a linear multi-label head.

    ``forward`` returns raw logits, one score per label per node. The loss
    used in this file, ``nn.MultiLabelSoftMarginLoss``, applies the sigmoid
    itself, so squashing the scores here as well would apply it twice and
    flatten the gradients. Call ``torch.sigmoid`` on the output when
    probabilities are needed (e.g. for thresholded predictions).
    """

    def __init__(self, num_features, hidden_size, num_labels):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_size)
        self.conv2 = GCNConv(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, x, edge_index):
        """Compute per-node label logits.

        Args:
            x: Node feature matrix of shape (num_nodes, num_features).
            edge_index: Long tensor of shape (2, num_edges) holding
                source/target node indices.

        Returns:
            Float tensor of shape (num_nodes, num_labels) with unnormalised
            scores (logits).
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return self.fc(x)
# Define the dataset class
class GraphDataset(Dataset):
    """Dataset yielding, for each graph, all node images and node labels.

    File layout (graph ids are 1-based, node ids are 0-based):
      * image of node ``j`` in graph ``i``:  ``<image_folder>/i.png_j.png``
      * labels of node ``j`` in graph ``i``: ``<label_folder>/i_j.txt``
        (whitespace-separated values)

    Args:
        num_graphs: Number of graphs in the dataset.
        num_nodes: Number of nodes in every graph.
        image_size: Side length each node image is resized to.
        num_labels: Number of label values expected per node.
        num_edges: Number of edges in the graph (stored, not used here).
        image_folder: Directory containing the node images.
        label_folder: Directory containing the node label files.
        edge_file: Edge list reference (stored, not used here).
    """

    def __init__(self, num_graphs, num_nodes, image_size, num_labels, num_edges, image_folder, label_folder, edge_file):
        self.num_graphs = num_graphs
        self.num_nodes = num_nodes
        self.image_size = image_size
        self.num_labels = num_labels
        self.num_edges = num_edges
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.edge_file = edge_file
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the number of graphs."""
        return self.num_graphs

    def __getitem__(self, idx):
        """Load every node of graph ``idx`` (0-based).

        Returns:
            A tuple ``(images, labels)`` where ``images`` has shape
            (num_nodes, 3, image_size, image_size) and ``labels`` has shape
            (num_nodes, num_labels), both float tensors.

        Raises:
            IndexError: If ``idx`` is outside ``[0, num_graphs)``.
            ValueError: If a label file does not hold ``num_labels`` values.
        """
        if not 0 <= idx < self.num_graphs:
            raise IndexError(f'graph index {idx} out of range')
        graph_id = idx + 1  # files on disk are numbered from 1

        images = []
        labels = []
        for node in range(self.num_nodes):
            image_path = os.path.join(self.image_folder, f'{graph_id}.png_{node}.png')
            # The context manager closes the file handle once the pixels
            # have been copied by convert()/resize().
            with Image.open(image_path) as raw:
                image = raw.convert('RGB').resize((self.image_size, self.image_size))
            images.append(self.transform(image))

            label_path = os.path.join(self.label_folder, f'{graph_id}_{node}.txt')
            # Default delimiter splits on any whitespace, which also
            # tolerates repeated or trailing spaces.
            node_labels = np.loadtxt(label_path, ndmin=1)
            if node_labels.size != self.num_labels:
                raise ValueError(
                    f'{label_path}: expected {self.num_labels} labels, got {node_labels.size}'
                )
            labels.append(torch.from_numpy(node_labels.reshape(-1)).float())

        return torch.stack(images), torch.stack(labels)


# Dataset constants, taken from the data description.
num_graphs = 42
num_nodes = 37
image_size = 40
num_labels = 8
num_edges = 61
num_train_nodes = 30  # first 30 nodes train, remaining 7 validate
batch_size = 7  # graphs per mini-batch

# Load the edge data (first column: source node, second column: target node).
# NOTE(review): node ids in the CSV are assumed to be 0-based so that they
# index directly into the per-graph node order - confirm against the file.
edge_file = 'C:/Users/jh/Desktop/data/input/edges_L.csv'
edge_data = pd.read_csv(edge_file, header=None)
edge_index = torch.tensor(edge_data.values, dtype=torch.long).t().contiguous()

# Create the dataset
dataset = GraphDataset(num_graphs, num_nodes, image_size, num_labels, num_edges,
                       'C:/Users/jh/Desktop/data/input/images',
                       'C:/Users/jh/Desktop/data/input/labels', edge_file)

# Node-level split: every graph contributes its first `num_train_nodes` nodes
# to the training loss and the remaining nodes to the validation loss.
train_mask = torch.arange(num_nodes) < num_train_nodes
val_mask = ~train_mask

# Both loaders iterate over all graphs; the masks decide which nodes count.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Initialize the CNN and GCN models
cnn_model = CNN()
gcn_model = GCN(num_features=64, hidden_size=32, num_labels=num_labels)

# Optimise both networks: the CNN features are learned jointly with the GCN.
optimizer = optim.Adam(
    list(cnn_model.parameters()) + list(gcn_model.parameters()), lr=0.001
)
loss_fn = nn.MultiLabelSoftMarginLoss()


def _masked_loss(images, labels, node_mask):
    """Return the loss over the nodes selected by ``node_mask``.

    Args:
        images: Tensor of shape (graphs, num_nodes, 3, H, W).
        labels: Tensor of shape (graphs, num_nodes, num_labels).
        node_mask: Bool tensor of shape (num_nodes,) selecting nodes per graph.
    """
    graphs_in_batch = images.size(0)
    # Run the CNN over every node image at once, then regroup per graph.
    node_features = cnn_model(images.flatten(0, 1))
    node_features = node_features.view(graphs_in_batch, num_nodes, -1)
    # Batch.from_data_list offsets edge_index for each graph in the batch.
    batched_data = Batch.from_data_list(
        [Data(x=graph_x, edge_index=edge_index) for graph_x in node_features]
    )
    output = gcn_model(batched_data.x, batched_data.edge_index)
    # Tile the per-graph mask so it lines up with the flattened node axis.
    flat_mask = node_mask.repeat(graphs_in_batch)
    return loss_fn(output[flat_mask], labels.flatten(0, 1)[flat_mask])


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0

    # Training
    cnn_model.train()
    gcn_model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = _masked_loss(images, labels, train_mask)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)

    # Validation (no gradients needed)
    cnn_model.eval()
    gcn_model.eval()
    with torch.no_grad():
        for images, labels in val_loader:
            loss = _masked_loss(images, labels, val_mask)
            val_loss += loss.item() * images.size(0)

    # Calculate average losses (weighted by graphs per batch)
    train_loss /= len(dataset)
    val_loss /= len(dataset)

    # Print epoch results
    print(f'Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
运行结果:
代码运行后,将输出每个epoch的训练损失和验证损失。
总结:
本项目利用CNN-GCN模型实现了节点级的多标签分类任务。通过提取节点像素特征和学习节点之间的关系,该模型可以为图中的每个节点预测多个标签。
原文地址: https://www.cveoy.top/t/topic/pgt7 著作权归作者所有。请勿转载和采集!