Upload 5 files
Browse files- app.py +252 -0
- onnx2mnn2.bat +121 -0
- pth2onnx.bat +36 -0
- pth2onnx.py +149 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import requests
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
from typing import Union
|
6 |
+
from datetime import datetime
|
7 |
+
from pth2onnx import convert_pth_to_onnx
|
8 |
+
from urllib.parse import urlparse
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
# 新增日志开关
|
12 |
+
log_to_terminal = True
|
13 |
+
|
14 |
+
# 新增全局任务计数器
|
15 |
+
task_counter = 0
|
16 |
+
|
17 |
+
# 新增日志函数
|
18 |
+
def print_log(task_id, stage, status):
|
19 |
+
if log_to_terminal:
|
20 |
+
print(f"任务{task_id}: [{status}] {stage}")
|
21 |
+
|
22 |
+
# 使用 MNN 库自带的转换工具
|
23 |
+
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
|
24 |
+
param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
|
25 |
+
if fp16:
|
26 |
+
param.append('--fp16')
|
27 |
+
subprocess.run(param, check=True)
|
28 |
+
|
29 |
+
def download_file(url: str, save_path: str):
|
30 |
+
response = requests.get(url)
|
31 |
+
with open(save_path, 'wb') as f:
|
32 |
+
f.write(response.content)
|
33 |
+
|
34 |
+
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
|
35 |
+
"""
|
36 |
+
从URL下载文件到指定文件夹,并进行文件大小检查
|
37 |
+
|
38 |
+
Args:
|
39 |
+
url: 要下载的文件URL
|
40 |
+
folder: 保存文件的目标文件夹路径
|
41 |
+
filesize_max: 最大允许文件大小(字节),超过此值将中断下载并删除文件
|
42 |
+
filesize_min: 最小允许文件大小(字节),小于此值将删除文件并返回None
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
成功下载的文件名,如果下载失败或文件大小不符合要求则返回None
|
46 |
+
"""
|
47 |
+
# 解析URL获取文件名
|
48 |
+
parsed_url = urlparse(url)
|
49 |
+
filename = os.path.basename(parsed_url.path)
|
50 |
+
if not filename:
|
51 |
+
return None # 无法从URL获取文件名
|
52 |
+
|
53 |
+
# 确保目标文件夹存在
|
54 |
+
os.makedirs(folder, exist_ok=True)
|
55 |
+
save_path = os.path.join(folder, filename)
|
56 |
+
|
57 |
+
try:
|
58 |
+
# 发送HTTP请求,流式下载
|
59 |
+
with requests.get(url, stream=True, timeout=10) as response:
|
60 |
+
response.raise_for_status() # 检查HTTP错误状态
|
61 |
+
|
62 |
+
# 获取文件总大小(如果服务器提供)
|
63 |
+
total_size = int(response.headers.get('content-length', 0))
|
64 |
+
if total_size > filesize_max:
|
65 |
+
return None # 文件大小超过最大值,不下载
|
66 |
+
|
67 |
+
downloaded_size = 0
|
68 |
+
with open(save_path, 'wb') as file:
|
69 |
+
for chunk in response.iter_content(chunk_size=8192):
|
70 |
+
if chunk: # 过滤空块
|
71 |
+
downloaded_size += len(chunk)
|
72 |
+
# 检查是否超过最大允许大小
|
73 |
+
if downloaded_size > filesize_max:
|
74 |
+
file.close()
|
75 |
+
os.remove(save_path)
|
76 |
+
return None
|
77 |
+
file.write(chunk)
|
78 |
+
|
79 |
+
# 下载完成后检查最小文件大小
|
80 |
+
if os.path.getsize(save_path) < filesize_min:
|
81 |
+
os.remove(save_path)
|
82 |
+
return None
|
83 |
+
|
84 |
+
return filename
|
85 |
+
|
86 |
+
except Exception as e:
|
87 |
+
# 发生异常时清理文件
|
88 |
+
if os.path.exists(save_path):
|
89 |
+
os.remove(save_path)
|
90 |
+
return None
|
91 |
+
|
92 |
+
# 原 process_model 函数重命名为 _process_model
|
93 |
+
async def _process_model(model_input: Union[str, gr.File], tilesize: int, log_box: gr.Textbox, output_dir: str):
|
94 |
+
global task_counter
|
95 |
+
task_id = task_counter
|
96 |
+
log = ('初始化日志记录...')
|
97 |
+
print_log(task_id, '初始化日志记录', '开始')
|
98 |
+
yield [],[], log
|
99 |
+
|
100 |
+
# 处理输入模型
|
101 |
+
if isinstance(model_input, str): # 处理链接
|
102 |
+
if model_input.startswith(('http://', 'https://')):
|
103 |
+
log += ( f'正在下载模型文件: {model_input}')
|
104 |
+
print_log(task_id, f'正在下载模型文件: {model_input}', '开始')
|
105 |
+
yield [],[], log
|
106 |
+
|
107 |
+
# 下载文件到output文件夹
|
108 |
+
filename = download_file2folder(
|
109 |
+
url=model_input,
|
110 |
+
folder=output_dir,
|
111 |
+
filesize_max=200*1024*1024, # 200MB
|
112 |
+
filesize_min=1024 # 1KB
|
113 |
+
)
|
114 |
+
input_path = os.path.join(output_dir, filename)
|
115 |
+
log += ( f'模型文件已下载到: {input_path}')
|
116 |
+
print_log(task_id, f'模型文件已下载到: {input_path}', '完成')
|
117 |
+
yield [],[], log
|
118 |
+
else:
|
119 |
+
input_path = model_input
|
120 |
+
log += ( f'使用本地文件: {input_path}')
|
121 |
+
print_log(task_id, f'使用本地文件: {input_path}', '开始')
|
122 |
+
yield [],[], log
|
123 |
+
else:
|
124 |
+
input_path = model_input.name
|
125 |
+
log += ( f'已上传模型文件: {input_path}')
|
126 |
+
print_log(task_id, f'已上传模型文件: {input_path}', '开始')
|
127 |
+
yield [],[], log
|
128 |
+
|
129 |
+
|
130 |
+
if not input_path:
|
131 |
+
log += ( f'未获得正确的模型文件')
|
132 |
+
print_log(task_id, f'未获得正确的模型文件', '错误')
|
133 |
+
yield [],[], log
|
134 |
+
return
|
135 |
+
|
136 |
+
|
137 |
+
if input_path.endswith('.onnx'):
|
138 |
+
onnx_path = input_path
|
139 |
+
log += ( '输入已是 ONNX 模型,直接使用...')
|
140 |
+
print_log(task_id, '输入已是 ONNX 模型,直接使用', '开始')
|
141 |
+
yield [],[], log
|
142 |
+
else:
|
143 |
+
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
144 |
+
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir)
|
145 |
+
if onnx_path:
|
146 |
+
log += ( f'成功生成ONNX模型: {onnx_path}')
|
147 |
+
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
|
148 |
+
else:
|
149 |
+
log += ( '生成ONNX模型失败')
|
150 |
+
print_log(task_id, '生成ONNX模型', '错误')
|
151 |
+
yield [], [], log
|
152 |
+
return
|
153 |
+
|
154 |
+
|
155 |
+
# 转换为 MNN 模型
|
156 |
+
output_name= os.path.splitext(os.path.basename(onnx_path))[0]
|
157 |
+
mnn_path = os.path.join(output_dir, f'{output_name}.mnn')
|
158 |
+
try:
|
159 |
+
log += ( '正在将 ONNX 模型转换为 MNN 格式...')
|
160 |
+
print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
|
161 |
+
convertmnn(onnx_path, mnn_path)
|
162 |
+
yield onnx_path,[], log
|
163 |
+
except Exception as e:
|
164 |
+
log += ( f'转换 MNN 模型时出错: {str(e)}')
|
165 |
+
print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
|
166 |
+
yield onnx_path,[], log
|
167 |
+
|
168 |
+
print_log(task_id, '模型转换任务完成', '完成')
|
169 |
+
|
170 |
+
# 转换为 MNN 模型后对文件检查
|
171 |
+
if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024: # 1KB = 1024 bytes
|
172 |
+
log += ( f'MNN 模型已保存到: {mnn_path}')
|
173 |
+
else:
|
174 |
+
log += ( 'MNN 模型生成失败或文件大小不足1KB')
|
175 |
+
mnn_path = None
|
176 |
+
|
177 |
+
yield onnx_path, mnn_path, log
|
178 |
+
|
179 |
+
with gr.Blocks() as demo:
|
180 |
+
gr.Markdown("# 模型转换工具")
|
181 |
+
with gr.Row():
|
182 |
+
with gr.Column():
|
183 |
+
input_type = gr.Radio(['模型链接', '上传模型文件'], label='输入类型')
|
184 |
+
url_input = gr.Textbox(label='模型链接')
|
185 |
+
file_input = gr.File(label='上传模型文件', visible=False)
|
186 |
+
|
187 |
+
|
188 |
+
def show_input(input_type):
|
189 |
+
if input_type == '模型链接':
|
190 |
+
return gr.update(visible=True), gr.update(visible=False)
|
191 |
+
else:
|
192 |
+
return gr.update(visible=False), gr.update(visible=True)
|
193 |
+
|
194 |
+
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
|
195 |
+
|
196 |
+
tilesize = gr.Number(label="Tilesize", value=0, precision=0)
|
197 |
+
convert_btn = gr.Button("开始转换")
|
198 |
+
with gr.Column():
|
199 |
+
log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
|
200 |
+
with gr.Row():
|
201 |
+
onnx_output = gr.File(label="ONNX 模型输出")
|
202 |
+
mnn_output = gr.File(label="MNN 模型输出")
|
203 |
+
|
204 |
+
async def process_model(input_type, url_input, file_input, tilesize, log_box):
|
205 |
+
if input_type == '模型链接' and url_input:
|
206 |
+
model_input = url_input
|
207 |
+
elif input_type == '上传模型文件' and file_input:
|
208 |
+
model_input = file_input
|
209 |
+
else:
|
210 |
+
# 改为通过yield返回错误日志
|
211 |
+
log = '\n请选择输入类型并提供有效的输入!'
|
212 |
+
yield None, None, log
|
213 |
+
return
|
214 |
+
|
215 |
+
# 创建不重名的输出目录
|
216 |
+
global task_counter
|
217 |
+
task_counter += 1
|
218 |
+
output_dir = os.path.join(os.getcwd(), f"output_{task_counter}")
|
219 |
+
os.makedirs(output_dir, exist_ok=True)
|
220 |
+
|
221 |
+
onnx_path = None
|
222 |
+
mnn_path = None
|
223 |
+
# 调用重命名后的函数
|
224 |
+
async for result in _process_model(model_input, int(tilesize), log_box, output_dir):
|
225 |
+
if isinstance(result, tuple) and len(result) == 3:
|
226 |
+
onnx_path, mnn_path, log_box = result
|
227 |
+
elif isinstance(result, tuple) and len(result) == 2:
|
228 |
+
# 处理纯日志yield
|
229 |
+
_, process_log = result
|
230 |
+
yield None, None, process_log
|
231 |
+
yield onnx_path, mnn_path, log_box
|
232 |
+
|
233 |
+
convert_btn.click(
|
234 |
+
process_model,
|
235 |
+
inputs=[input_type, url_input, file_input, tilesize, log_box],
|
236 |
+
outputs=[onnx_output, mnn_output, log_box],
|
237 |
+
api_name="convert_model"
|
238 |
+
)
|
239 |
+
|
240 |
+
# 将示例移至底部并包裹在列组件中
|
241 |
+
examples_column = gr.Column(visible=True)
|
242 |
+
with examples_column:
|
243 |
+
examples = [
|
244 |
+
["模型链接", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
|
245 |
+
["模型链接", "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
|
246 |
+
["模型链接", "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
|
247 |
+
["模型链接", "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
|
248 |
+
["模型链接", "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"]
|
249 |
+
]
|
250 |
+
example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
|
251 |
+
|
252 |
+
demo.launch()
|
onnx2mnn2.bat
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
setlocal enabledelayedexpansion
|
3 |
+
|
4 |
+
rem ���̶�shape��fp16����ͼƬ����
|
5 |
+
rem ����Ƿ��в�������
|
6 |
+
if "%~1"=="" (
|
7 |
+
rem ������������ʾ�û������ļ�·��
|
8 |
+
set /p onnx_path="������ .onnx �ļ���·��: "
|
9 |
+
) else (
|
10 |
+
rem ��������ʹ�õ�һ��������Ϊ�ļ�·��
|
11 |
+
set "onnx_path=%~1"
|
12 |
+
cd /d %~dp0
|
13 |
+
|
14 |
+
)
|
15 |
+
|
16 |
+
:main
|
17 |
+
rem ����ļ��Ƿ�������Ƿ��� .onnx ��β
|
18 |
+
if not exist "!onnx_path!" (
|
19 |
+
echo ����: �ļ������ڣ�����������·����
|
20 |
+
goto loop
|
21 |
+
)
|
22 |
+
|
23 |
+
set "onnx_ext=!onnx_path:~-5!"
|
24 |
+
if /i "!onnx_ext!" neq ".onnx" (
|
25 |
+
echo ����: �ļ����� .onnx ��ʽ������������·����
|
26 |
+
goto loop
|
27 |
+
)
|
28 |
+
|
29 |
+
rem ȥ��Ŀ¼·���ͺ���
|
30 |
+
for %%f in ("!onnx_path!") do set "onnx_name=%%~nf"
|
31 |
+
echo ��ǰ����Ŀ¼��: %cd%
|
32 |
+
rem ִ�� onnx2ncnn.exe
|
33 |
+
MNNConvert -f ONNX --modelFile "!onnx_path!" --MNNModel "!onnx_name!.mnn" --bizCode biz --fp16 --info --detectSparseSpeedUp
|
34 |
+
|
35 |
+
rem ��� fp16 �ļ��Ƿ����
|
36 |
+
if exist "!onnx_name!.mnn" (
|
37 |
+
rem ����
|
38 |
+
) else (
|
39 |
+
echo δ���ģ���ļ�������������·����
|
40 |
+
goto loop
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
:next_step
|
46 |
+
rem ��ʾ�û�ѡ����һ������
|
47 |
+
echo.
|
48 |
+
echo ��ѡ����һ������:
|
49 |
+
echo 1. ��һ��ѭ����Ĭ�ϣ�
|
50 |
+
echo 2. ����ģ�ͣ�ʹ�� test.png��
|
51 |
+
echo 3. �����Զ���ͼƬ·��������ģ��
|
52 |
+
set /p choice="������ѡ�� (1/2/3): "
|
53 |
+
if "%choice%"=="" set "choice=1"
|
54 |
+
|
55 |
+
if "%choice%"=="1" (
|
56 |
+
goto loop
|
57 |
+
) else if "%choice%"=="2" (
|
58 |
+
set "test_image=test.png"
|
59 |
+
if exist "!test_image!" (
|
60 |
+
mnnsr-ncnn.exe -i "!test_image!" -o "!onnx_name!.png" -m "!onnx_name!.mnn" -s 0
|
61 |
+
) else (
|
62 |
+
echo ����: �ļ� "!test_image!" �����ڣ�������ѡ��
|
63 |
+
goto next_step
|
64 |
+
)
|
65 |
+
) else if "%choice%"=="3" (
|
66 |
+
set /p custom_image="�������Զ���ͼƬ·��: "
|
67 |
+
if exist "!custom_image!" (
|
68 |
+
mnnsr-ncnn.exe -i "!custom_image!" -o "!onnx_name!.png" -m "!onnx_name!.mnn" -s 0
|
69 |
+
) else (
|
70 |
+
echo ����: �ļ� "!custom_image!" �����ڣ�������ѡ��
|
71 |
+
goto next_step
|
72 |
+
)
|
73 |
+
) else (
|
74 |
+
echo ��Ч��ѡ�������ѡ��
|
75 |
+
goto next_step
|
76 |
+
)
|
77 |
+
|
78 |
+
rem �������� PNG �ļ��Ƿ����
|
79 |
+
if exist "!onnx_name!.png" (
|
80 |
+
echo ����ļ����ɳɹ�: "!onnx_name!.png"
|
81 |
+
start "" "!onnx_name!.png"
|
82 |
+
) else (
|
83 |
+
echo ��֤ʧ��: ����ļ������ڣ���ѡ���Ƿ�ɾ��ģ���ļ���
|
84 |
+
set /p delete_model="�Ƿ���ģ���ļ� "!onnx_name!.mnn" ? (y/n, Ĭ��n): "
|
85 |
+
if /i "!delete_model!"=="y" (
|
86 |
+
echo ��������
|
87 |
+
) else (
|
88 |
+
del "!onnx_name!.mnn"
|
89 |
+
echo ģ���ļ���ɾ����
|
90 |
+
)
|
91 |
+
goto loop
|
92 |
+
)
|
93 |
+
|
94 |
+
rem ѯ���û�ѡ����һ������
|
95 |
+
echo.
|
96 |
+
echo ��ѡ����һ������:
|
97 |
+
echo 1. ������ļ��У�Ĭ�ϣ�
|
98 |
+
echo 2. ɾ��ģ���ļ��Ͳ���ͼ
|
99 |
+
echo 3. ��һ��ѭ��
|
100 |
+
set /p next_choice="������ѡ�� (1/2/3): "
|
101 |
+
if "%next_choice%"=="" set "next_choice=1"
|
102 |
+
|
103 |
+
if "%next_choice%"=="1" (
|
104 |
+
start "" .
|
105 |
+
) else if "%next_choice%"=="2" (
|
106 |
+
del "!onnx_name!.png"
|
107 |
+
del "!onnx_name!.mnn"
|
108 |
+
echo ģ���ļ��Ͳ���ͼ��ɾ����
|
109 |
+
) else if "%next_choice%"=="3" (
|
110 |
+
goto loop
|
111 |
+
) else (
|
112 |
+
echo ��Ч��ѡ�������ѡ��
|
113 |
+
goto next_choice
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
:loop
|
119 |
+
echo ====================================
|
120 |
+
set /p onnx_path="������ .onnx �ļ���·��: "
|
121 |
+
goto main
|
pth2onnx.bat
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
setlocal enabledelayedexpansion
|
3 |
+
|
4 |
+
rem ����Ƿ��в�������
|
5 |
+
if "%~1"=="" (
|
6 |
+
rem ������������ʾ�û������ļ�·��
|
7 |
+
set /p pth_path="������ .pth �ļ���·��: "
|
8 |
+
) else (
|
9 |
+
rem ��������ʹ�õ�һ��������Ϊ�ļ�·��
|
10 |
+
set "pth_path=%~1"
|
11 |
+
cd /d %~dp0
|
12 |
+
|
13 |
+
)
|
14 |
+
|
15 |
+
:main
|
16 |
+
rem ����ļ��Ƿ�������Ƿ��� .pth ��β
|
17 |
+
if not exist "!pth_path!" (
|
18 |
+
echo ����: �ļ������ڣ�����������·����
|
19 |
+
goto loop
|
20 |
+
)
|
21 |
+
|
22 |
+
set "ext=!pth_path:~-4!"
|
23 |
+
if /i "!ext!" neq ".pth" (
|
24 |
+
echo ����: �ļ����� .pth ��ʽ������������·����
|
25 |
+
goto loop
|
26 |
+
)
|
27 |
+
|
28 |
+
echo ��ǰ����Ŀ¼��: %cd%
|
29 |
+
rem ִ�� pth2onnx.py
|
30 |
+
rem python .\pth2onnx.py "!pth_path!" --fp16 --simplify --channel !channel!
|
31 |
+
python .\pth2onnx.py "!pth_path!" --fp16
|
32 |
+
|
33 |
+
:loop
|
34 |
+
echo ====================================
|
35 |
+
set /p pth_path="������ .pth �ļ���·��: "
|
36 |
+
goto main
|
pth2onnx.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import onnx
|
6 |
+
from spandrel import ImageModelDescriptor, ModelLoader
|
7 |
+
from onnxsim import simplify
|
8 |
+
|
9 |
+
def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None):
|
10 |
+
"""
|
11 |
+
Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
pth_path: Path to the input .pth model file.
|
15 |
+
onnx_path: Path to save the output .onnx file.
|
16 |
+
channel: Number of input channels for the model.
|
17 |
+
use_fp16: Boolean to determine if the model should be converted to half precision.
|
18 |
+
simplify_model: Boolean to determine if the ONNX model should be simplified.
|
19 |
+
"""
|
20 |
+
|
21 |
+
print(f"Loading model from: {pth_path}")
|
22 |
+
try:
|
23 |
+
# Use Spandrel to load the model architecture and state dict
|
24 |
+
model_descriptor = ModelLoader().load_from_file(pth_path)
|
25 |
+
|
26 |
+
# Ensure it's the expected type from Spandrel
|
27 |
+
if not isinstance(model_descriptor, ImageModelDescriptor):
|
28 |
+
print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
|
29 |
+
print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
|
30 |
+
return False
|
31 |
+
|
32 |
+
|
33 |
+
# Get the underlying torch.nn.Module
|
34 |
+
torch_model = model_descriptor.model
|
35 |
+
|
36 |
+
# Set the model to evaluation mode (important for dropout, batchnorm layers)
|
37 |
+
torch_model.eval()
|
38 |
+
|
39 |
+
except Exception as e:
|
40 |
+
print(f"Error loading model: {e}")
|
41 |
+
return False
|
42 |
+
|
43 |
+
if channel == 0:
|
44 |
+
channel = model_descriptor.input_channels
|
45 |
+
if tilesize<1:
|
46 |
+
tilesize = 64
|
47 |
+
example_input = torch.randn(1, channel, tilesize, tilesize)
|
48 |
+
print("Model input channels:", channel, "tile size:", tilesize)
|
49 |
+
|
50 |
+
if use_fp16:
|
51 |
+
if torch.cuda.is_available():
|
52 |
+
torch_model.cuda()
|
53 |
+
example_input = example_input.cuda()
|
54 |
+
else:
|
55 |
+
print("Warning: no CUDA device")
|
56 |
+
torch_model.half()
|
57 |
+
example_input = example_input.half() # 转换为半精度输入
|
58 |
+
print(f"Model loaded successfully: {type(torch_model).__name__}")
|
59 |
+
|
60 |
+
if output_folder:
|
61 |
+
os.makedirs(output_folder, exist_ok=True)
|
62 |
+
|
63 |
+
if onnx_path is None:
|
64 |
+
base_path, _ = os.path.splitext(pth_path)
|
65 |
+
if output_folder:
|
66 |
+
base_path = os.path.join(output_folder, os.path.basename(base_path))
|
67 |
+
|
68 |
+
scale = model_descriptor.scale
|
69 |
+
# 判断 pth_path 的文件名是否包含 xs 或者 sx,x 为大小写字母 x,s 为 int scale
|
70 |
+
filename = os.path.basename(pth_path).upper()
|
71 |
+
pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
|
72 |
+
if re.search(pattern, filename):
|
73 |
+
print(f'文件名 {filename} 包含匹配模式。')
|
74 |
+
else:
|
75 |
+
base_path = f"{base_path}-x{scale}"
|
76 |
+
|
77 |
+
onnx_path = base_path + ("-Grayscale" if channel==1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
|
78 |
+
|
79 |
+
# 处理相对路径情况
|
80 |
+
# elif output_folder and not os.path.isabs(onnx_path):
|
81 |
+
elif output_folder:
|
82 |
+
onnx_path = os.path.join(output_folder, onnx_path)
|
83 |
+
|
84 |
+
print(f"output_folder: {output_folder}, onnx_path: {onnx_path}")
|
85 |
+
|
86 |
+
try:
|
87 |
+
# Export the model
|
88 |
+
torch.onnx.export(
|
89 |
+
torch_model, # The model instance
|
90 |
+
example_input, # An example input tensor
|
91 |
+
onnx_path, # Where to save the model (file path)
|
92 |
+
export_params=True, # Store the trained parameter weights inside the model file
|
93 |
+
opset_version=11, # The ONNX version to export the model to (choose based on target runtime)
|
94 |
+
do_constant_folding=True, # Whether to execute constant folding for optimization
|
95 |
+
input_names=['input'], # The model's input names
|
96 |
+
output_names=['output'], # The model's output names
|
97 |
+
dynamic_axes={ # Allow variable input/output dimensions
|
98 |
+
"input": {0: "batch_size", 2: "height", 3: "width"}, # Batch, H, W can vary
|
99 |
+
"output": {0: "batch_size", 2: "height", 3: "width"},# Batch, H, W can vary
|
100 |
+
}
|
101 |
+
)
|
102 |
+
print(f"ONNX export successful: {onnx_path}")
|
103 |
+
|
104 |
+
# Optional: Simplify the ONNX model
|
105 |
+
if simplify_model:
|
106 |
+
model = onnx.load(onnx_path)
|
107 |
+
model_simplified, _ = simplify(model)
|
108 |
+
onnx.save(model_simplified, onnx_path)
|
109 |
+
print(f"ONNX model simplified successfully: {onnx_path}")
|
110 |
+
|
111 |
+
# 添加文件验证逻辑
|
112 |
+
if os.path.exists(onnx_path):
|
113 |
+
file_size = os.path.getsize(onnx_path)
|
114 |
+
if file_size > min_size:
|
115 |
+
return onnx_path
|
116 |
+
|
117 |
+
os.remove(onnx_path)
|
118 |
+
print(f"文件大小不足 {min_size} 字节,已删除无效文件")
|
119 |
+
return ""
|
120 |
+
|
121 |
+
except Exception as e:
|
122 |
+
print(f"导出失败: {e}")
|
123 |
+
return ""
|
124 |
+
|
125 |
+
if __name__ == "__main__":
|
126 |
+
import argparse
|
127 |
+
parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
|
128 |
+
parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
|
129 |
+
parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
|
130 |
+
parser.add_argument('--channel', type=int, default=0, help='Channel parameter.')
|
131 |
+
parser.add_argument('--tilesize', type=int, default=0, help='Tilesize parameter.')
|
132 |
+
parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
|
133 |
+
parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
|
134 |
+
args = parser.parse_args()
|
135 |
+
|
136 |
+
success = convert_pth_to_onnx(
|
137 |
+
pth_path=args.pth_path,
|
138 |
+
onnx_path=args.onnx_path,
|
139 |
+
channel=args.channel,
|
140 |
+
tilesize=args.tilesize,
|
141 |
+
use_fp16=args.use_fp16,
|
142 |
+
simplify_model=args.simplify_model
|
143 |
+
)
|
144 |
+
|
145 |
+
if success:
|
146 |
+
print("Conversion process finished.")
|
147 |
+
else:
|
148 |
+
print("Conversion process failed.")
|
149 |
+
exit(1) # Exit with error code
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
spandrel
|
2 |
+
torch
|
3 |
+
pnnx
|
4 |
+
onnx
|
5 |
+
mnn
|
6 |
+
gradio
|