# Gradio YOLOv5 Det v0.1
# Author: 曾逸夫
# Created: 2022-04-03
"""Gradio web front end for running YOLOv5 object detection on single images."""

import argparse
import csv
import sys

import gradio as gr
import torch
import yaml
from PIL import Image
from zmq import device  # NOTE(review): looks unused (every use of `device` below is a parameter/local) — confirm before removing

ROOT_PATH = sys.path[0]  # project root directory

# Path to a local YOLOv5 checkout (only needed for the local loading variant
# described in `model_loading`).
local_model_path = f"{ROOT_PATH}/yolov5"

# Name of the model currently held in the global `model`.
model_name_tmp = ""

# Device the global `model` was loaded on.
device_tmp = ""


def parse_args(known=False):
    """Parse the command-line options of the demo.

    Args:
        known: When true, unknown arguments are ignored
            (``parse_known_args``) instead of causing an error.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
    parser.add_argument(
        "--model_name", "-mn", default="yolov5s", type=str, help="model name"
    )
    parser.add_argument(
        "--model_cfg",
        "-mc",
        default="./model_config/model_name_p5_all.yaml",
        type=str,
        help="model config",
    )
    parser.add_argument(
        "--cls_name",
        "-cls",
        default="./cls_name/cls_name.yaml",
        type=str,
        help="cls name",
    )
    parser.add_argument(
        "--nms_conf",
        "-conf",
        default=0.5,
        type=float,
        help="model NMS confidence threshold",
    )
    parser.add_argument(
        "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold"
    )
    # store_false: labels are shown by default; passing the flag hides them.
    parser.add_argument(
        "--label_dnt_show",
        "-lds",
        action="store_false",
        default=True,
        help="label show",
    )
    parser.add_argument(
        "--device",
        "-dev",
        default="0",
        type=str,
        help="cuda or cpu",
    )

    if known:
        return parser.parse_known_args()[0]
    return parser.parse_args()


def model_loading(model_name, device):
    """Load a YOLOv5 model through ``torch.hub``.

    The model is fetched from the ``ultralytics/yolov5`` hub repository. To
    load weights from a local checkout instead, the equivalent call is
    ``torch.hub.load(local_model_path, "custom",
    path=f"{local_model_path}/{model_name}", source="local", device=device)``.

    Args:
        model_name: Hub entry point name, e.g. ``"yolov5s"``.
        device: Device string forwarded to the hub entry point
            (the UI offers ``"0"`` and ``"cpu"``).

    Returns:
        The loaded model object.
    """
    return torch.hub.load("ultralytics/yolov5", model_name, device=device)


def export_json(results, model, img_size):
    """Convert detection results into JSON-serialisable records.

    Args:
        results: Detection results returned by calling the model; its
            ``xyxyn`` attribute is iterated per image and ``t[1]`` is read
            as the inference time (treated here as milliseconds).
        model: The model that produced ``results``; ``model.model.names``
            maps class indices to class names.
        img_size: ``(width, height)`` of the input image.

    Returns:
        One list per image, each holding one dict per detection with the
        keys ``id``, ``class``, ``class_name``, ``normalized_box``,
        ``confidence``, ``fps``, ``width`` and ``height``.
    """
    # Values shared by every detection are computed once instead of per box.
    infer_ms = float(results.t[1])
    fps = round(1000 / infer_ms, 2) if infer_ms > 0 else 0.0
    width, height = img_size[0], img_size[1]
    class_names = model.model.names

    def to_record(idx, det):
        # Each row is laid out as [x0, y0, x1, y1, confidence, class].
        row = det.tolist()
        cls_id = int(row[5])
        return {
            "id": int(idx),
            "class": cls_id,
            "class_name": class_names[cls_id],
            "normalized_box": {
                "x0": round(row[0], 6),
                "y0": round(row[1], 6),
                "x1": round(row[2], 6),
                "y1": round(row[3], 6),
            },
            "confidence": round(float(row[4]), 2),
            "fps": fps,
            "width": width,
            "height": height,
        }

    return [
        [to_record(idx, det) for idx, det in enumerate(image_dets)]
        for image_dets in results.xyxyn
    ]


def yolo_det(img, device, model_name, conf, iou, label_opt, model_cls):
    """Run YOLOv5 detection on one image (Gradio callback).

    Args:
        img: Input PIL image.
        device: Device selected in the UI.
        model_name: Model selected in the UI.
        conf: NMS confidence threshold.
        iou: NMS IoU threshold.
        label_opt: Whether labels are drawn on the rendered image.
        model_cls: Indices of the classes to keep.

    Returns:
        A ``(rendered PIL image, list of detection dicts)`` tuple.
    """
    global model, model_name_tmp, device_tmp

    # Reload only when the requested model or device differs from what is
    # currently loaded. The cache keys are recorded after a successful load,
    # so a failed load never leaves them pointing at a model that is not
    # actually in memory.
    if model_name_tmp != model_name or device_tmp != device:
        model = model_loading(model_name, device)
        model_name_tmp = model_name
        device_tmp = device

    # ----------- model parameters -----------
    model.conf = conf  # NMS confidence threshold
    model.iou = iou  # NMS IoU threshold
    model.max_det = 1000  # maximum number of boxes
    model.classes = model_cls  # classes to keep

    results = model(img)  # inference
    results.render(labels=label_opt)  # draw boxes onto the result images

    # Newer YOLOv5 releases expose the rendered frames as `ims`, older ones
    # as `imgs`; torch.hub pulls whatever release is current.
    rendered = results.ims if hasattr(results, "ims") else results.imgs
    det_img = Image.fromarray(rendered[0])
    det_json = export_json(results, model, img.size)[0]

    return det_img, det_json


def yaml_parse(file_path):
    """Read a UTF-8 YAML file and return the parsed document.

    Args:
        file_path: Path of the YAML file.

    Returns:
        The object produced by the YAML loader.
    """
    with open(file_path, "r", encoding="utf-8") as yaml_file:
        # NOTE(review): FullLoader is kept for compatibility; if these config
        # files can ever come from an untrusted source, use yaml.safe_load.
        return yaml.load(yaml_file.read(), Loader=yaml.FullLoader)


def main(args):
    """Load the initial model, build the Gradio interface and launch it.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    global model, model_name_tmp, device_tmp

    slider_step = 0.05  # slider increment

    nms_conf = args.nms_conf
    nms_iou = args.nms_iou
    label_opt = args.label_dnt_show
    model_name = args.model_name
    model_cfg = args.model_cfg
    cls_name = args.cls_name
    device = args.device

    # Initial model load; remember what was loaded so the first request with
    # the default settings does not trigger a second load.
    model = model_loading(model_name, device)
    model_name_tmp = model_name
    device_tmp = device

    # Selectable model names and class names come from YAML config files.
    model_names = yaml_parse(model_cfg).get("model_names")
    model_cls_name = yaml_parse(cls_name).get("model_cls_name")

    # ------------------- input components -------------------
    inputs_img = gr.inputs.Image(type="pil", label="原始图片")
    inputs_device = gr.inputs.Dropdown(
        choices=["0", "cpu"], default=device, type="value", label="设备"
    )
    inputs_model = gr.inputs.Dropdown(
        choices=model_names, default=model_name, type="value", label="模型"
    )
    input_conf = gr.inputs.Slider(
        0, 1, step=slider_step, default=nms_conf, label="置信度阈值"
    )
    inputs_iou = gr.inputs.Slider(
        0, 1, step=slider_step, default=nms_iou, label="IoU 阈值"
    )
    inputs_label = gr.inputs.Checkbox(default=label_opt, label="标签显示")
    inputs_clsName = gr.inputs.CheckboxGroup(
        choices=model_cls_name, default=model_cls_name, type="index", label="类别"
    )

    # Order must match the parameter order of `yolo_det`.
    inputs = [
        inputs_img,  # input image
        inputs_device,  # device
        inputs_model,  # model
        input_conf,  # confidence threshold
        inputs_iou,  # IoU threshold
        inputs_label,  # label display
        inputs_clsName,  # classes
    ]

    # ------------------- output components -------------------
    outputs = gr.outputs.Image(type="pil", label="检测图片")
    outputs02 = gr.outputs.JSON(label="检测信息")

    title = "基于Gradio的YOLOv5通用目标检测系统"
    description = "\n可自定义目标检测模型、安装简单、使用方便\n"

    gr.close_all()

    gr.Interface(
        fn=yolo_det,
        inputs=inputs,
        outputs=[outputs, outputs02],
        title=title,
        description=description,
        theme="seafoam",
        flagging_dir="run",  # output directory for flagged samples
    ).launch(
        inbrowser=True,  # open the default browser automatically
        show_tips=True,  # show Gradio feature tips
        favicon_path="./icon/logo.ico",
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)