Zengyf-CVer commited on
Commit
6daa32c
1 Parent(s): 23f6cfc
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio YOLOv5 Det v0.1
2
+ # 创建人:曾逸夫
3
+ # 创建时间:2022-04-03
4
+
5
+ import argparse
6
+ import csv
7
+ import sys
8
+
9
+ import gradio as gr
10
+ import torch
11
+ import yaml
12
+ from PIL import Image
13
+ from zmq import device
14
+
15
+ ROOT_PATH = sys.path[0] # 根目录
16
+
17
+ # 本地模型路径
18
+ local_model_path = f"{ROOT_PATH}/yolov5"
19
+
20
+
21
+ # 模型名称临时变量
22
+ model_name_tmp = ""
23
+
24
+ # 设备临时变量
25
+ device_tmp = ""
26
+
27
+
28
+ def parse_args(known=False):
29
+ parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
30
+ parser.add_argument(
31
+ "--model_name", "-mn", default="yolov5s", type=str, help="model name"
32
+ )
33
+ parser.add_argument(
34
+ "--model_cfg",
35
+ "-mc",
36
+ default="./model_config/model_name_p5_all.yaml",
37
+ type=str,
38
+ help="model config",
39
+ )
40
+ parser.add_argument(
41
+ "--cls_name",
42
+ "-cls",
43
+ default="./cls_name/cls_name.yaml",
44
+ type=str,
45
+ help="cls name",
46
+ )
47
+ parser.add_argument(
48
+ "--nms_conf",
49
+ "-conf",
50
+ default=0.5,
51
+ type=float,
52
+ help="model NMS confidence threshold",
53
+ )
54
+ parser.add_argument(
55
+ "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold"
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--label_dnt_show",
60
+ "-lds",
61
+ action="store_false",
62
+ default=True,
63
+ help="label show",
64
+ )
65
+ parser.add_argument(
66
+ "--device",
67
+ "-dev",
68
+ default="0",
69
+ type=str,
70
+ help="cuda or cpu",
71
+ )
72
+
73
+ args = parser.parse_known_args()[0] if known else parser.parse_args()
74
+ return args
75
+
76
+
77
+ # 模型加载
78
+ def model_loading(model_name, device):
79
+
80
+ # 加载本地模型
81
+ model = torch.hub.load(
82
+ local_model_path,
83
+ "custom",
84
+ path=f"{local_model_path}/{model_name}",
85
+ source="local",
86
+ device=device,
87
+ )
88
+
89
+ return model
90
+
91
+
92
+ # 检测信息
93
+ def export_json(results, model, img_size):
94
+
95
+ return [
96
+ [
97
+ {
98
+ "id": int(i),
99
+ "class": int(result[i][5]),
100
+ "class_name": model.model.names[int(result[i][5])],
101
+ "normalized_box": {
102
+ "x0": round(result[i][:4].tolist()[0], 6),
103
+ "y0": round(result[i][:4].tolist()[1], 6),
104
+ "x1": round(result[i][:4].tolist()[2], 6),
105
+ "y1": round(result[i][:4].tolist()[3], 6),
106
+ },
107
+ "confidence": round(float(result[i][4]), 2),
108
+ "fps": round(1000 / float(results.t[1]), 2),
109
+ "width": img_size[0],
110
+ "height": img_size[1],
111
+ }
112
+ for i in range(len(result))
113
+ ]
114
+ for result in results.xyxyn
115
+ ]
116
+
117
+
118
+ # YOLOv5图片检测函数
119
+ def yolo_det(img, device, model_name, conf, iou, label_opt, model_cls):
120
+
121
+ global model, model_name_tmp, device_tmp
122
+
123
+ if model_name_tmp != model_name:
124
+ # 模型判断,避免反复加载
125
+ model_name_tmp = model_name
126
+ model = model_loading(model_name_tmp, device)
127
+ elif device_tmp != device:
128
+ device_tmp = device
129
+ model = model_loading(model_name_tmp, device)
130
+
131
+ # -----------模型调参-----------
132
+ model.conf = conf # NMS 置信度阈值
133
+ model.iou = iou # NMS IOU阈值
134
+ model.max_det = 1000 # 最大检测框数
135
+ model.classes = model_cls # 模型类别
136
+
137
+ results = model(img) # 检测
138
+ results.render(labels=label_opt) # 渲染
139
+
140
+ det_img = Image.fromarray(results.imgs[0]) # 检测图片
141
+
142
+ det_json = export_json(results, model, img.size)[0] # 检测信息
143
+
144
+ return det_img, det_json
145
+
146
+
147
+ # yaml文件解析
148
+ def yaml_parse(file_path):
149
+ return yaml.load(
150
+ open(file_path, "r", encoding="utf-8").read(), Loader=yaml.FullLoader
151
+ )
152
+
153
+
154
+ def main(args):
155
+ global model
156
+
157
+ slider_step = 0.05 # 滑动步长
158
+
159
+ nms_conf = args.nms_conf
160
+ nms_iou = args.nms_iou
161
+ label_opt = args.label_dnt_show
162
+ model_name = args.model_name
163
+ model_cfg = args.model_cfg
164
+ cls_name = args.cls_name
165
+ device = args.device
166
+
167
+ # 模型加载
168
+ model = model_loading(model_name, device)
169
+ # 模型名称
170
+ # model_names = [i[0] for i in list(csv.reader(open(model_cfg)))] # csv版
171
+ model_names = yaml_parse(model_cfg).get("model_names") # yaml版
172
+
173
+ # 类别名称
174
+ # model_cls_name = [i[0] for i in list(csv.reader(open(cls_name)))] # csv版
175
+ model_cls_name = yaml_parse(cls_name).get("model_cls_name") # yaml版
176
+
177
+ # -------------------输入组件-------------------
178
+ inputs_img = gr.inputs.Image(type="pil", label="原始图片")
179
+ device = gr.inputs.Dropdown(
180
+ choices=["0", "cpu"], default=device, type="value", label="设备"
181
+ )
182
+ inputs_model = gr.inputs.Dropdown(
183
+ choices=model_names, default=model_name, type="value", label="模型"
184
+ )
185
+ input_conf = gr.inputs.Slider(
186
+ 0, 1, step=slider_step, default=nms_conf, label="置信度阈值"
187
+ )
188
+ inputs_iou = gr.inputs.Slider(
189
+ 0, 1, step=slider_step, default=nms_iou, label="IoU 阈值"
190
+ )
191
+ inputs_label = gr.inputs.Checkbox(default=label_opt, label="标签显示")
192
+ inputs_clsName = gr.inputs.CheckboxGroup(
193
+ choices=model_cls_name, default=model_cls_name, type="index", label="类别"
194
+ )
195
+
196
+ # 输入参数
197
+ inputs = [
198
+ inputs_img, # 输入图片
199
+ device, # 设备
200
+ inputs_model, # 模型
201
+ input_conf, # 置信度阈值
202
+ inputs_iou, # IoU阈值
203
+ inputs_label, # 标签显示
204
+ inputs_clsName, # 类别
205
+ ]
206
+ # 输出参数
207
+ outputs = gr.outputs.Image(type="pil", label="检测图片")
208
+ outputs02 = gr.outputs.JSON(label="检测信息")
209
+
210
+ # 标题
211
+ title = "基于Gradio的YOLOv5通用目标检测系统"
212
+ # 描述
213
+ description = "<div align='center'>可自定义目标检测模型、安装简单、使用方便</div>"
214
+
215
+ gr.close_all()
216
+
217
+ # 接口
218
+ gr.Interface(
219
+ fn=yolo_det,
220
+ inputs=inputs,
221
+ outputs=[outputs, outputs02],
222
+ title=title,
223
+ description=description,
224
+ theme="seafoam",
225
+ # live=True, # 实时变更输出
226
+ flagging_dir="run" # 输出目录
227
+ # ).launch(inbrowser=True, auth=['admin', 'admin'])
228
+ ).launch(
229
+ inbrowser=True, # 自动打开默认浏览器
230
+ show_tips=True, # 自动显示gradio最新功能
231
+ favicon_path="./icon/logo.ico",
232
+ )
233
+
234
+
235
+ if __name__ == "__main__":
236
+ args = parse_args()
237
+ main(args)
cls_name/cls_name.csv ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ 自行车
3
+ 汽车
4
+ 摩托车
5
+ 飞机
6
+ 公交车
7
+ 火车
8
+ 卡车
9
+
10
+ 红绿灯
11
+ 消防栓
12
+ 停止标志
13
+ 停车收费表
14
+ 长凳
15
+
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+ 斑马
24
+ 长颈鹿
25
+ 背包
26
+ 雨伞
27
+ 手提包
28
+ 领带
29
+ 手提箱
30
+ 飞盘
31
+ 滑雪板
32
+ 单板滑雪
33
+ 运动球
34
+ 风筝
35
+ 棒球棒
36
+ 棒球手套
37
+ 滑板
38
+ 冲浪板
39
+ 网球拍
40
+ 瓶子
41
+ 红酒杯
42
+ 杯子
43
+ 叉子
44
+
45
+
46
+
47
+ 香蕉
48
+ 苹果
49
+ 三明治
50
+ 橙子
51
+ 西兰花
52
+ 胡萝卜
53
+ 热狗
54
+ 比萨
55
+ 甜甜圈
56
+ 蛋糕
57
+ 椅子
58
+ 长椅
59
+ 盆栽
60
+
61
+ 餐桌
62
+ 马桶
63
+ 电视
64
+ 笔记本电脑
65
+ 鼠标
66
+ 遥控器
67
+ 键盘
68
+ 手机
69
+ 微波炉
70
+ 烤箱
71
+ 烤面包机
72
+ 洗碗槽
73
+ 冰箱
74
+
75
+ 时钟
76
+ 花瓶
77
+ 剪刀
78
+ 泰迪熊
79
+ 吹风机
80
+ 牙刷
cls_name/cls_name.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ model_cls_name: ['人', '自行车', '汽车', '摩托车', '飞机', '公交车', '火车', '卡车', '船', '红绿灯', '消防栓', '停止标志',
2
+ '停车收费表', '长凳', '鸟', '猫', '狗', '马', '羊', '牛', '象', '熊', '斑马', '长颈鹿', '背包', '雨伞', '手提包', '领带',
3
+ '手提箱', '飞盘', '滑雪板', '单板滑雪', '运动球', '风筝', '棒球棒', '棒球手套', '滑板', '冲浪板', '网球拍', '瓶子', '红酒杯',
4
+ '杯子', '叉子', '刀', '勺', '碗', '香蕉', '苹果', '三明治', '橙子', '西兰花', '胡萝卜', '热狗', '比萨', '甜甜圈', '蛋糕',
5
+ '椅子', '长椅', '盆栽', '床', '餐桌', '马桶', '电视', '笔记本电脑', '鼠标', '遥控器', '键盘', '手机', '微波炉', '烤箱',
6
+ '烤面包机', '洗碗槽', '冰箱', '书', '时钟', '花瓶', '剪刀', '泰迪熊', '吹风机', '牙刷'
7
+ ]
icon/logo.ico ADDED
model_config/model_name_p5_all.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ yolov5n
2
+ yolov5s
3
+ yolov5m
4
+ yolov5l
5
+ yolov5x
model_config/model_name_p5_all.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ model_names: ["yolov5n", "yolov5s", "yolov5m", "yolov5l", "yolov5x"]
model_config/model_name_p5_n.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ yolov5n
model_config/model_name_p5_n.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ model_names: ["yolov5n"]
model_config/model_name_p6_all.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ yolov5n6
2
+ yolov5s6
3
+ yolov5m6
4
+ yolov5l6
5
+ yolov5x6
model_config/model_name_p6_all.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ model_names: ["yolov5n6", "yolov5s6", "yolov5m6", "yolov5l6", "yolov5x6"]
model_download/yolov5_model_p5_all.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ cd ./yolov5
2
+
3
+ # 下载YOLOv5模型
4
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5n.pt
5
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt
6
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5m.pt
7
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5l.pt
8
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5x.pt
model_download/yolov5_model_p5_n.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ cd ./yolov5
2
+
3
+ # 下载YOLOv5模型
4
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5n.pt
model_download/yolov5_model_p6_all.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ cd ./yolov5
2
+
3
+ # 下载YOLOv5模型
4
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5n6.pt
5
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s6.pt
6
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5m6.pt
7
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5l6.pt
8
+ wget -c -t 0 https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5x6.pt