PyG 自定义数据集:将 10 张图片的节点特征和标签数据构建为张量并使用掩码进行训练集和验证集划分
使用 PyG 库自定义数据集并进行训练集和验证集划分
本文将使用 PyG 库自定义数据集,将 10 张图片的节点特征和标签数据构建为张量形状 (10, 20, 2),并使用掩码将每张图片的前 16 个节点作为训练集,剩下的 4 个节点作为验证集。
数据描述:
- 边数据: 10 张图片,每张图片包含 20 个节点,一共 66 条边,每行包括一个源节点和一个目标节点,源节点和目标节点用空格隔开,储存在 'C:/Users/jh/Desktop/data/raw/edges.txt' 中。
- 节点特征: 每个节点包含两个一维特征,所有图片节点的第一个特征储存在 'features1.txt' 中,第二个特征储存在 'features2.txt' 中。文件中每一行代表一张图片的节点特征,不同节点间的特征用空格隔开。
- 节点标签: 每个节点还包含一个标签,储存在 'label.txt' 中,文件中每一行代表一张图片的节点标签,不同节点间的标签用空格隔开。
目标:
利用 PyG 库自定义数据集 My_dataset 类,用掩码 mask 将每张图片的前 16 个节点的特征和标签作为训练集,剩下 4 个节点作为验证集。
代码示例:
import torch
from torch_geometric.data import Dataset, Data
class My_dataset(Dataset):
def __init__(self, root):
super(My_dataset, self).__init__(root)
# 读取边数据
edge_file = open('C:/Users/jh/Desktop/data/raw/edges.txt', 'r')
edges = edge_file.readlines()
edge_file.close()
# 读取特征数据
feature_file1 = open('C:/Users/jh/Desktop/data/raw/features1.txt', 'r')
feature_file2 = open('C:/Users/jh/Desktop/data/raw/features2.txt', 'r')
features1 = feature_file1.readlines()
features2 = feature_file2.readlines()
feature_file1.close()
feature_file2.close()
# 读取标签数据
label_file = open('C:/Users/jh/Desktop/data/raw/label.txt', 'r')
labels = label_file.readlines()
label_file.close()
self.data_list = []
for i in range(len(edges)):
edge = edges[i].strip().split()
src = int(edge[0])
tgt = int(edge[1])
feature1 = [float(x) for x in features1[i].strip().split()]
feature2 = [float(x) for x in features2[i].strip().split()]
label = [int(x) for x in labels[i].strip().split()]
x = torch.tensor([feature1, feature2], dtype=torch.float)
y = torch.tensor(label, dtype=torch.float)
edge_index = torch.tensor([[src, tgt], [tgt, src]], dtype=torch.long)
data = Data(x=x, y=y, edge_index=edge_index)
self.data_list.append(data)
def __len__(self):
return len(self.data_list)
def __getitem__(self, index):
return self.data_list[index]
def get_mask(self):
masks = []
for i in range(len(self.data_list)):
mask = torch.zeros(len(self.data_list[i].x))
mask[:16] = 1
masks.append(mask.bool())
return masks
dataset = My_dataset('C:/Users/jh/Desktop/data/raw')
train_mask = dataset.get_mask()
代码解释:
- 首先定义了一个
My_dataset类,继承自Dataset类。 - 在类的初始化方法中,读取边数据、特征数据和标签数据,并将其存储到
data_list列表中。 - 然后,根据每个样本的边、特征和标签构造
Data对象,并将其添加到data_list列表中。 - 在
get_mask方法中,为每个样本生成一个掩码 (mask),其中前 16 个节点的掩码为 1,剩下的节点的掩码为 0。 - 最后,通过调用
get_mask方法获取训练集的掩码。
注意事项:
- 代码中的文件路径 'C:/Users/jh/Desktop/data/raw/' 请根据实际情况修改。
- 可以根据自己的需求对上述代码进行修改和调整,例如,改变训练集和验证集的节点数量,或添加其他数据预处理步骤。
- 在训练模型时,可以使用
train_mask掩码来选择训练集数据。
总结:
本文介绍了如何使用 PyG 库自定义数据集,并利用掩码将数据划分为训练集和验证集。这种方法可以方便地处理图数据,并进行模型训练和评估。
原文地址: https://www.cveoy.top/t/topic/mSIV 著作权归作者所有。请勿转载和采集!