CNN模型混淆矩阵可视化:Epoch 68
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from models import CNN # 确保你已经导入了CNN模型的定义
from mydataset import MyDataset # 确保你已经导入了MyDataset的定义
from torch.utils.data import DataLoader
import torch.nn as nn
# 加载模型
model = torch.load('./modelpth/68.pth')
# 定义测试数据集和损失函数
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None) # 使用你实际的路径和参数
# 确保你的MyDataset和DataLoader的定义与代码中一致
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
criterion = nn.CrossEntropyLoss()
# 进行模型测试
model.eval()
# 确保你的test函数与代码中一致
accuracy, _, _, C = test(model, test_loader, criterion)
# 绘制混淆矩阵图
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(C)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, xticklabels=['1', '2', '3', '4', '5', '6'], yticklabels=['1', '2', '3', '4', '5', '6'], cmap='Blues')
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.savefig('confusion_matrix_epoch_68.jpg')
plt.show()
解释:
- 导入库: 导入必要的库,包括
torch用于模型操作,matplotlib.pyplot用于绘制图形,pandas用于处理DataFrame,seaborn用于创建热力图,以及你定义的CNN模型和MyDataset类。 - 加载模型: 使用
torch.load加载训练好的模型文件,这里是./modelpth/68.pth。 - 定义测试数据集和损失函数: 定义测试数据集
test_dataset,它应该包含测试数据,以及相应的DataLoader用于批次处理数据。此外,定义损失函数criterion。 - 模型测试: 使用
test函数进行模型测试,并获取混淆矩阵C。 - 绘制混淆矩阵图: 使用
seaborn的heatmap函数绘制混淆矩阵图,并使用plt.savefig保存图片。
注意:
- 确保你已经定义了
CNN模型、MyDataset类和test函数,并根据你的具体情况调整代码中的参数,例如文件路径、数据集路径、批次大小等。 - 如果你没有定义
CNN模型或MyDataset类,你需要根据你的项目需求进行定义。 - 你还可以根据需要修改绘图参数,例如颜色、标签等。
原文地址: https://www.cveoy.top/t/topic/bDTN 著作权归作者所有。请勿转载和采集!