Zengyf-CVer committed on
Commit
f246420
1 Parent(s): b3457b5
Files changed (3) hide show
  1. .gitignore +3 -1
  2. app.py +99 -59
  3. util/fonts_opt.py +64 -0
.gitignore CHANGED
@@ -40,4 +40,6 @@
40
  !requirements.txt
41
  !cls_name/*
42
  !model_config/*
43
- !img_example/*
 
 
 
40
  !requirements.txt
41
  !cls_name/*
42
  !model_config/*
43
+ !img_example/*
44
+
45
+ app copy.py
app.py CHANGED
@@ -1,6 +1,6 @@
1
- # Gradio YOLOv5 Det v0.1
2
  # 创建人:曾逸夫
3
- # 创建时间:2022-04-03
4
  # email:[email protected]
5
  # 项目主页:https://gitee.com/CV_Lab/gradio_yolov5_det
6
 
@@ -12,14 +12,15 @@ from pathlib import Path
12
  import gradio as gr
13
  import torch
14
  import yaml
15
- from PIL import Image
 
 
16
 
17
  ROOT_PATH = sys.path[0] # 根目录
18
 
19
  # 模型路径
20
  model_path = "ultralytics/yolov5"
21
 
22
-
23
  # 模型名称临时变量
24
  model_name_tmp = ""
25
 
@@ -29,12 +30,13 @@ device_tmp = ""
29
  # 文件后缀
30
  suffix_list = [".csv", ".yaml"]
31
 
 
 
 
32
 
33
  def parse_args(known=False):
34
- parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")
35
- parser.add_argument(
36
- "--model_name", "-mn", default="yolov5s", type=str, help="model name"
37
- )
38
  parser.add_argument(
39
  "--model_cfg",
40
  "-mc",
@@ -56,15 +58,13 @@ def parse_args(known=False):
56
  type=float,
57
  help="model NMS confidence threshold",
58
  )
59
- parser.add_argument(
60
- "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold"
61
- )
62
 
63
  parser.add_argument(
64
  "--label_dnt_show",
65
  "-lds",
66
- action="store_false",
67
- default=True,
68
  help="label show",
69
  )
70
  parser.add_argument(
@@ -72,11 +72,9 @@ def parse_args(known=False):
72
  "-dev",
73
  default="cpu",
74
  type=str,
75
- help="cuda or cpu, hugging face only cpu",
76
- )
77
- parser.add_argument(
78
- "--inference_size", "-isz", default=640, type=int, help="model inference size"
79
  )
 
80
 
81
  args = parser.parse_known_args()[0] if known else parser.parse_args()
82
  return args
@@ -99,24 +97,44 @@ def export_json(results, model, img_size):
99
  return [
100
  [
101
  {
102
- "id": int(i),
103
  "class": int(result[i][5]),
104
- "class_name": model.model.names[int(result[i][5])],
 
105
  "normalized_box": {
106
  "x0": round(result[i][:4].tolist()[0], 6),
107
  "y0": round(result[i][:4].tolist()[1], 6),
108
  "x1": round(result[i][:4].tolist()[2], 6),
109
- "y1": round(result[i][:4].tolist()[3], 6),
110
- },
111
  "confidence": round(float(result[i][4]), 2),
112
  "fps": round(1000 / float(results.t[1]), 2),
113
  "width": img_size[0],
114
- "height": img_size[1],
115
- }
116
- for i in range(len(result))
117
- ]
118
- for result in results.xyxyn
119
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
 
122
  # YOLOv5图片检测函数
@@ -139,9 +157,43 @@ def yolo_det(img, device, model_name, inference_size, conf, iou, label_opt, mode
139
  model.classes = model_cls # 模型类别
140
 
141
  results = model(img, size=inference_size) # 检测
142
- results.render(labels=label_opt) # 渲染
143
 
144
- det_img = Image.fromarray(results.imgs[0]) # 检测图片
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  det_json = export_json(results, model, img.size)[0] # 检测信息
147
 
@@ -150,7 +202,7 @@ def yolo_det(img, device, model_name, inference_size, conf, iou, label_opt, mode
150
 
151
  # yaml文件解析
152
  def yaml_parse(file_path):
153
- return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
154
 
155
 
156
  # yaml csv 文件解析
@@ -172,7 +224,7 @@ def yaml_csv(file_path, file_tag):
172
  def main(args):
173
  gr.close_all()
174
 
175
- global model
176
 
177
  slider_step = 0.05 # 滑动步长
178
 
@@ -185,38 +237,30 @@ def main(args):
185
  device = args.device
186
  inference_size = args.inference_size
187
 
 
 
188
  # 模型加载
189
  model = model_loading(model_name, device)
190
 
191
  model_names = yaml_csv(model_cfg, "model_names")
192
  model_cls_name = yaml_csv(cls_name, "model_cls_name")
193
 
 
 
194
  # -------------------输入组件-------------------
195
  inputs_img = gr.inputs.Image(type="pil", label="原始图片")
196
- device = gr.inputs.Dropdown(
197
- choices=["cpu"], default=device, type="value", label="设备"
198
- )
199
- inputs_model = gr.inputs.Dropdown(
200
- choices=model_names, default=model_name, type="value", label="模型"
201
- )
202
- inputs_size = gr.inputs.Radio(
203
- choices=[320, 640], default=inference_size, label="推理尺寸"
204
- )
205
- input_conf = gr.inputs.Slider(
206
- 0, 1, step=slider_step, default=nms_conf, label="置信度阈值"
207
- )
208
- inputs_iou = gr.inputs.Slider(
209
- 0, 1, step=slider_step, default=nms_iou, label="IoU 阈值"
210
- )
211
- inputs_label = gr.inputs.Checkbox(default=label_opt, label="标签显示")
212
- inputs_clsName = gr.inputs.CheckboxGroup(
213
- choices=model_cls_name, default=model_cls_name, type="index", label="类别"
214
- )
215
 
216
  # 输入参数
217
  inputs = [
218
  inputs_img, # 输入图片
219
- device, # 设备
220
  inputs_model, # 模型
221
  inputs_size, # 推理尺寸
222
  input_conf, # 置信度阈值
@@ -243,8 +287,7 @@ def main(args):
243
  0.6,
244
  0.5,
245
  True,
246
- ["人", "公交车"],
247
- ],
248
  [
249
  "./img_example/Millenial-at-work.jpg",
250
  "cpu",
@@ -253,8 +296,7 @@ def main(args):
253
  0.5,
254
  0.45,
255
  True,
256
- ["人", "椅子", "杯子", "笔记本电脑"],
257
- ],
258
  [
259
  "./img_example/zidane.jpg",
260
  "cpu",
@@ -263,9 +305,7 @@ def main(args):
263
  0.25,
264
  0.5,
265
  False,
266
- ["人", "领带"],
267
- ],
268
- ]
269
 
270
  # 接口
271
  gr.Interface(
@@ -282,7 +322,7 @@ def main(args):
282
  ).launch(
283
  inbrowser=True, # 自动打开默认浏览器
284
  show_tips=True, # 自动显示gradio最新功能
285
- favicon_path="./icon/logo.ico",
286
  )
287
 
288
 
 
1
+ # Gradio YOLOv5 Det v0.2
2
  # 创建人:曾逸夫
3
+ # 创建时间:2022-05-01
4
  # email:[email protected]
5
  # 项目主页:https://gitee.com/CV_Lab/gradio_yolov5_det
6
 
 
12
  import gradio as gr
13
  import torch
14
  import yaml
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ from util.fonts_opt import is_fonts
18
 
19
  ROOT_PATH = sys.path[0] # 根目录
20
 
21
  # 模型路径
22
  model_path = "ultralytics/yolov5"
23
 
 
24
  # 模型名称临时变量
25
  model_name_tmp = ""
26
 
 
30
  # 文件后缀
31
  suffix_list = [".csv", ".yaml"]
32
 
33
+ # 字体大小
34
+ FONTSIZE = 25
35
+
36
 
37
  def parse_args(known=False):
38
+ parser = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.2")
39
+ parser.add_argument("--model_name", "-mn", default="yolov5s", type=str, help="model name")
 
 
40
  parser.add_argument(
41
  "--model_cfg",
42
  "-mc",
 
58
  type=float,
59
  help="model NMS confidence threshold",
60
  )
61
+ parser.add_argument("--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold")
 
 
62
 
63
  parser.add_argument(
64
  "--label_dnt_show",
65
  "-lds",
66
+ action="store_true",
67
+ default=False,
68
  help="label show",
69
  )
70
  parser.add_argument(
 
72
  "-dev",
73
  default="cpu",
74
  type=str,
75
+ help="cuda or cpu",
 
 
 
76
  )
77
+ parser.add_argument("--inference_size", "-isz", default=640, type=int, help="model inference size")
78
 
79
  args = parser.parse_known_args()[0] if known else parser.parse_args()
80
  return args
 
97
  return [
98
  [
99
  {
100
+ "id": i,
101
  "class": int(result[i][5]),
102
+ # "class_name": model.model.names[int(result[i][5])],
103
+ "class_name": model_cls_name_cp[int(result[i][5])],
104
  "normalized_box": {
105
  "x0": round(result[i][:4].tolist()[0], 6),
106
  "y0": round(result[i][:4].tolist()[1], 6),
107
  "x1": round(result[i][:4].tolist()[2], 6),
108
+ "y1": round(result[i][:4].tolist()[3], 6),},
 
109
  "confidence": round(float(result[i][4]), 2),
110
  "fps": round(1000 / float(results.t[1]), 2),
111
  "width": img_size[0],
112
+ "height": img_size[1],} for i in range(len(result))] for result in results.xyxyn]
113
+
114
+
115
+ # 帧转换
116
def pil_draw(img, countdown_msg, textFont, xyxy, font_size, label_opt):
    """Draw one detection box, and optionally its label, onto ``img`` in place.

    Args:
        img: PIL image to draw on; it is modified in place.
        countdown_msg: Label text drawn at the top-left corner of the box.
        textFont: Font object used to measure and render the label.
        xyxy: Box corners as ``[x0, y0, x1, y1]``.
        font_size: Unused; kept so existing callers keep working.
        label_opt: When truthy, draw the label background and the label text.

    Returns:
        The same ``img`` object that was passed in.
    """
    draw = ImageDraw.Draw(img)

    draw.rectangle(xyxy, fill=None, outline="green")  # bounding box

    if not label_opt:
        return img

    # ``getsize`` was removed in Pillow 10; prefer ``getbbox`` and only fall
    # back to ``getsize`` on older Pillow releases that lack it.
    if hasattr(textFont, "getbbox"):
        # right/bottom give the extent of the rendered text from the anchor
        _, _, text_w, text_h = textFont.getbbox(countdown_msg)
    else:
        text_w, text_h = textFont.getsize(countdown_msg)

    left, top = xyxy[0], xyxy[1]

    # Filled rectangle behind the label so the text stays readable
    draw.rectangle(
        (left, top, left + text_w, top + text_h),
        fill="green",
        outline="green",
    )
    draw.multiline_text(
        (left, top),
        countdown_msg,
        fill=(205, 250, 255),
        font=textFont,
        align="center",
    )

    return img
138
 
139
 
140
  # YOLOv5图片检测函数
 
157
  model.classes = model_cls # 模型类别
158
 
159
  results = model(img, size=inference_size) # 检测
 
160
 
161
+ img_size = img.size # 帧尺寸
162
+
163
+ # 加载字体
164
+ textFont = ImageFont.truetype(str(f"{ROOT_PATH}/fonts/SimSun.ttc"), size=FONTSIZE)
165
+
166
+ det_img = img.copy()
167
+
168
+ for result in results.xyxyn:
169
+ for i in range(len(result)):
170
+ id = int(i) # 实例ID
171
+ obj_cls_index = int(result[i][5]) # 类别索引
172
+ obj_cls = model_cls_name_cp[obj_cls_index] # 类别
173
+
174
+ # ------------边框坐标------------
175
+ x0 = float(result[i][:4].tolist()[0])
176
+ y0 = float(result[i][:4].tolist()[1])
177
+ x1 = float(result[i][:4].tolist()[2])
178
+ y1 = float(result[i][:4].tolist()[3])
179
+
180
+ # ------------边框实际坐标------------
181
+ x0 = int(img_size[0] * x0)
182
+ y0 = int(img_size[1] * y0)
183
+ x1 = int(img_size[0] * x1)
184
+ y1 = int(img_size[1] * y1)
185
+
186
+ conf = float(result[i][4]) # 置信度
187
+ # fps = f"{(1000 / float(results.t[1])):.2f}" # FPS
188
+
189
+ det_img = pil_draw(
190
+ img,
191
+ f"{id}-{obj_cls}:{conf:.2f}",
192
+ textFont,
193
+ [x0, y0, x1, y1],
194
+ FONTSIZE,
195
+ label_opt,
196
+ )
197
 
198
  det_json = export_json(results, model, img.size)[0] # 检测信息
199
 
 
202
 
203
  # yaml文件解析
204
  def yaml_parse(file_path):
205
+ return yaml.safe_load(open(file_path, encoding="utf-8").read())
206
 
207
 
208
  # yaml csv 文件解析
 
224
  def main(args):
225
  gr.close_all()
226
 
227
+ global model, model_cls_name_cp
228
 
229
  slider_step = 0.05 # 滑动步长
230
 
 
237
  device = args.device
238
  inference_size = args.inference_size
239
 
240
+ is_fonts(f"{ROOT_PATH}/fonts") # 检查字体文件
241
+
242
  # 模型加载
243
  model = model_loading(model_name, device)
244
 
245
  model_names = yaml_csv(model_cfg, "model_names")
246
  model_cls_name = yaml_csv(cls_name, "model_cls_name")
247
 
248
+ model_cls_name_cp = model_cls_name.copy() # 类别名称
249
+
250
  # -------------------输入组件-------------------
251
  inputs_img = gr.inputs.Image(type="pil", label="原始图片")
252
+ inputs_device = gr.inputs.Dropdown(choices=["0", "cpu"], default=device, type="value", label="设备")
253
+ inputs_model = gr.inputs.Dropdown(choices=model_names, default=model_name, type="value", label="模型")
254
+ inputs_size = gr.inputs.Radio(choices=[320, 640], default=inference_size, label="推理尺寸")
255
+ input_conf = gr.inputs.Slider(0, 1, step=slider_step, default=nms_conf, label="置信度阈值")
256
+ inputs_iou = gr.inputs.Slider(0, 1, step=slider_step, default=nms_iou, label="IoU 阈值")
257
+ inputs_label = gr.inputs.Checkbox(default=(not label_opt), label="标签显示")
258
+ inputs_clsName = gr.inputs.CheckboxGroup(choices=model_cls_name, default=model_cls_name, type="index", label="类别")
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  # 输入参数
261
  inputs = [
262
  inputs_img, # 输入图片
263
+ inputs_device, # 设备
264
  inputs_model, # 模型
265
  inputs_size, # 推理尺寸
266
  input_conf, # 置信度阈值
 
287
  0.6,
288
  0.5,
289
  True,
290
+ ["人", "公交车"],],
 
291
  [
292
  "./img_example/Millenial-at-work.jpg",
293
  "cpu",
 
296
  0.5,
297
  0.45,
298
  True,
299
+ ["人", "椅子", "杯子", "笔记本电脑"],],
 
300
  [
301
  "./img_example/zidane.jpg",
302
  "cpu",
 
305
  0.25,
306
  0.5,
307
  False,
308
+ ["人", "领带"],],]
 
 
309
 
310
  # 接口
311
  gr.Interface(
 
322
  ).launch(
323
  inbrowser=True, # 自动打开默认浏览器
324
  show_tips=True, # 自动显示gradio最新功能
325
+ # favicon_path="./icon/logo.ico",
326
  )
327
 
328
 
util/fonts_opt.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 字体管理
2
+ # 创建人:曾逸夫
3
+ # 创建时间:2022-05-01
4
+
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import wget
10
+ from rich.console import Console
11
+
12
+ ROOT_PATH = sys.path[0] # 项目根目录
13
+
14
+ fonts_list = ["SimSun.ttc"] # 字体列表
15
+ fonts_suffix = ["ttc", "ttf", "otf"] # 字体后缀
16
+
17
+ data_url_dict = {"SimSun.ttc": "https://gitee.com/CV_Lab/opencv_webcam/attach_files/959173/download/SimSun.ttc"}
18
+
19
+ console = Console()
20
+
21
+
22
+ # 创建字体库
23
def add_fronts(font_diff):
    """Download every font named in ``font_diff`` into ``{ROOT_PATH}/fonts``.

    Only fonts that have an entry in ``data_url_dict`` are downloaded. If a
    download fails, the error is printed and the process exits with a
    non-zero status.

    Args:
        font_diff: Collection of font file names that need to be fetched.

    Raises:
        SystemExit: If downloading a font fails.
    """
    # Kept as a module-level name for backward compatibility with the
    # original implementation, which also published it as a global.
    global font_name

    for wanted_name, url in data_url_dict.items():
        if wanted_name not in font_diff:
            continue

        font_name = url.split("/")[-1]  # file name is the last URL segment
        Path(f"{ROOT_PATH}/fonts").mkdir(parents=True, exist_ok=True)  # make sure the target directory exists

        file_path = f"{ROOT_PATH}/fonts/{font_name}"  # where the font is saved

        try:
            wget.download(url, file_path)
        except Exception as e:
            print("路径错误!程序结束!")
            print(e)
            # Exit non-zero so the failure is visible to the calling shell;
            # a bare sys.exit() would report success.
            sys.exit(1)
        else:
            print()
            console.print(f"{font_name} [bold green]字体文件下载完成![/bold green] 已保存至:{file_path}")
44
+
45
+
46
+ # 判断字体文件
47
def is_fonts(fonts_dir):
    """Check ``fonts_dir`` for the fonts in ``fonts_list`` and fetch missing ones.

    Args:
        fonts_dir: Directory expected to contain the font files.
    """
    if not os.path.isdir(fonts_dir):
        # No font directory yet: request the complete font list
        console.print("[bold red]字体库不存在,正在创建。。。[/bold red]")
        add_fronts(fonts_list)
        return

    # Required fonts that are absent from the local directory
    missing = list(set(fonts_list) - set(os.listdir(fonts_dir)))

    if missing:
        console.print("[bold red]字体不存在,正在加载。。。[/bold red]")
        add_fronts(missing)
    else:
        console.print(f"{fonts_list}[bold green]字体已存在![/bold green]")
64
+