import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torchvision import transforms
from PIL import Image

# Load the data and define the PyG-style dataset class:
class MyDataset(torch.utils.data.Dataset):
    """Dataset of image-feature graphs, one PyG ``Data`` object per network.

    Files read under ``root``:

    * ``input/edges_L.csv`` -- header-less edge list, one ``source, target``
      pair per row, shared by every network. Node ids are local to a network
      (``0 .. nodes_per_network - 1``).
    * ``input/images/{i}.png_{j}.png`` -- image of node ``j`` in network ``i``
      (networks are numbered from 1), converted to RGB and resized to 224x224.
    * ``input/labels/{i}_{j}.txt`` -- whitespace-separated integer label(s) of
      that node, read from the first line.

    Item ``k`` is network ``k + 1`` as a graph with ``x`` of shape
    ``[nodes_per_network, 3, 224, 224]``, a bidirectional ``edge_index`` of
    shape ``[2, num_directed_edges]``, ``y`` (one entry/row per node), and
    ``train_mask`` / ``val_mask`` of length ``nodes_per_network``: the first
    ``num_train_nodes`` nodes of every network are training nodes and the
    remaining ones are validation nodes.

    Args:
        root: dataset root directory.
        transform: optional callable applied to ``data.x`` (the stacked node
            images of one network) on every access.
        pre_transform: optional callable applied to the whole ``Data`` object
            on every access, before ``transform``.
        num_networks: number of networks (graphs) stored on disk.
        nodes_per_network: number of nodes in each network.
        num_train_nodes: leading nodes of each network used for training.

    Raises:
        ValueError: if the edge list is not a two-column (or two-row) table,
            or references a node id outside ``0 .. nodes_per_network - 1``.
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 num_networks=42, nodes_per_network=37, num_train_nodes=30):
        self.edges = pd.read_csv(os.path.join(root, 'input', 'edges_L.csv'), header=None)
        self.transform = transform
        self.pre_transform = pre_transform
        self.num_networks = num_networks
        self.nodes_per_network = nodes_per_network

        # Every network shares the same topology, so edge_index is built once.
        self.edge_index = self._build_edge_index(self.edges, nodes_per_network)

        # Node-level train/validation split, identical for every network.
        node_ids = torch.arange(nodes_per_network)
        self.train_mask = node_ids < num_train_nodes
        self.val_mask = ~self.train_mask

        # Read node features (images) and labels.
        to_tensor = transforms.ToTensor()
        features = []
        labels = []
        for i in range(1, num_networks + 1):
            for j in range(nodes_per_network):
                img_name = os.path.join(root, 'input', 'images', '{}.png_{}.png'.format(i, j))
                # Context manager so the underlying file handle is released.
                with Image.open(img_name) as img:
                    rgb = img.convert('RGB').resize((224, 224), resample=Image.BILINEAR)
                features.append(to_tensor(rgb))

                label_name = os.path.join(root, 'input', 'labels', '{}_{}.txt'.format(i, j))
                with open(label_name, 'r') as f:
                    labels.append([int(x) for x in f.readline().strip().split()])

        # Network-major stack: node j of network i sits at row (i - 1) * nodes_per_network + j.
        # With the defaults this is 1554 x 3 x 224 x 224 float32 (~0.94 GB).
        self.features = torch.stack(features, dim=0)
        self.labels = torch.tensor(labels)
        # One integer per label file -> 1-D class indices, the target shape
        # F.cross_entropy expects. Multi-value label files are kept 2-D.
        if self.labels.dim() == 2 and self.labels.size(1) == 1:
            self.labels = self.labels.squeeze(1)

        # Total number of nodes across all networks.
        self.num_nodes = len(self.labels)

    @staticmethod
    def _build_edge_index(edges, num_nodes):
        """Convert the edge table into a deduplicated, bidirectional [2, E] long tensor."""
        table = edges.to_numpy()
        # One (source, target) pair per row is expected; also accept the
        # transposed layout (row 0 = sources, row 1 = targets).
        if table.shape[1] != 2 and table.shape[0] == 2:
            table = table.T
        if table.shape[1] != 2:
            raise ValueError(
                'edges_L.csv must have two columns (source, target); got shape {}'.format(table.shape))

        pairs = torch.as_tensor(table, dtype=torch.long).t()
        if pairs.size(1) == 0:
            return pairs
        if int(pairs.min()) < 0 or int(pairs.max()) >= num_nodes:
            raise ValueError('edge list references node ids outside 0..{}'.format(num_nodes - 1))

        # The edges are undirected, but PyG edge_index lists directed edges, so
        # add every reverse edge; unique() drops duplicates if both were listed.
        both = torch.cat([pairs, pairs.flip(0)], dim=1)
        return torch.unique(both, dim=1)

    def __len__(self):
        """Number of networks (graphs), not nodes."""
        return self.num_networks

    def __getitem__(self, idx):
        """Return network ``idx`` (0-based) as a PyG ``Data`` graph."""
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError('network index {} out of range for {} networks'.format(idx, len(self)))

        start = idx * self.nodes_per_network
        stop = start + self.nodes_per_network

        data = Data(
            x=self.features[start:stop],              # node features (images)
            edge_index=self.edge_index.clone(),       # shared topology; cloned per item
            y=self.labels[start:stop],                # node labels
            train_mask=self.train_mask.clone(),
            val_mask=self.val_mask.clone(),
        )

        if self.pre_transform is not None:
            data = self.pre_transform(data)
        if self.transform is not None:
            data.x = self.transform(data.x)
        return data

# Define the CNN model
class CNN(torch.nn.Module):
    """Small convolutional encoder mapping each image to a flat feature vector.

    Layers: conv(in_channels->16) -> ReLU -> 2x2 max-pool -> conv(16->32)
    -> ReLU -> 2x2 max-pool -> flatten -> linear(out_features). The 3x3
    convolutions use padding 1 and keep the spatial size; each pool halves it
    (rounding down), so the flattened size is ``32 * (image_size // 4) ** 2``.

    With the defaults (224x224 RGB input, 40*40*3 = 4800 outputs), the
    flattened size is 32*56*56 = 100352. The final linear layer alone then
    holds ~481.7M weights (~1.9 GB in float32).

    Args:
        in_channels: channels of the input images (3 for RGB).
        image_size: side length of the square input images (>= 4).
        out_features: length of the per-image output vector.

    Raises:
        ValueError: if ``image_size`` < 4 (nothing would survive the pools).
    """

    def __init__(self, in_channels=3, image_size=224, out_features=40 * 40 * 3):
        super(CNN, self).__init__()
        if image_size < 4:
            raise ValueError('image_size must be >= 4, got {}'.format(image_size))
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # Derived once so fc and forward() can never disagree on the size.
        self._flat_features = 32 * (image_size // 4) ** 2
        self.fc = nn.Linear(self._flat_features, out_features)

    def forward(self, x):
        """Encode a batch of images ``[N, in_channels, H, W]`` into ``[N, out_features]``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self._flat_features)
        return self.fc(x)

# Define the GCN model
class GCN(torch.nn.Module):
    """Three-layer GCN node classifier over CNN-derived node features.

    ``self.cnn`` is an image encoder that ``forward`` does NOT call. The
    training/validation loops in this file run ``model.cnn`` on the raw
    images first and write the flattened result back into ``data.x`` before
    calling the model.

    Args:
        num_node_features: length of each node's input feature vector.
        num_classes: number of output classes.
    """

    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()

        self.cnn = CNN()  # image encoder, invoked by the caller

        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        """Return raw class logits of shape ``[num_nodes, num_classes]``.

        Logits rather than probabilities are returned: ``F.cross_entropy``
        applies log-softmax itself, so an extra softmax here would squash the
        gradients. Argmax-based predictions are unaffected. The output stays
        2-D so callers can index it with per-node boolean masks.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, training=self.training)
        return self.conv3(x, edge_index)

# Training and validation loops
def train_model(dataset, model, optimizer, device):
    """Run one training epoch and return the mean loss per contributing batch.

    Args:
        dataset: iterable of graph batches (in this file, a PyG ``DataLoader``)
            exposing ``x`` (raw images), ``y``, ``train_mask`` and ``.to()``.
        model: a model with a ``cnn`` image encoder attribute, called on the
            batch after ``data.x`` has been replaced by the encoded features.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        The average cross-entropy over batches containing at least one
        training node, or ``float('nan')`` if no batch did.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for data in dataset:
        data = data.to(device)
        # With no training nodes, cross_entropy would average over an empty
        # selection and return NaN, poisoning the epoch mean. Skip such
        # batches before the expensive CNN pass.
        if not data.train_mask.any():
            continue

        optimizer.zero_grad()

        features = model.cnn(data.x)  # encode node images
        data.x = features.view(features.size(0), -1)

        output = model(data)
        loss = F.cross_entropy(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')

def validate_model(dataset, model, device):
    """Return node-level accuracy over all ``val_mask`` nodes in ``dataset``.

    Args:
        dataset: iterable of graph batches exposing ``x`` (raw images), ``y``,
            ``val_mask`` and ``.to()``.
        model: a model with a ``cnn`` image encoder attribute, called on the
            batch after ``data.x`` has been replaced by the encoded features.
        device: device each batch is moved to.

    Returns:
        ``correct / total`` over validation nodes, or ``float('nan')`` when
        there are no validation nodes at all.
    """
    model.eval()
    correct = 0
    total = 0

    # Evaluation needs no gradients; without no_grad every batch would build
    # (and hold) an autograd graph through the large CNN.
    with torch.no_grad():
        for data in dataset:
            data = data.to(device)
            # Skip the CNN/GCN pass when the batch has nothing to score.
            if not data.val_mask.any():
                continue

            features = model.cnn(data.x)  # encode node images
            data.x = features.view(features.size(0), -1)

            output = model(data)
            _, predicted = torch.max(output[data.val_mask], 1)
            total += int(data.val_mask.sum())
            correct += int((predicted == data.y[data.val_mask]).sum())

    return correct / total if total else float('nan')

def main():
    """Train the CNN+GCN model on the image-graph dataset and report validation accuracy."""
    # Raw string: in a plain literal, the "\U" in "C:\Users" starts a
    # \UXXXXXXXX escape, which makes the whole file a SyntaxError.
    dataset = MyDataset(root=r"C:\Users\jh\Desktop\data")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The GCN owns its image encoder (model.cnn), which train_model and
    # validate_model use. A separate CNN instance would sit unused while
    # costing another 32*56*56 x 4800 fc weight matrix (~1.9 GB float32).
    model = GCN(num_node_features=40 * 40 * 3, num_classes=8).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    epochs = 2
    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}')

        val_accuracy = validate_model(val_loader, model, device)
        print(f'Val_Acc: {val_accuracy:.4f}')


if __name__ == '__main__':
    main()
        
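# Edge list (61 undirected edges among nodes 0-36), shown transposed: the first
# line below holds the source nodes, the second line the matching target nodes.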
# 0	0	0	1	1	2	2	3	3	4	4	5	5	6	6	7	8	8	9	9	10	10	11	12	12	13	13	14	15	16	16	17	17	18	19	19	20	20	21	22	22	23	23	24	25	25	26	26	27	27	28	29	29	30	31	31	32	32	33	34	35
# 1	4	8	2	5	3	6	7	14	5	9	6	10	7	11	13	9	15	10	17	11	18	12	13	19	14	20	21	16	17	22	18	23	24	20	28	21	29	30	23	25	24	26	27	26	34	27	31	28	31	32	30	33	36	32	34	33	35	36	35	36
# Stored in "C:\Users\jh\Desktop\data\input\edges_L.csv", where the first
# column is the source node and the second column is the target node.
# Edges are undirected.
# Graph Convolutional Network (GCN) with Image Features for Node Classification
#
# Original source: https://www.cveoy.top/t/topic/pdS7
# (copyright belongs to the author; do not repost or scrape).