PyTorch模型测试代码:性能评估与混淆矩阵分析
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import argparse
import torch.nn as nn
def test(model, dataset, criterion, num_classes):
model.eval()
total_batch_num = 0.
val_loss = 0
prediction = []
labels = []
feature_list = torch.tensor([])
if torch.cuda.is_available():
feature_list = feature_list.cuda()
class_correct = [0] * num_classes # 初始化每个类别的正确预测数为0
class_total = [0] * num_classes # 初始化每个类别的样本总数为0
for (step, i) in enumerate(dataset):
total_batch_num += 1
batch_x = i['data']
batch_y = i['label'] - 1
batch_x = torch.unsqueeze(batch_x, dim=1)
batch_x = batch_x.float()
if torch.cuda.is_available():
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()
feature, probs = model(batch_x)
batch_label = batch_y.unsqueeze(1).float()
feature_label = torch.cat((feature, batch_label), dim=1)
feature_list = torch.cat((feature_list, feature_label), dim=0)
loss = criterion(probs, batch_y)
_, pred = torch.max(probs, dim=1)
predi = pred.tolist()
label = batch_y.tolist()
val_loss += loss.item()
prediction.extend(predi)
labels.extend(label)
# 计算每个类别的正确预测数和样本总数
for j in range(len(label)):
if predi[j] == label[j]:
class_correct[label[j]] += 1
class_total[label[j]] += 1
accuracy = accuracy_score(labels, prediction)
C = confusion_matrix(labels, prediction)
class_accuracy = [class_correct[i] / class_total[i] for i in range(num_classes)] # 计算每个类别的精确率
return accuracy, val_loss / total_batch_num, feature_list, C, class_accuracy
# 调用测试函数
accuracy, _, _, C, class_accuracy = test(model, test_loader, criterion, num_classes=6)
print("Overall Accuracy:", accuracy)
for i in range(len(class_accuracy)):
print("Class {} Accuracy: {:.2f}%".format(i, class_accuracy[i] * 100))
代码说明:
- 导入库: 导入必要的库,包括 PyTorch、Matplotlib、Pandas、Seaborn 等用于数据处理、模型构建、可视化等的库。
- 测试函数: 定义
test()函数,用于测试模型性能。该函数接收模型、测试集、损失函数和类别数作为参数,并返回总体准确率、平均损失、特征列表、混淆矩阵以及每个类别的准确率。 - 数据预处理: 在测试函数中,代码首先将测试集的标签进行调整,使其从 2-7 变为 1-6,方便后续计算。
- 模型评估: 使用
model.eval()将模型设置为评估模式,并使用循环迭代测试集,计算每个批次的损失、预测结果以及每个类别的正确预测数和样本总数。 - 性能指标计算: 使用
accuracy_score()和confusion_matrix()计算总体准确率和混淆矩阵。 - 结果输出: 打印总体准确率、每个类别的准确率以及混淆矩阵,方便分析模型性能。
注意事项:
- 请确保已加载训练好的模型 (
model) 和测试集数据 (test_loader),以及定义了相应的损失函数 (criterion)。 - 类别数 (
num_classes) 应与数据集中实际类别数一致。 - 本代码示例针对 6 类进行测试,如果您的数据集类别数不同,请根据实际情况修改代码。
使用示例:
假设您已完成模型训练,并加载了训练好的模型 model 和测试集 test_loader,以及定义了损失函数 criterion,则可以使用以下代码进行测试:
# 加载训练好的模型
model = torch.load('model.pth')
# 创建测试集数据加载器
test_loader = DataLoader(MyDataset(test_data), batch_size=32, shuffle=False)
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 进行模型测试
accuracy, _, _, C, class_accuracy = test(model, test_loader, criterion, num_classes=6)
# 打印结果
print("Overall Accuracy:", accuracy)
for i in range(len(class_accuracy)):
print("Class {} Accuracy: {:.2f}%".format(i, class_accuracy[i] * 100))
代码分析:
本代码示例展示了如何使用 PyTorch 测试训练好的模型,并进行性能评估。代码涵盖了数据预处理、模型加载、损失函数等关键步骤,并提供了详细的注释,方便理解代码逻辑。通过分析测试结果,可以评估模型的性能,并根据结果对模型进行进一步优化。
原文地址: https://www.cveoy.top/t/topic/fttK 著作权归作者所有。请勿转载和采集!