PyG 数据集创建与 GCN 模型训练:解决边索引超出边界错误
PyG 数据集创建与 GCN 模型训练:解决边索引超出边界错误
本文将介绍如何使用 PyG 创建数据集并训练 GCN 模型,并重点讲解如何解决 GCNConv 中出现的边索引超出边界错误。
数据集信息
- 特征文件为 'C:\Users\jh\Desktop\data\input\images\i.png_j.png' 的所有图片,图片尺寸为 40 x 40,共有 42 个时刻的图数据。
- 每个时刻都有 37 张图片,即 37 个节点,其中 i 表示时刻 (1 到 42),j 表示节点 (0 到 36)。
- 每个节点有 8 个标签,储存在 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' 文本文件中,标签用空格隔开。
- 边的关系储存在 'C:\Users\jh\Desktop\data\input\edges_L.csv' csv 文件中,第一列为源节点,第二列为目标节点,共有 61 条无向边。
错误分析
根据报错信息 RuntimeError: index 4 is out of bounds for dimension 0 with size 3,可以看出是在 GCNConv 的 forward 函数中出现了错误。具体错误是在 gcn_norm 函数中,scatter 函数的 index 参数中存在索引超出边界的情况。
代码分析发现,在创建数据集时,边索引是从 csv 文件中读取的,而这里的边索引是从 torch.tensor 转换而来的,所以可能出现索引错误。
解决方案
- 检查 csv 文件中的边索引是否正确,确保没有超出节点范围的索引。
- 在创建数据集时,使用 Pandas 库读取 csv 文件的时候,将
header参数设置为None,以避免将第一行作为列名。
修改后的代码
# 加载数据并创建PyG数据集类:
class MyDataset(torch.utils.data.Dataset):
def __init__(self, root, transform=None, pre_transform=None):
self.edges = pd.read_csv(os.path.join(root, 'input', 'edges_L.csv'), header=None)
self.transform = transform
self.pre_transform = pre_transform
# 读取特征和标签数据
self.features = []
self.labels = []
self.cnn = CNN()
for i in range(1, 43):
for j in range(37):
# 读取特征
img_name = os.path.join(root, 'input', 'images', '{}.png_{}.png'.format(i, j))
img = Image.open(img_name).convert('RGBA').resize((40, 40), resample=Image.BILINEAR)
img_tensor = transforms.ToTensor()(img)
feature = self.cnn(img_tensor.unsqueeze(0))
self.features.append(feature)
# 读取标签
label_name = os.path.join(root, 'input', 'labels', '{}_{}.txt'.format(i, j))
with open(label_name, 'r') as f:
labels = [int(x) for x in f.readline().strip().split()]
self.labels.append(labels)
# 将特征调整维度为[batch_size, num_node_features, width, height]
self.features = torch.stack(self.features, dim=0)
self.labels = torch.tensor(self.labels)
# Calculate the total number of nodes
self.num_nodes = len(self.labels)
# 添加边的关系
self.edges = torch.tensor(self.edges.values, dtype=torch.long).t().contiguous()
self.cnn = CNN() # 添加CNN模型
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
x = self.features[idx]
y = self.labels[idx]
# Define graph-wide train_mask and val_mask
train_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
val_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
# Set train_mask for the first 30 nodes in each network, and val_mask for the last 7 nodes
node_id = idx % 37
network_id = idx // 37
if node_id < 30 and (node_id + network_id * 37) <self.num_nodes:
train_mask[node_id + network_id * 37] = 1
elif (node_id + network_id * 37) <self.num_nodes:
val_mask[node_id + network_id * 37] = 1
data = Data(x=x, edge_index=self.edges, y=y, train_mask=train_mask, val_mask=val_mask)
if self.pre_transform is not None:
data = self.pre_transform(data)
if self.transform is not None:
data.x = self.transform(data.x)
return data
# 创建GCN模型
class GCN(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_node_features, 8)
self.conv2 = GCNConv(8, 16)
self.conv3 = GCNConv(16, num_classes)
self.cnn = CNN()
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.cnn(x) # 使用CNN模型提取图像特征
print(x.shape)
print(edge_index.shape)
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv3(x, edge_index)
return x
修改后,重新运行代码,应该就能解决索引超出边界的错误了。
原文地址: https://www.cveoy.top/t/topic/pe7y 著作权归作者所有。请勿转载和采集!