PyTorch 模型测试和混淆矩阵可视化 - 评估第68次模型的类别准确率
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import argparse
import torch.nn as nn
def test(model, dataset, criterion, num_classes): # 该函数用于对模型进行测试,输入参数包括模型、数据集、损失函数和类别数量。函数的功能包括计算模型在测试集上的准确率、平均损失,并返回特征列表、混淆矩阵和每个类别的准确率。
model.eval() # 将模型设置为评估模式
total_batch_num = 0. # 初始化总批次数为0,验证损失为0,预测结果列表和标签列表为空,特征列表为空的Tensor。
val_loss = 0
prediction = []
labels = []
feature_list = torch.tensor([])
if torch.cuda.is_available():
feature_list = feature_list.cuda()
class_correct = [0] * num_classes # 初始化每个类别的正确预测数为0
class_total = [0] * num_classes # 初始化每个类别的样本总数为0
for (step, i) in enumerate(dataset):
total_batch_num += 1
batch_x = i['data']
batch_y = i['label']
batch_x = torch.unsqueeze(batch_x, dim=1)
batch_x = batch_x.float()
if torch.cuda.is_available():
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()
feature, probs = model(batch_x)
batch_label = batch_y.unsqueeze(1).float()
feature_label = torch.cat((feature, batch_label), dim=1)
feature_list = torch.cat((feature_list, feature_label), dim=0)
loss = criterion(probs, batch_y)
_, pred = torch.max(probs, dim=1)
predi = pred.tolist()
label = batch_y.tolist()
val_loss += loss.item()
prediction.extend(predi)
labels.extend(label)
# 计算每个类别的正确预测数和样本总数
for j in range(len(label)):
if predi[j] == label[j]:
class_correct[label[j]] += 1
class_total[label[j]] += 1
accuracy = accuracy_score(labels, prediction)
C = confusion_matrix(labels, prediction)
class_accuracy = [class_correct[i] / class_total[i] for i in range(num_classes)] # 计算每个类别的精确率
return accuracy, val_loss / total_batch_num, feature_list, C, class_accuracy
# 定义args
parser = argparse.ArgumentParser()
rootpath = 'das_data'
parser.add_argument('--root2', type=str, default=rootpath + '/test', help='rootpath of valdata')
parser.add_argument('--txtpath2', type=str, default=rootpath + '/test/label.txt', help='path pf val_list')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for testing')
args = parser.parse_args()
# 将args传递给相关函数
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
model = torch.load('./modelpth/68.pth')
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
criterion = nn.CrossEntropyLoss()
model.eval()
accuracy, _, _, C, class_accuracy = test(model, test_loader, criterion, num_classes=7)
print('Overall Accuracy:', accuracy)
for i in range(len(class_accuracy)):
print('Class {} Accuracy: {:.2f}%'.format(i, class_accuracy[i] * 100))
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(C)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, xticklabels=['1', '2', '3', '4', '5', '6'], yticklabels=['1', '2', '3', '4', '5', '6'], cmap='Blues')
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.savefig('confusion_matrix.jpg')
plt.show()
上述代码修改了测试函数,增加了对每个类别准确率的计算。通过修改后的代码,可以输出第68次模型的整体准确率和每个类别的准确率。
代码说明:
-
test函数的修改:- 添加参数
num_classes用于指定类别数量。 - 初始化两个列表
class_correct和class_total,分别用于存储每个类别的正确预测数和样本总数。 - 在循环中计算每个类别的正确预测数和样本总数。
- 计算每个类别的精确率并将其存储在列表
class_accuracy中。
- 添加参数
-
调用
test函数并输出结果:- 将
num_classes设置为 7,表示有 7 个类别。 - 输出整体准确率和每个类别的准确率。
- 将
-
混淆矩阵可视化:
- 使用
seaborn库绘制混淆矩阵的热力图。 - 保存热力图到
confusion_matrix.jpg文件中。 - 显示热力图。
- 使用
通过以上修改,可以更好地评估第68次模型的性能,并针对每个类别进行分析。
原文地址: https://www.cveoy.top/t/topic/fs15 著作权归作者所有。请勿转载和采集!