tumuyan2 commited on
Commit
78ca519
·
1 Parent(s): 3ba1191

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -24,7 +24,7 @@ def print_log(task_id, filename, stage, status):
24
  print(f"任务{task_id}: {filename}, [{status}] {stage}")
25
 
26
  # 修改 start_process 函数,处理新增输入
27
- def start_process(input_file, input_url, input2, shape0_str, shape1_str, output_type="+ Dynamic MNN & NCNN", input_suffix=".pth"):
28
  global task_counter
29
  task_counter += 1
30
  task_id = task_counter
@@ -221,7 +221,7 @@ def start_process(input_file, input_url, input2, shape0_str, shape1_str, output_
221
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
222
  log += "跳过转换为TorchScript模型\n"
223
  yield [], log
224
- else:
225
  print_log(task_id, input2, "转换为TorchScript模型", "开始")
226
  log+= "转换为TorchScript模型…\n"
227
  yield [], log
@@ -239,19 +239,20 @@ def start_process(input_file, input_url, input2, shape0_str, shape1_str, output_
239
 
240
 
241
  # 转换为 ONNX 模型
242
- if str(width_ratio) in input2:
243
- onnx_path = output_base + ".onnx"
244
- else:
245
- onnx_path = output_base + "-x" + str(width_ratio) + ".onnx"
246
- if os.path.exists(onnx_path):
247
- print_log(task_id, input2, "转换为ONNX模型", "跳过")
248
- log += "跳过转换为ONNX模型\n"
249
- yield [], log
250
- else:
251
- print_log(task_id, input2, "转换为ONNX模型", "开始")
252
- log += "转换为ONNX模型…\n"
253
- yield [], log
254
- torch.onnx.export(torch_model, example_input, onnx_path, opset_version=17, input_names=["input"], output_names=["output"])
 
255
 
256
 
257
  # 转换为 mnn 模型
@@ -390,8 +391,8 @@ with gr.Blocks() as demo:
390
 
391
  with gr.Row():
392
  input2 = gr.Textbox(label="自定义文件名")
393
- output_type = gr.Dropdown(choices=["TorchScript & ONNX", "+ Fixed MNN & NCNN", "+ Dynamic MNN & NCNN", "Fixed NCNN", "Dynamic NCNN", "Fixed MNN", "Dynamic MNN" ], value="+ Dynamic MNN & NCNN", label="模型类型")
394
- # 修改为字符串输入控件
395
  shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128")
396
  shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0")
397
  with gr.Row():
@@ -424,11 +425,11 @@ with gr.Blocks() as demo:
424
  ["","https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0"],
425
  ["","https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth", "", "1,3,128,128", "0,0,0,0"],
426
  ]
427
- # gr.Examples(
428
- # examples=examples,
429
- # inputs=[input1_file, input1, input2, shape0_str, shape1_str],
430
- # outputs=[output, log_textbox],
431
- # fn=start_process
432
- # )
433
 
434
  demo.launch()
 
24
  print(f"任务{task_id}: {filename}, [{status}] {stage}")
25
 
26
  # 修改 start_process 函数,处理新增输入
27
+ def start_process(input_file, input_url, input2, shape0_str, shape1_str, output_type=["TorchScript", "ONNX", "MNN", "ONNX"], input_suffix=".pth"):
28
  global task_counter
29
  task_counter += 1
30
  task_id = task_counter
 
221
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
222
  log += "跳过转换为TorchScript模型\n"
223
  yield [], log
224
+ elif "TorchScript" in output_type:
225
  print_log(task_id, input2, "转换为TorchScript模型", "开始")
226
  log+= "转换为TorchScript模型…\n"
227
  yield [], log
 
239
 
240
 
241
  # 转换为 ONNX 模型
242
+ if "ONNX" in output_type or "NCNN" in output_type or "MNN" in output_type:
243
+ if str(width_ratio) in input2:
244
+ onnx_path = output_base + ".onnx"
245
+ else:
246
+ onnx_path = output_base + "-x" + str(width_ratio) + ".onnx"
247
+ if os.path.exists(onnx_path):
248
+ print_log(task_id, input2, "转换为ONNX模型", "跳过")
249
+ log += "跳过转换为ONNX模型\n"
250
+ yield [], log
251
+ else:
252
+ print_log(task_id, input2, "转换为ONNX模型", "开始")
253
+ log += "转换为ONNX模型…\n"
254
+ yield [], log
255
+ torch.onnx.export(torch_model, example_input, onnx_path, opset_version=17, input_names=["input"], output_names=["output"])
256
 
257
 
258
  # 转换为 mnn 模型
 
391
 
392
  with gr.Row():
393
  input2 = gr.Textbox(label="自定义文件名")
394
+ output_type = gr.Dropdown( ["TorchScript", "ONNX", "Fixed", "MNN", "NCNN"], value=["TorchScript", "ONNX", "MNN", "ONNX"], multiselect=True, label="模型类型", info="1. 生成mnn和ncnn模型必须先生成onnx模型;2.如果选项中包含了Fixed,那么输出的onnx和mnn模型都使用固定shape的input。"
395
+ ),
396
  shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128")
397
  shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0")
398
  with gr.Row():
 
425
  ["","https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0"],
426
  ["","https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth", "", "1,3,128,128", "0,0,0,0"],
427
  ]
428
+ gr.Examples(
429
+ examples=examples,
430
+ inputs=[input1_file, input1, input2, shape0_str, shape1_str],
431
+ outputs=[output, log_textbox],
432
+ fn=start_process
433
+ )
434
 
435
  demo.launch()