# CNN模型在DAS数据上的混淆矩阵可视化 (confusion-matrix visualization of a CNN model on DAS data)
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix
from mydataset import MyDataset
from torch.utils.data import DataLoader
# Confusion-matrix visualization for the trained CNN on the DAS test split.
#
# Loads a whole-model checkpoint, runs it over the test set without gradient
# tracking, and saves/shows a heatmap of the resulting confusion matrix.

# --- Load the model ---------------------------------------------------------
# Choose the device once so the model and every batch live on the same device;
# moving only the batches to CUDA fails if the model was unpickled elsewhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = "./modelpth/68.pth"  # checkpoint path (epoch 68)
# The file holds a pickled model object (it is called directly below), not a
# state_dict. `weights_only=False` is therefore required on PyTorch >= 2.6,
# whose default refuses arbitrary pickles (the argument needs PyTorch >= 1.13).
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
# `map_location` lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()

# --- Load the test split ----------------------------------------------------
root_path = "das_data"
test_dataset = MyDataset(f"{root_path}/test", f"{root_path}/test/label.txt")
# Sample order is irrelevant for a confusion matrix; no shuffling keeps runs
# reproducible.
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False, num_workers=0)

# --- Run inference ----------------------------------------------------------
predictions = []
labels = []
with torch.no_grad():  # evaluation only: don't build the autograd graph
    for batch in test_loader:
        # Insert a channel axis at dim 1; presumably raw samples are
        # (batch, length) or (batch, H, W) -- TODO confirm against MyDataset.
        batch_x = torch.unsqueeze(batch['data'], dim=1).float().to(device)
        # Labels are only converted to Python ints, so they need not move to
        # the model's device.
        batch_y = batch['label']
        # The model's forward returns a pair; the second element is treated as
        # per-class scores (NOTE(review): confirm against the model code).
        _, probs = model(batch_x)
        _, preds = torch.max(probs, dim=1)
        predictions.extend(preds.cpu().tolist())
        labels.extend(batch_y.tolist())

if not labels:
    raise RuntimeError("Test set is empty; cannot build a confusion matrix.")

# --- Confusion matrix -------------------------------------------------------
# confusion_matrix orders rows/columns by the sorted union of observed labels,
# so reuse exactly those ids as tick labels; positional ticks (0..k-1) would
# mislabel every class after one that never occurs.
class_ids = sorted(set(labels) | set(predictions))
confusion = confusion_matrix(labels, predictions, labels=class_ids)

# --- Plot -------------------------------------------------------------------
fig, ax = plt.subplots()
df = pd.DataFrame(confusion, index=class_ids, columns=class_ids)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, cmap='Blues', ax=ax)
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
fig.savefig('./confusion_matrix.jpg')
plt.show()
# 运行上述代码即可生成epoch为68时的混淆矩阵图像(保存为 ./confusion_matrix.jpg)。
# 原文地址: http://www.cveoy.top/t/topic/bDTp 著作权归作者所有。请勿转载和采集!