```python
import os.path as osp

import torch
from torch_geometric.data import Data, Dataset

NUM_NODES = 20  # every graph in this dataset has 20 nodes


class My_dataset(Dataset):
    def __init__(self, root, transform=None):
        # Initialise the PyG base class so that self.root, transforms and
        # indexing are set up. No download()/process() is defined, so PyG
        # neither downloads nor caches anything.
        super().__init__(root, transform)
        self.edges = self.read_edges()
        self.features = self.read_features()
        self.labels = self.read_labels()

    def _read_int_rows(self, filename):
        # Parse a whitespace-separated text file into int rows, skipping blank lines.
        with open(osp.join(self.root, filename), 'r') as f:
            return [list(map(int, line.split())) for line in f if line.strip()]

    def read_edges(self):
        # edges.txt: one "src dst" pair per line, in local node indices
        # (0..19) shared by every graph. Result has shape [2, num_edges].
        edges = self._read_int_rows('edges.txt')
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

    def read_features(self):
        features1 = self._read_int_rows('features1.txt')
        features2 = self._read_int_rows('features2.txt')
        # The rows of features2 are appended after those of features1 and the
        # flat values are regrouped as [num_graphs, 20 nodes, 2 features].
        # If each file holds one value per node, torch.stack([f1, f2], dim=-1)
        # pairs them node by node instead.
        features = torch.tensor(features1 + features2, dtype=torch.float)
        return features.view(-1, NUM_NODES, 2)

    def read_labels(self):
        # One label per node, grouped as [num_graphs, 20].
        labels = self._read_int_rows('label.txt')
        return torch.tensor(labels, dtype=torch.long).view(-1, NUM_NODES)

    def len(self):
        # Number of graphs.
        return self.labels.size(0)

    def get(self, idx):
        # Keep local node indices (0..19). PyG's DataLoader adds each graph's
        # node offset itself when batching, so shifting by idx * 20 here would
        # point edges at nodes this graph does not have. clone() keeps graphs
        # from sharing one tensor that a transform might modify in place.
        edge_index = self.edges.clone()

        x = self.features[idx]  # [20, 2]
        y = self.labels[idx]    # [20]

        # Boolean masks: first 16 nodes for training, last 4 for validation.
        train_mask = torch.zeros(NUM_NODES, dtype=torch.bool)
        train_mask[:16] = True
        val_mask = torch.zeros(NUM_NODES, dtype=torch.bool)
        val_mask[16:] = True

        return Data(x=x, edge_index=edge_index, y=y,
                    train_mask=train_mask, val_mask=val_mask)


# Usage example:
dataset = My_dataset('C:/Users/jh/Desktop/data/raw')
data = dataset[0]
```
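When graphs like these are batched, PyG's `DataLoader` keeps each `Data` object's `edge_index` in local node indices (0 to 19 here) and adds a running node offset while concatenating the graphs into one disconnected graph, so per-graph code does not need to shift indices itself. The plain-Python sketch below imitates that step so the offsets can be checked by hand. `parse_edges` and `batch_edge_index` are hypothetical helpers, not PyG functions, and PyG's real collation also handles other index-like attributes.

```python
# Plain-Python sketch (ASSUMPTION: simplified model of PyG mini-batching,
# not PyG's actual implementation). Each graph keeps LOCAL node indices
# 0..num_nodes-1; batching adds the number of nodes seen so far.


def parse_edges(lines):
    """Parse 'src dst' lines into a 2 x E list, skipping blank lines."""
    pairs = [tuple(map(int, line.split())) for line in lines if line.strip()]
    src = [s for s, _ in pairs]
    dst = [d for _, d in pairs]
    return [src, dst]


def batch_edge_index(edge_indices, num_nodes_per_graph):
    """Concatenate per-graph edge lists, shifting each graph by a running node offset."""
    batched_src, batched_dst = [], []
    offset = 0
    for (src, dst), n in zip(edge_indices, num_nodes_per_graph):
        batched_src.extend(s + offset for s in src)
        batched_dst.extend(d + offset for d in dst)
        offset += n
    return [batched_src, batched_dst]


if __name__ == "__main__":
    # Hypothetical edges.txt content for a tiny 3-node graph.
    edges = parse_edges(["0 1\n", "1 2\n", "\n"])
    print(edges)  # [[0, 1], [1, 2]]

    # Two graphs share the same local topology; the second is shifted by 3.
    print(batch_edge_index([edges, edges], [3, 3]))  # [[0, 1, 3, 4], [1, 2, 4, 5]]
```

If each graph had already been shifted by `idx * num_nodes` before batching, the offsets would be applied twice and the indices would exceed the batch's node count.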

PyG custom dataset: from image features and edge information to graph data
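How two feature files become per-node features depends on their layout, which is not shown here. Assuming each line of `features1.txt` and `features2.txt` holds one graph with one integer per node, the plain-Python sketch below contrasts pairing the files node by node with appending the rows and regrouping the flat values in pairs, which is what `torch.tensor(features1 + features2).view(-1, 20, 2)` does. In PyTorch the node-by-node pairing would be `torch.stack([f1, f2], dim=-1)`. The helper names are hypothetical.

```python
# Plain-Python sketch. ASSUMPTION: each line of features1.txt / features2.txt
# holds one graph, with one integer per node. Helper names are hypothetical.


def parse_rows(lines):
    """Parse whitespace-separated integer rows, skipping blank lines."""
    return [list(map(int, line.split())) for line in lines if line.strip()]


def pair_node_features(rows1, rows2):
    """Node j of graph i gets [rows1[i][j], rows2[i][j]] (like torch.stack(..., dim=-1))."""
    if len(rows1) != len(rows2):
        raise ValueError("both files must describe the same number of graphs")
    return [[[a, b] for a, b in zip(r1, r2)] for r1, r2 in zip(rows1, rows2)]


def concat_then_reshape(rows1, rows2, num_nodes):
    """Mimic torch.tensor(rows1 + rows2).view(-1, num_nodes, 2)."""
    flat = [v for row in rows1 + rows2 for v in row]
    if len(flat) % (num_nodes * 2) != 0:
        raise ValueError("values cannot be grouped into [-1, num_nodes, 2]")
    pairs = [flat[k:k + 2] for k in range(0, len(flat), 2)]
    return [pairs[g:g + num_nodes] for g in range(0, len(pairs), num_nodes)]


if __name__ == "__main__":
    rows1 = parse_rows(["1 2\n"])    # hypothetical features1.txt: 1 graph, 2 nodes
    rows2 = parse_rows(["10 20\n"])  # hypothetical features2.txt
    print(pair_node_features(rows1, rows2))      # [[[1, 10], [2, 20]]]
    print(concat_then_reshape(rows1, rows2, 2))  # [[[1, 2], [10, 20]]]
```

With one graph of two nodes, pairing gives node 0 the values (1, 10), while appending and reshaping gives it (1, 2), two values taken from `features1.txt`.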

Original post: https://www.cveoy.top/t/topic/mTWO. Copyright belongs to the author. Reposting and scraping are prohibited.
