Zengyf-CVer committed on
Commit
c828a61
1 Parent(s): c0abb63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -83
app.py CHANGED
@@ -14,6 +14,9 @@ from pathlib import Path
14
 
15
  import cv2
16
  import gradio as gr
 
 
 
17
  import numpy as np
18
  from matplotlib import font_manager
19
  from ultralytics import YOLO
@@ -63,9 +66,28 @@ EXAMPLES_DET = [
63
  ["./img_examples/bus.jpg", "yolov8s", "cpu", 640, 0.6, 0.5, 100, "所有尺寸"],
64
  ["./img_examples/giraffe.jpg", "yolov8l", "cpu", 320, 0.5, 0.45, 100, "所有尺寸"],
65
  ["./img_examples/zidane.jpg", "yolov8m", "cpu", 640, 0.6, 0.5, 100, "所有尺寸"],
66
- ["./img_examples/Millenial-at-work.jpg", "yolov8x", "cpu", 1280, 0.5, 0.5, 100, "所有尺寸"],
 
 
 
 
 
 
 
 
 
67
  ["./img_examples/bus.jpg", "yolov8s-seg", "cpu", 640, 0.6, 0.5, 100, "所有尺寸"],
68
- ["./img_examples/Millenial-at-work.jpg", "yolov8x-seg", "cpu", 1280, 0.5, 0.5, 100, "所有尺寸"],]
 
 
 
 
 
 
 
 
 
 
69
 
70
  EXAMPLES_CLAS = [
71
  ["./img_examples/img_clas/ILSVRC2012_val_00000008.JPEG", "yolov8s-cls"],
@@ -73,7 +95,8 @@ EXAMPLES_CLAS = [
73
  ["./img_examples/img_clas/ILSVRC2012_val_00000023.JPEG", "yolov8m-cls"],
74
  ["./img_examples/img_clas/ILSVRC2012_val_00000067.JPEG", "yolov8m-cls"],
75
  ["./img_examples/img_clas/ILSVRC2012_val_00000077.JPEG", "yolov8m-cls"],
76
- ["./img_examples/img_clas/ILSVRC2012_val_00000247.JPEG", "yolov8m-cls"],]
 
77
 
78
  GYD_CSS = """#disp_image {
79
  text-align: center; /* Horizontally center the content */
@@ -82,7 +105,9 @@ GYD_CSS = """#disp_image {
82
 
83
  def parse_args(known=False):
84
  parser = argparse.ArgumentParser(description=GYD_VERSION)
85
- parser.add_argument("--model_name", "-mn", default="yolov8s", type=str, help="model name")
 
 
86
  parser.add_argument(
87
  "--model_cfg",
88
  "-mc",
@@ -111,10 +136,18 @@ def parse_args(known=False):
111
  type=float,
112
  help="model NMS confidence threshold",
113
  )
114
- parser.add_argument("--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold")
115
- parser.add_argument("--inference_size", "-isz", default=640, type=int, help="model inference size")
116
- parser.add_argument("--max_detnum", "-mdn", default=50, type=float, help="model max det num")
117
- parser.add_argument("--slider_step", "-ss", default=0.05, type=float, help="slider step")
 
 
 
 
 
 
 
 
118
  parser.add_argument(
119
  "--is_login",
120
  "-isl",
@@ -122,12 +155,14 @@ def parse_args(known=False):
122
  default=False,
123
  help="is login",
124
  )
125
- parser.add_argument('--usr_pwd',
126
- "-up",
127
- nargs='+',
128
- type=str,
129
- default=["admin", "admin"],
130
- help="user & password for login")
 
 
131
  parser.add_argument(
132
  "--is_share",
133
  "-is",
@@ -135,7 +170,9 @@ def parse_args(known=False):
135
  default=False,
136
  help="is login",
137
  )
138
- parser.add_argument("--server_port", "-sp", default=7860, type=int, help="server port")
 
 
139
 
140
  args = parser.parse_known_args()[0] if known else parser.parse_args()
141
  return args
@@ -167,6 +204,7 @@ def check_online():
167
  # 参考:https://github.com/ultralytics/yolov5/blob/master/utils/general.py
168
  # Check internet connectivity
169
  import socket
 
170
  try:
171
  socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
172
  return True
@@ -203,9 +241,12 @@ def pil_draw(img, score_l, bbox_l, cls_l, cls_index_l, textFont, color_list):
203
  img_pil = ImageDraw.Draw(img)
204
  id = 0
205
 
206
- for score, (xmin, ymin, xmax, ymax), label, cls_index in zip(score_l, bbox_l, cls_l, cls_index_l):
207
-
208
- img_pil.rectangle([xmin, ymin, xmax, ymax], fill=None, outline=color_list[cls_index], width=2) # 边界框
 
 
 
209
  countdown_msg = f"{id}-{label} {score:.2f}"
210
  # text_w, text_h = textFont.getsize(countdown_msg) # 标签尺寸 pillow 9.5.0
211
  # left, top, left + width, top + height
@@ -214,7 +255,12 @@ def pil_draw(img, score_l, bbox_l, cls_l, cls_index_l, textFont, color_list):
214
  # 标签背景
215
  img_pil.rectangle(
216
  # (xmin, ymin, xmin + text_w, ymin + text_h), # pillow 9.5.0
217
- (xmin, ymin, xmin + text_xmax - text_xmin, ymin + text_ymax - text_ymin), # pillow 10.0.0
 
 
 
 
 
218
  fill=color_list[cls_index],
219
  outline=color_list[cls_index],
220
  )
@@ -268,10 +314,19 @@ def seg_output(img_path, seg_mask_list, color_list, cls_list):
268
 
269
 
270
  # 目标检测和图像分割模型加载
271
- def model_det_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8n.pt"):
 
 
272
  model = YOLO(yolo_model)
273
 
274
- results = model(source=img_path, device=device_opt, imgsz=infer_size, conf=conf, iou=iou, max_det=max_det)
 
 
 
 
 
 
 
275
  results = list(results)[0]
276
  return results
277
 
@@ -286,8 +341,9 @@ def model_cls_loading(img_path, yolo_model="yolov8s-cls.pt"):
286
 
287
 
288
  # YOLOv8图片检测函数
289
- def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_det, obj_size):
290
-
 
291
  global model, model_name_tmp, device_tmp
292
 
293
  s_obj, m_obj, l_obj = 0, 0, 0
@@ -300,13 +356,15 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
300
  cls_index_det_stat = [] # 1
301
 
302
  # 模型加载
303
- predict_results = model_det_loading(img_path,
304
- device_opt,
305
- conf,
306
- iou,
307
- infer_size,
308
- max_det,
309
- yolo_model=f"{model_name}.pt")
 
 
310
  # 检测参数
311
  xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
312
  conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
@@ -315,34 +373,39 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
315
  # 颜色列表
316
  color_list = random_color(len(model_cls_name_cp), True)
317
 
 
 
 
318
  # 图像分割
319
- if (model_name[-3:] == "seg"):
320
  # masks_list = predict_results.masks.xyn
321
  masks_list = predict_results.masks.xy
322
  img_mask_merge = seg_output(img_path, masks_list, color_list, cls_list)
323
- img = Image.fromarray(cv2.cvtColor(img_mask_merge, cv2.COLOR_BGRA2RGBA))
324
- else:
325
- img = Image.open(img_path)
326
 
327
  # 判断检测对象是否为空
328
- if (xyxy_list != []):
329
-
330
  # ---------------- 加载字体 ----------------
331
  yaml_index = cls_name.index(".yaml")
332
- cls_name_lang = cls_name[yaml_index - 2:yaml_index]
333
 
334
  if cls_name_lang == "zh":
335
  # 中文
336
- textFont = ImageFont.truetype(str(f"{ROOT_PATH}/fonts/SimSun.ttf"), size=FONTSIZE)
 
 
337
  elif cls_name_lang in ["en", "ru", "es", "ar"]:
338
  # 英文、俄语、西班牙语、阿拉伯语
339
- textFont = ImageFont.truetype(str(f"{ROOT_PATH}/fonts/TimesNewRoman.ttf"), size=FONTSIZE)
 
 
340
  elif cls_name_lang == "ko":
341
  # 韩语
342
- textFont = ImageFont.truetype(str(f"{ROOT_PATH}/fonts/malgun.ttf"), size=FONTSIZE)
 
 
343
 
344
  for i in range(len(xyxy_list)):
345
-
346
  # ------------ 边框坐标 ------------
347
  x0 = int(xyxy_list[i][0])
348
  y0 = int(xyxy_list[i][1])
@@ -354,7 +417,7 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
354
  h_obj = y1 - y0
355
  area_obj = w_obj * h_obj # 目标尺寸
356
 
357
- if (obj_size == obj_style[0] and area_obj > 0 and area_obj <= 32 ** 2):
358
  obj_cls_index = int(cls_list[i]) # 类别索引
359
  cls_index_det_stat.append(obj_cls_index)
360
 
@@ -367,7 +430,9 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
367
  score_det_stat.append(conf)
368
 
369
  area_obj_all.append(area_obj)
370
- elif (obj_size == obj_style[1] and area_obj > 32 ** 2 and area_obj <= 96 ** 2):
 
 
371
  obj_cls_index = int(cls_list[i]) # 类别索引
372
  cls_index_det_stat.append(obj_cls_index)
373
 
@@ -380,7 +445,7 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
380
  score_det_stat.append(conf)
381
 
382
  area_obj_all.append(area_obj)
383
- elif (obj_size == obj_style[2] and area_obj > 96 ** 2):
384
  obj_cls_index = int(cls_list[i]) # 类别索引
385
  cls_index_det_stat.append(obj_cls_index)
386
 
@@ -393,7 +458,7 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
393
  score_det_stat.append(conf)
394
 
395
  area_obj_all.append(area_obj)
396
- elif (obj_size == "所有尺寸"):
397
  obj_cls_index = int(cls_list[i]) # 类别索引
398
  cls_index_det_stat.append(obj_cls_index)
399
 
@@ -407,20 +472,30 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
407
 
408
  area_obj_all.append(area_obj)
409
 
410
- det_img = pil_draw(img, score_det_stat, bbox_det_stat, cls_det_stat, cls_index_det_stat, textFont, color_list)
 
 
 
 
 
 
 
 
411
 
412
  # -------------- 目标尺寸计算 --------------
413
  for i in range(len(area_obj_all)):
414
- if (0 < area_obj_all[i] <= 32 ** 2):
415
  s_obj = s_obj + 1
416
- elif (32 ** 2 < area_obj_all[i] <= 96 ** 2):
417
  m_obj = m_obj + 1
418
- elif (area_obj_all[i] > 96 ** 2):
419
  l_obj = l_obj + 1
420
 
421
  sml_obj_total = s_obj + m_obj + l_obj
422
  objSize_dict = {}
423
- objSize_dict = {obj_style[i]: [s_obj, m_obj, l_obj][i] / sml_obj_total for i in range(3)}
 
 
424
 
425
  # ------------ 类别统计 ------------
426
  clsRatio_dict = {}
@@ -429,15 +504,23 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
429
  for k, v in clsDet_dict.items():
430
  clsRatio_dict[k] = v / clsDet_dict_sum
431
 
 
 
 
 
 
 
 
 
 
432
  gr.Info("图片检测成功!")
433
- return det_img, objSize_dict, clsRatio_dict
434
  else:
435
  raise gr.Error("图片检测失败!")
436
 
437
 
438
  # YOLOv8图片分类函数
439
  def yolo_cls_img(img_path, model_name):
440
-
441
  # 模型加载
442
  predict_results = model_cls_loading(img_path, yolo_model=f"{model_name}.pt")
443
 
@@ -484,8 +567,10 @@ def main(args):
484
  model_cls_name_cp = model_cls_name.copy() # 类别名称
485
  model_cls_imagenet_name_cp = model_cls_imagenet_name.copy() # 类别名称
486
 
487
- custom_theme = gr.themes.Soft(primary_hue="blue").set(button_secondary_background_fill="*neutral_100",
488
- button_secondary_background_fill_hover="*neutral_200")
 
 
489
 
490
  custom_css = GYD_CSS
491
 
@@ -500,58 +585,116 @@ def main(args):
500
  with gr.Tabs():
501
  with gr.TabItem("目标检测与图像分割"):
502
  with gr.Row():
503
- inputs_img = gr.Image(image_mode="RGB", type="filepath", label="原始图片")
 
 
504
  with gr.Row():
505
- device_opt = gr.Radio(choices=["cpu", "0", "1", "2", "3"], value="cpu", label="设备")
 
 
 
 
506
  with gr.Row():
507
- inputs_model = gr.Dropdown(choices=model_names, value=model_name, type="value", label="模型")
 
 
 
 
 
508
  with gr.Accordion("高级设置", open=True):
509
  with gr.Row():
510
- inputs_size = gr.Slider(320, 1600, step=1, value=inference_size, label="推理尺寸")
511
- max_det = gr.Slider(1, 1000, step=1, value=max_detnum, label="最大检测数")
 
 
 
 
 
 
 
 
512
  with gr.Row():
513
- input_conf = gr.Slider(0, 1, step=slider_step, value=nms_conf, label="置信度阈值")
514
- inputs_iou = gr.Slider(0, 1, step=slider_step, value=nms_iou, label="IoU 阈值")
 
 
 
 
 
 
 
 
 
 
 
 
515
  with gr.Row():
516
- obj_size = gr.Radio(choices=["所有尺寸", "小目标", "中目标", "大目标"], value="所有尺寸", label="目标尺寸")
 
 
 
 
517
  with gr.Row():
518
  gr.ClearButton(inputs_img, value="清除")
519
- det_btn_img = gr.Button(value='检测', variant="primary")
520
  with gr.Row():
521
  gr.Examples(
522
  examples=EXAMPLES_DET,
523
  fn=yolo_det_img,
524
  inputs=[
525
- inputs_img, inputs_model, device_opt, inputs_size, input_conf, inputs_iou, max_det,
526
- obj_size],
 
 
 
 
 
 
 
527
  # outputs=[outputs_img, outputs_objSize, outputs_clsSize],
528
- cache_examples=False)
 
529
 
530
  with gr.TabItem("图像分类"):
531
  with gr.Row():
532
- inputs_img_cls = gr.Image(image_mode="RGB", type="filepath", label="原始图片")
 
 
533
  with gr.Row():
534
- inputs_model_cls = gr.Dropdown(choices=[
535
- "yolov8n-cls", "yolov8s-cls", "yolov8l-cls", "yolov8m-cls", "yolov8x-cls"],
536
- value="yolov8s-cls",
537
- type="value",
538
- label="模型")
 
 
 
 
 
 
 
539
  with gr.Row():
540
  gr.ClearButton(inputs_img, value="清除")
541
- det_btn_img_cls = gr.Button(value='检测', variant="primary")
542
  with gr.Row():
543
  gr.Examples(
544
  examples=EXAMPLES_CLAS,
545
  fn=yolo_cls_img,
546
  inputs=[inputs_img_cls, inputs_model_cls],
547
  # outputs=[outputs_img_cls, outputs_ratio_cls],
548
- cache_examples=False)
 
549
 
 
550
  with gr.Column(scale=1):
551
  with gr.Tabs():
552
  with gr.TabItem("目标检测与图像分割"):
 
 
 
 
553
  with gr.Row():
554
- outputs_img = gr.Image(type="pil", label="检测图片")
555
  with gr.Row():
556
  outputs_objSize = gr.Label(label="目标尺寸占比统计")
557
  with gr.Row():
@@ -604,15 +747,31 @@ def main(args):
604
  """
605
  )
606
 
607
- det_btn_img.click(fn=yolo_det_img,
608
- inputs=[
609
- inputs_img, inputs_model, device_opt, inputs_size, input_conf, inputs_iou, max_det,
610
- obj_size],
611
- outputs=[outputs_img, outputs_objSize, outputs_clsSize])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
 
613
- det_btn_img_cls.click(fn=yolo_cls_img,
614
- inputs=[inputs_img_cls, inputs_model_cls],
615
- outputs=[outputs_img_cls, outputs_ratio_cls])
 
 
616
 
617
  return gyd
618
 
 
14
 
15
  import cv2
16
  import gradio as gr
17
+ from gradio_imageslider import ImageSlider
18
+ import tempfile
19
+ import uuid
20
  import numpy as np
21
  from matplotlib import font_manager
22
  from ultralytics import YOLO
 
66
  ["./img_examples/bus.jpg", "yolov8s", "cpu", 640, 0.6, 0.5, 100, "所有尺寸"],
67
  ["./img_examples/giraffe.jpg", "yolov8l", "cpu", 320, 0.5, 0.45, 100, "所有尺寸"],
68
  ["./img_examples/zidane.jpg", "yolov8m", "cpu", 640, 0.6, 0.5, 100, "所有尺寸"],
69
+ [
70
+ "./img_examples/Millenial-at-work.jpg",
71
+ "yolov8x",
72
+ "cpu",
73
+ 1280,
74
+ 0.5,
75
+ 0.5,
76
+ 100,
77
+ "所有尺寸",
78
+ ],
79
  ["./img_examples/bus.jpg", "yolov8s-seg", "cpu", 640, 0.6, 0.5, 100, "所有尺寸"],
80
+ [
81
+ "./img_examples/Millenial-at-work.jpg",
82
+ "yolov8x-seg",
83
+ "cpu",
84
+ 1280,
85
+ 0.5,
86
+ 0.5,
87
+ 100,
88
+ "所有尺寸",
89
+ ],
90
+ ]
91
 
92
  EXAMPLES_CLAS = [
93
  ["./img_examples/img_clas/ILSVRC2012_val_00000008.JPEG", "yolov8s-cls"],
 
95
  ["./img_examples/img_clas/ILSVRC2012_val_00000023.JPEG", "yolov8m-cls"],
96
  ["./img_examples/img_clas/ILSVRC2012_val_00000067.JPEG", "yolov8m-cls"],
97
  ["./img_examples/img_clas/ILSVRC2012_val_00000077.JPEG", "yolov8m-cls"],
98
+ ["./img_examples/img_clas/ILSVRC2012_val_00000247.JPEG", "yolov8m-cls"],
99
+ ]
100
 
101
  GYD_CSS = """#disp_image {
102
  text-align: center; /* Horizontally center the content */
 
105
 
106
  def parse_args(known=False):
107
  parser = argparse.ArgumentParser(description=GYD_VERSION)
108
+ parser.add_argument(
109
+ "--model_name", "-mn", default="yolov8s", type=str, help="model name"
110
+ )
111
  parser.add_argument(
112
  "--model_cfg",
113
  "-mc",
 
136
  type=float,
137
  help="model NMS confidence threshold",
138
  )
139
+ parser.add_argument(
140
+ "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold"
141
+ )
142
+ parser.add_argument(
143
+ "--inference_size", "-isz", default=640, type=int, help="model inference size"
144
+ )
145
+ parser.add_argument(
146
+ "--max_detnum", "-mdn", default=50, type=float, help="model max det num"
147
+ )
148
+ parser.add_argument(
149
+ "--slider_step", "-ss", default=0.05, type=float, help="slider step"
150
+ )
151
  parser.add_argument(
152
  "--is_login",
153
  "-isl",
 
155
  default=False,
156
  help="is login",
157
  )
158
+ parser.add_argument(
159
+ "--usr_pwd",
160
+ "-up",
161
+ nargs="+",
162
+ type=str,
163
+ default=["admin", "admin"],
164
+ help="user & password for login",
165
+ )
166
  parser.add_argument(
167
  "--is_share",
168
  "-is",
 
170
  default=False,
171
  help="is login",
172
  )
173
+ parser.add_argument(
174
+ "--server_port", "-sp", default=7860, type=int, help="server port"
175
+ )
176
 
177
  args = parser.parse_known_args()[0] if known else parser.parse_args()
178
  return args
 
204
  # 参考:https://github.com/ultralytics/yolov5/blob/master/utils/general.py
205
  # Check internet connectivity
206
  import socket
207
+
208
  try:
209
  socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
210
  return True
 
241
  img_pil = ImageDraw.Draw(img)
242
  id = 0
243
 
244
+ for score, (xmin, ymin, xmax, ymax), label, cls_index in zip(
245
+ score_l, bbox_l, cls_l, cls_index_l
246
+ ):
247
+ img_pil.rectangle(
248
+ [xmin, ymin, xmax, ymax], fill=None, outline=color_list[cls_index], width=2
249
+ ) # 边界框
250
  countdown_msg = f"{id}-{label} {score:.2f}"
251
  # text_w, text_h = textFont.getsize(countdown_msg) # 标签尺寸 pillow 9.5.0
252
  # left, top, left + width, top + height
 
255
  # 标签背景
256
  img_pil.rectangle(
257
  # (xmin, ymin, xmin + text_w, ymin + text_h), # pillow 9.5.0
258
+ (
259
+ xmin,
260
+ ymin,
261
+ xmin + text_xmax - text_xmin,
262
+ ymin + text_ymax - text_ymin,
263
+ ), # pillow 10.0.0
264
  fill=color_list[cls_index],
265
  outline=color_list[cls_index],
266
  )
 
314
 
315
 
316
  # 目标检测和图像分割模型加载
317
+ def model_det_loading(
318
+ img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8n.pt"
319
+ ):
320
  model = YOLO(yolo_model)
321
 
322
+ results = model(
323
+ source=img_path,
324
+ device=device_opt,
325
+ imgsz=infer_size,
326
+ conf=conf,
327
+ iou=iou,
328
+ max_det=max_det,
329
+ )
330
  results = list(results)[0]
331
  return results
332
 
 
341
 
342
 
343
  # YOLOv8图片检测函数
344
+ def yolo_det_img(
345
+ img_path, model_name, device_opt, infer_size, conf, iou, max_det, obj_size
346
+ ):
347
  global model, model_name_tmp, device_tmp
348
 
349
  s_obj, m_obj, l_obj = 0, 0, 0
 
356
  cls_index_det_stat = [] # 1
357
 
358
  # 模型加载
359
+ predict_results = model_det_loading(
360
+ img_path,
361
+ device_opt,
362
+ conf,
363
+ iou,
364
+ infer_size,
365
+ max_det,
366
+ yolo_model=f"{model_name}.pt",
367
+ )
368
  # 检测参数
369
  xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
370
  conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
 
373
  # 颜色列表
374
  color_list = random_color(len(model_cls_name_cp), True)
375
 
376
+ img = Image.open(img_path)
377
+ img_cp = img.copy()
378
+
379
  # 图像分割
380
+ if model_name[-3:] == "seg":
381
  # masks_list = predict_results.masks.xyn
382
  masks_list = predict_results.masks.xy
383
  img_mask_merge = seg_output(img_path, masks_list, color_list, cls_list)
384
+ img = Image.fromarray(cv2.cvtColor(img_mask_merge, cv2.COLOR_BGRA2RGB))
 
 
385
 
386
  # 判断检测对象是否为空
387
+ if xyxy_list != []:
 
388
  # ---------------- 加载字体 ----------------
389
  yaml_index = cls_name.index(".yaml")
390
+ cls_name_lang = cls_name[yaml_index - 2 : yaml_index]
391
 
392
  if cls_name_lang == "zh":
393
  # 中文
394
+ textFont = ImageFont.truetype(
395
+ str(f"{ROOT_PATH}/fonts/SimSun.ttf"), size=FONTSIZE
396
+ )
397
  elif cls_name_lang in ["en", "ru", "es", "ar"]:
398
  # 英文、俄语、西班牙语、阿拉伯语
399
+ textFont = ImageFont.truetype(
400
+ str(f"{ROOT_PATH}/fonts/TimesNewRoman.ttf"), size=FONTSIZE
401
+ )
402
  elif cls_name_lang == "ko":
403
  # 韩语
404
+ textFont = ImageFont.truetype(
405
+ str(f"{ROOT_PATH}/fonts/malgun.ttf"), size=FONTSIZE
406
+ )
407
 
408
  for i in range(len(xyxy_list)):
 
409
  # ------------ 边框坐标 ------------
410
  x0 = int(xyxy_list[i][0])
411
  y0 = int(xyxy_list[i][1])
 
417
  h_obj = y1 - y0
418
  area_obj = w_obj * h_obj # 目标尺寸
419
 
420
+ if obj_size == obj_style[0] and area_obj > 0 and area_obj <= 32**2:
421
  obj_cls_index = int(cls_list[i]) # 类别索引
422
  cls_index_det_stat.append(obj_cls_index)
423
 
 
430
  score_det_stat.append(conf)
431
 
432
  area_obj_all.append(area_obj)
433
+ elif (
434
+ obj_size == obj_style[1] and area_obj > 32**2 and area_obj <= 96**2
435
+ ):
436
  obj_cls_index = int(cls_list[i]) # 类别索引
437
  cls_index_det_stat.append(obj_cls_index)
438
 
 
445
  score_det_stat.append(conf)
446
 
447
  area_obj_all.append(area_obj)
448
+ elif obj_size == obj_style[2] and area_obj > 96**2:
449
  obj_cls_index = int(cls_list[i]) # 类别索引
450
  cls_index_det_stat.append(obj_cls_index)
451
 
 
458
  score_det_stat.append(conf)
459
 
460
  area_obj_all.append(area_obj)
461
+ elif obj_size == "所有尺寸":
462
  obj_cls_index = int(cls_list[i]) # 类别索引
463
  cls_index_det_stat.append(obj_cls_index)
464
 
 
472
 
473
  area_obj_all.append(area_obj)
474
 
475
+ det_img = pil_draw(
476
+ img,
477
+ score_det_stat,
478
+ bbox_det_stat,
479
+ cls_det_stat,
480
+ cls_index_det_stat,
481
+ textFont,
482
+ color_list,
483
+ )
484
 
485
  # -------------- 目标尺寸计算 --------------
486
  for i in range(len(area_obj_all)):
487
+ if 0 < area_obj_all[i] <= 32**2:
488
  s_obj = s_obj + 1
489
+ elif 32**2 < area_obj_all[i] <= 96**2:
490
  m_obj = m_obj + 1
491
+ elif area_obj_all[i] > 96**2:
492
  l_obj = l_obj + 1
493
 
494
  sml_obj_total = s_obj + m_obj + l_obj
495
  objSize_dict = {}
496
+ objSize_dict = {
497
+ obj_style[i]: [s_obj, m_obj, l_obj][i] / sml_obj_total for i in range(3)
498
+ }
499
 
500
  # ------------ 类别统计 ------------
501
  clsRatio_dict = {}
 
504
  for k, v in clsDet_dict.items():
505
  clsRatio_dict[k] = v / clsDet_dict_sum
506
 
507
+ images = (det_img, img_cp)
508
+ images_names = ("det", "raw")
509
+ images_path = tempfile.mkdtemp()
510
+ images_paths = []
511
+ uuid_name = uuid.uuid4()
512
+ for image, image_name in zip(images, images_names):
513
+ image.save(images_path + f"/img_{uuid_name}_{image_name}.jpg")
514
+ images_paths.append(images_path + f"/img_{uuid_name}_{image_name}.jpg")
515
+
516
  gr.Info("图片检测成功!")
517
+ return (det_img, img_cp), images_paths, objSize_dict, clsRatio_dict
518
  else:
519
  raise gr.Error("图片检测失败!")
520
 
521
 
522
  # YOLOv8图片分类函数
523
  def yolo_cls_img(img_path, model_name):
 
524
  # 模型加载
525
  predict_results = model_cls_loading(img_path, yolo_model=f"{model_name}.pt")
526
 
 
567
  model_cls_name_cp = model_cls_name.copy() # 类别名称
568
  model_cls_imagenet_name_cp = model_cls_imagenet_name.copy() # 类别名称
569
 
570
+ custom_theme = gr.themes.Soft(primary_hue="blue").set(
571
+ button_secondary_background_fill="*neutral_100",
572
+ button_secondary_background_fill_hover="*neutral_200",
573
+ )
574
 
575
  custom_css = GYD_CSS
576
 
 
585
  with gr.Tabs():
586
  with gr.TabItem("目标检测与图像分割"):
587
  with gr.Row():
588
+ inputs_img = gr.Image(
589
+ image_mode="RGB", type="filepath", label="原始图片"
590
+ )
591
  with gr.Row():
592
+ device_opt = gr.Radio(
593
+ choices=["cpu", "0", "1", "2", "3"],
594
+ value="cpu",
595
+ label="设备",
596
+ )
597
  with gr.Row():
598
+ inputs_model = gr.Dropdown(
599
+ choices=model_names,
600
+ value=model_name,
601
+ type="value",
602
+ label="模型",
603
+ )
604
  with gr.Accordion("高级设置", open=True):
605
  with gr.Row():
606
+ inputs_size = gr.Slider(
607
+ 320,
608
+ 1600,
609
+ step=1,
610
+ value=inference_size,
611
+ label="推理尺寸",
612
+ )
613
+ max_det = gr.Slider(
614
+ 1, 1000, step=1, value=max_detnum, label="最大检测数"
615
+ )
616
  with gr.Row():
617
+ input_conf = gr.Slider(
618
+ 0,
619
+ 1,
620
+ step=slider_step,
621
+ value=nms_conf,
622
+ label="置信度阈值",
623
+ )
624
+ inputs_iou = gr.Slider(
625
+ 0,
626
+ 1,
627
+ step=slider_step,
628
+ value=nms_iou,
629
+ label="IoU 阈值",
630
+ )
631
  with gr.Row():
632
+ obj_size = gr.Radio(
633
+ choices=["所有尺寸", "小目标", "中目标", "大目标"],
634
+ value="所有尺寸",
635
+ label="目标尺寸",
636
+ )
637
  with gr.Row():
638
  gr.ClearButton(inputs_img, value="清除")
639
+ det_btn_img = gr.Button(value="检测", variant="primary")
640
  with gr.Row():
641
  gr.Examples(
642
  examples=EXAMPLES_DET,
643
  fn=yolo_det_img,
644
  inputs=[
645
+ inputs_img,
646
+ inputs_model,
647
+ device_opt,
648
+ inputs_size,
649
+ input_conf,
650
+ inputs_iou,
651
+ max_det,
652
+ obj_size,
653
+ ],
654
  # outputs=[outputs_img, outputs_objSize, outputs_clsSize],
655
+ cache_examples=False,
656
+ )
657
 
658
  with gr.TabItem("图像分类"):
659
  with gr.Row():
660
+ inputs_img_cls = gr.Image(
661
+ image_mode="RGB", type="filepath", label="原始图片"
662
+ )
663
  with gr.Row():
664
+ inputs_model_cls = gr.Dropdown(
665
+ choices=[
666
+ "yolov8n-cls",
667
+ "yolov8s-cls",
668
+ "yolov8l-cls",
669
+ "yolov8m-cls",
670
+ "yolov8x-cls",
671
+ ],
672
+ value="yolov8s-cls",
673
+ type="value",
674
+ label="模型",
675
+ )
676
  with gr.Row():
677
  gr.ClearButton(inputs_img, value="清除")
678
+ det_btn_img_cls = gr.Button(value="检测", variant="primary")
679
  with gr.Row():
680
  gr.Examples(
681
  examples=EXAMPLES_CLAS,
682
  fn=yolo_cls_img,
683
  inputs=[inputs_img_cls, inputs_model_cls],
684
  # outputs=[outputs_img_cls, outputs_ratio_cls],
685
+ cache_examples=False,
686
+ )
687
 
688
+ # -------- 输出 --------
689
  with gr.Column(scale=1):
690
  with gr.Tabs():
691
  with gr.TabItem("目标检测与图像分割"):
692
+ # with gr.Row():
693
+ # outputs_img = gr.Image(type="pil", label="检测图片")
694
+ with gr.Row():
695
+ outputs_img_slider = ImageSlider(position=0.5, label="检测图片")
696
  with gr.Row():
697
+ outputs_imgfiles = gr.Files(label="图片下载")
698
  with gr.Row():
699
  outputs_objSize = gr.Label(label="目标尺寸占比统计")
700
  with gr.Row():
 
747
  """
748
  )
749
 
750
+ det_btn_img.click(
751
+ fn=yolo_det_img,
752
+ inputs=[
753
+ inputs_img,
754
+ inputs_model,
755
+ device_opt,
756
+ inputs_size,
757
+ input_conf,
758
+ inputs_iou,
759
+ max_det,
760
+ obj_size,
761
+ ],
762
+ outputs=[
763
+ outputs_img_slider,
764
+ outputs_imgfiles,
765
+ outputs_objSize,
766
+ outputs_clsSize,
767
+ ],
768
+ )
769
 
770
+ det_btn_img_cls.click(
771
+ fn=yolo_cls_img,
772
+ inputs=[inputs_img_cls, inputs_model_cls],
773
+ outputs=[outputs_img_cls, outputs_ratio_cls],
774
+ )
775
 
776
  return gyd
777