PyTorch深度学习模型测试与混淆矩阵可视化
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
import argparse
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
# 定义命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--root2', type=str, default='path_to_root2', help='测试集根目录')
parser.add_argument('--txtpath2', type=str, default='path_to_txtpath2', help='测试集txt文件路径')
parser.add_argument('--batch_size', type=int, default=32, help='测试批次大小')
args = parser.parse_args()
# 加载训练好的模型
model = torch.load('./modelpth/68.pth')
# 创建测试集和数据加载器
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None) # 使用args.txtpath2
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
# 定义损失函数
criterion = CrossEntropyLoss()
# 设置模型为评估模式
model.eval()
# 测试模型并获取指标
accuracy, _, _, C = test(model, test_loader, criterion) # 假设你有一个名为'test'的函数来测试模型
# 使用seaborn和matplotlib可视化混淆矩阵
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(C)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, xticklabels=['1', '2', '3', '4', '5', '6'], yticklabels=['1', '2', '3', '4', '5', '6'], cmap='Blues')
ax.set_xlabel('预测标签')
ax.set_ylabel('真实标签')
plt.savefig('confusion_matrix.jpg')
plt.show()
代码说明:
- 使用
argparse模块定义命令行参数,方便用户指定测试集路径和批次大小等信息。 - 使用
torch.load函数加载训练好的模型。 - 根据用户提供的路径创建
MyDataset实例,并使用DataLoader加载测试数据。 - 定义损失函数
CrossEntropyLoss。 - 调用
model.eval()将模型设置为评估模式。 - 调用自定义的
test函数对模型进行测试,并获取准确率和混淆矩阵等指标。 - 使用
pandas和seaborn库将混淆矩阵可视化,并使用matplotlib保存图像。
修改说明:
将 test_dataset = MyDataset(args.root2, das_data + '/test', transform=None) 修改为 test_dataset = MyDataset(args.root2, args.txtpath2, transform=None),确保使用命令行参数指定的测试集路径。
原文地址: https://www.cveoy.top/t/topic/fsrQ 著作权归作者所有。请勿转载和采集!