Graph Convolutional Network with CNN Feature Extraction for Image-Based Node Classification
import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torchvision import transforms
from PIL import Image
# Load the data and build the PyG dataset class.
class MyDataset(torch.utils.data.Dataset):
    """Per-node image dataset that yields one PyG ``Data`` object per node.

    Files read below ``root``:

    * ``input/edges_L.csv`` -- header-less CSV; columns 0 and 1 are used as
      the two rows of ``edge_index``.
    * ``input/images/{i}.png_{j}.png`` -- image for node ``j`` of network ``i``.
    * ``input/labels/{i}_{j}.txt`` -- whitespace-separated integers on the
      first line of the file.

    On disk networks are numbered ``1..num_networks`` and nodes
    ``0..nodes_per_network - 1``; samples are stored network by network, so
    sample ``idx`` is node ``idx % nodes_per_network`` of the
    ``idx // nodes_per_network``-th network (0-based).

    Args:
        root: Directory containing the ``input`` folder.
        transform: Optional callable applied to ``data.x`` on every access.
        pre_transform: Optional callable applied to the whole ``Data`` object
            on every access, before ``transform``.
        num_networks: Number of networks to load (default 42).
        nodes_per_network: Number of nodes per network (default 37).
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 num_networks=42, nodes_per_network=37):
        self.edges = pd.read_csv(os.path.join(root, 'input', 'edges_L.csv'), header=None)
        self.transform = transform
        self.pre_transform = pre_transform
        self.num_networks = num_networks
        self.nodes_per_network = nodes_per_network

        # Build the converter once instead of once per image.
        to_tensor = transforms.ToTensor()
        features = []
        labels = []
        for i in range(1, num_networks + 1):
            for j in range(nodes_per_network):
                # Node feature: the image, forced to RGB and resized to 224x224.
                img_name = os.path.join(root, 'input', 'images', '{}.png_{}.png'.format(i, j))
                # The context manager closes the underlying file deterministically.
                with Image.open(img_name) as raw:
                    img = raw.convert('RGB').resize((224, 224), resample=Image.BILINEAR)
                features.append(to_tensor(img))

                # Node label(s): the integers on the first line of the label file.
                label_name = os.path.join(root, 'input', 'labels', '{}_{}.txt'.format(i, j))
                with open(label_name, 'r') as f:
                    labels.append([int(x) for x in f.readline().strip().split()])

        # Stack into one tensor of shape [num_samples, 3, 224, 224].
        self.features = torch.stack(features, dim=0)
        self.labels = torch.tensor(labels)
        # Total number of nodes across all networks.
        self.num_nodes = len(self.labels)

    def __len__(self):
        """Return the number of samples (one per node)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``Data`` object for sample ``idx``.

        The returned object carries the single node's image tensor as ``x``,
        its label row as ``y``, the edges selected for its network (shifted
        by the network's node offset) and two boolean masks of length
        ``self.num_nodes`` in which only this node's entry is set.
        """
        network_id, node_id = divmod(idx, self.nodes_per_network)
        offset = network_id * self.nodes_per_network

        # NOTE(review): this keeps rows whose *column 0* equals the 0-based
        # network index, yet column 0 is also used as the source node below.
        # Confirm the CSV schema (a separate network-id column may be intended).
        network_edges = self.edges[self.edges[0] == network_id]

        # Convert both columns in one call; building a tensor from a list of
        # ndarrays is slow and triggers a PyTorch warning.
        pairs = torch.tensor(network_edges[[0, 1]].to_numpy(), dtype=torch.long)
        edge_index = pairs.t().contiguous() + offset

        x = self.features[idx]  # node feature (image tensor)
        y = self.labels[idx]    # node label(s)

        # Masks span all nodes of all networks; only this node's entry is set.
        train_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        # The first 30 nodes of every network are training nodes, the rest
        # are validation nodes.
        if node_id < 30:
            train_mask[offset + node_id] = True
        else:
            val_mask[offset + node_id] = True

        # NOTE(review): ``x`` describes one node while ``edge_index`` and the
        # masks use dataset-wide node indices -- verify that downstream code
        # really expects this combination.
        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask)
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        if self.transform is not None:
            data.x = self.transform(data.x)
        return data
# CNN feature extractor.
class CNN(torch.nn.Module):
    """Small convolutional encoder: two conv + ReLU + max-pool stages, then a
    single linear layer that maps the flattened activations to a
    ``40 * 40 * 3``-dimensional feature vector.

    The flatten step hard-codes ``32 * 56 * 56`` values per sample, which is
    what two stride-2 poolings produce from a 224x224 input.
    """

    # Values per sample after the second pooling stage (channels * H * W).
    _FLAT_FEATURES = 32 * 56 * 56
    # Size of the feature vector emitted for every image.
    _OUT_FEATURES = 40 * 40 * 3

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(self._FLAT_FEATURES, self._OUT_FEATURES)

    def forward(self, x):
        """Encode image tensor(s) ``x`` into flat feature vectors."""
        # Both stages share the same pattern: convolution -> ReLU -> pooling.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, self._FLAT_FEATURES)
        return self.fc(flat)
# GCN node classifier.
class GCN(torch.nn.Module):
    """Three-layer GCN that outputs per-node class probabilities.

    Args:
        num_node_features: Size of each node's input feature vector.
        num_classes: Number of output classes.
        nodes_per_graph: Number of nodes the output is grouped by
            (default 37).

    ``forward`` returns *probabilities* (softmax is already applied), shaped
    ``[-1, nodes_per_graph, num_classes]``. Losses that expect raw logits,
    such as ``F.cross_entropy``, must not be applied to this output directly.
    """

    # Class-level default so the grouping size is defined on every instance.
    nodes_per_graph = 37

    def __init__(self, num_node_features, num_classes, nodes_per_graph=37):
        super().__init__()
        self.nodes_per_graph = nodes_per_graph
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        """Run the network on ``data.x`` / ``data.edge_index``.

        Raises:
            RuntimeError: If the number of nodes is not a multiple of
                ``nodes_per_graph`` (the final reshape cannot succeed).
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Dropout is active only while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        x = F.softmax(x, dim=1)
        # Group nodes by graph. Only one dimension of a view may be inferred
        # with -1, so the class dimension is spelled out explicitly.
        return x.view(-1, self.nodes_per_graph, x.size(-1))
# Training loop (one epoch).
def train_model(dataset, model, optimizer, device, cnn=None):
    """Train ``model`` for one pass over ``dataset`` and return the mean loss.

    Args:
        dataset: Iterable of batches. Each batch must provide ``to(device)``,
            ``x``, ``y`` and a boolean ``train_mask``.
        model: Module called as ``model(data)``. Its output is treated as
            class *probabilities* (softmax already applied) whose last
            dimension is the class dimension.
        optimizer: Optimizer stepped once per contributing batch.
        device: Device every batch is moved to.
        cnn: Optional ``nn.Module`` used to extract image features. When
            omitted, ``model.cnn`` is used.

    Returns:
        Mean loss over the batches that contained at least one training
        node, or ``0.0`` when there were none.

    Raises:
        AttributeError: If ``cnn`` is omitted and ``model`` has no ``cnn``.
    """
    feature_extractor = model.cnn if cnn is None else cnn
    model.train()
    if cnn is not None:
        # An external extractor is not covered by model.train().
        cnn.train()

    total_loss = 0.0
    steps = 0
    for data in dataset:
        data = data.to(device)
        mask = data.train_mask
        # A mean-reduced loss over zero samples is NaN; back-propagating it
        # would corrupt every weight, so such batches are skipped.
        if not bool(mask.any()):
            continue

        optimizer.zero_grad()
        # Extract image features with the CNN and flatten them per node.
        features = feature_extractor(data.x)
        data.x = features.view(features.size(0), -1)

        output = model(data)
        # Collapse any leading grouping dimensions so that rows line up with
        # the per-node mask.
        probs = output.reshape(-1, output.size(-1))
        # The model already applied softmax, so take the log and use NLL
        # rather than cross_entropy (which would apply softmax a second time).
        log_probs = torch.log(probs[mask].clamp_min(1e-12))
        loss = F.nll_loss(log_probs, data.y[mask])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        steps += 1

    return total_loss / steps if steps else 0.0
def validate_model(dataset, model, device, cnn=None):
    """Return the accuracy of ``model`` on the validation nodes of ``dataset``.

    Args:
        dataset: Iterable of batches. Each batch must provide ``to(device)``,
            ``x``, ``y`` and a boolean ``val_mask``.
        model: Module called as ``model(data)``; the last dimension of its
            output is treated as the class dimension.
        device: Device every batch is moved to.
        cnn: Optional ``nn.Module`` used to extract image features. When
            omitted, ``model.cnn`` is used.

    Returns:
        ``correct / total`` over all validation nodes, or ``0.0`` when no
        validation node was seen.

    Raises:
        AttributeError: If ``cnn`` is omitted and ``model`` has no ``cnn``.
    """
    feature_extractor = model.cnn if cnn is None else cnn
    model.eval()
    if cnn is not None:
        # An external extractor is not covered by model.eval().
        cnn.eval()

    correct = 0
    total = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in dataset:
            data = data.to(device)
            mask = data.val_mask
            # Extract image features with the CNN and flatten them per node.
            features = feature_extractor(data.x)
            data.x = features.view(features.size(0), -1)

            output = model(data)
            # Collapse any leading grouping dimensions so that rows line up
            # with the per-node mask.
            scores = output.reshape(-1, output.size(-1))
            predicted = scores[mask].argmax(dim=1)

            total += int(mask.sum().item())
            correct += int((predicted == data.y[mask]).sum().item())

    return correct / total if total else 0.0
if __name__ == '__main__':
    # Raw string: in a normal literal "\U" starts a unicode escape, which
    # makes this Windows path a SyntaxError.
    dataset = MyDataset(root=r"C:\Users\jh\Desktop\data")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cnn_model = CNN().to(device)
    model = GCN(num_node_features=40 * 40 * 3, num_classes=8).to(device)
    # Register the CNN as a sub-module of the GCN so that the `model.cnn`
    # lookup in the training/validation loops resolves and the optimizer
    # below also updates the CNN weights.
    model.cnn = cnn_model
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Hold out 10% of the samples for validation.
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    epochs = 2
    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}')
        val_accuracy = validate_model(val_loader, model, device)
        print(f'Val_Acc: {val_accuracy:.4f}')
In this code, a CNN model extracts features from the image data, and the CNN output then serves as the node-feature input to the GCN model. The CNN reduces each 3 × 224 × 224 image (150,528 values) to a 4,800-dimensional vector (40 × 40 × 3), so the GCN works on a much smaller input than the raw pixels, which makes the graph model more efficient.
Feel free to adjust the CNN model's structure and hyperparameters to suit your specific needs for more flexible feature extraction.
原文地址: http://www.cveoy.top/t/topic/pdmz 著作权归作者所有。请勿转载和采集!