Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
"""Perform MMYOLO inference on a video as: | |
```shell | |
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth # noqa: E501, E261. | |
python demo/video_demo.py \ | |
demo/video_demo.mp4 \ | |
configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \ | |
checkpoint/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth \ | |
--out demo_result.mp4 | |
``` | |
""" | |
import argparse | |
import cv2 | |
import mmcv | |
from mmcv.transforms import Compose | |
from mmdet.apis import inference_detector, init_detector | |
from mmengine.utils import track_iter_progress | |
from mmyolo.registry import VISUALIZERS | |
def parse_args(argv=None):
    """Parse command-line arguments for the MMYOLO video demo.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            ``None``, in which case argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``video``,
        ``config``, ``checkpoint``, ``device``, ``score_thr``, ``out``,
        ``show`` and ``wait_time``.
    """
    parser = argparse.ArgumentParser(description='MMYOLO video demo')
    parser.add_argument('video', help='Video file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--show', action='store_true', help='Show video')
    # The value is forwarded to mmcv.imshow, which hands it to cv2.waitKey;
    # that delay is in milliseconds (the previous help text said seconds).
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (ms), 0 is block')
    return parser.parse_args(argv)
def main():
    """Run detection on every frame of a video, then show and/or save it.

    Reads the video given on the command line and runs
    ``inference_detector`` on each frame. The predictions are drawn with the
    visualizer built from the model config. Each annotated frame is then
    displayed (``--show``) and/or written to an MP4 file (``--out``).

    Raises:
        AssertionError: If neither ``--out`` nor ``--show`` is given.
    """
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    # build test pipeline: frames are in-memory arrays, so the first transform
    # is switched to LoadImageFromNDArray (assumes pipeline[0] is the loader)
    model.cfg.test_dataloader.dataset.pipeline[
        0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    # mmcv.imshow passes wait_time to cv2.waitKey, which takes an integer
    # millisecond delay; argparse yields a float for a user-supplied
    # --wait-time, which some OpenCV builds reject.
    wait_ms = int(args.wait_time)
    if args.wait_time > 0 and wait_ms == 0:
        wait_ms = 1  # keep a short positive delay from becoming "block"

    video_reader = mmcv.VideoReader(args.video)
    video_writer = None
    if args.out:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            args.out, fourcc, video_reader.fps,
            (video_reader.width, video_reader.height))

    try:
        if args.show:
            # Create the resizable window once rather than on every frame.
            cv2.namedWindow('video', cv2.WINDOW_NORMAL)
        for frame in track_iter_progress(video_reader):
            result = inference_detector(
                model, frame, test_pipeline=test_pipeline)
            # NOTE(review): mmcv.VideoReader yields OpenCV frames (presumably
            # BGR); confirm whether the visualizer expects RGB input here.
            visualizer.add_datasample(
                name='video',
                image=frame,
                data_sample=result,
                draw_gt=False,
                show=False,
                pred_score_thr=args.score_thr)
            frame = visualizer.get_image()
            if args.show:
                mmcv.imshow(frame, 'video', wait_ms)
            if video_writer is not None:
                video_writer.write(frame)
    finally:
        # Always finalize the output file, even if inference fails midway.
        if video_writer is not None:
            video_writer.release()
        # Only touch HighGUI when a window was used: headless OpenCV builds
        # can raise on destroyAllWindows().
        if args.show:
            cv2.destroyAllWindows()
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()