#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import time

import cv2

import megengine as mge
import megengine.functional as F
from loguru import logger

from yolox.data.datasets import COCO_CLASSES
from yolox.utils import vis
from yolox.data.data_augment import preproc as preprocess

from build import build_and_load

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-n", "--name", type=str, default="yolox-s", help="model name")
    parser.add_argument("--path", default="./test.png", help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    # Convert (cx, cy, w, h) predictions to corner format (x1, y1, x2, y2).
    box_corner = F.zeros_like(prediction)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        # If none are remaining => process next image
        if not image_pred.shape[0]:
            continue

        # Get score and class with highest confidence
        class_conf = F.max(image_pred[:, 5: 5 + num_classes], 1, keepdims=True)
        class_pred = F.argmax(image_pred[:, 5: 5 + num_classes], 1, keepdims=True)

        # Keep boxes whose objectness * class confidence clears the threshold.
        class_conf_squeeze = F.squeeze(class_conf)
        conf_mask = image_pred[:, 4] * class_conf_squeeze >= conf_thre
        detections = F.concat((image_pred[:, :5], class_conf, class_pred), 1)
        detections = detections[conf_mask]
        if not detections.shape[0]:
            continue

        # Class-agnostic NMS on the surviving detections.
        nms_out_index = F.vision.nms(
            detections[:, :4],
            detections[:, 4] * detections[:, 5],
            nms_thre,
        )
        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = F.concat((output[i], detections))

    return output
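
# A minimal, commented-out sanity check for postprocess (a sketch only: the
# (1, 8400, 85) prediction shape, the random input, and `import numpy as np`
# are assumptions, not part of this demo):
#
#   dummy = mge.tensor(np.random.rand(1, 8400, 85).astype("float32"))
#   outs = postprocess(dummy, num_classes=80, conf_thre=0.25, nms_thre=0.45)
#   # Each entry is None or an (num_kept, 7) tensor:
#   # x1, y1, x2, y2, obj_conf, class_conf, class_pred.
#   print(None if outs[0] is None else outs[0].shape)
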
class Predictor(object):
    def __init__(
        self,
        model,
        confthre=0.01,
        nmsthre=0.65,
        test_size=(640, 640),
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = 80
        self.confthre = confthre
        self.nmsthre = nmsthre
        self.test_size = test_size

    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
            if img is None:
                raise ValueError("test image path is invalid!")
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preprocess(img, self.test_size)
        img_info["ratio"] = ratio
        img = F.expand_dims(mge.tensor(img), 0)

        t0 = time.time()
        outputs = self.model(img)
        outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
        logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img

        output = output.numpy()

        # preprocessing: resize
        bboxes = output[:, 0:4] / ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, os.path.basename(args.path))
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break


def main(args):
    file_name = os.path.join("./yolox_outputs", args.name)
    os.makedirs(file_name, exist_ok=True)

    # Always define vis_folder so the demo functions can receive it;
    # the directory is only created on disk when results are saved.
    vis_folder = os.path.join(file_name, "vis_res")
    if args.save_result:
        os.makedirs(vis_folder, exist_ok=True)

    confthre = 0.01
    nmsthre = 0.65
    test_size = (640, 640)
    if args.conf is not None:
        confthre = args.conf
    if args.nms is not None:
        nmsthre = args.nms
    if args.tsize is not None:
        test_size = (args.tsize, args.tsize)

    model = build_and_load(args.ckpt, name=args.name)
    model.eval()

    predictor = Predictor(model, confthre, nmsthre, test_size, COCO_CLASSES, None, None)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
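
# Example invocations (a sketch; the script name "demo.py", the checkpoint path
# and the media paths are placeholders, adjust them to your setup):
#
#   python demo.py image -n yolox-s -c ./yolox_s.pkl --path ./test.png --conf 0.25 --nms 0.45 --tsize 640 --save_result
#   python demo.py video -n yolox-s -c ./yolox_s.pkl --path ./demo.mp4 --save_result
#   python demo.py webcam -n yolox-s -c ./yolox_s.pkl --camid 0 --save_result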