PyG GCN模型错误修复:解决'Dimension out of range'问题
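The error is raised in the 'forward' method of the 'GCN' class, inside the first 'GCNConv' call. 'GCNConv' expects the node-feature tensor 'x' to be 2-D, with shape [num_nodes, num_node_features]. "Dimension out of range" means 'x' has fewer dimensions than that, so the problem is in how the input data is built.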
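In the provided code, 'x' comes from 'data.x', which 'MyDataset.__getitem__' sets to 'self.features[idx]'. However, each entry of 'self.features' is the single 1-D vector returned by 'extract_features' for one image, so 'data.x' has shape [8] instead of [num_nodes, 8]. To fix this, build 'x' per graph by stacking the feature vectors of all of that graph's nodes (images) into a 2-D tensor. Likewise, build 'y' as a 1-D tensor with one label per node, so that 'train_mask[:30]' and 'val_mask[30:]' select real nodes.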
Here's the updated code:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
# Load the data and define the PyG-style dataset class.
class MyDataset(torch.utils.data.Dataset):
    """Dataset of graphs whose nodes are images.

    Image ``images/{i}_{j}.png`` with label file ``labels/{i}_{j}.txt`` is
    node ``j`` of graph ``i`` (``i`` in ``1..num_graphs``, ``j`` in
    ``0..nodes_per_graph-1``). Every graph uses the same edges, read from
    ``edges_L.csv`` (presumably 0-based (source, target) node-index pairs —
    confirm against the data).

    Each item is a ``Data`` object with:

    * ``x``: ``[nodes_per_graph, num_node_features]`` node-feature tensor,
    * ``y``: ``[nodes_per_graph]`` long tensor, one label per node,
    * ``edge_index``: ``[2, num_edges]`` long tensor,
    * ``train_mask`` / ``val_mask``: boolean masks; the first
      ``num_train_nodes`` nodes are training nodes, the rest validation nodes.

    Args:
        root: directory containing ``edges_L.csv``, ``images/`` and ``labels/``.
        transform: optional callable applied to each ``Data`` in ``__getitem__``.
        pre_transform: stored for API compatibility; not applied.
        num_graphs: number of graphs (default 42, the original hard-coded value).
        nodes_per_graph: number of images per graph (default 37).
        num_train_nodes: nodes per graph put in ``train_mask`` (default 30).

    Raises:
        ValueError: if a graph's label files do not yield exactly one label
            per node.
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 num_graphs=42, nodes_per_graph=37, num_train_nodes=30):
        self.edges = pd.read_csv(os.path.join(root, 'edges_L.csv'))
        # The topology is the same for every graph, so convert it once here
        # rather than rebuilding it from the DataFrame on every __getitem__.
        self.edge_index = torch.tensor(self.edges.values, dtype=torch.long).t().contiguous()
        self.transform = transform
        self.pre_transform = pre_transform  # kept for API compatibility; not applied
        self.num_classes = 8  # 8 label classes
        self.num_train_nodes = num_train_nodes
        # 'ansi' is a Windows-only alias of the 'mbcs' codec; on other
        # platforms it raises LookupError, so use pandas' default there.
        label_encoding = 'ansi' if os.name == 'nt' else None
        self.features = []  # per graph: [nodes_per_graph, num_node_features]
        self.labels = []  # per graph: [nodes_per_graph] long tensor
        for i in range(1, num_graphs + 1):
            node_features = []
            node_labels = []
            for j in range(nodes_per_graph):
                image_path = os.path.join(root, 'images', f'{i}_{j}.png')
                node_features.append(self.extract_features(image_path))
                label_path = os.path.join(root, 'labels', f'{i}_{j}.txt')
                labels = pd.read_csv(label_path, header=None, sep=' ', encoding=label_encoding)
                # reshape(-1) keeps the labels 1-D even for a single-value
                # file; squeeze() would give a 0-d tensor without a size(0).
                node_labels.append(torch.tensor(labels.values.reshape(-1), dtype=torch.long))
            # GCNConv needs node features as a 2-D [num_nodes, num_features]
            # tensor; passing one 1-D image vector as x is what raised
            # "Dimension out of range".
            x = torch.stack(node_features)
            y = torch.cat(node_labels)
            if y.size(0) != x.size(0):
                raise ValueError(
                    f'graph {i}: expected {x.size(0)} node labels, got {y.size(0)}'
                )
            self.features.append(x)
            self.labels.append(y)

    def __len__(self):
        """Return the number of graphs."""
        return len(self.features)

    def __getitem__(self, idx):
        """Build the ``Data`` object for graph ``idx``."""
        x = self.features[idx]
        y = self.labels[idx]
        # Training nodes are the first num_train_nodes; validation nodes are
        # exactly the remaining ones.
        train_mask = torch.zeros(y.size(0), dtype=torch.bool)
        train_mask[:self.num_train_nodes] = True
        val_mask = ~train_mask
        # Clone so a transform that edits edge_index in place cannot corrupt
        # the copy shared by all graphs.
        data = Data(x=x, edge_index=self.edge_index.clone(), y=y,
                    train_mask=train_mask, val_mask=val_mask)
        if self.transform is not None:
            data = self.transform(data)
        return data

    def extract_features(self, image_path):
        """Return the feature vector of one image (one graph node).

        Placeholder: ignores ``image_path`` and returns a constant 8-value
        vector. A real colour-feature extractor must return a 1-D tensor of the
        same length for every image so the vectors can be stacked.
        """
        return torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
# Define the GCN model.
class GCN(torch.nn.Module):
    """Three-layer graph convolutional network for node classification.

    Layer widths: ``num_node_features -> 8 -> 16 -> num_classes``. ReLU
    follows the first two convolutions, and dropout (``F.dropout`` default
    probability, active only in training mode) precedes the last one.

    ``forward`` takes an object exposing ``x`` (expected shape
    ``[num_nodes, num_node_features]``, as ``GCNConv`` requires) and
    ``edge_index``. It returns raw, unnormalized per-node class scores; no
    softmax is applied.
    """

    def __init__(self, num_node_features, num_classes):
        super().__init__()
        # Creation order and attribute names are kept so that parameter
        # initialization and state_dict keys stay the same.
        self.conv1 = GCNConv(num_node_features, 8)
        self.conv2 = GCNConv(8, 16)
        self.conv3 = GCNConv(16, num_classes)

    def forward(self, data):
        edges = data.edge_index
        hidden = F.relu(self.conv1(data.x, edges))
        hidden = F.relu(self.conv2(hidden, edges))
        hidden = F.dropout(hidden, training=self.training)
        return self.conv3(hidden, edges)
# Training and validation routines.
def train_model(dataset, model, optimizer, device):
    """Train ``model`` for one pass over ``dataset`` and return the mean loss.

    Args:
        dataset: iterable of graph batches with ``len()`` (e.g. a DataLoader).
            Each batch supports ``.to(device)`` and provides ``y``, a boolean
            ``train_mask``, and whatever inputs ``model`` reads.
        model: module mapping a batch to per-node class scores.
        optimizer: optimizer over ``model``'s parameters; stepped once per
            batch.
        device: device each batch is moved to.

    Returns:
        Sum of per-batch cross-entropy losses (over ``train_mask`` nodes)
        divided by ``len(dataset)``.
    """

    def _one_step(batch):
        # Forward/backward/update on a single batch; returns its loss value.
        batch = batch.to(device)
        optimizer.zero_grad()
        scores = model(batch)
        mask = batch.train_mask
        batch_loss = F.cross_entropy(scores[mask], batch.y[mask])
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    model.train()
    return sum(_one_step(batch) for batch in dataset) / len(dataset)
def validate_model(dataset, model, device):
    """Return the model's accuracy on the validation nodes.

    Args:
        dataset: iterable of graph batches (e.g. a DataLoader). Each batch
            supports ``.to(device)`` and provides ``y``, a boolean
            ``val_mask``, and whatever inputs ``model`` reads.
        model: module mapping a batch to per-node class scores
            ``[num_nodes, num_classes]``.
        device: device each batch is moved to.

    Returns:
        Fraction of ``val_mask`` nodes whose arg-max prediction equals ``y``.

    Raises:
        ValueError: if no node is selected by any ``val_mask`` (accuracy would
            be 0/0).
    """
    model.eval()
    correct = 0
    total = 0
    # Evaluation needs no gradients; disabling autograd avoids building and
    # holding a computation graph for every batch.
    with torch.no_grad():
        for data in dataset:
            data = data.to(device)
            output = model(data)
            predicted = output[data.val_mask].argmax(dim=1)
            total += data.val_mask.sum().item()
            correct += (predicted == data.y[data.val_mask]).sum().item()
    if total == 0:
        raise ValueError('no validation nodes: every val_mask is empty')
    return correct / total
# Load the dataset, create the model and optimizer, then train and validate.
if __name__ == '__main__':
    # Raw string: in a regular literal "\U" starts a \UXXXXXXXX escape, so
    # "C:\Users\..." is a SyntaxError that stops the whole file compiling.
    dataset = MyDataset(root=r"C:\Users\jh\Desktop\data\input")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 8 input features per node: the length of extract_features' vector.
    model = GCN(num_node_features=8, num_classes=dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    epochs = 2
    for epoch in range(epochs):
        train_loss = train_model(train_loader, model, optimizer, device)
        print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}')
        val_accuracy = validate_model(val_loader, model, device)
        print(f'Val_Acc: {val_accuracy:.4f}')
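Once 'x' has shape [num_nodes, num_node_features] and 'y' has shape [num_nodes], GCNConv no longer raises the "Dimension out of range" error. Print 'dataset[0]' to confirm these shapes before training.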
原文地址: https://www.cveoy.top/t/topic/pbmv 著作权归作者所有。请勿转载和采集!