chendl's picture
Add application file
0b7b08a
raw
history blame
8.12 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import os
import time
import cv2
import megengine as mge
import megengine.functional as F
from loguru import logger
from yolox.data.datasets import COCO_CLASSES
from yolox.utils import vis
from yolox.data.data_augment import preproc as preprocess
from build import build_and_load
# Dot-prefixed, lower-case file extensions that get_image_list accepts as images.
IMAGE_EXT = ".jpg .jpeg .webp .bmp .png".split()
def make_parser():
    """Build the command-line parser for the YOLOX demo.

    Returns:
        argparse.ArgumentParser: parser with an optional positional ``demo``
        argument (defaulting to "image") plus options for the model name,
        input path, camera id, result saving, checkpoint, confidence / NMS
        thresholds and test image size.
    """
    parser = argparse.ArgumentParser("YOLOX Demo!")
    # nargs="?" makes the positional optional so that default="image" can take
    # effect; a plain positional argument is always required and never uses
    # its default.
    parser.add_argument(
        "demo",
        nargs="?",
        default="image",
        help="demo type, eg. image, video and webcam",
    )
    parser.add_argument("-n", "--name", type=str, default="yolox-s", help="model name")
    parser.add_argument("--path", default="./test.png", help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    # None means "use the default chosen in main()".
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    return parser
def get_image_list(path):
    """Recursively collect image files under ``path``.

    A file counts as an image when its extension, compared case-insensitively,
    appears in ``IMAGE_EXT``. For example, both ``photo.jpg`` and ``photo.JPG``
    are picked up.

    Args:
        path: directory to walk with ``os.walk``.

    Returns:
        list[str]: path of every matching file, each joined onto the directory
        it was found in, in ``os.walk`` order (not sorted).
    """
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            # IMAGE_EXT holds lower-case suffixes; lower-case the file's suffix
            # too so upper-case names such as "IMG_0001.JPG" are not skipped.
            ext = os.path.splitext(apath)[1].lower()
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Filter raw predictions by confidence and run NMS per image.

    Args:
        prediction: MegEngine tensor indexed as ``[image, box, field]``.
            Fields 0-3 are treated as (center x, center y, width, height).
            Field 4 is treated as the objectness score. Fields
            ``5 : 5 + num_classes`` are treated as per-class scores.
        num_classes: number of class-score columns to read.
        conf_thre: boxes whose ``objectness * best class score`` is below this
            value are dropped.
        nms_thre: IoU threshold passed to ``F.vision.nms``.

    Returns:
        list: one entry per image. The entry is ``None`` when no box survives.
        Otherwise it is a tensor of rows
        ``(x1, y1, x2, y2, objectness, class_conf, class_pred)`` kept by NMS.
        The caller's ``prediction`` tensor is not modified.
    """
    # Convert (cx, cy, w, h) -> (x1, y1, x2, y2) in a fresh tensor and copy
    # the score columns over, so the caller's tensor is not overwritten.
    boxes = F.zeros_like(prediction)
    boxes[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    boxes[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    boxes[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    boxes[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    boxes[:, :, 4:] = prediction[:, :, 4:]
    prediction = boxes

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        # No candidate boxes for this image: leave its slot as None.
        if not image_pred.shape[0]:
            continue
        # Best class score and its index for every box.
        class_scores = image_pred[:, 5: 5 + num_classes]
        class_conf = F.max(class_scores, 1, keepdims=True)
        class_pred = F.argmax(class_scores, 1, keepdims=True)
        conf_mask = image_pred[:, 4] * F.squeeze(class_conf) >= conf_thre
        detections = F.concat((image_pred[:, :5], class_conf, class_pred), 1)
        detections = detections[conf_mask]
        if not detections.shape[0]:
            continue
        # NMS ranks boxes by objectness * class confidence.
        nms_out_index = F.vision.nms(
            detections[:, :4], detections[:, 4] * detections[:, 5], nms_thre,
        )
        # Each image index is visited exactly once, so the slot is still None
        # here and can be assigned directly.
        output[i] = detections[nms_out_index]
    return output
class Predictor(object):
    """Run a YOLOX MegEngine model on single images and draw its detections."""

    def __init__(
        self,
        model,
        confthre=0.01,
        nmsthre=0.65,
        test_size=(640, 640),
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
    ):
        """
        Args:
            model: callable that takes a batched input tensor and returns raw
                predictions for ``postprocess``.
            confthre: confidence threshold forwarded to ``postprocess``.
            nmsthre: NMS IoU threshold forwarded to ``postprocess``.
            test_size: target size passed to ``preprocess``.
            cls_names: class label sequence used when drawing. Its length also
                sets how many class-score columns ``postprocess`` reads.
            trt_file: accepted but not used by this class.
            decoder: stored on the instance; not used by the methods here.
        """
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        # Derive the class count from the label list instead of hard-coding 80,
        # so models with a custom label set are post-processed correctly.
        # With the COCO_CLASSES default this is still 80.
        self.num_classes = len(cls_names)
        self.confthre = confthre
        self.nmsthre = nmsthre
        self.test_size = test_size

    def inference(self, img):
        """Preprocess ``img``, run the model and post-process its output.

        Args:
            img: path to an image file, or an already-decoded image array
                (for example a frame read from ``cv2.VideoCapture``).

        Returns:
            tuple: ``(outputs, img_info)``. ``outputs`` is the per-image list
            returned by ``postprocess``. ``img_info`` records the id, file name
            (None for arrays), original height and width, the raw image and
            the preprocessing ratio.

        Raises:
            ValueError: if ``img`` is a path that ``cv2.imread`` cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
            # cv2.imread returns None instead of raising on unreadable files.
            if img is None:
                raise ValueError("test image path is invalid!")
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preprocess(img, self.test_size)
        img_info["ratio"] = ratio
        # Add a leading batch dimension of size 1.
        img = F.expand_dims(mge.tensor(img), 0)

        t0 = time.time()
        outputs = self.model(img)
        outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
        logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        """Draw detections from ``postprocess`` onto the raw image.

        Args:
            output: one image's detection tensor from ``postprocess``, or None.
            img_info: dict produced by ``inference``.
            cls_conf: score threshold passed through to ``vis``.

        Returns:
            The image returned by ``vis``. If ``output`` is None, the raw image
            is returned unchanged.
        """
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.numpy()

        # Undo the preprocessing resize so boxes are in raw-image pixels.
        bboxes = output[:, 0:4] / ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res
def image_demo(predictor, vis_folder, path, current_time, save_result):
    """Run ``predictor`` on one image file or on every image under a directory.

    Files are processed in sorted order. When ``save_result`` is true, each
    visualised result is written into a timestamped sub-folder of
    ``vis_folder``, using the source file's base name. The loop stops early
    when ``cv2.waitKey`` reports Esc, "q" or "Q".
    """
    files = sorted(get_image_list(path) if os.path.isdir(path) else [path])
    quit_keys = (27, ord("q"), ord("Q"))

    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info)

        if save_result:
            stamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            save_folder = os.path.join(vis_folder, stamp)
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)

        if cv2.waitKey(0) in quit_keys:
            break
def imageflow_demo(predictor, vis_folder, current_time, args):
    """Run ``predictor`` frame by frame over a video file or a webcam stream.

    Args:
        predictor: object exposing ``inference(frame)`` and
            ``visual(output, img_info)``.
        vis_folder: root folder for saved results. Only used when
            ``args.save_result`` is true.
        current_time: ``time.struct_time`` used to name the output sub-folder.
        args: parsed CLI arguments. Reads ``demo``, ``path``, ``camid`` and
            ``save_result``.

    Stops when a frame cannot be read, or when ``cv2.waitKey`` reports Esc,
    "q" or "Q". The capture, and the writer if one was opened, are always
    released on exit.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    vid_writer = None
    try:
        # Only create the output folder and writer when results are saved.
        # Otherwise vis_folder may be None and no empty file should be left.
        if args.save_result:
            width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
            height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
            fps = cap.get(cv2.CAP_PROP_FPS)
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            if args.demo == "video":
                save_path = os.path.join(save_folder, os.path.basename(args.path))
            else:
                save_path = os.path.join(save_folder, "camera.mp4")
            logger.info(f"video save_path is {save_path}")
            vid_writer = cv2.VideoWriter(
                save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
            )

        while True:
            ret_val, frame = cap.read()
            if not ret_val:
                break
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info)
            if vid_writer is not None:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
    finally:
        # Release explicitly so the writer closes its output file and the
        # capture source is freed, even if inference raises mid-stream.
        if vid_writer is not None:
            vid_writer.release()
        cap.release()
def main(args):
    """Build the model described by ``args`` and run the requested demo.

    Args:
        args: namespace produced by ``make_parser().parse_args()``.

    Raises:
        ValueError: if ``args.demo`` is not "image", "video" or "webcam".
            This is checked before any directory is created or model loaded.
    """
    if args.demo not in ("image", "video", "webcam"):
        raise ValueError(f"unsupported demo type: {args.demo!r}")

    file_name = os.path.join("./yolox_outputs", args.name)
    os.makedirs(file_name, exist_ok=True)

    # Always bind vis_folder. It is only a real folder when results are saved,
    # and it is passed to the demo functions unconditionally below.
    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    # CLI values override the built-in defaults only when given.
    confthre = 0.01 if args.conf is None else args.conf
    nmsthre = 0.65 if args.nms is None else args.nms
    test_size = (640, 640) if args.tsize is None else (args.tsize, args.tsize)

    model = build_and_load(args.ckpt, name=args.name)
    model.eval()

    predictor = Predictor(model, confthre, nmsthre, test_size, COCO_CLASSES, None, None)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    else:
        # "video" or "webcam", guaranteed by the check at the top.
        imageflow_demo(predictor, vis_folder, current_time, args)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand it to main().
    main(make_parser().parse_args())