PyG 数据集创建与 GCN 模型训练:解决边索引超出边界错误

本文将介绍如何使用 PyG 创建数据集并训练 GCN 模型,并重点讲解如何解决 GCNConv 中出现的边索引超出边界错误。

数据集信息

  • 特征文件为 'C:\Users\jh\Desktop\data\input\images\i.png_j.png' 的所有图片,图片尺寸为 40 x 40,共有 42 个时刻的图数据。
  • 每个时刻都有 37 张图片,即 37 个节点,其中 i 表示时刻 (1 到 42),j 表示节点 (0 到 36)。
  • 每个节点有 8 个标签,储存在 'C:\Users\jh\Desktop\data\input\labels\i_j.txt' 文本文件中,标签用空格隔开。
  • 边的关系储存在 'C:\Users\jh\Desktop\data\input\edges_L.csv' csv 文件中,第一列为源节点,第二列为目标节点,共有 61 条无向边。

错误分析

根据报错信息 RuntimeError: index 4 is out of bounds for dimension 0 with size 3,可以看出是在 GCNConv 的 forward 函数中出现了错误。具体错误是在 gcn_norm 函数中,scatter 函数的 index 参数中存在索引超出边界的情况。

代码分析发现,在创建数据集时,边索引是从 csv 文件中读取的,而这里的边索引是从 torch.tensor 转换而来的,所以可能出现索引错误。

解决方案

  1. 检查 csv 文件中的边索引是否正确,确保没有超出节点范围的索引。
  2. 在创建数据集时,使用 Pandas 库读取 csv 文件的时候,将 header 参数设置为 None,以避免将第一行作为列名。

修改后的代码

# 加载数据并创建PyG数据集类:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        self.edges = pd.read_csv(os.path.join(root, 'input', 'edges_L.csv'), header=None)
        self.transform = transform
        self.pre_transform = pre_transform
        
        # 读取特征和标签数据
        self.features = []
        self.labels = []
        self.cnn = CNN()
        for i in range(1, 43):
            for j in range(37):
                # 读取特征
                img_name = os.path.join(root, 'input', 'images', '{}.png_{}.png'.format(i, j))
                img = Image.open(img_name).convert('RGBA').resize((40, 40), resample=Image.BILINEAR)
                img_tensor = transforms.ToTensor()(img)
                feature = self.cnn(img_tensor.unsqueeze(0))
                self.features.append(feature)
                
                # 读取标签
                label_name = os.path.join(root, 'input', 'labels', '{}_{}.txt'.format(i, j))
                with open(label_name, 'r') as f:
                    labels = [int(x) for x in f.readline().strip().split()]
                self.labels.append(labels)
        
        # 将特征调整维度为[batch_size, num_node_features, width, height]
        self.features = torch.stack(self.features, dim=0)
        self.labels = torch.tensor(self.labels)
        
        # Calculate the total number of nodes
        self.num_nodes = len(self.labels)
        
        # 添加边的关系
        self.edges = torch.tensor(self.edges.values, dtype=torch.long).t().contiguous()
        
        self.cnn = CNN()  # 添加CNN模型
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.labels[idx]
        
        # Define graph-wide train_mask and val_mask
        train_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(self.num_nodes, dtype=torch.bool)
        
        # Set train_mask for the first 30 nodes in each network, and val_mask for the last 7 nodes
        node_id = idx % 37
        network_id = idx // 37
        if node_id < 30 and (node_id + network_id * 37) <self.num_nodes:
            train_mask[node_id + network_id * 37] = 1
        elif (node_id + network_id * 37) <self.num_nodes:
            val_mask[node_id + network_id * 37] = 1
        
        data = Data(x=x, edge_index=self.edges, y=y, train_mask=train_mask, val_mask=val_mask)
        
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        
        if self.transform is not None:
            data.x = self.transform(data.x)
        
        return data

# 创建GCN模型
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)
        self.cnn = CNN()
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.cnn(x)  # 使用CNN模型提取图像特征
        print(x.shape)
        print(edge_index.shape)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        
        return x

修改后,重新运行代码,应该就能解决索引超出边界的错误了。

PyG 数据集创建与 GCN 模型训练:解决边索引超出边界错误

原文地址: https://www.cveoy.top/t/topic/pe7y 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录