# 基于CNN和GCN的多标签节点分类模型
# (Multi-label node classification model based on a CNN and a GCN)
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx
# Define the CNN model for node feature extraction
class CNNEncoder(nn.Module):
    """Two-stage convolutional encoder for single-channel node images.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU
    and a 2x2 max-pool with stride 2. The spatial size is therefore halved
    twice and the channel count grows from 1 to 16 to 32.
    """

    def __init__(self):
        # The constructor must be named ``__init__``; a plain ``init`` is
        # never called by Python, which would leave the module without layers.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Encode a batch of images into feature maps.

        Args:
            x: Float tensor of shape (N, 1, H, W).

        Returns:
            Tensor of shape (N, 32, H // 4, W // 4).
        """
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        return x
# Define the GCN model for multi-label classification
class GCNClassifier(nn.Module):
    """Two-layer graph convolutional network producing per-node label scores.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of the hidden graph-convolution layer.
        output_dim: Number of scores produced per node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        # The constructor must be named ``__init__``; a plain ``init`` is
        # never called by Python, which would leave the module without layers.
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Run both graph convolutions with a ReLU in between.

        Args:
            x: Node feature matrix passed to the first ``GCNConv``.
            edge_index: Graph connectivity passed to both ``GCNConv`` layers.

        Returns:
            The raw output of the second layer; no activation is applied,
            so callers that need probabilities must apply one themselves.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x
# Set the paths for node features, labels, and edges
# Path templates are filled in with str.format(i=..., j=...), where (i, j)
# identifies one node image and its matching label file.
node_features_path = 'C:/Users/jh/Desktop/data/input/images/{i}.png_{j}.png'
labels_path = 'C:/Users/jh/Desktop/data/input/labels/{i}_{j}.txt'
# Single CSV file holding the edge list.
edges_path = 'C:/Users/jh/Desktop/data/input/edges_L.csv'
# Load node features and labels
# Every node is identified by a pair (i, j) with i in 1..42 and j in 0..36.
# Features and labels are appended in the same order so they stay aligned.
node_features = []
labels = []
for i in range(1, 43):
    for j in range(37):
        image_path = node_features_path.format(i=i, j=j)
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into a NumPy array.
        with Image.open(image_path) as image:
            # Convert to grayscale so each image is a single-channel 2-D array.
            node_features.append(np.array(image.convert('L')))

        label_path = labels_path.format(i=i, j=j)
        with open(label_path, 'r') as f:
            # The label file holds whitespace-separated integers.
            labels.append([int(token) for token in f.read().split()])

# NOTE(review): stacking into regular arrays assumes all images share one
# size and all label vectors share one length — confirm against the dataset.
node_features = np.array(node_features)
labels = np.array(labels)
# Load edge relations
# Read the CSV with no header row (the first line is data, not column
# names) and keep only the underlying NumPy array.
edge_frame = pd.read_csv(edges_path, header=None)
edges = edge_frame.values
# Split the data into train and validation sets
# An integer test_size is an absolute sample count, so exactly 7 samples are
# held out for validation; the fixed random_state makes the split repeatable.
split = train_test_split(node_features, labels, test_size=7, random_state=42)
train_features, val_features, train_labels, val_labels = split
# Convert the data to PyTorch Geometric format
# The same edge list is attached to every sample, so the transposed index
# tensor is built once instead of once per sample.
# NOTE(review): presumably one edge per CSV row with two columns, giving a
# (2, num_edges) tensor after the transpose — confirm against the file.
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()


def _to_graph_samples(features, targets):
    """Wrap each (image, label vector) pair in a ``Data`` object.

    Args:
        features: Sequence of image arrays, one per sample.
        targets: Sequence of label vectors aligned with ``features``.

    Returns:
        A list of ``Data`` objects. ``x`` is the image as a float tensor with
        two leading singleton dimensions, ``y`` is the label vector as a
        float tensor, and ``edge_index`` is the shared edge tensor.
    """
    samples = []
    for image, target in zip(features, targets):
        x = torch.tensor(image).unsqueeze(0).unsqueeze(0).float()
        y = torch.tensor(target).float()
        samples.append(Data(x=x, y=y, edge_index=edge_index))
    return samples


# NOTE(review): every sample carries a single image (first dimension of x is
# 1) together with the complete edge list — confirm the edge indices are
# valid for a graph built this way.
train_data = _to_graph_samples(train_features, train_labels)
val_data = _to_graph_samples(val_features, val_labels)
# Create data loaders
# One sample per step. Training order is reshuffled every epoch, while the
# validation order stays fixed.
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
# Initialize the models
cnn_encoder = CNNEncoder()
# Positional arguments are (input_dim, hidden_dim, output_dim): 32 input
# features per node, 16 hidden units, and 8 output scores per node.
gcn_classifier = GCNClassifier(32, 16, 8)
# Set device
# Use the GPU when one is available, otherwise fall back to the CPU, and
# move both models' parameters to that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_encoder.to(device)
gcn_classifier.to(device)
# Set optimizer and loss function
# A single optimizer updates the CNN and the GCN together so that both are
# trained end to end.
optimizer = optim.Adam(
    list(cnn_encoder.parameters()) + list(gcn_classifier.parameters()),
    lr=0.001,
)
# MultiLabelSoftMarginLoss applies a sigmoid to its input internally, so it
# expects raw (pre-sigmoid) scores.
loss_fn = nn.MultiLabelSoftMarginLoss()
# Training loop
# Train for 10 epochs; each epoch runs one pass over the training loader,
# then one pass over the validation loader, then prints both mean losses.
for epoch in range(10):
    cnn_encoder.train()
    gcn_classifier.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        x = cnn_encoder(data.x)
        # NOTE(review): the classifier is constructed with input_dim=32 and
        # the encoder emits 32 channels, so this flatten only yields 32
        # features when the pooled feature map is 1x1 — confirm image size.
        x = x.view(x.size(0), -1)
        logits = gcn_classifier(x, data.edge_index)

        # Compute loss
        # The loss is fed raw scores: MultiLabelSoftMarginLoss applies the
        # sigmoid itself, so applying torch.sigmoid first would squash the
        # scores twice. The target is reshaped to one row per score row.
        loss = loss_fn(logits, data.y.view(logits.size(0), -1))
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

    # Validation
    cnn_encoder.eval()
    gcn_classifier.eval()
    val_loss = 0

    # No gradients are needed for evaluation; skipping them avoids building
    # the autograd graph for every validation sample.
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)

            # Forward pass
            x = cnn_encoder(data.x)
            x = x.view(x.size(0), -1)
            logits = gcn_classifier(x, data.edge_index)

            # Compute loss (raw scores, same as in training)
            loss = loss_fn(logits, data.y.view(logits.size(0), -1))
            val_loss += loss.item()

    print(f'Epoch {epoch+1}, Train Loss: {total_loss/len(train_loader)}, Val Loss: {val_loss/len(val_loader)}')
# Predict labels for all nodes
cnn_encoder.eval()
gcn_classifier.eval()

# The edge tensor is identical for every node, so it is built and moved to
# the device once rather than on every iteration.
predict_edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(device)

pred_labels = []
# Inference only: disable gradient tracking for the whole prediction pass.
with torch.no_grad():
    for i in range(1, 43):
        for j in range(37):
            image_path = node_features_path.format(i=i, j=j)
            image = Image.open(image_path).convert('L')  # Convert to grayscale
            image = np.array(image)
            x = torch.tensor(image).unsqueeze(0).unsqueeze(0).float().to(device)

            # Forward pass
            x = cnn_encoder(x)
            x = x.view(x.size(0), -1)
            x = gcn_classifier(x, predict_edge_index)
            # The classifier returns raw scores; the sigmoid maps them to
            # per-label probabilities before thresholding.
            pred = torch.sigmoid(x)

            # Get predicted label vector: 1.0 where the probability exceeds
            # 0.5, else 0.0. Only the first row is kept.
            pred_label = (pred > 0.5).float()
            pred_labels.append(pred_label.cpu().numpy()[0])
# Print the predicted label vectors
# Emit one line per prediction, numbering nodes from 1 in the order the
# predictions were collected.
for i, pred_label in enumerate(pred_labels):
    print(f'Node {i+1}: {pred_label}')
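# 原文地址: https://www.cveoy.top/t/topic/pk3C 著作权归作者所有。请勿转载和采集!
# (Original source: https://www.cveoy.top/t/topic/pk3C — copyright belongs to the author; do not repost or scrape.)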