PyTorch 验证集损失和准确率计算
PyTorch 验证集损失和准确率计算
这段代码演示如何在 PyTorch 中计算模型在验证集上的损失和准确率,并打印结果。代码使用 torch.no_grad() 来禁用梯度计算,并使用 custom_loss 函数计算自定义损失函数。
with torch.no_grad():
val_loss = 0.0
val_correct = 0
val_total = 0
for j, val_input_tensor in enumerate(val_tensors):
val_output = network(val_input_tensor)
# 计算相似度
val_target_similarity = F.cosine_similarity(val_output, tensor_list[j].unsqueeze(0), dim=1)
val_other_similarities = []
for k, tensor in enumerate(tensor_list):
if k != j:
similarity = F.cosine_similarity(val_output, tensor.unsqueeze(0), dim=1)
val_other_similarities.append(similarity)
val_other_similarities = torch.cat(val_other_similarities)
val_labels = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]), torch.tensor([1, 1, 1, 1])]
val_label_index = torch.argmax(tensor_list[j])
val_label = val_labels[val_label_index]
if val_target_similarity > torch.max(val_other_similarities):
val_predicted_index = torch.argmax(val_output)
if torch.all(torch.eq(val_label, val_labels[val_predicted_index])):
val_correct += 1
val_total += 1
val_loss += custom_loss(val_output, tensor_list[j]).item()
# 计算验证集上的损失和准确率
val_loss /= val_total
val_accuracy = 100 * val_correct / val_total
# 打印验证信息
print('Validation Loss: %.3f, Accuracy: %.2f%%' % (val_loss, val_accuracy))
代码解释:
torch.no_grad():禁用梯度计算,因为在验证阶段我们不需要计算梯度。- 初始化变量
val_loss、val_correct和val_total分别用于累计验证集上的损失、正确预测的数量和总样本数量。 - 循环遍历验证集样本,计算相似度、判断预测是否正确,并累加相关变量。
- 使用
custom_loss函数计算自定义损失函数,并累加到val_loss中。 - 计算验证集上的平均损失和准确率。
- 打印验证信息,包括损失和准确率。
关键点:
- 代码使用
F.cosine_similarity计算样本之间的相似度。 val_labels列表定义了不同样本的标签。- 代码使用
torch.argmax函数获取预测结果的索引,并与标签进行比较判断预测是否正确。 - 代码使用
custom_loss函数计算自定义损失函数。 - 代码打印验证集上的平均损失和准确率,以评估模型的性能。
注意:
val_tensors和tensor_list是验证集样本和对应的标签列表。network是模型对象。custom_loss是自定义的损失函数。- 代码中使用
%.3f和%.2f格式化输出损失和准确率,可以根据需要修改格式。
希望这个解释对您有所帮助!如果您还有其他问题,请随时提问。
原文地址: https://www.cveoy.top/t/topic/NgD 著作权归作者所有。请勿转载和采集!