import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torchvision import transforms
from PIL import Image

# 加载数据并创建PyG数据集类:
class MyDataset(torch.utils.data.Dataset):
    """Node-level dataset of image features and labels for a set of small graphs.

    Images are read from ``<root>/input/images/{i}.png_{j}.png`` and labels from
    ``<root>/input/labels/{i}_{j}.txt`` for networks ``i = 1..num_networks`` and
    nodes ``j = 0..nodes_per_network - 1``. Edges come from
    ``<root>/input/edges_L.csv`` (read without a header row). Item ``idx`` is
    node ``idx % nodes_per_network`` of the 0-based network
    ``idx // nodes_per_network`` (i.e. image file number
    ``idx // nodes_per_network + 1``).

    Args:
        root: Data directory containing the ``input`` folder.
        transform: Optional callable applied to ``data.x`` on every access.
        pre_transform: Optional callable applied to the whole ``Data`` object on
            every access, before ``transform``.
        num_networks: Number of networks (graphs) to load.
        nodes_per_network: Number of nodes in each network.
        num_train_nodes: Nodes whose index within their network is below this
            value are put in ``train_mask``; the remaining ones in ``val_mask``.
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 num_networks=42, nodes_per_network=37, num_train_nodes=30):
        self.edges = pd.read_csv(os.path.join(root, 'input', 'edges_L.csv'), header=None)
        self.transform = transform
        self.pre_transform = pre_transform
        self.nodes_per_network = nodes_per_network
        self.num_train_nodes = num_train_nodes

        # ToTensor takes no configuration, so a single instance can be reused.
        to_tensor = transforms.ToTensor()
        features = []
        labels = []
        for i in range(1, num_networks + 1):
            for j in range(nodes_per_network):
                # Node features: the RGB image resized to 224x224.
                img_name = os.path.join(root, 'input', 'images', '{}.png_{}.png'.format(i, j))
                # Context manager releases the file handle deterministically.
                with Image.open(img_name) as img:
                    rgb = img.convert('RGB').resize((224, 224), resample=Image.BILINEAR)
                features.append(to_tensor(rgb))

                # Labels: whitespace-separated integers on the first line of the file.
                label_name = os.path.join(root, 'input', 'labels', '{}_{}.txt'.format(i, j))
                with open(label_name, 'r') as f:
                    labels.append([int(x) for x in f.readline().strip().split()])

        # Stacked image tensors: (num_nodes, channels, height, width).
        self.features = torch.stack(features, dim=0)
        # torch.tensor requires every label file to hold the same number of integers.
        self.labels = torch.tensor(labels)

        # Total number of nodes across all networks.
        self.num_nodes = len(self.labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a ``Data`` object for the single node ``idx``.

        ``x`` is that node's image tensor and ``y`` its label row. ``edge_index``
        contains the edge rows selected for the node's network, shifted by
        ``network_id * nodes_per_network`` into dataset-wide node ids.
        ``train_mask`` and ``val_mask`` both have length ``num_nodes``; exactly
        this node's position is set in one of them.
        """
        per_network = self.nodes_per_network
        network_id = idx // per_network  # which network the node belongs to
        node_id = idx % per_network  # node index within that network
        offset = network_id * per_network  # first dataset-wide id of this network

        # NOTE(review): column 0 is used both to select this network's rows and
        # as the edge source index; confirm against the layout of edges_L.csv.
        network_edges = self.edges[self.edges[0] == network_id]

        # Build one (2, num_edges) array first: calling torch.tensor() on a list
        # of numpy arrays takes a slow element-wise path and emits a UserWarning.
        edge_index = torch.tensor(network_edges[[0, 1]].to_numpy().T + offset, dtype=torch.long)

        x = self.features[idx]  # node features
        y = self.labels[idx]  # node labels

        # Graph-wide masks; the first num_train_nodes nodes of each network are
        # training nodes and the rest are validation nodes.
        train_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        if node_id < self.num_train_nodes:
            train_mask[offset + node_id] = True
        else:
            val_mask[offset + node_id] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        if self.transform is not None:
            data.x = self.transform(data.x)

        return data

# 定义CNN模型
class CNN(torch.nn.Module):
    """Image encoder mapping each image to a flat feature vector.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages followed by one fully
    connected layer. With the defaults (3-channel 224x224 input, 4800 outputs)
    the output width equals the ``40 * 40 * 3`` node-feature size used for the GCN.

    Args:
        in_channels: Number of input image channels (3 for RGB).
        image_size: Side length of the square input images. Each pooling stage
            halves the spatial size (rounding down), so the flattened conv output
            has ``32 * (image_size // 4) ** 2`` features.
        out_features: Width of the produced feature vector.
    """

    def __init__(self, in_channels=3, image_size=224, out_features=40 * 40 * 3):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # NOTE: with the defaults this layer is 100352 x 4800 (~482M weights,
        # roughly 1.9 GB in float32).
        self.fc = nn.Linear(32 * (image_size // 4) ** 2, out_features)

    def forward(self, x):
        """Encode images into feature vectors.

        Args:
            x: A batch ``(N, C, H, W)`` or a single unbatched image ``(C, H, W)``.

        Returns:
            Tensor of shape ``(N, out_features)``; ``N == 1`` for unbatched input.
            Inputs whose spatial size differs from ``image_size`` raise an error
            instead of being silently reshaped into a different batch size.
        """
        if x.dim() == 3:
            # A single image: add the batch dimension explicitly.
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the batch dimension; flatten only channels and spatial dimensions.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

# 创建GCN模型
class GCN(torch.nn.Module):
    """Three-layer graph convolutional network producing per-node class scores.

    Args:
        num_node_features: Width of each node's input feature vector (``data.x``).
        num_classes: Number of output classes.
    """

    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        """Return unnormalized class logits of shape ``(num_nodes, num_classes)``.

        Logits (not probabilities) are returned because ``F.cross_entropy``
        applies log-softmax itself; a softmax here would be applied twice during
        training. Use ``F.softmax(out, dim=1)`` when probabilities are needed;
        the argmax prediction is the same either way.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, training=self.training)
        # One row per node, so callers can index the output with node-level
        # boolean masks such as data.train_mask / data.val_mask.
        return self.conv3(x, edge_index)

# 创建训练和验证模型
def train_model(dataset, model, optimizer, device):
    """Run one training epoch and return the mean loss per contributing batch.

    For each batch, ``model.cnn`` turns ``data.x`` into one flat feature row per
    sample, the model is run on the batch, and cross-entropy is computed on the
    nodes selected by ``data.train_mask``.

    Args:
        dataset: Iterable of batches (e.g. a DataLoader). Each batch needs
            ``x``, ``y``, ``train_mask`` and a ``to(device)`` method.
        model: Module with a ``cnn`` sub-module whose ``forward`` takes the batch.
        optimizer: Optimizer that steps the parameters being trained.
        device: Device each batch is moved to.

    Returns:
        Average loss over the batches whose ``train_mask`` selected at least one
        node, or ``0.0`` if no batch did.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for data in dataset:
        data = data.to(device)
        # With an empty mask, cross_entropy averages over zero elements and
        # returns NaN, which would poison total_loss; skip such batches.
        if not data.train_mask.any():
            continue

        optimizer.zero_grad()

        # Extract image features with the CNN, one flat row per sample.
        features = model.cnn(data.x)
        data.x = features.view(features.size(0), -1)

        output = model(data)
        loss = F.cross_entropy(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

def validate_model(dataset, model, device):
    """Return classification accuracy over the nodes selected by ``val_mask``.

    For each batch, ``model.cnn`` turns ``data.x`` into one flat feature row per
    sample, the model is run on the batch, and the argmax prediction of every
    node in ``data.val_mask`` is compared with ``data.y``.

    Args:
        dataset: Iterable of batches (e.g. a DataLoader). Each batch needs
            ``x``, ``y``, ``val_mask`` and a ``to(device)`` method.
        model: Module with a ``cnn`` sub-module whose ``forward`` takes the batch.
        device: Device each batch is moved to.

    Returns:
        ``correct / total`` over all batches, or ``0.0`` if no node was selected.
    """
    model.eval()
    correct = 0
    total = 0

    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for data in dataset:
            data = data.to(device)

            # Extract image features with the CNN, one flat row per sample.
            features = model.cnn(data.x)
            data.x = features.view(features.size(0), -1)

            output = model(data)
            predicted = output[data.val_mask].argmax(dim=1)
            total += data.val_mask.sum().item()
            correct += (predicted == data.y[data.val_mask]).sum().item()

    return correct / total if total else 0.0

if __name__ == '__main__':
    # Raw string: in a normal string literal "\U" starts a \UXXXXXXXX escape,
    # which makes this Windows path a SyntaxError.
    dataset = MyDataset(root=r"C:\Users\jh\Desktop\data")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cnn_model = CNN().to(device)
    model = GCN(num_node_features=40 * 40 * 3, num_classes=8).to(device)
    # train_model/validate_model call model.cnn(...), which GCN does not define.
    # Assigning a Module attribute registers it as a submodule: it follows
    # model.train()/model.eval(), and because this happens before the optimizer
    # is built, its weights are in model.parameters() and get trained too.
    # NOTE: with the default CNN this adds ~482M weights (plus Adam state).
    model.cnn = cnn_model
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Hold out 10% of the node items for validation.
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    epochs = 2

    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}')

        val_accuracy = validate_model(val_loader, model, device)
        print(f'Val_Acc: {val_accuracy:.4f}')

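# In this code, a CNN extracts a feature vector from each node's image, and that
# vector is used as the node's input features for the GCN. The CNN reduces each
# node's input from 3 x 224 x 224 = 150,528 pixel values to 40 x 40 x 3 = 4,800
# features, which keeps the GCN layers small (the CNN's own final linear layer,
# however, has about 482 million weights).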

# Feel free to adjust the CNN's architecture and hyperparameters to suit your data
# and the features you need.

# Graph Convolutional Network with CNN Feature Extraction for Image-Based Node Classification

# 原文地址 (original article): http://www.cveoy.top/t/topic/pdmz
# 著作权归作者所有。请勿转载和采集! (Copyright belongs to the author; do not repost or scrape.)

# 免费AI点我,无需注册和登录 (Free AI: click here, no registration or login required.)