使用 PyTorch 生成图像分类模型的混淆矩阵
import pandas as pd
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
设置设备
# Select the compute device.
# Preferred GPU index; change it to suit the machine this script runs on.
GPU_INDEX = 3
# Fall back instead of failing later with "invalid device ordinal" when the
# machine has fewer than GPU_INDEX + 1 GPUs, or has no CUDA at all.
if torch.cuda.is_available():
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
定义数据预处理（评估阶段使用，不含随机数据增强）
# Deterministic evaluation preprocessing (no random augmentation):
# 1. resize the shorter side while keeping the aspect ratio,
# 2. center-crop to the model input size,
# 3. convert to a tensor,
# 4. normalize with the ImageNet channel mean/std.
# A fixed (224, 224) resize would distort non-square images and make the
# confusion matrix disagree with the model's reported validation accuracy.
# NOTE(review): Swin-derived classification repos (presumably including
# InternImage) evaluate with resize-to-256 + center-crop-224 and bicubic
# interpolation; confirm this against the config the checkpoint was trained with.
IMG_SIZE = 224
transform = transforms.Compose([
    transforms.Resize(int(IMG_SIZE * 256 / 224),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
加载数据集
# ImageNet validation images. ImageFolder assigns each class sub-directory an
# integer label (in sorted directory-name order) and applies `transform` to
# every image it returns.
data_path = '/mnt/disk1/lh/landmark/imagenet/val'
dataset = datasets.ImageFolder(
    data_path,
    transform=transform,
)
加载模型
# Load the trained checkpoint and put the network into evaluation mode on `device`.
model_path = '/mnt/disk1/lh/code/InternImage/classification/outpre/0628/internimage_b_1k_224/ckpt_epoch_best.pth'
# map_location='cpu' lets a checkpoint saved from any GPU load even when that
# GPU does not exist here; the model is moved to `device` explicitly below.
# SECURITY: torch.load unpickles the file -- only load checkpoints you trust.
# On torch >= 2.6 torch.load defaults to weights_only=True; a file that holds a
# pickled nn.Module (or other non-tensor objects) then needs weights_only=False.
checkpoint = torch.load(model_path, map_location='cpu')
# Accept either a training checkpoint dict with a 'model' entry or a bare module.
if isinstance(checkpoint, dict) and 'model' in checkpoint:
    model = checkpoint['model']
else:
    model = checkpoint
if not isinstance(model, torch.nn.Module):
    # NOTE(review): Swin-style training scripts (which InternImage appears to
    # follow) usually save model.state_dict() under 'model', not the module
    # itself; in that case build the network first and load the weights into it.
    raise TypeError(
        f"{model_path!r} does not contain an nn.Module (got {type(model).__name__}); "
        "if it is a state_dict, build the network first and call "
        "model.load_state_dict(state_dict)."
    )
model.to(device)
model.eval()
获取预测结果和标签
# Run the model over the whole dataset and collect true / predicted class indices.
# A batched DataLoader replaces one-image-at-a-time inference; with
# shuffle=False the order of `labels` / `preds` follows the dataset order.
BATCH_SIZE = 64
# NOTE: DataLoader worker processes need an `if __name__ == '__main__':` guard on
# platforms that start workers by spawning (Windows/macOS); use 0 there otherwise.
NUM_WORKERS = 4
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=device.type == 'cuda',
)
labels = []
preds = []
# Evaluation only: skip building the autograd graph.
with torch.no_grad():
    for inputs, targets in loader:
        # ImageFolder targets are plain Python ints (calling .to()/.item() on
        # them fails); the DataLoader collates them into a CPU tensor, which
        # never needs to visit the GPU.
        inputs = inputs.to(device, non_blocking=True)
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        labels.extend(targets.tolist())
        preds.extend(predicted.cpu().tolist())
生成混淆矩阵
# Build the confusion matrix: rows are true classes, columns are predicted
# classes, both in ImageFolder's class-index order.
classes = dataset.classes
# Pin the label set so the matrix is always len(classes) x len(classes), even if
# some class never occurs in `labels` or `preds`; otherwise the DataFrame
# construction below fails with a shape mismatch.
cm = confusion_matrix(labels, preds, labels=list(range(len(classes))))
cm_df = pd.DataFrame(cm, index=classes, columns=classes)
保存为CSV文件
# Write the labelled confusion matrix to a CSV file in the current working
# directory, then report where it was written.
csv_file = 'confusion_matrix.csv'
cm_df.to_csv(csv_file)
print('Confusion matrix saved as confusion_matrix.csv')
原文地址: https://www.cveoy.top/t/topic/o3iA 著作权归作者所有。请勿转载和采集!