import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
# NOTE(review): torch_geometric.data.DataLoader is deprecated since PyG 2.0 and
# may be missing in recent releases; the current home is
# torch_geometric.loader.DataLoader.
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx

# Define the CNN model that extracts features from each node's image

class CNNEncoder(nn.Module):
    """Two-stage convolutional encoder applied to each node's grayscale image.

    Each stage is a 3x3 convolution (stride 1, padding 1, so H and W are
    preserved), a ReLU, and a 2x2 max-pool (which halves H and W, rounding
    down). An input of shape ``(N, 1, H, W)`` therefore yields an output of
    shape ``(N, 32, H // 4, W // 4)``.
    """

    def __init__(self):
        # Must be ``__init__`` (not ``init``) so nn.Module actually runs it
        # and registers the layers below as parameters.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Encode a batch of single-channel images.

        Args:
            x: float tensor of shape ``(N, 1, H, W)``.

        Returns:
            Feature maps of shape ``(N, 32, H // 4, W // 4)``.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        return x

# Define the GCN model for multi-label node classification

class GCNClassifier(nn.Module):
    """Two-layer graph convolutional network producing per-node label scores.

    Args:
        input_dim: size of each node's input feature vector.
        hidden_dim: width of the hidden GCN layer.
        output_dim: number of labels predicted per node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        # Must be ``__init__`` (not ``init``) so nn.Module actually runs it
        # and registers the GCN layers.
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Return raw (pre-sigmoid) scores of shape ``(num_nodes, output_dim)``.

        Args:
            x: node features of shape ``(num_nodes, input_dim)``.
            edge_index: ``(2, num_edges)`` long tensor of node-index pairs.
        """
        x = F.relu(self.conv1(x, edge_index))
        # No activation on the output layer: callers get logits.
        return self.conv2(x, edge_index)

# Set the paths for node features, labels and edges.
# Assumption (inferred from the loop bounds and test_size=7 below; confirm
# against the dataset): index i selects one of 42 graphs, j selects one of the
# 37 nodes in that graph, and every graph shares the edge list in edges_path.
node_features_path = 'C:/Users/jh/Desktop/data/input/images/{i}.png_{j}.png'
labels_path = 'C:/Users/jh/Desktop/data/input/labels/{i}_{j}.txt'
edges_path = 'C:/Users/jh/Desktop/data/input/edges_L.csv'

NUM_GRAPHS = 42
NODES_PER_GRAPH = 37


def _load_image(i, j):
    """Return the image of node j in graph i as a 2-D grayscale array."""
    with Image.open(node_features_path.format(i=i, j=j)) as img:
        return np.array(img.convert('L'))  # Convert to grayscale


def _load_label(i, j):
    """Return the label vector of node j in graph i (whitespace-separated ints)."""
    with open(node_label_path := labels_path.format(i=i, j=j), 'r') as f:
        return [int(tok) for tok in f.read().split()]


# Load node features and labels, grouped per graph (all images must share one
# size and all label files one length for these to form regular arrays):
#   node_features: (NUM_GRAPHS, NODES_PER_GRAPH, H, W)
#   labels:        (NUM_GRAPHS, NODES_PER_GRAPH, num_labels)
node_features = np.array([
    [_load_image(i, j) for j in range(NODES_PER_GRAPH)]
    for i in range(1, NUM_GRAPHS + 1)
])
labels = np.array([
    [_load_label(i, j) for j in range(NODES_PER_GRAPH)]
    for i in range(1, NUM_GRAPHS + 1)
])

# Load edge relations. Each CSV row is presumably one (source, target) pair;
# transposing gives the (2, num_edges) layout PyG expects.
edges = pd.read_csv(edges_path, header=None).values
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
# The edge list indexes nodes *within* one graph, so ids must be 0-based and
# smaller than NODES_PER_GRAPH; fail early instead of deep inside GCNConv.
if edge_index.numel() and (
        int(edge_index.min()) < 0 or int(edge_index.max()) >= NODES_PER_GRAPH):
    raise ValueError(
        f'{edges_path} references node ids outside 0..{NODES_PER_GRAPH - 1}; '
        'expected 0-based node positions within a graph')

# Split whole graphs (not individual nodes) into train and validation sets.
train_features, val_features, train_labels, val_labels = train_test_split(
    node_features, labels, test_size=7, random_state=42)


def _to_graph(features, targets):
    """Build one PyG ``Data`` graph from a graph's images and label vectors.

    ``x`` has shape (num_nodes, 1, H, W) so the CNN sees every node's image as
    one batch; ``y`` has shape (num_nodes, num_labels).
    """
    x = torch.tensor(features, dtype=torch.float32).unsqueeze(1)
    y = torch.tensor(targets, dtype=torch.float32)
    return Data(x=x, y=y, edge_index=edge_index, num_nodes=x.size(0))


# Convert the data to PyTorch Geometric format (one Data object per graph).
train_data = [_to_graph(f, t) for f, t in zip(train_features, train_labels)]
val_data = [_to_graph(f, t) for f, t in zip(val_features, val_labels)]

# Create data loaders (one graph per batch).
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)

# Initialize the models.
# The encoder's feature maps are flattened per node before the GCN, so the GCN
# input width depends on the image size. Probe the encoder once instead of
# hard-coding it.
image_height, image_width = node_features.shape[-2:]
cnn_encoder = CNNEncoder()
with torch.no_grad():
    gcn_input_dim = cnn_encoder(
        torch.zeros(1, 1, image_height, image_width)).numel()
num_labels = labels.shape[-1]  # length of each node's label vector, from the data
gcn_classifier = GCNClassifier(gcn_input_dim, 16, num_labels)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_encoder.to(device)
gcn_classifier.to(device)

# Set optimizer and loss function.
# MultiLabelSoftMarginLoss applies a sigmoid internally, so it must be fed raw
# logits, not probabilities.
optimizer = optim.Adam(
    list(cnn_encoder.parameters()) + list(gcn_classifier.parameters()),
    lr=0.001)
loss_fn = nn.MultiLabelSoftMarginLoss()

# Training loop


def _node_logits(batch):
    """Return per-node label logits, shape (num_nodes, num_labels), for a batch."""
    feats = cnn_encoder(batch.x)           # (num_nodes, 32, H', W')
    feats = feats.view(feats.size(0), -1)  # flatten each node's feature maps
    return gcn_classifier(feats, batch.edge_index)


for epoch in range(10):
    cnn_encoder.train()
    gcn_classifier.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass + loss. loss_fn applies its own sigmoid, so it gets raw
        # logits; an extra torch.sigmoid here would squash the scores twice.
        loss = loss_fn(_node_logits(data), data.y)
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

    # Validation: evaluation mode, and no autograd graph is needed.
    cnn_encoder.eval()
    gcn_classifier.eval()
    val_loss = 0

    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            val_loss += loss_fn(_node_logits(data), data.y).item()

    print(f'Epoch {epoch+1}, Train Loss: {total_loss/len(train_loader)}, Val Loss: {val_loss/len(val_loader)}')

# Predict labels for all nodes. The GCN must see a whole graph at once, because
# the shared edge list references every node of a graph.
# Assumption: node_features is ordered graph-major (i outer, j inner), with 37
# nodes per graph, as produced by the loading loops.
nodes_per_graph = 37
cnn_encoder.eval()
gcn_classifier.eval()
full_edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(device)
graph_images = node_features.reshape(-1, nodes_per_graph, *node_features.shape[-2:])
pred_labels = []
with torch.no_grad():
    for images in graph_images:
        # (nodes_per_graph, 1, H, W): the whole graph forms one CNN batch.
        x = torch.tensor(images, dtype=torch.float32).unsqueeze(1).to(device)

        # Forward pass
        feats = cnn_encoder(x)
        feats = feats.view(feats.size(0), -1)
        logits = gcn_classifier(feats, full_edge_index)

        # Threshold each label's probability at 0.5; one 1-D vector per node.
        pred = torch.sigmoid(logits)
        pred_labels.extend((pred > 0.5).float().cpu().numpy())

# Print the predicted label vectors (nodes numbered 1.. in loading order).
for i, pred_label in enumerate(pred_labels):
    print(f'Node {i+1}: {pred_label}')


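# 原文地址: https://www.cveoy.top/t/topic/pk3C 著作权归作者所有。请勿转载和采集!
# (Original article: https://www.cveoy.top/t/topic/pk3C — copyright belongs to the author; do not repost or scrape.)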

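# Scraped page residue, not part of the script: "免费AI点我,无需注册和登录" ("Free AI — click here, no sign-up or login required").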