# PyTorch image classification model test: accuracy evaluation and example display
# (PyTorch图像分类模型测试:准确率评估与示例展示)
import os
import random

import matplotlib.pyplot as plt
import torch
from PIL import Image
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
# Seed Python's RNG so the randomly sampled test images are reproducible
# (设置随机种子).
random.seed(123)

# Image preprocessing (定义图像预处理操作): resize, convert to a tensor, and
# normalize each channel with the commonly used ImageNet mean/std.
# NOTE(review): IMG_HEIGHT / IMG_WIDTH are not defined in the visible source;
# presumably they are set earlier (e.g. alongside training) — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Load the test split (加载测试集数据); every image goes through `transform`.
# NOTE(review): val_dir is not defined in the visible source — presumably the
# root folder of the held-out images, one sub-folder per class; confirm.
test_dataset = ImageFolder(root=val_dir, transform=transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1)

# Report the size of the test set (获取测试集样本数). With batch_size=1 the
# number of loader batches equals the number of images.
total_test = len(test_loader)
print('Total testing data batches: ', total_test)
# Model test setup (模型测试): use the GPU when one is available, move the
# model onto that device, and switch it to evaluation mode.
# NOTE(review): `model` is not defined in the visible source; it is presumably
# the trained network from an earlier step — confirm.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()

# Accumulators for ground-truth and predicted class indices.
y_true_class, y_pred_class = [], []
# Predict on up to 10 randomly chosen test images (随机选择10张图片进行预测).
# NOTE: the accuracy below is measured on this small random subset only, not
# on the whole test set; iterate `test_loader` for a full-set figure.
# random.sample raises ValueError when k exceeds the population, so cap k.
num_samples = min(10, len(test_dataset))
with torch.no_grad():
    for idx in random.sample(range(len(test_dataset)), num_samples):
        # Indexing the dataset loads the image through ImageFolder's own loader
        # (no file handle left open) and applies test_dataset.transform. It
        # also returns the class index ImageFolder assigned to the file. That
        # index stays correct even for images nested below their class folder,
        # whereas deriving it from the parent directory name would not.
        img, label = test_dataset[idx]
        output = model(img.unsqueeze(0).to(device))
        y_pred_class.append(torch.argmax(output, dim=1).item())
        y_true_class.append(label)

# Accuracy over the sampled images (计算准确率).
accuracy = accuracy_score(y_true_class, y_pred_class)
print('Accuracy: {:.2f}%'.format(accuracy * 100))
# Show a few random test samples with their true and predicted labels
# (随机选择一些测试样本进行展示).
# random.sample() requires a real sequence and raises TypeError for a torch
# Dataset, so sample indices and fetch the (image, label) pairs by index.
for idx in random.sample(range(len(test_dataset)), min(5, len(test_dataset))):
    img, label = test_dataset[idx]
    with torch.no_grad():
        output = model(img.unsqueeze(0).to(device))
    pred = torch.argmax(output, dim=1).item()

    label_name = test_dataset.classes[label]
    pred_name = test_dataset.classes[pred]

    # Undo Normalize (反归一化): x * std + mean per channel. This uses the
    # CPU tensor from the dataset, not the copy sent to `device`, because
    # combining a CUDA tensor with these CPU mean/std tensors raises a
    # device-mismatch error.
    img = img.permute(1, 2, 0)  # ToTensor gives (C, H, W); imshow wants (H, W, C)
    img = img * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])
    # Floating-point error can push values slightly outside [0, 1]; clamp so
    # imshow does not clip them with a warning.
    img = img.clamp(0.0, 1.0)

    plt.title('true label: {}, predicted label: {}'.format(label_name, pred_name))
    plt.imshow(img.numpy())
    plt.show()
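# 原文地址: https://www.cveoy.top/t/topic/pivV 著作权归作者所有。请勿转载和采集!
# (Original article: https://www.cveoy.top/t/topic/pivV — copyright belongs to the author; do not repost or scrape.)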