import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from PIL import Image

class MyDataset(torch.utils.data.Dataset):
    """In-memory collection of graph samples built from image/label files.

    Files read under ``root``::

        root/images/{i}.png_{j}.png   for i in 1..42 and j in 0..36
        root/labels/{i}_{j}.txt       whitespace-separated integer labels
        root/edges_L.csv              one ``src,tgt`` node-index pair per line

    Every image is resized to 40x40 and each pixel becomes one node whose
    feature vector is its (R, G, B) value. All samples share the edge list
    read from ``edges_L.csv``. Each image/label pair yields two ``Data``
    objects: one whose ``train_mask`` selects the first 30 nodes, and one
    with ``train_mask`` and ``val_mask`` swapped.

    Args:
        root: Directory containing ``images/``, ``labels/`` and
            ``edges_L.csv``.
        transform: Optional callable applied to a sample each time it is
            fetched through ``__getitem__``.
        pre_transform: Stored on the instance; not applied anywhere in this
            class.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        self.img_dir = os.path.join(root, 'images')
        self.label_dir = os.path.join(root, 'labels')
        self.edge_file = os.path.join(root, 'edges_L.csv')
        self.transform = transform
        self.pre_transform = pre_transform
        self.dataset = []
        self.create_dataset()

    def create_dataset(self):
        """Read every image/label pair from disk and append it to ``self.dataset``.

        Returns:
            A ``(dataset, None)`` tuple. The second element has always been
            ``None`` and is kept only so existing callers that unpack two
            values keep working.
        """
        edge_index, num_nodes = self.read_edges(self.edge_file)

        # The node split is identical for every sample, so build it once.
        # Masks are 1-D (one flag per node) so they can index per-node
        # tensors such as the model output and ``y`` directly.
        # NOTE(review): assumes the node count derived from the edge file
        # equals the 1600 pixels of a 40x40 image - confirm against the data.
        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        train_mask[:30] = True
        val_mask = ~train_mask

        for i in range(1, 43):
            for j in range(37):
                image_path = os.path.join(self.img_dir, f'{i}.png_{j}.png')
                label_path = os.path.join(self.label_dir, f'{i}_{j}.txt')

                # Node feature matrix of shape (num_pixels, 3); no leading
                # singleton dimension, so dim 0 is the node dimension.
                features = torch.tensor(
                    self.read_image_features(image_path), dtype=torch.float
                )
                labels = torch.tensor(
                    self.read_labels(label_path), dtype=torch.long
                )

                train_data = Data(
                    x=features,
                    edge_index=edge_index,
                    y=labels,
                    train_mask=train_mask.clone(),
                    val_mask=val_mask.clone(),
                )
                self.dataset.append(train_data)

                # Same graph with the two masks swapped.
                val_data = Data(
                    x=features,
                    edge_index=edge_index,
                    y=labels,
                    train_mask=val_mask.clone(),
                    val_mask=train_mask.clone(),
                )
                self.dataset.append(val_data)

        return self.dataset, None

    def read_edges(self, edge_path):
        """Parse a ``src,tgt`` edge list file.

        Blank lines are skipped.

        Args:
            edge_path: Path to a text file with one ``src,tgt`` integer pair
                per line.

        Returns:
            ``(edge_index, num_nodes)`` where ``edge_index`` is a contiguous
            ``torch.long`` tensor of shape ``(2, num_edges)`` and
            ``num_nodes`` is the largest node index seen plus one.

        Raises:
            ValueError: If a line is not two comma-separated integers, or if
                the file contains no edges.
        """
        edges = []
        with open(edge_path, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines.
                    continue
                src, tgt = line.split(',')
                edges.append((int(src), int(tgt)))

        if not edges:
            raise ValueError(f'no edges found in {edge_path}')

        num_nodes = max(max(pair) for pair in edges) + 1
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        return edge_index, num_nodes

    def read_image_features(self, image_path):
        """Return the RGB value of every pixel of the image resized to 40x40.

        Args:
            image_path: Path of the image file to open.

        Returns:
            A list of 1600 ``[r, g, b]`` lists. The outer index runs over x
            (column) first and y (row) second, i.e. entry ``x * 40 + y``
            holds pixel ``(x, y)``.
        """
        # The context manager closes the underlying file handle; resize()
        # and convert() return independent in-memory images.
        with Image.open(image_path) as img:
            rgb_img = img.resize((40, 40)).convert('RGB')

        pixels = rgb_img.load()
        return [list(pixels[x, y]) for x in range(40) for y in range(40)]

    def read_labels(self, label_path):
        """Read whitespace-separated integer labels from ``label_path``.

        Returns:
            A list of ints, in file order.

        Raises:
            ValueError: If a token cannot be parsed as an integer.
        """
        with open(label_path, 'r') as file:
            return [int(label) for label in file.read().split()]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transform`` when one is set."""
        data = self.dataset[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data

class GCN(torch.nn.Module):
    """Three-layer graph convolutional network producing per-node class scores.

    Args:
        num_node_features: Size of each node's input feature vector.
        num_classes: Number of output channels of the last layer.
    """

    def __init__(self, num_node_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        """Run the three convolutions over ``data.x`` / ``data.edge_index``.

        Dropout is applied before the last layer and only while the module
        is in training mode. The returned scores are raw (no softmax).
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        return x


def train_model(data_loader, model, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    For every batch the cross-entropy loss is computed on the entries
    selected by ``data.train_mask`` and one optimizer step is taken.

    Args:
        data_loader: Iterable of batches exposing ``to(device)``,
            ``train_mask`` and ``y``.
        model: Module called as ``model(data)``.
        optimizer: Optimizer over the model's parameters.
        device: Device each batch is moved to.

    Returns:
        The mean loss per batch as a float, or ``0.0`` when the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for data in data_loader:
        data = data.to(device)
        train_mask = data.train_mask
        labels = data.y

        optimizer.zero_grad()
        output = model(data)

        loss = F.cross_entropy(output[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    return total_loss / num_batches if num_batches else 0.0


def validate_model(data_loader, model, device):
    """Compute accuracy over the entries selected by ``data.val_mask``.

    Args:
        data_loader: Iterable of batches exposing ``to(device)``,
            ``val_mask`` and ``y``.
        model: Module called as ``model(data)``.
        device: Device each batch is moved to.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when no entry is selected.
    """
    model.eval()
    correct = 0
    total = 0

    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            val_mask = data.val_mask
            labels = data.y

            output = model(data)
            predicted = output[val_mask].argmax(dim=1)

            total += val_mask.sum().item()
            correct += (predicted == labels[val_mask]).sum().item()

    return correct / total if total else 0.0


# These helpers were previously nested in the GCN class body; keep them
# reachable as GCN.train_model / GCN.validate_model for existing callers.
GCN.train_model = staticmethod(train_model)
GCN.validate_model = staticmethod(validate_model)

if __name__ == '__main__':
    # Raw string literal: in a normal string the "\U" of "C:\Users" starts a
    # \UXXXXXXXX unicode escape and the whole file fails to compile.
    dataset = MyDataset(root=r"C:\Users\jh\Desktop\data\input")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(num_node_features=3, num_classes=8).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Hold out 10% of the samples for the final accuracy check.
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    epochs = 2
    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}')

    # Accuracy is measured once, after the last epoch.
    val_accuracy = validate_model(val_loader, model, device)
    print(f'Val_Acc: {val_accuracy:.4f}')
# Graph Neural Network for Image Feature Classification with Node Labels
#
# 原文地址: https://www.cveoy.top/t/topic/pcrm 著作权归作者所有。请勿转载和采集!
#
# 免费AI点我,无需注册和登录