PyTorch Geometric Graph Convolutional Network (GCN) for Multi-Label Image Classification
import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset of PyTorch Geometric ``Data`` graphs built from files.

    Files read under ``root``:

    * ``images/{i}.png_{j}.png`` -- source of the node features ``x``;
    * ``labels/{i}_{j}.txt`` -- whitespace-separated integers used as ``y``;
    * ``edges_L.csv`` -- one ``src,tgt`` integer pair per line; the resulting
      ``edge_index`` is shared by every graph.

    All graphs are built eagerly in ``__init__`` and stored in
    ``self.dataset``. For every ``(i, j)`` pair two ``Data`` objects are
    appended: one with the masks as built and one with ``train_mask`` and
    ``val_mask`` swapped.

    Args:
        root: directory containing ``images/``, ``labels/`` and
            ``edges_L.csv``.
        transform: optional callable applied to each ``Data`` in
            ``__getitem__``.
        pre_transform: stored as an attribute; not applied by this class.
    """

    # Index ranges scanned by create_dataset: i in 1..NUM_IMAGES (inclusive),
    # j in 0..PATCHES_PER_IMAGE-1. Override in a subclass for other layouts.
    NUM_IMAGES = 42
    PATCHES_PER_IMAGE = 37
    # Node indices [0, NUM_TRAIN_NODES) form the "train" side of the mask.
    NUM_TRAIN_NODES = 30
    # Images are resized to IMAGE_SIZE x IMAGE_SIZE before pixel extraction.
    IMAGE_SIZE = 40

    def __init__(self, root, transform=None, pre_transform=None):
        self.img_dir = os.path.join(root, 'images')
        self.label_dir = os.path.join(root, 'labels')
        self.edge_file = os.path.join(root, 'edges_L.csv')
        self.transform = transform
        self.pre_transform = pre_transform
        self.dataset = []
        self.create_dataset()

    def create_dataset(self):
        """Populate ``self.dataset`` with two ``Data`` objects per ``(i, j)``.

        Returns:
            tuple: ``(self.dataset, None)``. The second element is always
            ``None``; the tuple form is kept for backward compatibility.

        Raises:
            OSError: if an expected image, label or edge file is missing.
            ValueError: if the edge or label files are malformed.
        """
        edge_index, num_nodes = self.read_edges(self.edge_file)

        node_train = torch.zeros(num_nodes, dtype=torch.bool)
        node_train[:self.NUM_TRAIN_NODES] = True
        # NOTE(review): this expands the per-node mask into a
        # (num_nodes, num_nodes) matrix. PyG node masks are normally 1-D of
        # shape (num_nodes,), and boolean indexing of the model output / y
        # only works when their leading dims match the mask -- confirm intent.
        train_mask = node_train.unsqueeze(1).repeat(1, num_nodes)
        # The complement does not depend on (i, j), so compute it once.
        val_mask = ~train_mask

        for i in range(1, self.NUM_IMAGES + 1):
            for j in range(self.PATCHES_PER_IMAGE):
                image_path = os.path.join(self.img_dir, f'{i}.png_{j}.png')
                label_path = os.path.join(self.label_dir, f'{i}_{j}.txt')
                # NOTE(review): unsqueeze(0) gives x a leading dim of size 1,
                # i.e. (1, pixels, 3); GCNConv is usually fed 2-D
                # (num_nodes, num_features) input -- confirm.
                features = torch.tensor(
                    self.read_image_features(image_path), dtype=torch.float
                ).unsqueeze(0)
                labels = torch.tensor(self.read_labels(label_path), dtype=torch.long)
                # NOTE(review): both objects share the same x / y; if this list
                # is split randomly, one sample can land on both sides.
                self.dataset.append(Data(x=features, edge_index=edge_index, y=labels,
                                         train_mask=train_mask.clone(),
                                         val_mask=val_mask.clone()))
                self.dataset.append(Data(x=features, edge_index=edge_index, y=labels,
                                         train_mask=val_mask.clone(),
                                         val_mask=train_mask.clone()))
        return self.dataset, None

    def read_edges(self, edge_path):
        """Parse ``src,tgt`` integer pairs from ``edge_path``.

        Blank lines (e.g. a trailing newline) are skipped.

        Returns:
            tuple: ``(edge_index, num_nodes)`` where ``edge_index`` is a
            ``torch.long`` tensor of shape ``(2, num_edges)`` and
            ``num_nodes`` is the largest node index seen plus one.

        Raises:
            ValueError: if a line is not two comma-separated integers, or if
                the file contains no edges at all.
        """
        edges = []
        with open(edge_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                src, tgt = line.split(',')
                edges.append((int(src), int(tgt)))
        if not edges:
            raise ValueError(f'no edges found in {edge_path!r}')
        num_nodes = max(max(edge) for edge in edges) + 1
        # (num_edges, 2) -> (2, num_edges), the COO layout edge_index uses.
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        return edge_index, num_nodes

    def read_image_features(self, image_path):
        """Return the RGB values of every pixel of the resized image.

        The image is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and then
        converted to RGB. Pixels are listed column by column (x outer, y
        inner), each as an ``[r, g, b]`` list.
        """
        # ``with`` releases the file handle Image.open holds; resize() returns
        # a new image object, so rgb_img stays valid after the file is closed.
        with Image.open(image_path) as img:
            rgb_img = img.resize((self.IMAGE_SIZE, self.IMAGE_SIZE)).convert('RGB')
        width, height = rgb_img.size
        return [list(rgb_img.getpixel((x, y)))
                for x in range(width)
                for y in range(height)]

    def read_labels(self, label_path):
        """Return the whitespace-separated integers in ``label_path`` as a list."""
        with open(label_path, 'r', encoding='utf-8') as file:
            return [int(token) for token in file.read().split()]

    def __len__(self):
        """Return the number of ``Data`` objects built."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]``, passed through ``transform`` if set."""
        data = self.dataset[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data
class GCN(torch.nn.Module):
    """Three-layer graph convolutional network returning per-node scores.

    Layers: GCNConv -> ReLU -> GCNConv -> ReLU -> dropout -> GCNConv. The
    last layer's output is returned as-is (no softmax is applied).

    Args:
        num_node_features: input feature size of each node.
        num_classes: number of output scores per node.
        hidden_channels: widths of the two hidden layers; the default
            ``(8, 16)`` reproduces the original architecture.
        dropout: dropout probability before the last layer; the default 0.5
            equals ``F.dropout``'s own default used originally.
    """

    def __init__(self, num_node_features, num_classes, hidden_channels=(8, 16), dropout=0.5):
        super().__init__()
        hidden1, hidden2 = hidden_channels
        self.conv1 = GCNConv(num_node_features, hidden1)
        self.conv2 = GCNConv(hidden1, hidden2)
        self.conv3 = GCNConv(hidden2, num_classes)
        self.dropout = dropout

    def forward(self, data):
        """Run the network on ``data.x`` over the graph ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Dropout is active only while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv3(x, edge_index)
def train_model(data_loader, model, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        data_loader: iterable of batches; each batch must support
            ``.to(device)`` and expose ``.y`` and a boolean ``.train_mask``.
        model: module called as ``model(data)``; its output is indexed with
            ``train_mask`` and fed to ``F.cross_entropy``.
        optimizer: optimizer over ``model``'s parameters, stepped once per
            batch.
        device: device each batch is moved to.

    Returns:
        float: average of the per-batch cross-entropy losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in data_loader:
        data = data.to(device)
        train_mask = data.train_mask
        optimizer.zero_grad()
        output = model(data)
        # Only entries selected by train_mask contribute to the loss.
        loss = F.cross_entropy(output[train_mask], data.y[train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError('train_model: data_loader yielded no batches')
    return total_loss / num_batches
def validate_model(data_loader, model, device):
    """Return the fraction of masked entries whose arg-max prediction is correct.

    Args:
        data_loader: iterable of batches; each batch must support
            ``.to(device)`` and expose ``.y`` and a boolean ``.val_mask``.
        model: module called as ``model(data)``; the arg-max over dim 1 of
            its masked output is compared with the masked ``y``.
        device: device each batch is moved to.

    Returns:
        float: ``correct / total`` accumulated over all batches.

    Raises:
        ValueError: if no entry was selected by any ``val_mask`` (the ratio
            would otherwise be 0 / 0).
    """
    model.eval()
    correct = 0
    total = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            val_mask = data.val_mask
            predicted = model(data)[val_mask].argmax(dim=1)
            total += val_mask.sum().item()
            correct += (predicted == data.y[val_mask]).sum().item()
    if total == 0:
        raise ValueError('validate_model: val_mask selected no entries')
    return correct / total
if __name__ == '__main__':
    # Raw string: in a normal literal the '\U' of 'C:\Users' starts a
    # \UXXXXXXXX unicode escape and the module fails to compile.
    dataset = MyDataset(root=r'C:\Users\jh\Desktop\data\input')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(num_node_features=3, num_classes=8).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # 90 / 10 random split of the Data objects (no fixed seed, varies per run).
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)
    # NOTE(review): newer PyG versions move DataLoader to
    # torch_geometric.loader; torch_geometric.data.DataLoader is deprecated
    # there -- confirm against the installed version.
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    epochs = 2
    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}')
        val_accuracy = validate_model(val_loader, model, device)
        print(f'Val_Acc: {val_accuracy:.4f}')
This code implements a Graph Convolutional Network (GCN) in PyTorch Geometric for multi-label image classification. Here's a breakdown:
Dataset Class (MyDataset)
- `__init__`: Initializes the dataset, defining the paths to the image directory, the label directory, and the edge file. It also initializes the `dataset` list, which stores the `Data` objects.
- `create_dataset`: Reads edge connections from `edges_L.csv` and generates train and validation masks. Note that, as written, these masks are `num_nodes × num_nodes` matrices rather than the one-dimensional per-node masks PyTorch Geometric normally uses. It then loops through the image and label files, converts features and labels to tensors, creates `Data` objects, and appends them to the `dataset` list.
- `read_edges`: Reads edge connections from `edges_L.csv` and converts them to a `torch.Tensor`.
- `read_image_features`: Reads an image, resizes it, converts it to RGB, and extracts per-pixel color features as a list of lists.
- `read_labels`: Reads labels from a label file.
- `__len__`: Returns the length of the dataset.
- `__getitem__`: Retrieves a single `Data` object from the dataset, optionally applying a transform.
GCN Model Class (GCN)
- `__init__`: Initializes the GCN model with three convolutional layers, using `GCNConv` from PyTorch Geometric.
- `forward`: Implements the forward pass of the GCN, applying the convolutions with ReLU activations and dropout.
Training and Validation Functions
- `train_model`: Trains the model on the training data by computing the cross-entropy loss, backpropagating gradients, and updating the model parameters.
- `validate_model`: Evaluates the model on the validation data and computes its accuracy.
Main Script
- Dataset Initialization: Creates an instance of `MyDataset`, specifying the data root.
- Model and Optimizer: Initializes the `GCN` model and the `Adam` optimizer.
- Data Splitting: Splits the dataset into training and validation sets using `train_test_split`.
- Data Loaders: Creates `DataLoader` objects for training and validation.
- Training Loop: Iterates over epochs, training the model and evaluating its performance on the validation set.
Important Notes
- Data Structure: The dataset is structured with folders for images, labels, and a separate file for edge connections. This ensures clear organization and easy access to data during training and validation.
- Edge Connections: The edge connections are read once from `edges_L.csv`, and the same `edge_index` is attached to every graph. The graph structure is therefore assumed to be static and identical across all images. The GCN's message passing relies on this shared structure.
- Image Feature Extraction: The code extracts simple color features from the images. You can adjust the feature extraction process based on the nature of your data and desired features.
- Label Format: Labels are assumed to be stored as a space-separated list of integers in text files.
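- Multi-Label Classification: `F.cross_entropy` performs single-label, multi-class classification: each node is assigned exactly one integer class target. For genuine multi-label classification (several labels per node), use `F.binary_cross_entropy_with_logits` with multi-hot float targets instead.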
This code provides a starting point for building a GCN model for multi-label image classification using PyTorch Geometric. You can extend this code by adding more sophisticated feature extraction techniques, experimenting with different GCN architectures, and optimizing hyperparameters for improved performance.
原文地址: https://www.cveoy.top/t/topic/pcrJ 著作权归作者所有。请勿转载和采集!