import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mydataset import MyDataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import argparse
import torch.nn as nn


def test(model, dataset, criterion, epoch, per_class_epoch=68):
    """Evaluate ``model`` on ``dataset`` and collect metrics and features.

    Args:
        model: network whose forward pass returns ``(feature, scores)``: a
            per-sample feature tensor and per-class scores.
        dataset: iterable of batches (e.g. a ``DataLoader``); each batch is a
            mapping with a ``'data'`` tensor and a ``'label'`` tensor.
        criterion: loss called as ``criterion(scores, labels)``.
        epoch: current epoch number, compared against ``per_class_epoch``.
        per_class_epoch: epoch at which per-class accuracies are computed
            (default 68, the previously hard-coded value); ``None`` computes
            them on every call.

    Returns:
        ``(accuracy, mean_loss, feature_list, C, accuracy_dict)``: overall
        accuracy; loss averaged over batches; a tensor with one row per sample
        holding its features followed by its label (as float) in the last
        column; the sklearn confusion matrix; and ``{label: accuracy}``
        computed over the samples of each true label (empty unless per-class
        stats were requested for this epoch).
    """
    model.eval()
    use_cuda = torch.cuda.is_available()
    num_batches = 0
    val_loss = 0.0
    prediction = []
    labels = []
    feature_chunks = []

    # Evaluation needs no gradients. Without no_grad every stored feature
    # tensor would keep its batch's autograd graph alive until return.
    with torch.no_grad():
        for batch in dataset:
            num_batches += 1
            # Insert a size-1 axis at dim 1 (presumably the channel axis the
            # model expects — TODO confirm) and convert to float.
            batch_x = torch.unsqueeze(batch['data'], dim=1).float()
            batch_y = batch['label']
            if use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
            feature, probs = model(batch_x)
            # Append the label as an extra column so each feature row stays
            # paired with its class.
            batch_label = batch_y.unsqueeze(1).float()
            feature_chunks.append(torch.cat((feature, batch_label), dim=1))
            val_loss += criterion(probs, batch_y).item()
            _, pred = torch.max(probs, dim=1)
            prediction.extend(pred.tolist())
            labels.extend(batch_y.tolist())

    # Concatenate once at the end; growing a tensor with torch.cat per batch
    # re-copies all previous rows each time (quadratic in dataset size).
    if feature_chunks:
        feature_list = torch.cat(feature_chunks, dim=0)
    else:
        feature_list = torch.tensor([])
        if use_cuda:
            feature_list = feature_list.cuda()

    accuracy = accuracy_score(labels, prediction)
    C = confusion_matrix(labels, prediction)

    accuracy_dict = {}
    if per_class_epoch is None or epoch == per_class_epoch:
        for cls in sorted(set(labels)):
            # Predictions made for the samples whose true label is ``cls``;
            # their accuracy is the fraction predicted as ``cls``.
            cls_pred = [p for p, t in zip(prediction, labels) if t == cls]
            accuracy_dict[cls] = accuracy_score([cls] * len(cls_pred), cls_pred)

    return accuracy, val_loss / num_batches, feature_list, C, accuracy_dict


# Command-line arguments.
parser = argparse.ArgumentParser()
rootpath = 'das_data'
parser.add_argument('--root2', type=str, default=rootpath + '/test', help='rootpath of valdata')
parser.add_argument('--txtpath2', type=str, default=rootpath + '/test/label.txt', help='path pf val_list')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for testing')
parser.add_argument('--model_path', type=str, default='./modelpth/68.pth', help='path of the saved model')
args = parser.parse_args()

# Build the test loader once. Sample order does not affect the metrics, so
# no shuffling: the returned features then come out in dataset order.
test_dataset = MyDataset(args.root2, args.txtpath2, transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

# Load the whole pickled model object. map_location lets a checkpoint saved on
# a GPU load on a CPU-only machine; .to(device) puts the weights on the same
# device that test() moves each batch to.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles like this one; pass weights_only=False there if needed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load(args.model_path, map_location=device)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
# test() switches the model to eval mode itself.
accuracy, _, _, C, accuracy_dict = test(model, test_loader, criterion, epoch=68)
print(f'Test accuracy: {accuracy:.4f}')

# Plot the confusion matrix (rows: true label, columns: predicted label).
# Tick labels are 1-based class numbers derived from the matrix size rather
# than a fixed list of six, so any number of classes is labelled.
# NOTE(review): rows/columns follow the sorted labels present in the data, so
# tick i+1 matches class i only when every class appears.
class_names = [str(i + 1) for i in range(C.shape[0])]
fig = plt.figure()
ax = fig.add_subplot(111)
df = pd.DataFrame(C)
sns.heatmap(df, fmt='g', annot=True, annot_kws={'size': 10}, xticklabels=class_names,
            yticklabels=class_names, cmap='Blues', ax=ax)
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.savefig('confusion_matrix.jpg')
plt.show()

# Per-class correct-prediction rate. Row i of C counts the samples whose true
# label is the i-th class, so C[i, i] / row_sum is the fraction of that class
# predicted correctly. A row can be all zeros when a label occurs only among
# the predictions; its rate is reported as nan instead of dividing by zero.
total_predictions = C.sum(axis=1)
correct_predictions = C.diagonal()
predicted_probabilities = [
    n_correct / n_total if n_total else float('nan')
    for n_correct, n_total in zip(correct_predictions, total_predictions)
]

for label, probability in enumerate(predicted_probabilities):
    print(f'Label {label}: Predicted Probability {probability}')

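# 原文地址 (original source): https://www.cveoy.top/t/topic/fttI
# 著作权归作者所有。请勿转载和采集! (Copyright belongs to the author; do not repost or scrape.)
#
# 免费AI点我,无需注册和登录 (site advertisement: "free AI, no registration or login required")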