update
Browse files- app.py +12 -10
- pth2onnx.py +16 -9
app.py
CHANGED
@@ -197,7 +197,8 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
|
|
197 |
os.remove(save_path)
|
198 |
return None
|
199 |
|
200 |
-
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool,onnxsim:bool,opset:int):
|
|
|
201 |
|
202 |
log = ('初始化日志记录...\n')
|
203 |
print_log(task_id, '初始化日志记录', '开始')
|
@@ -226,7 +227,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
226 |
yield [],log
|
227 |
else:
|
228 |
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
229 |
-
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16, simplify_model=onnxsim, opset=opset)
|
230 |
if onnx_path:
|
231 |
log += ( f'成功生成ONNX模型: {onnx_path}\n')
|
232 |
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
|
@@ -277,12 +278,11 @@ with gr.Blocks() as demo:
|
|
277 |
return gr.update(visible=False), gr.update(visible=True)
|
278 |
|
279 |
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
|
280 |
-
|
281 |
-
|
282 |
-
# 添加fp16和try_run复选框
|
283 |
fp16 = gr.Checkbox(label="FP16", value=False)
|
284 |
onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
|
285 |
-
|
286 |
try_run = gr.Checkbox(label="MNNSR test", value=False)
|
287 |
convert_btn = gr.Button("Run")
|
288 |
with gr.Column():
|
@@ -298,7 +298,8 @@ with gr.Blocks() as demo:
|
|
298 |
return gr.update(visible=False)
|
299 |
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
|
300 |
|
301 |
-
async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, try_run):
|
|
|
302 |
|
303 |
global task_counter
|
304 |
task_counter += 1
|
@@ -347,7 +348,8 @@ with gr.Blocks() as demo:
|
|
347 |
onnx_path = None
|
348 |
mnn_path = None
|
349 |
# 调用重命名后的函数
|
350 |
-
async for result in _process_model(model_input, tilesize if tilesize>0 else 64, output_dir, task_counter, fp16, onnxsim, opset):
|
|
|
351 |
if isinstance(result, tuple) and len(result) == 3:
|
352 |
onnx_path, mnn_path, process_log = result
|
353 |
yield onnx_path, mnn_path, log+process_log, None
|
@@ -360,7 +362,7 @@ with gr.Blocks() as demo:
|
|
360 |
if mnn_path:
|
361 |
if try_run:
|
362 |
print_log(task_counter, f'测试模型: {mnn_path}', '开始')
|
363 |
-
processed_image_np, load_time, infer_time = modelTest_for_gradio(mnn_path, "./sample.jpg",
|
364 |
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
365 |
# processed_image_pil = Image.fromarray(processed_image_np)
|
366 |
yield onnx_path, mnn_path, log+process_log+f"MNNSR 加载模型用时 {load_time:.4f} 秒, 推理({tilesize} px)用时 {infer_time:.4f} 秒", processed_image_pil
|
@@ -370,7 +372,7 @@ with gr.Blocks() as demo:
|
|
370 |
|
371 |
convert_btn.click(
|
372 |
process_model,
|
373 |
-
inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, try_run],
|
374 |
outputs=[onnx_output, mnn_output, log_box, img_output],
|
375 |
api_name="convert_nmm_model"
|
376 |
)
|
|
|
197 |
os.remove(save_path)
|
198 |
return None
|
199 |
|
200 |
+
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool,onnxsim:bool,opset:int,dynamic_axes:bool):
|
201 |
+
|
202 |
|
203 |
log = ('初始化日志记录...\n')
|
204 |
print_log(task_id, '初始化日志记录', '开始')
|
|
|
227 |
yield [],log
|
228 |
else:
|
229 |
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
230 |
+
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16, simplify_model=onnxsim, opset=opset, dynamic_axes=dynamic_axes)
|
231 |
if onnx_path:
|
232 |
log += ( f'成功生成ONNX模型: {onnx_path}\n')
|
233 |
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
|
|
|
278 |
return gr.update(visible=False), gr.update(visible=True)
|
279 |
|
280 |
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
|
281 |
+
tilesize = gr.Number(label="Dummy input width/height, default 64", value=64, precision=0)
|
282 |
+
opset = gr.Number(label="ONNX export opset version, suggest 9/11/13/16/17/18", value=13, precision=0)
|
|
|
283 |
fp16 = gr.Checkbox(label="FP16", value=False)
|
284 |
onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
|
285 |
+
dynamic_axes = gr.Checkbox(label="ONNX input apply dynamic axes", value=True)
|
286 |
try_run = gr.Checkbox(label="MNNSR test", value=False)
|
287 |
convert_btn = gr.Button("Run")
|
288 |
with gr.Column():
|
|
|
298 |
return gr.update(visible=False)
|
299 |
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
|
300 |
|
301 |
+
async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, try_run):
|
302 |
+
|
303 |
|
304 |
global task_counter
|
305 |
task_counter += 1
|
|
|
348 |
onnx_path = None
|
349 |
mnn_path = None
|
350 |
# 调用重命名后的函数
|
351 |
+
async for result in _process_model(model_input, tilesize if tilesize>0 else 64, output_dir, task_counter, fp16, onnxsim, opset, dynamic_axes):
|
352 |
+
|
353 |
if isinstance(result, tuple) and len(result) == 3:
|
354 |
onnx_path, mnn_path, process_log = result
|
355 |
yield onnx_path, mnn_path, log+process_log, None
|
|
|
362 |
if mnn_path:
|
363 |
if try_run:
|
364 |
print_log(task_counter, f'测试模型: {mnn_path}', '开始')
|
365 |
+
processed_image_np, load_time, infer_time = modelTest_for_gradio(mnn_path, "./sample.jpg", tilesize if tilesize>0 and dynamic_axes else 0, 0)
|
366 |
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
367 |
# processed_image_pil = Image.fromarray(processed_image_np)
|
368 |
yield onnx_path, mnn_path, log+process_log+f"MNNSR 加载模型用时 {load_time:.4f} 秒, 推理({tilesize} px)用时 {infer_time:.4f} 秒", processed_image_pil
|
|
|
372 |
|
373 |
convert_btn.click(
|
374 |
process_model,
|
375 |
+
inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, try_run],
|
376 |
outputs=[onnx_output, mnn_output, log_box, img_output],
|
377 |
api_name="convert_nmm_model"
|
378 |
)
|
pth2onnx.py
CHANGED
@@ -6,7 +6,7 @@ import onnx
|
|
6 |
from spandrel import ImageModelDescriptor, ModelLoader
|
7 |
from onnxsim import simplify
|
8 |
|
9 |
-
def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None, opset: int = 11):
|
10 |
"""
|
11 |
Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
|
12 |
|
@@ -86,6 +86,14 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
|
|
86 |
print(f"ONNX model exporting...")
|
87 |
try:
|
88 |
# Export the model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
torch.onnx.export(
|
90 |
torch_model, # The model instance
|
91 |
example_input, # An example input tensor
|
@@ -95,10 +103,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
|
|
95 |
do_constant_folding=True, # Whether to execute constant folding for optimization
|
96 |
input_names=['input'], # The model's input names
|
97 |
output_names=['output'], # The model's output names
|
98 |
-
dynamic_axes=
|
99 |
-
"input": {0: "batch_size", 2: "height", 3: "width"}, # Batch, H, W can vary
|
100 |
-
"output": {0: "batch_size", 2: "height", 3: "width"},# Batch, H, W can vary
|
101 |
-
}
|
102 |
)
|
103 |
print(f"ONNX model export successful: {onnx_path}")
|
104 |
|
@@ -133,16 +138,18 @@ if __name__ == "__main__":
|
|
133 |
parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
|
134 |
parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
|
135 |
parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
|
|
|
136 |
args = parser.parse_args()
|
137 |
|
138 |
success = convert_pth_to_onnx(
|
139 |
-
pth_path=args.
|
140 |
-
onnx_path=args.
|
141 |
channel=args.channel,
|
142 |
tilesize=args.tilesize,
|
143 |
-
use_fp16=args.
|
144 |
-
simplify_model=args.
|
145 |
opset=args.opset,
|
|
|
146 |
)
|
147 |
|
148 |
if success:
|
|
|
6 |
from spandrel import ImageModelDescriptor, ModelLoader
|
7 |
from onnxsim import simplify
|
8 |
|
9 |
+
def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None, opset: int = 11, dynamic_axes: bool = True):
|
10 |
"""
|
11 |
Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
|
12 |
|
|
|
86 |
print(f"ONNX model exporting...")
|
87 |
try:
|
88 |
# Export the model
|
89 |
+
if dynamic_axes:
|
90 |
+
axes = {
|
91 |
+
"input": {2: "height", 3: "width"},
|
92 |
+
"output": {2: "height", 3: "width"},
|
93 |
+
}
|
94 |
+
else:
|
95 |
+
axes = {}
|
96 |
+
|
97 |
torch.onnx.export(
|
98 |
torch_model, # The model instance
|
99 |
example_input, # An example input tensor
|
|
|
103 |
do_constant_folding=True, # Whether to execute constant folding for optimization
|
104 |
input_names=['input'], # The model's input names
|
105 |
output_names=['output'], # The model's output names
|
106 |
+
dynamic_axes=axes
|
|
|
|
|
|
|
107 |
)
|
108 |
print(f"ONNX model export successful: {onnx_path}")
|
109 |
|
|
|
138 |
parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
|
139 |
parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
|
140 |
parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
|
141 |
+
parser.add_argument('--fixed_axes', action='store_true', help='Use dynamic axes.')
|
142 |
args = parser.parse_args()
|
143 |
|
144 |
success = convert_pth_to_onnx(
|
145 |
+
pth_path=args.pthpath,
|
146 |
+
onnx_path=args.onnxpath,
|
147 |
channel=args.channel,
|
148 |
tilesize=args.tilesize,
|
149 |
+
use_fp16=args.fp16,
|
150 |
+
simplify_model=args.simplify,
|
151 |
opset=args.opset,
|
152 |
+
dynamic_axes= not args.fixed_axes,
|
153 |
)
|
154 |
|
155 |
if success:
|