PyTorch模型测试及混淆矩阵可视化
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
import argparse
from torch.utils.data import DataLoader
from torch import nn
# 定义命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--root2', type=str, default='path_to_root2', help='测试数据集根目录')
parser.add_argument('--txtpath2', type=str, default='path_to_txtpath2', help='测试数据集txt文件路径')
parser.add_argument('--batch_size', type=int, default=32, help='测试批次大小')
args = parser.parse_args()
# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载测试数据集
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
# 加载训练好的模型
model = torch.load('./modelpth/68.pth').to(device)
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 测试模型
def test(model, test_loader, criterion):
model.eval()
correct = 0
total = 0
C = [[0 for _ in range(6)] for _ in range(6)] # 初始化混淆矩阵
with torch.no_grad():
for i, (images, labels) in enumerate(test_loader):
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
for t, p in zip(labels.view(-1), predicted.view(-1)):
C[t.long()][p.long()] += 1 # 更新混淆矩阵
accuracy = 100 * correct / total
return accuracy, correct, total, C
# 测试并获取混淆矩阵
accuracy, _, _, C = test(model, test_loader, criterion)
# 绘制混淆矩阵
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(C)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, xticklabels=['1', '2', '3', '4', '5', '6'], yticklabels=['1', '2', '3', '4', '5', '6'], cmap='Blues')
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.savefig('confusion_matrix.jpg')
plt.show()
# 打印测试结果
print(f'Test Accuracy of the model on the test images: {accuracy:.2f} %')
使用方法:
- 确保已安装必要的库:PyTorch、Matplotlib、Pandas、Seaborn。
- 将代码保存为Python文件,例如
test_model.py。 - 在命令行中运行以下命令,将
path_to_root2和path_to_txtpath2替换为实际路径:
python test_model.py --root2 /path/to/your/test/dataset --txtpath2 /path/to/your/test/dataset.txt
代码说明:
- 代码首先定义了命令行参数解析器,方便用户指定测试数据集路径和批次大小。
- 然后加载测试数据集和预训练的模型。
test()函数用于测试模型并在测试集上计算准确率和混淆矩阵。- 最后使用Seaborn库绘制混淆矩阵,并将结果保存为图片。
- 同时打印了模型在测试集上的准确率。
注意:
- 需要将代码中的占位符路径替换为实际路径。
MyDataset类需要根据实际数据集进行修改。- 模型路径
./modelpth/68.pth需要根据实际情况修改。 - 混淆矩阵的标签数量和类别名称需要根据实际情况进行调整。
- 代码中使用了GPU加速,如果需要使用CPU,请将
device设置为'cpu'。
原文地址: https://www.cveoy.top/t/topic/fssO 著作权归作者所有。请勿转载和采集!