Zengyf-CVer commited on
Commit
0c35b7f
·
1 Parent(s): 9e12f9c

fix yaml load

Browse files
Files changed (1) hide show
  1. app.py +6 -20
app.py CHANGED
@@ -10,12 +10,11 @@ import gradio as gr
10
  import torch
11
  import yaml
12
  from PIL import Image
13
- from zmq import device
14
 
15
  ROOT_PATH = sys.path[0] # 根目录
16
 
17
- # 本地模型路径
18
- local_model_path = f"{ROOT_PATH}/yolov5"
19
 
20
 
21
  # 模型名称临时变量
@@ -67,7 +66,7 @@ def parse_args(known=False):
67
  "-dev",
68
  default="cpu",
69
  type=str,
70
- help="cuda or cpu",
71
  )
72
 
73
  args = parser.parse_known_args()[0] if known else parser.parse_args()
@@ -78,18 +77,7 @@ def parse_args(known=False):
78
  def model_loading(model_name, device):
79
 
80
  # 加载本地模型
81
- # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', device=device)
82
- model = torch.hub.load(
83
- "ultralytics/yolov5", model_name, force_reload=True, device=device
84
- )
85
-
86
- # model = torch.hub.load(
87
- # local_model_path,
88
- # "custom",
89
- # path=f"{local_model_path}/{model_name}",
90
- # source="local",
91
- # device=device,
92
- # )
93
 
94
  return model
95
 
@@ -151,9 +139,7 @@ def yolo_det(img, device, model_name, conf, iou, label_opt, model_cls):
151
 
152
  # yaml文件解析
153
  def yaml_parse(file_path):
154
- return yaml.load(
155
- open(file_path, "r", encoding="utf-8").read(), Loader=yaml.FullLoader
156
- )
157
 
158
 
159
  def main(args):
@@ -182,7 +168,7 @@ def main(args):
182
  # -------------------输入组件-------------------
183
  inputs_img = gr.inputs.Image(type="pil", label="原始图片")
184
  device = gr.inputs.Dropdown(
185
- choices=["0", "cpu"], default=device, type="value", label="设备"
186
  )
187
  inputs_model = gr.inputs.Dropdown(
188
  choices=model_names, default=model_name, type="value", label="模型"
 
10
  import torch
11
  import yaml
12
  from PIL import Image
 
13
 
14
  ROOT_PATH = sys.path[0] # 根目录
15
 
16
+ # 模型路径
17
+ model_path = "ultralytics/yolov5"
18
 
19
 
20
  # 模型名称临时变量
 
66
  "-dev",
67
  default="cpu",
68
  type=str,
69
+ help="cuda or cpu, hugging face only cpu",
70
  )
71
 
72
  args = parser.parse_known_args()[0] if known else parser.parse_args()
 
77
  def model_loading(model_name, device):
78
 
79
  # 加载本地模型
80
+ model = torch.hub.load(model_path, model_name, force_reload=True, device=device)
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  return model
83
 
 
139
 
140
  # yaml文件解析
141
  def yaml_parse(file_path):
142
+ return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
 
 
143
 
144
 
145
  def main(args):
 
168
  # -------------------输入组件-------------------
169
  inputs_img = gr.inputs.Image(type="pil", label="原始图片")
170
  device = gr.inputs.Dropdown(
171
+ choices=["cpu"], default=device, type="value", label="设备"
172
  )
173
  inputs_model = gr.inputs.Dropdown(
174
  choices=model_names, default=model_name, type="value", label="模型"