# 图卷积神经网络 (GCN) 多标签分类任务: 使用 PyG 库和 CNN 特征提取
# 注：下面的代码实际是逐节点的单标签多分类（交叉熵损失），节点特征直接取自归一化后的图像像素，并未使用 CNN。
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torchvision import transforms
# Define the GCN model.
class GCN(torch.nn.Module):
    """Three-layer graph convolutional network for node classification.

    Architecture: GCNConv(num_node_features -> 8) -> ReLU -> GCNConv(8 -> 16)
    -> ReLU -> dropout -> GCNConv(16 -> num_classes). ``forward`` returns raw
    per-node logits (no softmax), which is the form ``F.cross_entropy`` expects.

    Args:
        num_node_features: Width of each node's input feature vector.
        num_classes: Number of output logits per node.
        dropout: Dropout probability applied before the last layer while in
            training mode. Defaults to 0.5, the ``F.dropout`` default that the
            original code relied on implicitly.
    """

    def __init__(self, num_node_features, num_classes, dropout=0.5):
        # The pasted source read ``def init`` / ``super(...).init()`` (the
        # double underscores were lost), so no __init__ was ever defined and
        # the layers below were never created.
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)
        self.dropout = dropout

    def forward(self, data):
        """Return per-node logits of shape (num_nodes, num_classes).

        Args:
            data: Graph object exposing ``x`` (node feature matrix) and
                ``edge_index`` (connectivity), e.g. a PyG ``Data``/``Batch``.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Active only in training mode; model.eval() switches it off.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv3(x, edge_index)
# Locate the image and label files (one pair per time step) and load the
# shared edge list.
# NOTE(review): these absolute Windows paths are the original author's local
# layout; point DATA_DIR at your own copy of the data.
DATA_DIR = 'C:/Users/jh/Desktop/data/input/'

# Image files (.png), sorted so that pairing with the labels is by file order.
image_folder = os.path.join(DATA_DIR, 'images/')
image_files = sorted(f for f in os.listdir(image_folder) if f.endswith('.png'))

# Label files (.txt), sorted the same way as the images.
label_folder = os.path.join(DATA_DIR, 'labels/')
label_files = sorted(f for f in os.listdir(label_folder) if f.endswith('.txt'))

# Images and labels are paired purely by sorted position, so the counts must
# match; otherwise graph construction would either crash part-way through or
# silently drop / misalign files.
if len(image_files) != len(label_files):
    raise ValueError(
        f'Found {len(image_files)} .png images in {image_folder} but '
        f'{len(label_files)} .txt label files in {label_folder}; '
        'they must pair up one-to-one.'
    )

# Edge list shared by every graph.
# NOTE(review): pd.read_csv treats the first row as a header; graph
# construction assumes exactly two columns (source, target) holding 0-based
# node indices — confirm against the data.
edge_file = os.path.join(DATA_DIR, 'edges_L.csv')
edges = pd.read_csv(edge_file)

# Image preprocessing: ToTensor converts the PIL image to a float tensor
# (scaled to [0, 1] for 8-bit images); Normalize((0.5,), (0.5,)) then maps
# that range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Build one PyG graph per time step.
# Every time step shares the same topology, so the edge index is built once
# rather than inside the loop. edges.values has one row per edge; transposing
# gives PyG's (2, num_edges) layout.
edge_index = torch.tensor(edges.values, dtype=torch.long).t().contiguous()

data_list = []

for image_file, label_file in zip(image_files, label_files):
    image_path = os.path.join(image_folder, image_file)
    label_path = os.path.join(label_folder, label_file)

    # Load the image features. Image.open is lazy and holds the file open;
    # the `with` block guarantees the handle is closed.
    with Image.open(image_path) as image:
        pixels = transform(image)

    # Load the labels as a float32 column vector (one row per node). Using
    # float32 (np.loadtxt yields float64) keeps the concatenated features in
    # the same dtype as the float32 GCN weights.
    labels = torch.as_tensor(
        np.loadtxt(label_path), dtype=torch.float32
    ).reshape(-1, 1)

    # ToTensor yields a (channels, height, width) tensor, which cannot be
    # concatenated with the 2-D label column. Drop the channel axis so each
    # image row becomes one node's feature vector.
    # NOTE(review): assumes single-channel images with one row per node —
    # confirm against the data.
    node_image_features = pixels.squeeze(0)
    if (node_image_features.dim() != 2
            or node_image_features.size(0) != labels.size(0)):
        raise ValueError(
            f'{image_file}: expected a single-channel image with one row per '
            f'node ({labels.size(0)} rows), got a tensor of shape '
            f'{tuple(pixels.shape)}'
        )

    # Node features = image features followed by the label as the last
    # column (the training code reads the target from x[:, -1]).
    node_features = torch.cat((node_image_features, labels), dim=1)

    graph_data = Data(x=node_features, edge_index=edge_index)
    data_list.append(graph_data)
# Split each graph's nodes into training and validation sets.
if not data_list:
    raise RuntimeError('No graphs were loaded; check the image and label folders.')

num_nodes = data_list[0].num_nodes
if any(graph.num_nodes != num_nodes for graph in data_list):
    raise ValueError(
        'Every graph must have the same number of nodes to share one node mask.'
    )

# The masks select NODES, so they are 1-D with one entry per node (a
# (num_graphs, 37) mask cannot index the (num_nodes, num_classes) output).
# As in the original, the first 30 nodes train and the rest validate.
num_train_nodes = 30
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:num_train_nodes] = True
val_mask = ~train_mask

# Create the GCN model.
# The last column of data.x is the label, i.e. the prediction target; feeding
# it back in as an input would leak the answer, so the model only sees the
# remaining columns (8 image features in the original setup).
num_node_features = data_list[0].num_node_features - 1
num_classes = 8
model = GCN(num_node_features, num_classes)
# The original script called optimizer.step() without ever creating one.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def _split_inputs_and_targets(batch):
    """Split a batch into model inputs and class targets.

    Returns a ``Data`` holding every feature column except the last, plus the
    last column cast to int64 class ids for ``F.cross_entropy``.
    """
    inputs = Data(x=batch.x[:, :-1], edge_index=batch.edge_index)
    # NOTE(review): cross_entropy needs ids in [0, num_classes); assumes the
    # label files store 0-based class ids — confirm against the data.
    targets = batch.x[:, -1].long()
    return inputs, targets


# Define the data loader. batch_size=1 keeps each batch a single graph, so the
# per-node masks line up with the rows of the model output.
loader = DataLoader(data_list, batch_size=1)

# Train the model.
for epoch in range(10):
    model.train()
    for data in loader:
        inputs, targets = _split_inputs_and_targets(data)
        out = model(inputs)
        loss = F.cross_entropy(out[train_mask], targets[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate: print the predicted class of every validation node.
model.eval()
with torch.no_grad():
    for data in loader:
        inputs, _ = _split_inputs_and_targets(data)
        pred = model(inputs)[val_mask].argmax(dim=1)
        print(pred)
# 原文地址: https://www.cveoy.top/t/topic/pek2 著作权归作者所有。请勿转载和采集!