PyTorch 模型测试代码解析:逐行解释和注释
这段代码是一个 PyTorch 模型训练和测试的循环。下面是代码的详细解释和注释:
with torch.no_grad():
# 禁用梯度计算,因为在测试阶段我们不需要计算梯度
# 遍历测试集数据
for imgs, labels in test_loader:
# 将标签转换为torch.long类型的张量
labels = torch.tensor(labels, dtype=torch.long)
# 将图像和标签移到设备(如GPU)上进行计算
imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
# 通过模型计算图像的输出
outputs = model(imgs)
# 计算损失函数
loss = loss_fn(outputs, labels)
# 累积测试集上的损失
test_loss += loss.item()
# 计算模型的预测概率
ps = torch.exp(outputs)
# 获取最高概率和对应的类别
top_p, top_class = ps.topk(1, dim=1)
# 检查预测类别是否与真实标签相等
equals = top_class == labels.view(*top_class.shape)
# 计算准确率并累积
accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
# 计算平均训练损失并将其添加到train_losses列表中
train_losses.append(running_loss / len(train_loader))
# 计算平均测试损失并将其添加到test_losses列表中
test_losses.append(test_loss / len(test_loader))
这段代码首先使用torch.no_grad()上下文管理器禁用梯度计算,因为在测试阶段我们不需要计算梯度。然后通过遍历test_loader中的图像和标签进行测试。在每次迭代中,将标签转换为torch.long类型的张量,并将图像和标签移到设备上进行计算。然后,通过模型计算图像的输出,并计算损失函数。接下来,将测试集上的损失累积到test_loss变量中。然后,通过计算模型输出的指数函数来获取预测的概率,并使用topk方法获取最高概率和对应的类别。最后,将预测类别与真实标签进行比较,计算准确率并将其累积到accuracy变量中。
在循环结束后,计算平均训练损失并将其添加到train_losses列表中。然后,计算平均测试损失并将其添加到test_losses列表中。
原文地址: https://www.cveoy.top/t/topic/fx3I 著作权归作者所有。请勿转载和采集!