import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
import argparse
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss

# 定义命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--root2', type=str, default='path_to_root2', help='测试集根目录')
parser.add_argument('--txtpath2', type=str, default='path_to_txtpath2', help='测试集txt文件路径')
parser.add_argument('--batch_size', type=int, default=32, help='测试批次大小')
args = parser.parse_args()

# 加载训练好的模型
model = torch.load('./modelpth/68.pth')

# 创建测试集和数据加载器
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None)  # 使用args.txtpath2
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)

# 定义损失函数
criterion = CrossEntropyLoss()

# 设置模型为评估模式
model.eval()

# 测试模型并获取指标
accuracy, _, _, C = test(model, test_loader, criterion)  # 假设你有一个名为'test'的函数来测试模型

# 使用seaborn和matplotlib可视化混淆矩阵
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(C)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, xticklabels=['1', '2', '3', '4', '5', '6'], yticklabels=['1', '2', '3', '4', '5', '6'], cmap='Blues')
ax.set_xlabel('预测标签')
ax.set_ylabel('真实标签')
plt.savefig('confusion_matrix.jpg')
plt.show()

代码说明:

  1. 使用 argparse 模块定义命令行参数,方便用户指定测试集路径和批次大小等信息。
  2. 使用 torch.load 函数加载训练好的模型。
  3. 根据用户提供的路径创建 MyDataset 实例,并使用 DataLoader 加载测试数据。
  4. 定义损失函数 CrossEntropyLoss
  5. 调用 model.eval() 将模型设置为评估模式。
  6. 调用自定义的 test 函数对模型进行测试,并获取准确率和混淆矩阵等指标。
  7. 使用 pandasseaborn 库将混淆矩阵可视化,并使用 matplotlib 保存图像。

修改说明:

test_dataset = MyDataset(args.root2, das_data + '/test', transform=None) 修改为 test_dataset = MyDataset(args.root2, args.txtpath2, transform=None),确保使用命令行参数指定的测试集路径。


原文地址: https://www.cveoy.top/t/topic/fsrQ 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录