Fixing "IndexError: tensors used as indices must be long, int, byte or bool tensors" in PyTorch
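The error message means that a tensor used as an index has an unsupported dtype: index tensors must be long, int, byte, or bool. In this code the indexing happens in output[train_mask.squeeze()] and labels[train_mask.squeeze()], before F.cross_entropy is called, so the offending tensor is train_mask, most likely because it is stored as a float tensor.
To fix this, convert train_mask to one of the accepted index types before indexing with it. The labels are the tensor being indexed, not the index, so they do not cause this error; F.cross_entropy does, however, require them to be class indices of dtype long.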
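The sketch below reproduces the error on made-up tensors, assuming the mask was stored as float (the names values and float_mask are hypothetical). Indexing with the float tensor raises the IndexError, while the same mask converted with .bool() indexes normally.

```python
import torch

values = torch.tensor([10.0, 20.0, 30.0, 40.0])
# Hypothetical mask stored as float (1.0 = keep, 0.0 = drop): not a valid index dtype
float_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

try:
    values[float_mask]
    error_message = None
except IndexError as err:
    # Expected: "tensors used as indices must be long, int, byte or bool tensors"
    error_message = str(err)

# Converting the mask to an accepted dtype (here bool) makes indexing work
selected = values[float_mask.bool()]

print(error_message)
print(selected)
```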
Here's the modified code:
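```python
import torch.nn.functional as F


def train_model(data_loader, model, optimizer, device):
    model.train()
    total_loss = 0.0
    for data in data_loader:
        data = data.to(device)
        # Keep the mask boolean: casting it to long would turn it into
        # integer positions (0s and 1s) and select the wrong rows.
        train_mask = data.train_mask.bool()
        # cross_entropy expects class-index targets of shape [N] and dtype long
        labels = data.y.long()
        optimizer.zero_grad()
        # Assumes the model returns logits of shape [num_nodes, num_classes];
        # they are not flattened, because cross_entropy needs the class dimension.
        output = model(data)
        loss = F.cross_entropy(output[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)
```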
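As a rough check of the shapes and dtypes F.cross_entropy expects, the sketch below uses made-up tensors (logits, y, and train_mask are illustrative, not taken from the original model): float logits of shape [N, C], long targets of shape [N], and a boolean mask that selects the same rows from both. It also recomputes the loss by hand as the mean negative log-probability of each target class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, num_classes = 6, 3

# Made-up logits for 6 nodes and 3 classes: float, shape [N, C]
logits = torch.randn(num_nodes, num_classes)
# Class-index targets: dtype int64 (long), shape [N], no extra unsqueeze(1)
y = torch.tensor([0, 2, 1, 1, 0, 2])
# Boolean mask marking the hypothetical training nodes
train_mask = torch.tensor([True, True, False, True, False, False])

masked_logits = logits[train_mask]  # shape [3, 3]
masked_targets = y[train_mask]      # shape [3]
loss = F.cross_entropy(masked_logits, masked_targets)

# Same value by hand: mean negative log-probability of the true class
log_probs = F.log_softmax(masked_logits, dim=1)
manual = -log_probs[torch.arange(masked_targets.numel()), masked_targets].mean()

print(masked_logits.shape, masked_targets.shape, loss.item(), manual.item())
```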
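Once train_mask has one of the accepted index types, the IndexError goes away. For a mask, .bool() is the safer conversion: a long tensor is read as a list of integer positions, so casting a 0/1 mask to torch.long also silences the error but picks row 1 for every True entry and row 0 for every False entry instead of selecting the training nodes. F.cross_entropy additionally expects logits of shape [N, C] and long class-index targets of shape [N], so the model output should not be flattened with view(-1) and the labels should not be unsqueezed.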
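A small sketch with made-up values illustrates the difference between indexing with a boolean mask and indexing with the same mask cast to long. It also shows one way (mask.nonzero) to obtain integer positions from a mask when they are genuinely needed.

```python
import torch

values = torch.tensor([10, 20, 30, 40])
mask = torch.tensor([True, False, True, False])

# Boolean indexing keeps the positions where the mask is True
by_mask = values[mask]
# A long tensor is read as integer positions: mask.long() is [1, 0, 1, 0]
by_cast = values[mask.long()]
# If integer indices are really needed, derive them from the mask instead
positions = mask.nonzero(as_tuple=True)[0]
by_positions = values[positions]

print(by_mask)       # tensor([10, 30])
print(by_cast)       # tensor([20, 10, 20, 10])
print(by_positions)  # tensor([10, 30])
```

The cast version returns as many elements as the mask has entries, which is often the first visible sign that a mask has been treated as a list of positions.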
Original article: https://www.cveoy.top/t/topic/pcyH. Copyright belongs to the author; please do not repost or scrape.