基于图卷积网络的时序图数据多标签分类

本项目使用图卷积网络 (GCN) 对时序图数据进行多标签分类。每个节点表示一张图片,包含从图片提取的颜色特征,并拥有多个标签。通过图网络学习节点之间的关系,并预测每个节点的标签。

数据描述

  • 数据集包含 42 个时刻的图数据,每个时刻有 37 张图片,每张图片代表一个节点。
  • 图片大小为 40x40 像素,并从图片中提取颜色特征作为节点的特征向量。
  • 每个节点拥有 8 个标签,标签用空格隔开,存储在 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' 文件中,其中 i 表示图片序号,j 表示节点序号。
  • 节点之间的连接关系储存在 'C:\Users\jh\Desktop\data\input\edges_L.csv' 文件中,第一列为源节点,第二列为目标节点,边为无向边。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import os
from PIL import Image
import numpy as np
import pandas as pd

# 定义 GCN 模型类
class GCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        self.conv2 = GCNConv(hid_feats, out_feats)
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

# 函数用于读取图片特征
def read_image_features(image_path):
    img = Image.open(image_path).convert('RGB')
    img = img.resize((40, 40))  # 将图片尺寸调整为 40x40
    img_array = np.array(img)
    features = img_array.flatten()
    return features.tolist()

# 函数用于读取标签
def read_labels(label_path):
    with open(label_path, 'r') as file:
        labels = file.readline().split()
        labels = [int(label) for label in labels]
    return labels

# 函数用于读取边
def read_edges(edge_path):
    edges_df = pd.read_csv(edge_path)
    edges = [(int(src), int(tgt)) for src, tgt in edges_df.values]
    return edges

# 函数用于从图片特征、标签和边创建 PyG 数据集
def create_dataset(image_dir, label_dir, edge_file):
    dataset = []
    edges = read_edges(edge_file)

    for i in range(1, 43):
        for j in range(37):
            image_path = os.path.join(image_dir, f'i{i}_{j}.png')
            label_path = os.path.join(label_dir, f'i{i}_{j}.txt')
            features = read_image_features(image_path)
            labels = read_labels(label_path)
            data = Data(x=torch.tensor(features).float(), y=torch.tensor(labels))
            dataset.append(data)

    return dataset, edges

# 函数用于将数据集拆分为训练集和验证集
def split_dataset(dataset):
    train_dataset = dataset[:30 * 37]
    val_dataset = dataset[30 * 37:]
    return train_dataset, val_dataset

# 创建 GCN 模型
in_feats = 40
hid_feats = 64
out_feats = 8
model = GCN(in_feats, hid_feats, out_feats)

# 加载并拆分数据集
image_dir = 'C:\Users\jh\Desktop\data\input\images'
label_dir = 'C:\Users\jh\Desktop\data\input\labels'
edge_file = 'C:\Users\jh\Desktop\data\input\edges_L.csv'
dataset, edges = create_dataset(image_dir, label_dir, edge_file)
train_dataset, val_dataset = split_dataset(dataset)

# 定义数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# 训练循环
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y.float())
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_loss = 0
    
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index)
            val_loss += criterion(out, data.y.float()).item()
    
    print(f'Epoch: {epoch+1}, Val Loss: {val_loss:.4f}')

代码说明

  • 代码中首先定义了 GCN 模型类,包含两个图卷积层,并使用 ReLU 激活函数。
  • 接着定义了三个函数,分别用于读取图片特征、标签和边数据。
  • create_dataset 函数用于从图片特征、标签和边数据创建 PyG 数据集。
  • split_dataset 函数用于将数据集拆分为训练集和验证集。
  • 最后,代码创建 GCN 模型,加载并拆分数据集,定义数据加载器、优化器和损失函数,并进行训练循环。

结果

训练完成后,可以使用训练好的模型对新的图数据进行多标签分类。

注意

  • 代码中的路径需要根据实际情况进行修改。
  • 可以根据需要调整模型结构、超参数和训练参数。

未来工作

  • 可以尝试使用其他图卷积网络模型,例如 GAT、GraphSAGE 等。
  • 可以尝试使用其他特征提取方法,例如提取图片的纹理特征、形状特征等。
  • 可以尝试使用其他损失函数,例如 focal loss、label smoothing 等。
  • 可以尝试使用其他训练方法,例如对抗训练、半监督学习等。
基于图卷积网络的时序图数据多标签分类

原文地址: https://www.cveoy.top/t/topic/pbNA 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录