# PyTorch deep learning model testing and confusion matrix visualization
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import argparse
import torch.nn as nn
def test(model, dataset, criterion, epoch):
    """Evaluate ``model`` on every batch yielded by ``dataset``.

    Args:
        model: Callable that returns a ``(feature, probs)`` pair for a batch.
            It is switched to eval mode before the loop.
        dataset: Iterable of batches; each batch is indexed with the keys
            ``'data'`` and ``'label'``.
        criterion: Loss callable invoked as ``criterion(probs, batch_y)``.
        epoch: Epoch number. The per-class accuracy dict is only filled in
            when this equals 68; otherwise it is returned empty.

    Returns:
        A 5-tuple ``(accuracy, mean_loss, feature_list, C, accuracy_dict)``:
        overall accuracy, loss averaged over batches, a tensor whose rows are
        the model features with the true label appended as the last column,
        the confusion matrix of true labels versus predictions, and a dict
        mapping each true label to the fraction of its samples that were
        predicted correctly.
    """
    model.eval()
    total_batch_num = 0.
    val_loss = 0
    prediction = []
    labels = []
    feature_chunks = []
    accuracy_dict = {}
    use_cuda = torch.cuda.is_available()

    # Pure inference: no autograd graph is needed, and keeping one alive for
    # every batch (through the accumulated features) only wastes memory.
    with torch.no_grad():
        for batch in dataset:
            total_batch_num += 1
            # Insert a new axis at position 1 before feeding the model
            # (presumably the channel axis -- TODO confirm).
            batch_x = torch.unsqueeze(batch['data'], dim=1).float()
            batch_y = batch['label']
            if use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            feature, probs = model(batch_x)

            # Append the true label as an extra trailing column of the features.
            batch_label = batch_y.unsqueeze(1).float()
            feature_chunks.append(torch.cat((feature, batch_label), dim=1))

            val_loss += criterion(probs, batch_y).item()
            _, pred = torch.max(probs, dim=1)
            prediction.extend(pred.tolist())
            labels.extend(batch_y.tolist())

    # Concatenate once instead of re-copying the growing tensor per batch.
    if feature_chunks:
        feature_list = torch.cat(feature_chunks, dim=0)
    else:
        feature_list = torch.tensor([])
        if use_cuda:
            feature_list = feature_list.cuda()

    accuracy = accuracy_score(labels, prediction)
    C = confusion_matrix(labels, prediction)

    if epoch == 68:
        # Per-class accuracy: of the samples whose true label is `truth`,
        # the fraction predicted as `truth`. (The previous code compared a
        # list of labels with itself and therefore always reported 1.0.)
        class_total = {}
        class_correct = {}
        for truth, guess in zip(labels, prediction):
            class_total[truth] = class_total.get(truth, 0) + 1
            class_correct[truth] = class_correct.get(truth, 0) + (truth == guess)
        accuracy_dict = {
            truth: class_correct[truth] / count
            for truth, count in class_total.items()
        }

    return accuracy, val_loss / total_batch_num, feature_list, C, accuracy_dict
# Define the command-line arguments.
parser = argparse.ArgumentParser()
rootpath = 'das_data'
parser.add_argument('--root2', type=str, default=rootpath + '/test', help='rootpath of valdata')
parser.add_argument('--txtpath2', type=str, default=rootpath + '/test/label.txt', help='path pf val_list')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for testing')
args = parser.parse_args()

# Build the test dataset and its loader once from the parsed arguments
# (they were previously constructed twice with identical parameters).
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)

# Load the checkpoint onto the same device test() sends its inputs to:
# test() moves every batch to CUDA whenever CUDA is available.
# NOTE(review): torch.load unpickles a full model object here; only load
# checkpoint files from a trusted source.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('./modelpth/68.pth', map_location=device)
criterion = nn.CrossEntropyLoss()
model.eval()
accuracy, _, _, C, accuracy_dict = test(model, test_loader, criterion, epoch=68)
# Draw the confusion matrix as an annotated heatmap, save it, then show it.
class_ticks = ['1', '2', '3', '4', '5', '6']  # tick labels for both axes
fig, ax = plt.subplots()
matrix_frame = pd.DataFrame(C)
# No explicit axes are passed, so the heatmap lands on the current axes,
# which is the `ax` created just above.
sns.heatmap(
    matrix_frame,
    fmt='g',
    annot=True,
    annot_kws={'size': 10},
    xticklabels=class_ticks,
    yticklabels=class_ticks,
    cmap='Blues',
)
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.savefig('confusion_matrix.jpg')
plt.show()
# For every row of the confusion matrix, report the column holding the
# largest count and that count as a fraction of the row total.
# NOTE(review): the printed label is the row's most frequent prediction
# (argmax column), not the row's own class index -- confirm this is intended.
row_totals = C.sum(axis=1)
top_columns = C.argmax(axis=1)
top_fractions = C.max(axis=1) / row_totals
for label, probability in zip(top_columns, top_fractions):
    print(f'Label {label}: Predicted Probability {probability}')
# Original source: https://www.cveoy.top/t/topic/fttI -- copyright belongs to the author; do not repost or scrape.