tumuyan2 commited on
Commit
2152dfa
·
verified ·
1 Parent(s): 1d965a8

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +252 -0
  2. onnx2mnn2.bat +121 -0
  3. pth2onnx.bat +36 -0
  4. pth2onnx.py +149 -0
  5. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import os
4
+ import subprocess
5
+ from typing import Union
6
+ from datetime import datetime
7
+ from pth2onnx import convert_pth_to_onnx
8
+ from urllib.parse import urlparse
9
+ from typing import Optional
10
+
11
+ # 新增日志开关
12
+ log_to_terminal = True
13
+
14
+ # 新增全局任务计数器
15
+ task_counter = 0
16
+
17
+ # 新增日志函数
18
+ def print_log(task_id, stage, status):
19
+ if log_to_terminal:
20
+ print(f"任务{task_id}: [{status}] {stage}")
21
+
22
+ # 使用 MNN 库自带的转换工具
23
+ def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
24
+ param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
25
+ if fp16:
26
+ param.append('--fp16')
27
+ subprocess.run(param, check=True)
28
+
29
+ def download_file(url: str, save_path: str):
30
+ response = requests.get(url)
31
+ with open(save_path, 'wb') as f:
32
+ f.write(response.content)
33
+
34
+ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
35
+ """
36
+ 从URL下载文件到指定文件夹,并进行文件大小检查
37
+
38
+ Args:
39
+ url: 要下载的文件URL
40
+ folder: 保存文件的目标文件夹路径
41
+ filesize_max: 最大允许文件大小(字节),超过此值将中断下载并删除文件
42
+ filesize_min: 最小允许文件大小(字节),小于此值将删除文件并返回None
43
+
44
+ Returns:
45
+ 成功下载的文件名,如果下载失败或文件大小不符合要求则返回None
46
+ """
47
+ # 解析URL获取文件名
48
+ parsed_url = urlparse(url)
49
+ filename = os.path.basename(parsed_url.path)
50
+ if not filename:
51
+ return None # 无法从URL获取文件名
52
+
53
+ # 确保目标文件夹存在
54
+ os.makedirs(folder, exist_ok=True)
55
+ save_path = os.path.join(folder, filename)
56
+
57
+ try:
58
+ # 发送HTTP请求,流式下载
59
+ with requests.get(url, stream=True, timeout=10) as response:
60
+ response.raise_for_status() # 检查HTTP错误状态
61
+
62
+ # 获取文件总大小(如果服务器提供)
63
+ total_size = int(response.headers.get('content-length', 0))
64
+ if total_size > filesize_max:
65
+ return None # 文件大小超过最大值,不下载
66
+
67
+ downloaded_size = 0
68
+ with open(save_path, 'wb') as file:
69
+ for chunk in response.iter_content(chunk_size=8192):
70
+ if chunk: # 过滤空块
71
+ downloaded_size += len(chunk)
72
+ # 检查是否超过最大允许大小
73
+ if downloaded_size > filesize_max:
74
+ file.close()
75
+ os.remove(save_path)
76
+ return None
77
+ file.write(chunk)
78
+
79
+ # 下载完成后检查最小文件大小
80
+ if os.path.getsize(save_path) < filesize_min:
81
+ os.remove(save_path)
82
+ return None
83
+
84
+ return filename
85
+
86
+ except Exception as e:
87
+ # 发生异常时清理文件
88
+ if os.path.exists(save_path):
89
+ os.remove(save_path)
90
+ return None
91
+
92
+ # 原 process_model 函数重命名为 _process_model
93
+ async def _process_model(model_input: Union[str, gr.File], tilesize: int, log_box: gr.Textbox, output_dir: str):
94
+ global task_counter
95
+ task_id = task_counter
96
+ log = ('初始化日志记录...')
97
+ print_log(task_id, '初始化日志记录', '开始')
98
+ yield [],[], log
99
+
100
+ # 处理输入模型
101
+ if isinstance(model_input, str): # 处理链接
102
+ if model_input.startswith(('http://', 'https://')):
103
+ log += ( f'正在下载模型文件: {model_input}')
104
+ print_log(task_id, f'正在下载模型文件: {model_input}', '开始')
105
+ yield [],[], log
106
+
107
+ # 下载文件到output文件夹
108
+ filename = download_file2folder(
109
+ url=model_input,
110
+ folder=output_dir,
111
+ filesize_max=200*1024*1024, # 200MB
112
+ filesize_min=1024 # 1KB
113
+ )
114
+ input_path = os.path.join(output_dir, filename)
115
+ log += ( f'模型文件已下载到: {input_path}')
116
+ print_log(task_id, f'模型文件已下载到: {input_path}', '完成')
117
+ yield [],[], log
118
+ else:
119
+ input_path = model_input
120
+ log += ( f'使用本地文件: {input_path}')
121
+ print_log(task_id, f'使用本地文件: {input_path}', '开始')
122
+ yield [],[], log
123
+ else:
124
+ input_path = model_input.name
125
+ log += ( f'已上传模型文件: {input_path}')
126
+ print_log(task_id, f'已上传模型文件: {input_path}', '开始')
127
+ yield [],[], log
128
+
129
+
130
+ if not input_path:
131
+ log += ( f'未获得正确的模型文件')
132
+ print_log(task_id, f'未获得正确的模型文件', '错误')
133
+ yield [],[], log
134
+ return
135
+
136
+
137
+ if input_path.endswith('.onnx'):
138
+ onnx_path = input_path
139
+ log += ( '输入已是 ONNX 模型,直接使用...')
140
+ print_log(task_id, '输入已是 ONNX 模型,直接使用', '开始')
141
+ yield [],[], log
142
+ else:
143
+ print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
144
+ onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir)
145
+ if onnx_path:
146
+ log += ( f'成功生成ONNX模型: {onnx_path}')
147
+ print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
148
+ else:
149
+ log += ( '生成ONNX模型失败')
150
+ print_log(task_id, '生成ONNX模型', '错误')
151
+ yield [], [], log
152
+ return
153
+
154
+
155
+ # 转换为 MNN 模型
156
+ output_name= os.path.splitext(os.path.basename(onnx_path))[0]
157
+ mnn_path = os.path.join(output_dir, f'{output_name}.mnn')
158
+ try:
159
+ log += ( '正在将 ONNX 模型转换为 MNN 格式...')
160
+ print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
161
+ convertmnn(onnx_path, mnn_path)
162
+ yield onnx_path,[], log
163
+ except Exception as e:
164
+ log += ( f'转换 MNN 模型时出错: {str(e)}')
165
+ print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
166
+ yield onnx_path,[], log
167
+
168
+ print_log(task_id, '模型转换任务完成', '完成')
169
+
170
+ # 转换为 MNN 模型后对文件检查
171
+ if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024: # 1KB = 1024 bytes
172
+ log += ( f'MNN 模型已保存到: {mnn_path}')
173
+ else:
174
+ log += ( 'MNN 模型生成失败或文件大小不足1KB')
175
+ mnn_path = None
176
+
177
+ yield onnx_path, mnn_path, log
178
+
179
+ with gr.Blocks() as demo:
180
+ gr.Markdown("# 模型转换工具")
181
+ with gr.Row():
182
+ with gr.Column():
183
+ input_type = gr.Radio(['模型链接', '上传模型文件'], label='输入类型')
184
+ url_input = gr.Textbox(label='模型链接')
185
+ file_input = gr.File(label='上传模型文件', visible=False)
186
+
187
+
188
+ def show_input(input_type):
189
+ if input_type == '模型链接':
190
+ return gr.update(visible=True), gr.update(visible=False)
191
+ else:
192
+ return gr.update(visible=False), gr.update(visible=True)
193
+
194
+ input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
195
+
196
+ tilesize = gr.Number(label="Tilesize", value=0, precision=0)
197
+ convert_btn = gr.Button("开始转换")
198
+ with gr.Column():
199
+ log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
200
+ with gr.Row():
201
+ onnx_output = gr.File(label="ONNX 模型输出")
202
+ mnn_output = gr.File(label="MNN 模型输出")
203
+
204
+ async def process_model(input_type, url_input, file_input, tilesize, log_box):
205
+ if input_type == '模型链接' and url_input:
206
+ model_input = url_input
207
+ elif input_type == '上传模型文件' and file_input:
208
+ model_input = file_input
209
+ else:
210
+ # 改为通过yield返回错误日志
211
+ log = '\n请选择输入类型并提供有效的输入!'
212
+ yield None, None, log
213
+ return
214
+
215
+ # 创建不重名的输出目录
216
+ global task_counter
217
+ task_counter += 1
218
+ output_dir = os.path.join(os.getcwd(), f"output_{task_counter}")
219
+ os.makedirs(output_dir, exist_ok=True)
220
+
221
+ onnx_path = None
222
+ mnn_path = None
223
+ # 调用重命名后的函数
224
+ async for result in _process_model(model_input, int(tilesize), log_box, output_dir):
225
+ if isinstance(result, tuple) and len(result) == 3:
226
+ onnx_path, mnn_path, log_box = result
227
+ elif isinstance(result, tuple) and len(result) == 2:
228
+ # 处理纯日志yield
229
+ _, process_log = result
230
+ yield None, None, process_log
231
+ yield onnx_path, mnn_path, log_box
232
+
233
+ convert_btn.click(
234
+ process_model,
235
+ inputs=[input_type, url_input, file_input, tilesize, log_box],
236
+ outputs=[onnx_output, mnn_output, log_box],
237
+ api_name="convert_model"
238
+ )
239
+
240
+ # 将示例移至底部并包裹在列组件中
241
+ examples_column = gr.Column(visible=True)
242
+ with examples_column:
243
+ examples = [
244
+ ["模型链接", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
245
+ ["模型链接", "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
246
+ ["模型链接", "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
247
+ ["模型链接", "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
248
+ ["模型链接", "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"]
249
+ ]
250
+ example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
251
+
252
+ demo.launch()
onnx2mnn2.bat ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ setlocal enabledelayedexpansion
3
+
4
+ rem ���̶�shape��fp16����ͼƬ����
5
+ rem ����Ƿ��в�������
6
+ if "%~1"=="" (
7
+ rem ������������ʾ�û������ļ�·��
8
+ set /p onnx_path="������ .onnx �ļ���·��: "
9
+ ) else (
10
+ rem ��������ʹ�õ�һ��������Ϊ�ļ�·��
11
+ set "onnx_path=%~1"
12
+ cd /d %~dp0
13
+
14
+ )
15
+
16
+ :main
17
+ rem ����ļ��Ƿ�������Ƿ��� .onnx ��β
18
+ if not exist "!onnx_path!" (
19
+ echo ����: �ļ������ڣ�����������·����
20
+ goto loop
21
+ )
22
+
23
+ set "onnx_ext=!onnx_path:~-5!"
24
+ if /i "!onnx_ext!" neq ".onnx" (
25
+ echo ����: �ļ����� .onnx ��ʽ������������·����
26
+ goto loop
27
+ )
28
+
29
+ rem ȥ��Ŀ¼·���ͺ�׺��
30
+ for %%f in ("!onnx_path!") do set "onnx_name=%%~nf"
31
+ echo ��ǰ����Ŀ¼��: %cd%
32
+ rem ִ�� onnx2ncnn.exe
33
+ MNNConvert -f ONNX --modelFile "!onnx_path!" --MNNModel "!onnx_name!.mnn" --bizCode biz --fp16 --info --detectSparseSpeedUp
34
+
35
+ rem ��� fp16 �ļ��Ƿ����
36
+ if exist "!onnx_name!.mnn" (
37
+ rem ����
38
+ ) else (
39
+ echo δ���ģ���ļ�������������·����
40
+ goto loop
41
+ )
42
+
43
+
44
+
45
+ :next_step
46
+ rem ��ʾ�û�ѡ����һ������
47
+ echo.
48
+ echo ��ѡ����һ������:
49
+ echo 1. ��һ��ѭ����Ĭ�ϣ�
50
+ echo 2. ����ģ�ͣ�ʹ�� test.png��
51
+ echo 3. �����Զ���ͼƬ·��������ģ��
52
+ set /p choice="������ѡ�� (1/2/3): "
53
+ if "%choice%"=="" set "choice=1"
54
+
55
+ if "%choice%"=="1" (
56
+ goto loop
57
+ ) else if "%choice%"=="2" (
58
+ set "test_image=test.png"
59
+ if exist "!test_image!" (
60
+ mnnsr-ncnn.exe -i "!test_image!" -o "!onnx_name!.png" -m "!onnx_name!.mnn" -s 0
61
+ ) else (
62
+ echo ����: �ļ� "!test_image!" �����ڣ�������ѡ��
63
+ goto next_step
64
+ )
65
+ ) else if "%choice%"=="3" (
66
+ set /p custom_image="�������Զ���ͼƬ·��: "
67
+ if exist "!custom_image!" (
68
+ mnnsr-ncnn.exe -i "!custom_image!" -o "!onnx_name!.png" -m "!onnx_name!.mnn" -s 0
69
+ ) else (
70
+ echo ����: �ļ� "!custom_image!" �����ڣ�������ѡ��
71
+ goto next_step
72
+ )
73
+ ) else (
74
+ echo ��Ч��ѡ�������ѡ��
75
+ goto next_step
76
+ )
77
+
78
+ rem �������� PNG �ļ��Ƿ����
79
+ if exist "!onnx_name!.png" (
80
+ echo ����ļ����ɳɹ�: "!onnx_name!.png"
81
+ start "" "!onnx_name!.png"
82
+ ) else (
83
+ echo ��֤ʧ��: ����ļ������ڣ���ѡ���Ƿ�ɾ��ģ���ļ���
84
+ set /p delete_model="�Ƿ���ģ���ļ� "!onnx_name!.mnn" ? (y/n, Ĭ��n): "
85
+ if /i "!delete_model!"=="y" (
86
+ echo ��������
87
+ ) else (
88
+ del "!onnx_name!.mnn"
89
+ echo ģ���ļ���ɾ����
90
+ )
91
+ goto loop
92
+ )
93
+
94
+ rem ѯ���û�ѡ����һ������
95
+ echo.
96
+ echo ��ѡ����һ������:
97
+ echo 1. ������ļ��У�Ĭ�ϣ�
98
+ echo 2. ɾ��ģ���ļ��Ͳ���ͼ
99
+ echo 3. ��һ��ѭ��
100
+ set /p next_choice="������ѡ�� (1/2/3): "
101
+ if "%next_choice%"=="" set "next_choice=1"
102
+
103
+ if "%next_choice%"=="1" (
104
+ start "" .
105
+ ) else if "%next_choice%"=="2" (
106
+ del "!onnx_name!.png"
107
+ del "!onnx_name!.mnn"
108
+ echo ģ���ļ��Ͳ���ͼ��ɾ����
109
+ ) else if "%next_choice%"=="3" (
110
+ goto loop
111
+ ) else (
112
+ echo ��Ч��ѡ�������ѡ��
113
+ goto next_choice
114
+ )
115
+
116
+
117
+
118
+ :loop
119
+ echo ====================================
120
+ set /p onnx_path="������ .onnx �ļ���·��: "
121
+ goto main
pth2onnx.bat ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ setlocal enabledelayedexpansion
3
+
4
+ rem ����Ƿ��в�������
5
+ if "%~1"=="" (
6
+ rem ������������ʾ�û������ļ�·��
7
+ set /p pth_path="������ .pth �ļ���·��: "
8
+ ) else (
9
+ rem ��������ʹ�õ�һ��������Ϊ�ļ�·��
10
+ set "pth_path=%~1"
11
+ cd /d %~dp0
12
+
13
+ )
14
+
15
+ :main
16
+ rem ����ļ��Ƿ�������Ƿ��� .pth ��β
17
+ if not exist "!pth_path!" (
18
+ echo ����: �ļ������ڣ�����������·����
19
+ goto loop
20
+ )
21
+
22
+ set "ext=!pth_path:~-4!"
23
+ if /i "!ext!" neq ".pth" (
24
+ echo ����: �ļ����� .pth ��ʽ������������·����
25
+ goto loop
26
+ )
27
+
28
+ echo ��ǰ����Ŀ¼��: %cd%
29
+ rem ִ�� pth2onnx.py
30
+ rem python .\pth2onnx.py "!pth_path!" --fp16 --simplify --channel !channel!
31
+ python .\pth2onnx.py "!pth_path!" --fp16
32
+
33
+ :loop
34
+ echo ====================================
35
+ set /p pth_path="������ .pth �ļ���·��: "
36
+ goto main
pth2onnx.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import re
5
+ import onnx
6
+ from spandrel import ImageModelDescriptor, ModelLoader
7
+ from onnxsim import simplify
8
+
9
+ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None):
10
+ """
11
+ Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
12
+
13
+ Args:
14
+ pth_path: Path to the input .pth model file.
15
+ onnx_path: Path to save the output .onnx file.
16
+ channel: Number of input channels for the model.
17
+ use_fp16: Boolean to determine if the model should be converted to half precision.
18
+ simplify_model: Boolean to determine if the ONNX model should be simplified.
19
+ """
20
+
21
+ print(f"Loading model from: {pth_path}")
22
+ try:
23
+ # Use Spandrel to load the model architecture and state dict
24
+ model_descriptor = ModelLoader().load_from_file(pth_path)
25
+
26
+ # Ensure it's the expected type from Spandrel
27
+ if not isinstance(model_descriptor, ImageModelDescriptor):
28
+ print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
29
+ print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
30
+ return False
31
+
32
+
33
+ # Get the underlying torch.nn.Module
34
+ torch_model = model_descriptor.model
35
+
36
+ # Set the model to evaluation mode (important for dropout, batchnorm layers)
37
+ torch_model.eval()
38
+
39
+ except Exception as e:
40
+ print(f"Error loading model: {e}")
41
+ return False
42
+
43
+ if channel == 0:
44
+ channel = model_descriptor.input_channels
45
+ if tilesize<1:
46
+ tilesize = 64
47
+ example_input = torch.randn(1, channel, tilesize, tilesize)
48
+ print("Model input channels:", channel, "tile size:", tilesize)
49
+
50
+ if use_fp16:
51
+ if torch.cuda.is_available():
52
+ torch_model.cuda()
53
+ example_input = example_input.cuda()
54
+ else:
55
+ print("Warning: no CUDA device")
56
+ torch_model.half()
57
+ example_input = example_input.half() # 转换为半精度输入
58
+ print(f"Model loaded successfully: {type(torch_model).__name__}")
59
+
60
+ if output_folder:
61
+ os.makedirs(output_folder, exist_ok=True)
62
+
63
+ if onnx_path is None:
64
+ base_path, _ = os.path.splitext(pth_path)
65
+ if output_folder:
66
+ base_path = os.path.join(output_folder, os.path.basename(base_path))
67
+
68
+ scale = model_descriptor.scale
69
+ # 判断 pth_path 的文件名是否包含 xs 或者 sx,x 为大小写字母 x,s 为 int scale
70
+ filename = os.path.basename(pth_path).upper()
71
+ pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
72
+ if re.search(pattern, filename):
73
+ print(f'文件名 {filename} 包含匹配模式。')
74
+ else:
75
+ base_path = f"{base_path}-x{scale}"
76
+
77
+ onnx_path = base_path + ("-Grayscale" if channel==1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
78
+
79
+ # 处理相对路径情况
80
+ # elif output_folder and not os.path.isabs(onnx_path):
81
+ elif output_folder:
82
+ onnx_path = os.path.join(output_folder, onnx_path)
83
+
84
+ print(f"output_folder: {output_folder}, onnx_path: {onnx_path}")
85
+
86
+ try:
87
+ # Export the model
88
+ torch.onnx.export(
89
+ torch_model, # The model instance
90
+ example_input, # An example input tensor
91
+ onnx_path, # Where to save the model (file path)
92
+ export_params=True, # Store the trained parameter weights inside the model file
93
+ opset_version=11, # The ONNX version to export the model to (choose based on target runtime)
94
+ do_constant_folding=True, # Whether to execute constant folding for optimization
95
+ input_names=['input'], # The model's input names
96
+ output_names=['output'], # The model's output names
97
+ dynamic_axes={ # Allow variable input/output dimensions
98
+ "input": {0: "batch_size", 2: "height", 3: "width"}, # Batch, H, W can vary
99
+ "output": {0: "batch_size", 2: "height", 3: "width"},# Batch, H, W can vary
100
+ }
101
+ )
102
+ print(f"ONNX export successful: {onnx_path}")
103
+
104
+ # Optional: Simplify the ONNX model
105
+ if simplify_model:
106
+ model = onnx.load(onnx_path)
107
+ model_simplified, _ = simplify(model)
108
+ onnx.save(model_simplified, onnx_path)
109
+ print(f"ONNX model simplified successfully: {onnx_path}")
110
+
111
+ # 添加文件验证逻辑
112
+ if os.path.exists(onnx_path):
113
+ file_size = os.path.getsize(onnx_path)
114
+ if file_size > min_size:
115
+ return onnx_path
116
+
117
+ os.remove(onnx_path)
118
+ print(f"文件大小不足 {min_size} 字节,已删除无效文件")
119
+ return ""
120
+
121
+ except Exception as e:
122
+ print(f"导出失败: {e}")
123
+ return ""
124
+
125
+ if __name__ == "__main__":
126
+ import argparse
127
+ parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
128
+ parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
129
+ parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
130
+ parser.add_argument('--channel', type=int, default=0, help='Channel parameter.')
131
+ parser.add_argument('--tilesize', type=int, default=0, help='Tilesize parameter.')
132
+ parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
133
+ parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
134
+ args = parser.parse_args()
135
+
136
+ success = convert_pth_to_onnx(
137
+ pth_path=args.pth_path,
138
+ onnx_path=args.onnx_path,
139
+ channel=args.channel,
140
+ tilesize=args.tilesize,
141
+ use_fp16=args.use_fp16,
142
+ simplify_model=args.simplify_model
143
+ )
144
+
145
+ if success:
146
+ print("Conversion process finished.")
147
+ else:
148
+ print("Conversion process failed.")
149
+ exit(1) # Exit with error code
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spandrel
2
+ torch
3
+ pnnx
4
+ onnx
5
+ mnn
6
+ gradio