import argparse
import datetime
import random
import time
from pathlib import Path

import torch
import torchvision.transforms as standard_transforms
import numpy as np

from PIL import Image
import cv2
# 假设您有crowd_datasets.py和engine.py
from crowd_datasets import build_dataset
from engine import *
from models import build_model
import os
import warnings
warnings.filterwarnings('ignore')

def get_args_parser():
    """Create the option-only parent parser for P2PNet evaluation.

    The parser is built with ``add_help=False`` so that it can be handed to
    another ``ArgumentParser`` via ``parents=[...]`` without producing a
    duplicate ``-h`` option.

    Options: ``--backbone`` (str), ``--row`` / ``--line`` (int, number of
    anchor rows / columns), ``--output_dir``, ``--weight_path`` and
    ``--gpu_id`` (int).
    """
    opts = argparse.ArgumentParser('设置 P2PNet 评估参数', add_help=False)

    # * Backbone
    opts.add_argument('--backbone', type=str, default='vgg16_bn',
                      help='使用的卷积骨干网络名称')

    # Anchor layout: rows first, then columns; both default to 2.
    for flag, text in (('--row', '锚点的行数'), ('--line', '锚点的列数')):
        opts.add_argument(flag, type=int, default=2, help=text)

    # Where visualizations are written and which checkpoint is loaded.
    opts.add_argument('--output_dir', default='vis', help='保存结果的路径')
    opts.add_argument('--weight_path', default='weights/SHTechA.pth',
                      help='保存训练模型权重的路径')

    opts.add_argument('--gpu_id', type=int, default=0, help='用于评估的GPU ID')
    return opts

def compute_mae_mse(pred_density_map, true_density_map):
    """Compare a predicted density map with a ground-truth map element-wise.

    Args:
        pred_density_map: array-like predicted density map.
        true_density_map: array-like ground-truth density map with the same
            shape as ``pred_density_map``.

    Returns:
        ``(mae, mse)``: the mean absolute error and the mean squared error
        over all elements, as NumPy float32 scalars.

    Raises:
        ValueError: if the two maps do not have the same shape.
    """
    pred = np.asarray(pred_density_map, dtype=np.float32)
    true = np.asarray(true_density_map, dtype=np.float32)
    # Refuse to silently broadcast maps of different sizes.
    if pred.shape != true.shape:
        raise ValueError(
            'density map shapes differ: {} vs {}'.format(pred.shape, true.shape))

    diff = pred - true
    # Use per-element means for both metrics. Normalizing the summed absolute
    # error by the ground-truth total gives a relative error rather than an
    # MAE, and it divides by zero for scenes with no people.
    mae = np.mean(np.abs(diff))
    mse = np.mean(np.square(diff))
    return mae, mse

def generate_density_map(points, image_size):
    """Rasterize points into a per-pixel count map.

    Args:
        points: iterable of points in pixel coordinates; only ``p[0]``
            (x, column) and ``p[1]`` (y, row) of each point are used.
        image_size: ``(height, width)`` of the output map.

    Returns:
        float32 array of shape ``image_size`` in which each pixel holds the
        number of points that fall inside it. Points outside the image are
        ignored.
    """
    height, width = image_size[0], image_size[1]
    density_map = np.zeros(image_size, dtype=np.float32)

    for p in points:
        # floor (not int() truncation) so that, e.g., x = -0.5 maps to
        # column -1 and is discarded instead of being folded into column 0.
        x, y = int(np.floor(p[0])), int(np.floor(p[1]))
        # Reject negative indices explicitly: NumPy would otherwise wrap them
        # around to the opposite edge of the map.
        if 0 <= x < width and 0 <= y < height:
            density_map[y, x] += 1

    return density_map

def _load_resized_image(img_path):
    """Open *img_path* as RGB, resized so each side is a multiple of 128.

    Each side is rounded down to a multiple of 128, as the original script
    did (presumably required by the model's downsampling; TODO confirm), but
    never below 128. Images smaller than 128 px are therefore upscaled
    instead of collapsing to a zero-sized image.
    """
    # The context manager closes the file handle; convert() returns an
    # independent in-memory image.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    width, height = img.size
    new_size = (max(128, width // 128 * 128), max(128, height // 128 * 128))
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
    return img.resize(new_size, Image.LANCZOS)


def _draw_points(img_rgb, points, radius=2):
    """Return a BGR array of *img_rgb* with a filled red dot at every point."""
    canvas = cv2.cvtColor(np.array(img_rgb), cv2.COLOR_RGB2BGR)
    for p in points:
        # (0, 0, 255) is red in OpenCV's BGR order; thickness -1 fills the dot.
        canvas = cv2.circle(canvas, (int(p[0]), int(p[1])), radius, (0, 0, 255), -1)
    return canvas


def main(args, debug=False, img_path='./vis/demo1.jpg', threshold=0.5):
    """Run P2PNet on one image and save a visualization of the predicted points.

    The model is built from *args*. When ``args.weight_path`` is not None, it
    is initialised from ``checkpoint['model']`` of that file. Predicted
    points whose score exceeds *threshold* are drawn as red dots, and the
    result is written to ``<args.output_dir>/pred<count>.jpg``.

    Args:
        args: parsed options from ``get_args_parser()``.
        debug: unused; kept for interface compatibility.
        img_path: image to evaluate (formerly hard-coded; same default).
        threshold: minimum score for a predicted point to be kept
            (formerly hard-coded; same default).

    Raises:
        OSError: if the visualization image cannot be written.
    """
    # Evaluation is forced onto the CPU: hiding every CUDA device overrides
    # args.gpu_id. For GPU use, set this to str(args.gpu_id) and use
    # torch.device('cuda') below.
    os.environ["CUDA_VISIBLE_DEVICES"] = ''

    print(args)
    device = torch.device('cpu')

    model = build_model(args)
    model.to(device)
    if args.weight_path is not None:
        checkpoint = torch.load(args.weight_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    model.eval()

    # ToTensor + ImageNet mean/std normalisation.
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img_raw = _load_resized_image(img_path)
    # transform() already yields a tensor; just add the batch dimension.
    samples = transform(img_raw).unsqueeze(0).to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(samples)
    # The indexing implies pred_logits is (batch, points, classes); take the
    # softmax probability of class 1 (presumably "person") for image 0.
    outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
    outputs_points = outputs['pred_points'][0]

    keep = outputs_scores > threshold
    points = outputs_points[keep].cpu().numpy().tolist()
    predict_cnt = int(keep.sum())

    # Predicted per-pixel count map; only consumed by the metric TODO below.
    pred_density_map = generate_density_map(points, (img_raw.height, img_raw.width))

    # TODO: load the ground-truth density map for this image, then:
    #   mae, mse = compute_mae_mse(pred_density_map, true_density_map)
    #   print('MAE:', mae)
    #   print('MSE:', mse)

    # cv2.imwrite does not create directories and reports failure by
    # returning False, so ensure the directory exists and check the result.
    os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, 'pred{}.jpg'.format(predict_cnt))
    if not cv2.imwrite(out_path, _draw_points(img_raw, points)):
        raise OSError('could not write visualization to {}'.format(out_path))

if __name__ == '__main__':
    # Add -h/--help on top of the option-only parent parser, then evaluate.
    args = argparse.ArgumentParser(
        'P2PNet 评估脚本', parents=[get_args_parser()]
    ).parse_args()
    main(args)
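# P2PNet crowd-counting model evaluation and visualization.
#
# Source: https://www.cveoy.top/t/topic/RfG
# Copyright belongs to the original author; reposting and scraping are not permitted.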