使用 Torch Geometric 库自定义数据集进行图神经网络训练和验证
本文将展示如何利用 torch_geometric 库自定义数据集,并进行 GCN 模型的训练和验证。
数据集描述
数据集中包含 10 张图片,每张图片包含 20 个节点。节点编号从 0 到 19。每张图片包含 66 条边,存储在 'C:\Users\jh\Desktop\data\raw\edges.csv' 文件中。该文件每一行代表一条边,共 66 行。第一列为源节点,第二列为目标节点。
每个节点包含 2 个特征,所有图片节点的第一个特征存储在 'C:\Users\jh\Desktop\data\raw\features1.csv' 文件中,第二个特征存储在 'C:\Users\jh\Desktop\data\raw\features2.csv' 文件中。每一行代表一张图片的 20 个节点特征,共有 10 行。我们将两个特征拼接成形状为 (10, 20, 2) 的张量。
每个节点还包含一个标签,共有 0 和 1 两类标签,存储在 'C:\Users\jh\Desktop\data\raw\label.csv' 文件中。每一行代表一张图片的标签,一个值代表一个节点的标签。
自定义数据集类
我们将使用 torch_geometric 库的 InMemoryDataset 类来定义我们的自定义数据集。以下代码展示了自定义数据集类的示例代码:
import torch
from torch_geometric.data import InMemoryDataset, Data
class My_dataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(My_dataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['edges.csv', 'features1.csv', 'features2.csv', 'label.csv']
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download the raw data from the given URLs and save to self.raw_dir
...
def process(self):
# Read data from raw files and process into PyG Data objects
edges_path = self.raw_paths[0]
features1_path = self.raw_paths[1]
features2_path = self.raw_paths[2]
label_path = self.raw_paths[3]
# Process edges
edges = []
with open(edges_path, 'r') as f:
for line in f.readlines():
src, tgt = line.strip().split(',')
edges.append((int(src), int(tgt)))
# Process features
features1 = torch.tensor(pd.read_csv(features1_path, header=None).values)
features2 = torch.tensor(pd.read_csv(features2_path, header=None).values)
# Process labels
labels = torch.tensor(pd.read_csv(label_path, header=None).values.squeeze())
# Create PyG Data object
data = Data(x=torch.cat([features1, features2], dim=2), edge_index=torch.tensor(edges).t().contiguous(), y=labels)
# Save processed data
torch.save(self.collate([data]), self.processed_paths[0])
数据加载和模型训练
以下代码展示了如何使用自定义数据集类加载数据,并进行 GCN 模型训练和验证。
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
dataset = My_dataset(root='C:\Users\jh\Desktop\data')
loader = DataLoader(dataset, batch_size=1, shuffle=True)
model = GCNConv(2, 16) # Define your GCN model
# Training loop
for data in loader:
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
model.eval()
_, pred = model(data.x, data.edge_index).max(dim=1)
correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
数据划分
在上述代码中,我们使用 data.train_mask 和 data.test_mask 来划分训练集和验证集。每个图数据的第一个 16 个节点作为训练集,剩余的 4 个节点作为验证集。需要注意的是,我们并没有对边连接关系进行掩码。
模型结构
你可以根据自己的需求修改 GCN 模型的结构。在上面的代码中,我们定义了一个输入特征维度为 2,输出特征维度为 16 的 GCNConv 层。
训练和验证
我们使用 DataLoader 加载数据,并使用 for 循环遍历每个图数据,进行模型训练和验证。训练过程中,我们使用 F.cross_entropy 计算交叉熵损失,并使用优化器进行梯度下降。验证过程中,我们计算模型在验证集上的准确率。
总结
本文展示了如何使用 torch_geometric 库自定义数据集,并进行 GCN 模型的训练和验证。你可以根据自己的需求修改数据格式、数据集类、模型结构以及训练和验证过程。
注意: 上述代码仅供参考,你需要根据自己的模型结构和需求进行相应的修改和调整。
原文地址: https://www.cveoy.top/t/topic/mulK 著作权归作者所有。请勿转载和采集!