PyTorch 图像分类模型测试:准确率计算和可视化
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch
import random
import os
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# 设置随机种子
random.seed(123)
# 定义图像预处理操作
IMG_HEIGHT = 224
IMG_WIDTH = 224
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
transforms.Resize((IMG_HEIGHT, IMG_WIDTH)), # 调整图像大小
transforms.ToTensor(), # 转换为Tensor
normalize # 归一化
])
# 加载测试集数据
val_dir = 'path/to/your/validation/dataset' # 替换为实际路径
test_dataset = ImageFolder(val_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# 获取测试集样本数
total_test = len(test_loader)
print('Total testing data batches: ', total_test)
# 模型测试
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ... # 替换为您的模型
model.to(device)
model.eval()
y_true_class = []
y_pred_class = []
# 随机选择10张图片进行预测
selected_img_paths = random.choices(test_dataset.imgs, k=10)
for img_path, _ in selected_img_paths:
img = Image.open(img_path).convert('RGB')
img = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
output = model(img)
pred = torch.argmax(output, dim=1)
y_pred_class.append(pred.item())
y_true_class.append(test_dataset.class_to_idx[os.path.basename(os.path.dirname(img_path))])
# 计算准确率
y_true = torch.Tensor(y_true_class)
y_pred = torch.Tensor(y_pred_class)
accuracy = (y_true.eq(y_pred).sum() / len(y_true)).item()
print('Accuracy: {:.2f}%'.format(accuracy * 100))
# 随机选择一些测试样本进行展示
for img, label in random.sample(test_dataset, 5):
img = img.unsqueeze(0).to(device)
with torch.no_grad():
output = model(img)
pred = torch.argmax(output, dim=1).item()
label_name = test_dataset.classes[label]
pred_name = test_dataset.classes[pred]
# 反归一化
img = img.squeeze(0).permute(1, 2, 0)
img = img * torch.Tensor([0.229, 0.224, 0.225]) + torch.Tensor([0.485, 0.456, 0.406])
plt.title('true label: '{}', predicted label: '{}''.format(label_name, pred_name))
plt.imshow(img)
plt.show()
代码说明:
- 导入必要的库:导入
torchvision,torch,random,os,PIL,sklearn.metrics和matplotlib.pyplot库。 - 设置随机种子:使用
random.seed(123)设置随机种子,确保代码运行结果一致。 - 定义图像预处理操作:使用
transforms.Compose定义图像预处理操作,包括调整图像大小、转换为 Tensor 和归一化。 - 加载测试集数据:使用
ImageFolder加载测试集数据,并使用DataLoader创建数据加载器。 - 获取测试集样本数:使用
len(test_loader)获取测试集样本数。 - 模型测试:将模型加载到设备上并设置为评估模式。
- 随机选择测试样本:使用
random.choices随机选择 10 个测试样本进行预测。 - 计算准确率:使用
accuracy_score计算模型的准确率。 - 可视化预测结果:随机选择 5 个测试样本,将真实标签和预测标签显示在图像标题中。
优化点:
- 使用
transforms.Normalize时,将归一化参数定义为变量,以便在多个地方重复使用。 - 使用
random.choices随机选择多个样本,避免样本重复选择。 - 使用
torch.no_grad()上下文管理器关闭梯度计算,提高代码运行效率。 - 使用
torch.Tensor的eq和mean方法计算准确率。
注意:
- 将
val_dir替换为您的实际测试集路径。 - 将
model替换为您的图像分类模型。 - 确保您的模型已经训练完毕。
原文地址: https://www.cveoy.top/t/topic/pivR 著作权归作者所有。请勿转载和采集!