Process detections for i det in enumeratepred # detections per image if webcam # batch_size = 1 p s im0 frame = pathi fi im0sicopy datasetcount else
这段代码是 YOLOv5 检测器的主要代码,它对输入的图像进行预测并返回检测结果。以下是对代码的详细解释:
-
设置相关参数
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size device = select_device(device) half = device.type != 'cpu' # half precision only supported on CUDA
在此处,我们检查输入的图像大小是否符合模型要求,选择使用哪个设备(CPU 或 GPU),并设置是否使用半精度浮点数运算。
-
构建模型
model = attempt_load(weights, map_location=device) # load FP32 model
在此处,我们加载模型权重,构建模型并将其移到所选设备上。
-
加载图像
dataset = LoadImages(source, img_size=imgsz, stride=stride)
在此处,我们使用 LoadImages 类加载输入图像,并将其调整为模型需要的大小。
-
进行预测
for path, img, im0s, vid_cap in dataset: img = torch.from_numpy(img).to(device) img = img.half() if half else img.float() # uint8 to fp16/32 img /= 255.0 # 0 - 255 to 0.0 - 1.0 if img.ndimension() == 3: img = img.unsqueeze(0)
# Inference t1 = time_synchronized() pred = model(img, augment=opt.augment)[0]
在此处,我们将图像转换为 PyTorch 张量,并将其输入模型进行预测。如果设置了数据增强(augment),则会在预测时进行增强。
-
处理预测结果
# Apply NMS pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) t2 = time_synchronized() # Process detections for i, det in enumerate(pred): # detections per image if webcam: # batch_size >= 1 p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count else: p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
在此处,我们使用非极大值抑制(NMS)筛选预测结果,并对筛选后的结果进行处理。如果处理的是视频,则需要记录当前处理的帧数。
-
输出结果
p = Path(p) # to Path save_path = str(save_dir / p.name) # img.jpg txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt s += '%gx%g ' % img.shape[2:] # print string gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh imc = im0.copy() if save_crop else im0 # for save_crop if len(det): # Rescale boxes from img_size to im0 size det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() # Print results for c in det[:, -1].unique(): n = (det[:, -1] == c).sum() # detections per class s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
在此处,我们将输出结果的保存路径(图片和文本文件)设置为指定的目录,并对检测到的物体进行计数和分类。如果检测到物体,则将其在图像上标注出来。
-
显示和保存结果
# Stream results if view_img: cv2.imshow(str(p), im0) cv2.waitKey(1) # 1 millisecond # Save results (image with detections) if save_img: if dataset.mode == 'image': cv2.imwrite(save_path, im0) else: # 'video' or 'stream' if vid_path[i] != save_path: # new video vid_path[i] = save_path if isinstance(vid_writer[i], cv2.VideoWriter): vid_writer[i].release() # release previous video writer if vid_cap: # video fps = vid_cap.get(cv2.CAP_PROP_FPS) w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) else: # stream fps, w, h = 30, im0.shape[1], im0.shape[0] save_path += '.mp4' vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) vid_writer[i].write(im0)
在此处,如果设置了 view_img,代码将显示检测结果。如果设置了 save_img,代码将保存检测结果。对于视频,代码将保存为视频文件。
-
输出结果
if save_txt or save_img: s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' print(f"Results saved to {save_dir}{s}")
在此处,我们输出保存的文件路径。
-
更新模型
if update: strip_optimizer(weights) # update model (to fix SourceChangeWarning)
在此处,如果设置了 update,代码将更新模型以修复可能出现的 SourceChangeWarning
原文地址: https://www.cveoy.top/t/topic/he0w 著作权归作者所有。请勿转载和采集!