import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from PIL import Image

class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of pixel graphs built from image/label file pairs.

    Each image is resized to ``IMAGE_SIZE`` and every pixel becomes one node
    whose feature vector is its RGB value. All samples share the edge list
    read from ``<root>/edges_L.csv``. Every image/label pair contributes two
    ``Data`` objects: one whose ``train_mask`` selects the first
    ``NUM_TRAIN_NODES`` nodes, and one with the train and validation masks
    swapped.

    Args:
        root: Directory containing ``images/``, ``labels/`` and
            ``edges_L.csv``.
        transform: Optional callable applied to a sample in ``__getitem__``.
        pre_transform: Stored on the instance; not used by this class.
    """

    # Files are named '<i>.png_<j>.png' and '<i>_<j>.txt'; these ranges
    # enumerate the first (i) and second (j) index.
    PRIMARY_IDS = range(1, 43)
    SECONDARY_IDS = range(37)
    # Number of leading nodes flagged in the base training mask.
    NUM_TRAIN_NODES = 30
    # (width, height) every image is resized to before pixels are read.
    IMAGE_SIZE = (40, 40)

    def __init__(self, root, transform=None, pre_transform=None):
        self.img_dir = os.path.join(root, 'images')
        self.label_dir = os.path.join(root, 'labels')
        self.edge_file = os.path.join(root, 'edges_L.csv')
        self.transform = transform
        self.pre_transform = pre_transform
        self.dataset = []
        self.create_dataset()

    def create_dataset(self):
        """Load every image/label pair and (re)build ``self.dataset``.

        Returns:
            A tuple ``(dataset, None)``. The second element is always
            ``None`` and is kept only for backward compatibility.
        """
        # Start from scratch so that calling this twice does not duplicate
        # every sample.
        self.dataset = []
        edge_index, num_nodes = self.read_edges(self.edge_file)

        # Node-level masks are 1-D (one flag per node) so they can index
        # both the [num_nodes, num_classes] model output and the 1-D label
        # vector. NOTE(review): assumes the node count derived from the edge
        # file equals the number of pixels/labels per sample - confirm.
        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        train_mask[:self.NUM_TRAIN_NODES] = True
        val_mask = ~train_mask

        for i in self.PRIMARY_IDS:
            for j in self.SECONDARY_IDS:
                image_path = os.path.join(self.img_dir, f'{i}.png_{j}.png')
                label_path = os.path.join(self.label_dir, f'{i}_{j}.txt')

                # x has shape [num_pixels, 3]; no leading batch dimension,
                # the loader takes care of batching graphs.
                features = torch.tensor(
                    self.read_image_features(image_path), dtype=torch.float
                )
                labels = torch.tensor(
                    self.read_labels(label_path), dtype=torch.long
                )

                # First the regular sample, then the same graph with the
                # two masks swapped.
                for first, second in ((train_mask, val_mask),
                                      (val_mask, train_mask)):
                    self.dataset.append(
                        Data(
                            x=features,
                            edge_index=edge_index,
                            y=labels,
                            train_mask=first.clone(),
                            val_mask=second.clone(),
                        )
                    )

        return self.dataset, None

    def read_edges(self, edge_path):
        """Read a 'src,tgt' CSV edge list.

        Blank lines are skipped.

        Returns:
            ``(edge_index, num_nodes)`` where ``edge_index`` is a contiguous
            ``[2, num_edges]`` long tensor and ``num_nodes`` is the largest
            node index found plus one.

        Raises:
            ValueError: If the file contains no edges or a line is not two
                comma-separated integers.
        """
        edges = []
        with open(edge_path, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                src, tgt = line.split(',')
                edges.append((int(src), int(tgt)))

        if not edges:
            raise ValueError(f'no edges found in {edge_path!r}')

        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        num_nodes = int(edge_index.max()) + 1
        return edge_index, num_nodes

    def read_image_features(self, image_path):
        """Return the RGB value of every pixel of the resized image.

        The result is a list of ``[r, g, b]`` lists ordered with the x
        coordinate as the outer loop and y as the inner loop, i.e. the
        pixel at ``(x, y)`` is at position ``x * height + y``.
        """
        width, height = self.IMAGE_SIZE
        # The context manager closes the underlying file handle, which
        # matters when thousands of images are read in a row.
        with Image.open(image_path) as img:
            rgb_img = img.resize((width, height)).convert('RGB')
            pixels = rgb_img.load()
            return [
                list(pixels[x, y])
                for x in range(width)
                for y in range(height)
            ]

    def read_labels(self, label_path):
        """Read whitespace-separated integer labels from a text file."""
        with open(label_path, 'r') as file:
            return [int(label) for label in file.read().split()]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transform`` if one is set."""
        data = self.dataset[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data

class GCN(torch.nn.Module):
    """Three-layer graph convolutional network producing per-node logits.

    Layer widths are ``num_node_features -> 8 -> 16 -> num_classes``.
    """

    def __init__(self, num_node_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        """Return raw (un-normalised) class scores for the nodes of ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        graph_edges = data.edge_index

        hidden = F.relu(self.conv1(data.x, graph_edges))
        hidden = F.relu(self.conv2(hidden, graph_edges))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        return self.conv3(hidden, graph_edges)

def train_model(data_loader, model, optimizer, device):
    """Run one training epoch and return the mean loss per batch.

    Args:
        data_loader: Iterable of batches. Each batch must provide ``to()``,
            ``y`` and ``train_mask`` and be accepted by ``model``.
        model: Module called as ``model(batch)`` that returns per-node
            logits.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The average cross-entropy loss over all batches, or ``0.0`` when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for data in data_loader:
        data = data.to(device)
        train_mask = data.train_mask

        optimizer.zero_grad()
        output = model(data)

        # Only the nodes selected by the training mask contribute to the loss.
        loss = F.cross_entropy(output[train_mask], data.y[train_mask])
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Counting batches here avoids requiring len() on the loader.
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def validate_model(data_loader, model, device):
    """Return the accuracy of ``model`` on the nodes selected by ``val_mask``.

    Args:
        data_loader: Iterable of batches. Each batch must provide ``to()``,
            ``y`` and ``val_mask`` and be accepted by ``model``.
        model: Module called as ``model(batch)`` that returns per-node
            logits.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The fraction of correctly classified validation nodes, or ``0.0``
        when no node is selected at all.
    """
    model.eval()
    correct = 0
    total = 0

    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch.
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            val_mask = data.val_mask
            labels = data.y

            output = model(data)
            predicted = output[val_mask].argmax(dim=1)

            total += val_mask.sum().item()
            correct += (predicted == labels[val_mask]).sum().item()

    if total == 0:
        return 0.0
    return correct / total

if __name__ == '__main__':
    # Raw string literal: in a regular string the "\U" of "C:\Users" starts
    # a \UXXXXXXXX unicode escape, which is a syntax error.
    dataset = MyDataset(root=r'C:\Users\jh\Desktop\data\input')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Three input features (RGB) per node, eight output classes.
    model = GCN(num_node_features=3, num_classes=8).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Hold out 10% of the samples for validation.
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    epochs = 2
    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}')

    # Accuracy is measured once, after the final epoch.
    val_accuracy = validate_model(val_loader, model, device)
    print(f'Val_Acc: {val_accuracy:.4f}')

This code implements a Graph Convolutional Network (GCN) in PyTorch Geometric that classifies the nodes of graphs built from images: every pixel of a resized image becomes a node carrying its RGB values. Here's a breakdown:

Dataset Class (MyDataset)

  • __init__: Initializes the dataset, defining paths to image, label, and edge files. Also initializes the dataset list, which will store Data objects.
  • create_dataset: Reads edge connections from edges_L.csv, generates complementary train and validation masks, loops through the image and label files, converts features and labels to tensors, and for each image/label pair appends two Data objects to the dataset list (the second one with the train and validation masks swapped).
  • read_edges: Reads edge connections from edges_L.csv, converting them to a torch.Tensor.
  • read_image_features: Reads image data, resizes it, converts to RGB, and extracts color features as a list of lists.
  • read_labels: Reads labels from the label files.
  • __len__: Returns the length of the dataset.
  • __getitem__: Retrieves a single Data object from the dataset, optionally applying a transform.

GCN Model Class (GCN)

  • __init__: Initializes the GCN model with three convolutional layers, using GCNConv from PyTorch Geometric.
  • forward: Implements the forward pass of the GCN, applying convolutions and activation functions.

Training and Validation Functions

  • train_model: Trains the model on the training data, calculating the cross-entropy loss, backpropagating gradients, and updating model parameters.
  • validate_model: Evaluates the model on the validation data, calculating accuracy.

Main Script

  • Dataset Initialization: Creates an instance of MyDataset, specifying the data root.
  • Model and Optimizer: Initializes the GCN model and Adam optimizer.
  • Data Splitting: Splits the dataset into train and validation sets using train_test_split.
  • Data Loaders: Creates DataLoader objects for training and validation.
  • Training Loop: Iterates over epochs to train the model, then evaluates accuracy on the validation set once after the final epoch.

Important Notes

  • Data Structure: The dataset is structured with folders for images, labels, and a separate file for edge connections. This ensures clear organization and easy access to data during training and validation.
  • Edge Connections: The edge connections are assumed to be static and identical for every sample, since all graphs share the single edge list read from edges_L.csv. This is crucial for the graph structure and message passing within the GCN.
  • Image Feature Extraction: The code extracts simple color features from the images. You can adjust the feature extraction process based on the nature of your data and desired features.
  • Label Format: Labels are assumed to be stored as a space-separated list of integers in text files.
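  • Classification Type: F.cross_entropy is called with integer class indices as targets, so this is single-label, multi-class classification: each node is assigned exactly one class. True multi-label targets (several labels per node) would need a different loss, such as F.binary_cross_entropy_with_logits.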

This code provides a starting point for building a GCN model for node-level classification on image-derived graphs using PyTorch Geometric. You can extend this code by adding more sophisticated feature extraction techniques, experimenting with different GCN architectures, and optimizing hyperparameters for improved performance.

PyTorch Geometric Graph Convolutional Network (GCN) for Multi-Label Image Classification

原文地址: https://www.cveoy.top/t/topic/pcrJ 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录