import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torchvision import transforms

# Define the GCN model.
class GCN(torch.nn.Module):
    """Three-layer graph convolutional network for per-node classification.

    Layer widths are ``num_node_features -> 8 -> 16 -> num_classes``. ReLU
    follows the first two convolutions, and dropout (active only in training
    mode) is applied before the last one.

    Args:
        num_node_features: Number of input features per node.
        num_classes: Number of output classes per node.
    """

    def __init__(self, num_node_features, num_classes):
        # Must be ``__init__``. A method named ``init`` would leave
        # ``nn.Module.__init__`` in place, and ``GCN(a, b)`` would fail.
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        """Return raw per-node class scores (no softmax applied).

        Args:
            data: A PyG ``Data``/``Batch`` object providing ``x`` (node
                features) and ``edge_index`` (graph connectivity).

        Returns:
            The output of ``conv3``: one ``num_classes``-wide row of logits
            per node, suitable for ``F.cross_entropy``.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Dropout is disabled automatically after ``model.eval()``.
        x = F.dropout(x, training=self.training)
        return self.conv3(x, edge_index)

# Locate the per-time-step inputs and load the shared edge list.
# NOTE(review): absolute, machine-specific Windows paths; adjust them (or make
# them configurable) before running anywhere else.

# Node-feature images (*.png), sorted by filename.
image_folder = 'C:/Users/jh/Desktop/data/input/images/'
image_files = sorted(
    f for f in os.listdir(image_folder) if f.endswith('.png')
)

# Label files (*.txt), sorted by filename so the i-th label file lines up with
# the i-th image.
label_folder = 'C:/Users/jh/Desktop/data/input/labels/'
label_files = sorted(
    f for f in os.listdir(label_folder) if f.endswith('.txt')
)

# Edge list shared by every graph. read_csv's default header='infer' uses the
# first row as column names.
# NOTE(review): confirm edges_L.csv really has a header row; otherwise the
# first edge is silently consumed as the header.
edge_file = 'C:/Users/jh/Desktop/data/input/edges_L.csv'
edges = pd.read_csv(edge_file)

# Image preprocessing:
#   ToTensor converts the PIL image to a (C, H, W) float tensor, scaled to
#   [0, 1] for 8-bit images.
#   Normalize((0.5,), (0.5,)) then maps each value v to (v - 0.5) / 0.5.
#   The one-element mean/std broadcasts across however many channels the
#   image has.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Build one PyG graph per time step.
#
# The i-th image is paired with the i-th label file (both lists are sorted by
# filename), and every graph shares a single edge list. Node features are the
# image values reshaped to one row per node, with the node's label appended as
# the last column. Downstream code reads the target from ``x[:, -1]``.

data_list = []

# The edge list is identical for every time step, so convert it once instead
# of inside the loop. PyG expects edge_index as a (2, num_edges) long tensor.
# NOTE(review): assumes edges_L.csv has exactly two columns (source, target
# node index); confirm against the data.
edge_index = torch.tensor(edges.values, dtype=torch.long).t().contiguous()
if edge_index.size(0) != 2:
    raise ValueError(
        f'expected a two-column edge list in {edge_file}, '
        f'got shape {tuple(edges.shape)}'
    )

# Pairing by position only makes sense if both folders have the same count.
if len(image_files) != len(label_files):
    raise ValueError(
        f'{len(image_files)} images but {len(label_files)} label files; '
        'each image needs exactly one label file'
    )

for image_name, label_name in zip(image_files, label_files):
    image_path = os.path.join(image_folder, image_name)
    label_path = os.path.join(label_folder, label_name)

    # Per-node labels as a (num_nodes, 1) column.
    # - ndmin=1 keeps a single-value file 1-D.
    # - float32 matches the ToTensor output, so torch.cat does not promote the
    #   features to float64.
    labels = np.loadtxt(label_path, ndmin=1)
    if labels.ndim != 1:
        raise ValueError(
            f'{label_path}: expected one label per node, got shape {labels.shape}'
        )
    labels = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
    num_nodes = labels.size(0)

    # Load and normalize the image. The context manager closes the file handle.
    with Image.open(image_path) as img:
        image = transform(img)

    # NOTE(review): assumes the image stores one row of features per node,
    # e.g. a single-channel image whose height equals num_nodes, so that
    # (C, H, W) flattens row-major to (num_nodes, features_per_node).
    # Concatenating the raw 3-D image with the 2-D label column always raises.
    # Confirm this layout against the data.
    if num_nodes == 0 or image.numel() % num_nodes:
        raise ValueError(
            f'{image_path}: {image.numel()} pixel values cannot be split '
            f'evenly across {num_nodes} nodes'
        )
    image_features = image.reshape(num_nodes, -1)

    node_features = torch.cat((image_features, labels), dim=1)
    data_list.append(Data(x=node_features, edge_index=edge_index))

# Split nodes into training / validation sets, build the model, train it, and
# report validation predictions.
#
# Masks are stored on each graph (PyG's train_mask / val_mask convention), so
# they always match that graph's nodes and batch correctly for any batch_size.
# A single (num_graphs, 37) mask cannot index the per-graph
# (num_nodes, num_classes) output.
TRAIN_NODES = 30  # nodes [0, TRAIN_NODES) are trained on; the rest validate
num_classes = 8

if not data_list:
    raise ValueError('no graphs were loaded; check the input folders')

for graph in data_list:
    # The label arrives as the last feature column. Move it into ``y`` and
    # drop it from the model input. If the target is also an input feature,
    # the network can simply copy it instead of learning from the image.
    graph.y = graph.x[:, -1].long()
    graph.x = graph.x[:, :-1]
    if graph.y.numel() and (graph.y.min() < 0 or graph.y.max() >= num_classes):
        raise ValueError(
            f'node labels must lie in [0, {num_classes}), '
            f'got {graph.y.unique().tolist()}'
        )
    graph.train_mask = torch.arange(graph.num_nodes) < TRAIN_NODES
    graph.val_mask = ~graph.train_mask

# Create the GCN model. Its inputs are the image features only (label removed
# above). cross_entropy makes this single-label classification: one class id
# per node.
num_node_features = data_list[0].num_node_features
model = GCN(num_node_features, num_classes)
# The original script used ``optimizer`` without defining it.
# NOTE(review): lr / weight_decay are a common GCN starting point; tune them.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Data loader: one graph per batch, in file order.
loader = DataLoader(data_list, batch_size=1)

# Train the model.
model.train()
for epoch in range(10):
    total_loss = 0.0
    for data in loader:
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'epoch {epoch:02d}  mean loss {total_loss / len(loader):.4f}')

# Validate the model: predicted class per validation node, plus accuracy.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in loader:
        out = model(data)
        pred = out[data.val_mask].argmax(dim=1)
        print(pred)
        correct += int((pred == data.y[data.val_mask]).sum())
        total += int(data.val_mask.sum())
if total:
    print(f'validation accuracy: {correct / total:.4f}')

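# 图卷积神经网络 (GCN) 多标签分类任务: 使用 PYG 库和 CNN 特征提取
# (Graph convolutional network (GCN) multi-label classification task, using the PyG library and CNN feature extraction.)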

# 原文地址 (original article): https://www.cveoy.top/t/topic/pek2
# 著作权归作者所有。请勿转载和采集! (Copyright belongs to the author; do not repost or scrape.)

# 免费AI点我,无需注册和登录 (advertisement text left over from the source web page)