tumuyan2 commited on
Commit
da006e5
·
1 Parent(s): 20ad880
Files changed (2) hide show
  1. app.py +12 -10
  2. pth2onnx.py +16 -9
app.py CHANGED
@@ -197,7 +197,8 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
197
  os.remove(save_path)
198
  return None
199
 
200
- async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool,onnxsim:bool,opset:int):
 
201
 
202
  log = ('初始化日志记录...\n')
203
  print_log(task_id, '初始化日志记录', '开始')
@@ -226,7 +227,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
226
  yield [],log
227
  else:
228
  print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
229
- onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16, simplify_model=onnxsim, opset=opset)
230
  if onnx_path:
231
  log += ( f'成功生成ONNX模型: {onnx_path}\n')
232
  print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
@@ -277,12 +278,11 @@ with gr.Blocks() as demo:
277
  return gr.update(visible=False), gr.update(visible=True)
278
 
279
  input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
280
-
281
- tilesize = gr.Number(label="Tilesize", value=0, precision=0)
282
- # 添加fp16和try_run复选框
283
  fp16 = gr.Checkbox(label="FP16", value=False)
284
  onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
285
- opset = gr.Number(label="ONNX export opset version, suggest 9/11/13/16/17/18", value=13, precision=0)
286
  try_run = gr.Checkbox(label="MNNSR test", value=False)
287
  convert_btn = gr.Button("Run")
288
  with gr.Column():
@@ -298,7 +298,8 @@ with gr.Blocks() as demo:
298
  return gr.update(visible=False)
299
  try_run.change(show_try_run, inputs=try_run, outputs=img_output)
300
 
301
- async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, try_run):
 
302
 
303
  global task_counter
304
  task_counter += 1
@@ -347,7 +348,8 @@ with gr.Blocks() as demo:
347
  onnx_path = None
348
  mnn_path = None
349
  # 调用重命名后的函数
350
- async for result in _process_model(model_input, tilesize if tilesize>0 else 64, output_dir, task_counter, fp16, onnxsim, opset):
 
351
  if isinstance(result, tuple) and len(result) == 3:
352
  onnx_path, mnn_path, process_log = result
353
  yield onnx_path, mnn_path, log+process_log, None
@@ -360,7 +362,7 @@ with gr.Blocks() as demo:
360
  if mnn_path:
361
  if try_run:
362
  print_log(task_counter, f'测试模型: {mnn_path}', '开始')
363
- processed_image_np, load_time, infer_time = modelTest_for_gradio(mnn_path, "./sample.jpg", int(tilesize), 0)
364
  processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
365
  # processed_image_pil = Image.fromarray(processed_image_np)
366
  yield onnx_path, mnn_path, log+process_log+f"MNNSR 加载模型用时 {load_time:.4f} 秒, 推理({tilesize} px)用时 {infer_time:.4f} 秒", processed_image_pil
@@ -370,7 +372,7 @@ with gr.Blocks() as demo:
370
 
371
  convert_btn.click(
372
  process_model,
373
- inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, try_run],
374
  outputs=[onnx_output, mnn_output, log_box, img_output],
375
  api_name="convert_nmm_model"
376
  )
 
197
  os.remove(save_path)
198
  return None
199
 
200
+ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool,onnxsim:bool,opset:int,dynamic_axes:bool):
201
+
202
 
203
  log = ('初始化日志记录...\n')
204
  print_log(task_id, '初始化日志记录', '开始')
 
227
  yield [],log
228
  else:
229
  print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
230
+ onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16, simplify_model=onnxsim, opset=opset, dynamic_axes=dynamic_axes)
231
  if onnx_path:
232
  log += ( f'成功生成ONNX模型: {onnx_path}\n')
233
  print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
 
278
  return gr.update(visible=False), gr.update(visible=True)
279
 
280
  input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
281
+ tilesize = gr.Number(label="Dummy input width/height, default 64", value=64, precision=0)
282
+ opset = gr.Number(label="ONNX export opset version, suggest 9/11/13/16/17/18", value=13, precision=0)
 
283
  fp16 = gr.Checkbox(label="FP16", value=False)
284
  onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
285
+ dynamic_axes = gr.Checkbox(label="ONNX input apply dynamic axes", value=True)
286
  try_run = gr.Checkbox(label="MNNSR test", value=False)
287
  convert_btn = gr.Button("Run")
288
  with gr.Column():
 
298
  return gr.update(visible=False)
299
  try_run.change(show_try_run, inputs=try_run, outputs=img_output)
300
 
301
+ async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, try_run):
302
+
303
 
304
  global task_counter
305
  task_counter += 1
 
348
  onnx_path = None
349
  mnn_path = None
350
  # 调用重命名后的函数
351
+ async for result in _process_model(model_input, tilesize if tilesize>0 else 64, output_dir, task_counter, fp16, onnxsim, opset, dynamic_axes):
352
+
353
  if isinstance(result, tuple) and len(result) == 3:
354
  onnx_path, mnn_path, process_log = result
355
  yield onnx_path, mnn_path, log+process_log, None
 
362
  if mnn_path:
363
  if try_run:
364
  print_log(task_counter, f'测试模型: {mnn_path}', '开始')
365
+ processed_image_np, load_time, infer_time = modelTest_for_gradio(mnn_path, "./sample.jpg", tilesize if tilesize>0 and dynamic_axes else 0, 0)
366
  processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
367
  # processed_image_pil = Image.fromarray(processed_image_np)
368
  yield onnx_path, mnn_path, log+process_log+f"MNNSR 加载模型用时 {load_time:.4f} 秒, 推理({tilesize} px)用时 {infer_time:.4f} 秒", processed_image_pil
 
372
 
373
  convert_btn.click(
374
  process_model,
375
+ inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, try_run],
376
  outputs=[onnx_output, mnn_output, log_box, img_output],
377
  api_name="convert_nmm_model"
378
  )
pth2onnx.py CHANGED
@@ -6,7 +6,7 @@ import onnx
6
  from spandrel import ImageModelDescriptor, ModelLoader
7
  from onnxsim import simplify
8
 
9
- def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None, opset: int = 11):
10
  """
11
  Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
12
 
@@ -86,6 +86,14 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
86
  print(f"ONNX model exporting...")
87
  try:
88
  # Export the model
 
 
 
 
 
 
 
 
89
  torch.onnx.export(
90
  torch_model, # The model instance
91
  example_input, # An example input tensor
@@ -95,10 +103,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
95
  do_constant_folding=True, # Whether to execute constant folding for optimization
96
  input_names=['input'], # The model's input names
97
  output_names=['output'], # The model's output names
98
- dynamic_axes={ # Allow variable input/output dimensions
99
- "input": {0: "batch_size", 2: "height", 3: "width"}, # Batch, H, W can vary
100
- "output": {0: "batch_size", 2: "height", 3: "width"},# Batch, H, W can vary
101
- }
102
  )
103
  print(f"ONNX model export successful: {onnx_path}")
104
 
@@ -133,16 +138,18 @@ if __name__ == "__main__":
133
  parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
134
  parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
135
  parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
 
136
  args = parser.parse_args()
137
 
138
  success = convert_pth_to_onnx(
139
- pth_path=args.pth_path,
140
- onnx_path=args.onnx_path,
141
  channel=args.channel,
142
  tilesize=args.tilesize,
143
- use_fp16=args.use_fp16,
144
- simplify_model=args.simplify_model,
145
  opset=args.opset,
 
146
  )
147
 
148
  if success:
 
6
  from spandrel import ImageModelDescriptor, ModelLoader
7
  from onnxsim import simplify
8
 
9
+ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None, opset: int = 11, dynamic_axes: bool = True):
10
  """
11
  Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
12
 
 
86
  print(f"ONNX model exporting...")
87
  try:
88
  # Export the model
89
+ if dynamic_axes:
90
+ axes = {
91
+ "input": {2: "height", 3: "width"},
92
+ "output": {2: "height", 3: "width"},
93
+ }
94
+ else:
95
+ axes = {}
96
+
97
  torch.onnx.export(
98
  torch_model, # The model instance
99
  example_input, # An example input tensor
 
103
  do_constant_folding=True, # Whether to execute constant folding for optimization
104
  input_names=['input'], # The model's input names
105
  output_names=['output'], # The model's output names
106
+ dynamic_axes=axes
 
 
 
107
  )
108
  print(f"ONNX model export successful: {onnx_path}")
109
 
 
138
  parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
139
  parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
140
  parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
141
+ parser.add_argument('--fixed_axes', action='store_true', help='Use fixed input axes (disable dynamic axes).')
142
  args = parser.parse_args()
143
 
144
  success = convert_pth_to_onnx(
145
+ pth_path=args.pth_path,
146
+ onnx_path=args.onnx_path,
147
  channel=args.channel,
148
  tilesize=args.tilesize,
149
+ use_fp16=args.fp16,
150
+ simplify_model=args.simplify,
151
  opset=args.opset,
152
+ dynamic_axes= not args.fixed_axes,
153
  )
154
 
155
  if success: