import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sklearn.metrics import confusion_matrix
from mydataset import MyDataset
from torch.utils.data import DataLoader

# 加载模型
model_path = "./modelpth/68.pth"  # 模型保存的路径
model = torch.load(model_path)

# 加载数据集
root_path = "das_data"
test_dataset = MyDataset(root_path + "/test", root_path + "/test/label.txt")
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=True, num_workers=0)

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 测试模型
model.eval()
predictions = []
labels = []
for i in test_loader:
    batch_x = i['data']
    batch_y = i['label']
    batch_x = torch.unsqueeze(batch_x, dim=1)
    batch_x = batch_x.float()

    if torch.cuda.is_available():
        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda()

    _, probs = model(batch_x)
    _, preds = torch.max(probs, dim=1)

    predictions.extend(preds.tolist())
    labels.extend(batch_y.tolist())

# 计算混淆矩阵
confusion = confusion_matrix(labels, predictions)

# 绘制混淆矩阵图像
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(confusion)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, cmap='Blues')

ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')

plt.savefig('./confusion_matrix.jpg')
plt.show()

运行上述代码即可生成epoch为68时的混淆矩阵图像。

CNN模型在DAS数据上的混淆矩阵可视化

原文地址: http://www.cveoy.top/t/topic/bDTp 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录