Zengyf-CVer commited on
Commit
1b9e070
1 Parent(s): a4c3d05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -36
app.py CHANGED
@@ -1,7 +1,7 @@
1
- # Gradio YOLOv8 Det v1.0.0
2
  # 创建人:曾逸夫
3
- # 创建时间:2023-10-28
4
- # pip install gradio>=4.0.2
5
 
6
  import argparse
7
  import csv
@@ -34,7 +34,7 @@ from util.fonts_opt import is_fonts
34
  ROOT_PATH = sys.path[0] # 根目录
35
 
36
  # Gradio YOLOv8 Det版本
37
- GYD_VERSION = "Gradio YOLOv8 Det v1.0.0"
38
 
39
  # 文件后缀
40
  suffix_list = [".csv", ".yaml"]
@@ -47,9 +47,7 @@ obj_style = ["小目标", "中目标", "大目标"]
47
 
48
 
49
  def parse_args(known=False):
50
- parser = argparse.ArgumentParser(description="Gradio YOLOv8 Det v1.0.0")
51
- parser.add_argument("--model_type", "-mt", default="online", type=str, help="model type")
52
- parser.add_argument("--source", "-src", default="upload", type=str, help="image input source")
53
  parser.add_argument("--model_name", "-mn", default="yolov8s", type=str, help="model name")
54
  parser.add_argument(
55
  "--model_cfg",
@@ -96,7 +94,7 @@ def parse_args(known=False):
96
  default=False,
97
  help="is login",
98
  )
99
- parser.add_argument("--server_port", "-sp", default=7861, type=int, help="server port")
100
 
101
  args = parser.parse_known_args()[0] if known else parser.parse_args()
102
  return args
@@ -228,7 +226,7 @@ def seg_output(img_path, seg_mask_list, color_list, cls_list):
228
  return img_mask_merge
229
 
230
 
231
- # 模型加载
232
  def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8n.pt"):
233
  model = YOLO(yolo_model)
234
 
@@ -237,6 +235,15 @@ def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_mod
237
  return results
238
 
239
 
 
 
 
 
 
 
 
 
 
240
  # YOLOv8图片检测函数
241
  def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_det, obj_size):
242
 
@@ -381,12 +388,35 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
381
  raise gr.Error("图片检测失败!")
382
 
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  def main(args):
385
  gr.close_all()
386
 
387
  global model_cls_name_cp, cls_name
388
 
389
- source = args.source
390
  nms_conf = args.nms_conf
391
  nms_iou = args.nms_iou
392
  model_name = args.model_name
@@ -406,8 +436,6 @@ def main(args):
406
  # ------------ Gradio Blocks ------------
407
  with gr.Blocks() as gyd:
408
  with gr.Row():
409
- # gr.HTML(value="<div align='center' style='font-size:25px'>Gradio YOLOv8 Det<div> \
410
- # <div align='center' style='font-size:20px'>基于 YOLOv8 的目标检测与图像分割系统<div>")
411
  gr.Markdown(value="<p align='center'><a href='https://gitee.com/CV_Lab/gradio-yolov8-det'>\
412
  <img src='https://pycver.gitee.io/ows-pics/imgs/gradio_yolov8_det_logo.png' alt='Simple Icons' ></a>\
413
  <p align='center'>基于 Gradio 的 YOLOv8 通用目标检测与图像分割演示系统</p><p align='center'>可自定义检测模型、演示便捷、安装简单</p>"
@@ -416,32 +444,51 @@ def main(args):
416
  gr.Markdown(value="作者:曾逸夫,Gitee:https://gitee.com/PyCVer ,Github:https://github.com/Zengyf-CVer")
417
  with gr.Row():
418
  with gr.Column(scale=1):
419
- with gr.Row():
420
- inputs_img = gr.Image(image_mode="RGB", sources=source, type="filepath", label="原始图片")
421
-
422
- with gr.Row():
423
- device_opt = gr.Radio(choices=["cpu", "0", "1", "2", "3"], value="cpu", label="设备")
424
- with gr.Row():
425
- inputs_model = gr.Dropdown(choices=model_names, value=model_name, type="value", label="模型")
426
- with gr.Row():
427
- inputs_size = gr.Slider(320, 1600, step=1, value=inference_size, label="推理尺寸")
428
- max_det = gr.Slider(1, 1000, step=1, value=max_detnum, label="最大检测数")
429
- with gr.Row():
430
- input_conf = gr.Slider(0, 1, step=slider_step, value=nms_conf, label="置信度阈值")
431
- inputs_iou = gr.Slider(0, 1, step=slider_step, value=nms_iou, label="IoU 阈值")
432
- with gr.Row():
433
- obj_size = gr.Radio(choices=["所有尺寸", "小目标", "中目标", "大目标"], value="所有尺寸", label="目标尺寸")
434
- with gr.Row():
435
- gr.ClearButton(inputs_img, value="清除")
436
- det_btn_img = gr.Button(value='检测', variant="primary")
 
 
 
 
 
 
 
 
 
 
437
 
438
  with gr.Column(scale=1):
439
- with gr.Row():
440
- outputs_img = gr.Image(type="pil", label="检测图片")
441
- with gr.Row():
442
- outputs_objSize = gr.Label(label="目标尺寸占比统计")
443
- with gr.Row():
444
- outputs_clsSize = gr.Label(label="类别检测占比统计")
 
 
 
 
 
 
 
 
 
445
 
446
  with gr.Row():
447
  example_list = [
@@ -463,6 +510,11 @@ def main(args):
463
  obj_size],
464
  outputs=[outputs_img, outputs_objSize, outputs_clsSize])
465
 
 
 
 
 
 
466
  return gyd
467
 
468
 
 
1
+ # Gradio YOLOv8 Det v1.1.0
2
  # 创建人:曾逸夫
3
+ # 创建时间:2023-11-04
4
+ # pip install gradio>=4.1.1
5
 
6
  import argparse
7
  import csv
 
34
  ROOT_PATH = sys.path[0] # 根目录
35
 
36
  # Gradio YOLOv8 Det版本
37
+ GYD_VERSION = "Gradio YOLOv8 Det v1.1.0"
38
 
39
  # 文件后缀
40
  suffix_list = [".csv", ".yaml"]
 
47
 
48
 
49
  def parse_args(known=False):
50
+ parser = argparse.ArgumentParser(description="Gradio YOLOv8 Det v1.1.0")
 
 
51
  parser.add_argument("--model_name", "-mn", default="yolov8s", type=str, help="model name")
52
  parser.add_argument(
53
  "--model_cfg",
 
94
  default=False,
95
  help="is login",
96
  )
97
+ parser.add_argument("--server_port", "-sp", default=7860, type=int, help="server port")
98
 
99
  args = parser.parse_known_args()[0] if known else parser.parse_args()
100
  return args
 
226
  return img_mask_merge
227
 
228
 
229
+ # 目标检测和图像分割模型加载
230
  def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8n.pt"):
231
  model = YOLO(yolo_model)
232
 
 
235
  return results
236
 
237
 
238
+ # 图像分类模型加载
239
+ def model_cls_loading(img_path, yolo_model="yolov8s-cls.pt"):
240
+ model = YOLO(yolo_model)
241
+
242
+ results = model(source=img_path)
243
+ results = list(results)[0]
244
+ return results
245
+
246
+
247
  # YOLOv8图片检测函数
248
  def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_det, obj_size):
249
 
 
388
  raise gr.Error("图片检测失败!")
389
 
390
 
391
+ # YOLOv8图片分类函数
392
+ def yolo_cls_img(img_path, model_name):
393
+
394
+ # 模型加载
395
+ predict_results = model_cls_loading(img_path, yolo_model=f"{model_name}.pt")
396
+
397
+ det_img = Image.open(img_path)
398
+ clas_ratio_list = predict_results.probs.top5conf.tolist()
399
+ clas_index_list = predict_results.probs.top5
400
+
401
+ clas_name_list = []
402
+ for i in clas_index_list:
403
+ clas_name_list.append(predict_results.names[i])
404
+
405
+ clsRatio_dict = {}
406
+ index_cls = 0
407
+ clsDet_dict = Counter(clas_name_list)
408
+ for k, v in clsDet_dict.items():
409
+ clsRatio_dict[k] = clas_ratio_list[index_cls]
410
+ index_cls+=1
411
+
412
+ return det_img, clsRatio_dict
413
+
414
+
415
  def main(args):
416
  gr.close_all()
417
 
418
  global model_cls_name_cp, cls_name
419
 
 
420
  nms_conf = args.nms_conf
421
  nms_iou = args.nms_iou
422
  model_name = args.model_name
 
436
  # ------------ Gradio Blocks ------------
437
  with gr.Blocks() as gyd:
438
  with gr.Row():
 
 
439
  gr.Markdown(value="<p align='center'><a href='https://gitee.com/CV_Lab/gradio-yolov8-det'>\
440
  <img src='https://pycver.gitee.io/ows-pics/imgs/gradio_yolov8_det_logo.png' alt='Simple Icons' ></a>\
441
  <p align='center'>基于 Gradio 的 YOLOv8 通用目标检测与图像分割演示系统</p><p align='center'>可自定义检测模型、演示便捷、安装简单</p>"
 
444
  gr.Markdown(value="作者:曾逸夫,Gitee:https://gitee.com/PyCVer ,Github:https://github.com/Zengyf-CVer")
445
  with gr.Row():
446
  with gr.Column(scale=1):
447
+ with gr.Tabs():
448
+ with gr.TabItem("目标检测与图像分割"):
449
+ with gr.Row():
450
+ inputs_img = gr.Image(image_mode="RGB", type="filepath", label="原始图片")
451
+ with gr.Row():
452
+ device_opt = gr.Radio(choices=["cpu", "0", "1", "2", "3"], value="cpu", label="设备")
453
+ with gr.Row():
454
+ inputs_model = gr.Dropdown(choices=model_names, value=model_name, type="value", label="模型")
455
+ with gr.Row():
456
+ inputs_size = gr.Slider(320, 1600, step=1, value=inference_size, label="推理尺寸")
457
+ max_det = gr.Slider(1, 1000, step=1, value=max_detnum, label="最大检测数")
458
+ with gr.Row():
459
+ input_conf = gr.Slider(0, 1, step=slider_step, value=nms_conf, label="置信度阈值")
460
+ inputs_iou = gr.Slider(0, 1, step=slider_step, value=nms_iou, label="IoU 阈值")
461
+ with gr.Row():
462
+ obj_size = gr.Radio(choices=["所有尺寸", "小目标", "中目标", "大目标"], value="所有尺寸", label="目标尺寸")
463
+ with gr.Row():
464
+ gr.ClearButton(inputs_img, value="清除")
465
+ det_btn_img = gr.Button(value='检测', variant="primary")
466
+
467
+ with gr.TabItem("图像分类"):
468
+ with gr.Row():
469
+ inputs_img_cls = gr.Image(image_mode="RGB", type="filepath", label="原始图片")
470
+ with gr.Row():
471
+ inputs_model_cls = gr.Dropdown(choices=["yolov8n-cls", "yolov8s-cls", "yolov8l-cls", "yolov8m-cls", "yolov8x-cls"], value="yolov8s-cls", type="value", label="模型")
472
+ with gr.Row():
473
+ gr.ClearButton(inputs_img, value="清除")
474
+ det_btn_img_cls = gr.Button(value='检测', variant="primary")
475
 
476
  with gr.Column(scale=1):
477
+ with gr.Tabs():
478
+ with gr.TabItem("目标检测与图像分割"):
479
+ with gr.Row():
480
+ outputs_img = gr.Image(type="pil", label="检测图片")
481
+ with gr.Row():
482
+ outputs_objSize = gr.Label(label="目标尺寸占比统计")
483
+ with gr.Row():
484
+ outputs_clsSize = gr.Label(label="类别检测占比统计")
485
+
486
+ with gr.TabItem("图像分类"):
487
+ with gr.Row():
488
+ outputs_img_cls = gr.Image(type="pil", label="检测图片")
489
+ with gr.Row():
490
+ outputs_ratio_cls = gr.Label(label="图像分类结果")
491
+
492
 
493
  with gr.Row():
494
  example_list = [
 
510
  obj_size],
511
  outputs=[outputs_img, outputs_objSize, outputs_clsSize])
512
 
513
+ det_btn_img_cls.click(fn=yolo_cls_img,
514
+ inputs=[
515
+ inputs_img_cls, inputs_model_cls],
516
+ outputs=[outputs_img_cls, outputs_ratio_cls])
517
+
518
  return gyd
519
 
520