import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
from torchvision import transforms
from PIL import Image

# Define the GCN model class
class GCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        self.conv2 = GCNConv(hid_feats, out_feats)
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

# Function to read image features
def read_image_features(image_path):
    transform = transforms.Compose([
        transforms.Resize((40, 40)),
        transforms.ToTensor()
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    return image.flatten()

# Function to read labels
def read_labels(label_path):
    with open(label_path, 'r') as file:
        labels = file.readline().split()
        labels = [int(label) for label in labels]
    return labels

# Function to read edges
def read_edges(edge_path):
    edges = []
    with open(edge_path, 'r') as file:
        for line in file:
            src, tgt = line.split(',')
            edges.append((int(src), int(tgt)))
    return edges

# Function to create PyG dataset from image features, labels, and edges
def create_dataset(image_dir, label_dir, edge_file):
    dataset = []
    edges = read_edges(edge_file)
    for i in range(1, 43):
        for j in range(37):
            image_path = f"{image_dir}/i{i}_{j}.png"
            label_path = f"{label_dir}/i{i}_{j}.txt"
            features = read_image_features(image_path)
            labels = read_labels(label_path)
            data = Data(x=torch.tensor(features), y=torch.tensor(labels))
            dataset.append(data)
    return dataset, edges

# Function to split dataset into train and validation sets
def split_dataset(dataset):
    train_dataset = dataset[:30*37]
    val_dataset = dataset[30*37:]
    return train_dataset, val_dataset

# Create the GCN model
in_feats = 40*40*3
hid_feats = 64
out_feats = 8
model = GCN(in_feats, hid_feats, out_feats)

# Load and split the dataset
image_dir = "C:/Users/jh/Desktop/data/input/images"
label_dir = "C:/Users/jh/Desktop/data/input/labels"
edge_file = "C:/Users/jh/Desktop/data/input/edges_L.csv"
dataset, edges = create_dataset(image_dir, label_dir, edge_file)
train_dataset, val_dataset = split_dataset(dataset)

# Define the dataloaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Define the optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# Training loop
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x.float(), data.edge_index)
        loss = criterion(out, data.y.float())
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x.float(), data.edge_index)
            val_loss += criterion(out, data.y.float()).item()
    
    print(f"Epoch: {epoch+1}, Val Loss: {val_loss:.4f}")

该代码使用PyTorch Geometric库构建了一个GCN模型,并使用图像特征作为节点特征,边信息来学习节点之间的关系。代码中包含了以下步骤:

  1. 定义GCN模型: 使用GCN类定义一个两层GCN模型。
  2. 读取图像特征: 使用read_image_features函数读取图像特征,并将其展平成一维向量。
  3. 读取标签: 使用read_labels函数读取每个节点的标签。
  4. 读取边信息: 使用read_edges函数读取边的信息。
  5. 创建PyG数据集: 使用create_dataset函数将图像特征、标签和边信息组合成PyG数据集。
  6. 划分训练集和验证集: 使用split_dataset函数将数据集划分为训练集和验证集。
  7. 定义数据加载器: 使用DataLoader类创建训练数据加载器和验证数据加载器。
  8. 定义优化器和损失函数: 使用torch.optim.Adam类定义优化器,使用nn.BCEWithLogitsLoss类定义损失函数。
  9. 训练模型: 使用训练数据加载器训练模型,并使用验证数据加载器评估模型性能。

该代码示例展示了如何使用GCN模型进行动态图数据节点标签预测。需要注意的是,该代码只是一个简单的示例,实际应用中可能需要根据具体情况进行调整。


原文地址: https://www.cveoy.top/t/topic/pbNN 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录