如何将预测单张图片的程序改为预测数据集的程序
要将这个程序改为预测一个数据集的程序,需要进行以下步骤:
-
修改数据集的文件夹路径。将
img_path变量修改为数据集文件夹的路径,例如dataset_path = "/path/to/dataset"。 -
修改输出结果的方式。当前程序中使用
print_res和plt.title输出预测结果,可以将其修改为将预测结果保存到一个文件中,或者将预测结果打印出来。 -
遍历数据集中的每张图片。使用
os.listdir函数获取数据集文件夹中的所有图片文件名,然后遍历这些文件名。 -
在循环中进行预测。将当前图片的路径传给
Image.open函数打开图片,然后进行后续的预处理和预测操作。 -
将预测结果保存或打印。可以将预测结果保存到一个列表中,然后在循环结束后进行保存或打印。
下面是修改后的程序示例:
import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model_v2 import MobileNetV2
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# load class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{"json_path}' does not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
# create model
model = MobileNetV2(num_classes=6).to(device)
# load model weights
model_weight_path = "./MobileNetV2.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
dataset_path = "/path/to/dataset"
assert os.path.exists(dataset_path), "dataset path: '{"dataset_path}' does not exist.".format(dataset_path)
image_files = os.listdir(dataset_path)
results = []
for image_file in image_files:
image_path = os.path.join(dataset_path, image_file)
img = Image.open(image_path)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
with torch.no_grad():
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
result = {
"image_path": image_path,
"class": class_indict[str(predict_cla)],
"prob": predict[predict_cla].numpy()
}
results.append(result)
for result in results:
print("image: {} class: {:10} prob: {:.3}".format(result["image_path"],
result["class"],
result["prob"]))
if __name__ == '__main__':
main()
在这个修改后的程序中,将预测结果保存到了一个列表results中,并在循环结束后打印了每张图片的预测结果。你可以根据需要将这个程序进一步修改以满足你的需求。
原文地址: https://www.cveoy.top/t/topic/hL0C 著作权归作者所有。请勿转载和采集!