ADD GDOWN
Browse files- app.py +165 -45
- pth2onnx.py +1 -1
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,18 +3,31 @@ import requests
|
|
3 |
import os
|
4 |
import subprocess
|
5 |
from typing import Union
|
6 |
-
from datetime import datetime
|
7 |
from pth2onnx import convert_pth_to_onnx
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
from typing import Optional
|
10 |
|
11 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
log_to_terminal = True
|
13 |
-
|
14 |
-
# 新增全局任务计数器
|
15 |
task_counter = 0
|
|
|
16 |
|
17 |
-
#
|
18 |
def print_log(task_id, stage, status):
|
19 |
if log_to_terminal:
|
20 |
print(f"任务{task_id}: [{status}] {stage}")
|
@@ -31,28 +44,90 @@ def download_file(url: str, save_path: str):
|
|
31 |
with open(save_path, 'wb') as f:
|
32 |
f.write(response.content)
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
"""
|
36 |
-
从
|
37 |
-
|
38 |
Args:
|
39 |
-
url:
|
40 |
-
folder:
|
41 |
-
filesize_max
|
42 |
-
filesize_min
|
43 |
-
|
44 |
Returns:
|
45 |
-
|
46 |
"""
|
47 |
-
|
48 |
-
parsed_url = urlparse(url)
|
49 |
-
filename = os.path.basename(parsed_url.path)
|
50 |
-
if not filename:
|
51 |
-
return None # 无法从URL获取文件名
|
52 |
|
53 |
-
#
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
try:
|
58 |
# 发送HTTP请求,流式下载
|
@@ -64,6 +139,32 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
|
|
64 |
if total_size > filesize_max:
|
65 |
return None # 文件大小超过最大值,不下载
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
downloaded_size = 0
|
68 |
with open(save_path, 'wb') as file:
|
69 |
for chunk in response.iter_content(chunk_size=8192):
|
@@ -81,11 +182,14 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
|
|
81 |
os.remove(save_path)
|
82 |
return None
|
83 |
|
84 |
-
|
|
|
|
|
|
|
85 |
|
86 |
except Exception as e:
|
87 |
# 发生异常时清理文件
|
88 |
-
if os.path.exists(save_path):
|
89 |
os.remove(save_path)
|
90 |
return None
|
91 |
|
@@ -154,14 +258,15 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
154 |
|
155 |
with gr.Blocks() as demo:
|
156 |
gr.Markdown("# 模型转换工具")
|
|
|
157 |
with gr.Row():
|
158 |
with gr.Column():
|
159 |
-
input_type = gr.Radio(
|
160 |
url_input = gr.Textbox(label='模型链接')
|
161 |
-
file_input = gr.File(label='
|
162 |
|
163 |
def show_input(input_type):
|
164 |
-
if input_type ==
|
165 |
return gr.update(visible=True), gr.update(visible=False)
|
166 |
else:
|
167 |
return gr.update(visible=False), gr.update(visible=True)
|
@@ -186,30 +291,43 @@ with gr.Blocks() as demo:
|
|
186 |
os.makedirs(output_dir, exist_ok=True)
|
187 |
log=""
|
188 |
|
189 |
-
if input_type ==
|
190 |
-
# 新增:下载模型文件到 output_dir
|
191 |
log = f'正在下载模型文件: {url_input}\n'
|
192 |
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
|
193 |
yield None, None, log
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
log += f'\n��型文件已下载到: {model_input}\n'
|
203 |
print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
|
204 |
yield None, None, log
|
205 |
-
elif input_type ==
|
206 |
model_input = file_input
|
207 |
else:
|
208 |
# 改为通过yield返回错误日志
|
209 |
log = '\n请选择输入类型并提供有效的输入!'
|
210 |
yield None, None, log
|
211 |
return
|
212 |
-
|
213 |
|
214 |
onnx_path = None
|
215 |
mnn_path = None
|
@@ -235,11 +353,13 @@ with gr.Blocks() as demo:
|
|
235 |
examples_column = gr.Column(visible=True)
|
236 |
with examples_column:
|
237 |
examples = [
|
238 |
-
[
|
239 |
-
[
|
240 |
-
[
|
241 |
-
[
|
242 |
-
[
|
|
|
|
|
243 |
]
|
244 |
example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
|
245 |
|
|
|
3 |
import os
|
4 |
import subprocess
|
5 |
from typing import Union
|
|
|
6 |
from pth2onnx import convert_pth_to_onnx
|
7 |
+
import re
|
8 |
+
import time
|
9 |
+
from urllib.parse import urlparse, unquote
|
10 |
+
import gdown
|
11 |
+
import sys
|
12 |
from typing import Optional
|
13 |
|
14 |
+
# def format_bytes(size: int) -> str:
|
15 |
+
# """将字节数格式化为更易读的单位 (KB, MB, GB)"""
|
16 |
+
# if size is None:
|
17 |
+
# return "未知大小"
|
18 |
+
# power = 1024
|
19 |
+
# n = 0
|
20 |
+
# power_labels = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
|
21 |
+
# while size > power and n < len(power_labels) -1 :
|
22 |
+
# size /= power
|
23 |
+
# n += 1
|
24 |
+
# return f"{size:.2f} {power_labels[n]}B"
|
25 |
+
# 日志开关
|
26 |
log_to_terminal = True
|
|
|
|
|
27 |
task_counter = 0
|
28 |
+
download_cache = {} # 格式: {url: 文件路径}
|
29 |
|
30 |
+
# 日志函数
|
31 |
def print_log(task_id, stage, status):
|
32 |
if log_to_terminal:
|
33 |
print(f"任务{task_id}: [{status}] {stage}")
|
|
|
44 |
with open(save_path, 'wb') as f:
|
45 |
f.write(response.content)
|
46 |
|
47 |
+
|
48 |
+
|
49 |
+
def download_gdrive_file(
|
50 |
+
url: str,
|
51 |
+
folder: str,
|
52 |
+
filesize_max: int,
|
53 |
+
filesize_min: int = 0
|
54 |
+
) -> Optional[str]:
|
55 |
"""
|
56 |
+
从 Google Drive 链接获取文件信息,检查大小后下载文件。
|
57 |
+
|
58 |
Args:
|
59 |
+
url (str): Google Drive 的分享链接。
|
60 |
+
folder (str): 用于保存文件的目标文件夹路径。
|
61 |
+
filesize_max (int): 允许下载的最大文件大小(单位:字节)。
|
62 |
+
filesize_min (int): 允许下载的最小文件大小(单位:字节),默认为 0。
|
63 |
+
|
64 |
Returns:
|
65 |
+
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
|
66 |
"""
|
67 |
+
print(f"--- 开始处理链接: {url} ---")
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
# 1. 获取文件信息
|
70 |
+
try:
|
71 |
+
info = gdown.get_file_info(url)
|
72 |
+
if not info:
|
73 |
+
print(f"错误:无法从链接获取文件信息。请检查链接是否有效且文件是否已公开分享。", file=sys.stderr)
|
74 |
+
return None
|
75 |
+
|
76 |
+
filename = info.get('name')
|
77 |
+
filesize = info.get('size')
|
78 |
+
|
79 |
+
if not filename or filesize is None:
|
80 |
+
print(f"错误:无法获取完整的文件名或文件大小。元数据: {info}", file=sys.stderr)
|
81 |
+
return None
|
82 |
+
|
83 |
+
except Exception as e:
|
84 |
+
print(f"错误:获取文件信息时发生异常: {e}", file=sys.stderr)
|
85 |
+
return None
|
86 |
+
|
87 |
+
# 2. 验证文件大小
|
88 |
+
if filesize < filesize_min or filesize > filesize_max:
|
89 |
+
return None
|
90 |
+
|
91 |
+
|
92 |
+
# 3. 准备下载路径并创建文件夹
|
93 |
+
try:
|
94 |
+
os.makedirs(folder, exist_ok=True)
|
95 |
+
output_path = os.path.join(folder, filename)
|
96 |
+
except OSError as e:
|
97 |
+
print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
|
98 |
+
return None
|
99 |
+
|
100 |
+
# 4. 下载文件
|
101 |
+
try:
|
102 |
+
print(f"开始下载google drive文件到: {output_path}")
|
103 |
+
downloaded_path = gdown.download(url, output_path, quiet=True)
|
104 |
+
if downloaded_path and os.path.exists(downloaded_path):
|
105 |
+
print(f"下载成功!文件已保存至: {downloaded_path}")
|
106 |
+
return downloaded_path
|
107 |
+
else:
|
108 |
+
print("错误:gdown 下载过程未返回有效路径或文件下载后不存在。", file=sys.stderr)
|
109 |
+
return None
|
110 |
+
|
111 |
+
except Exception as e:
|
112 |
+
print(f"错误:下载过程中发生异常: {e}", file=sys.stderr)
|
113 |
+
# 如果下载中断,清理可能不完整的文件
|
114 |
+
if os.path.exists(output_path):
|
115 |
+
os.remove(output_path)
|
116 |
+
return None
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
|
121 |
+
global download_cache
|
122 |
+
|
123 |
+
# 检查缓存是否存在且文件有效
|
124 |
+
if url in download_cache:
|
125 |
+
cached_path = download_cache[url]
|
126 |
+
if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
|
127 |
+
print(f" 使用缓存文件: {cached_path}")
|
128 |
+
return cached_path
|
129 |
+
|
130 |
+
save_path = None # 初始化save_path为None,避免except块中引用错误
|
131 |
|
132 |
try:
|
133 |
# 发送HTTP请求,流式下载
|
|
|
139 |
if total_size > filesize_max:
|
140 |
return None # 文件大小超过最大值,不下载
|
141 |
|
142 |
+
# 提取文件名
|
143 |
+
filename = None
|
144 |
+
content_disposition = response.headers.get('Content-Disposition')
|
145 |
+
if content_disposition:
|
146 |
+
# 使用正则表达式提取filename
|
147 |
+
match = re.search(r'filename="?([^"]+)"?', content_disposition)
|
148 |
+
if match:
|
149 |
+
filename = match.group(1)
|
150 |
+
# 处理可能的URL编码文件名
|
151 |
+
filename = unquote(filename)
|
152 |
+
|
153 |
+
# 如果响应头没有,从URL解析
|
154 |
+
if not filename:
|
155 |
+
parsed_url = urlparse(url)
|
156 |
+
filename = os.path.basename(parsed_url.path)
|
157 |
+
# 处理URL编码的文件名
|
158 |
+
filename = unquote(filename)
|
159 |
+
|
160 |
+
# 如果仍然没有文件名,生成默认文件名
|
161 |
+
if not filename:
|
162 |
+
filename = f"download_{int(time.time())}.bin"
|
163 |
+
|
164 |
+
# 确保目标文件夹存在
|
165 |
+
os.makedirs(folder, exist_ok=True)
|
166 |
+
save_path = os.path.join(folder, filename)
|
167 |
+
|
168 |
downloaded_size = 0
|
169 |
with open(save_path, 'wb') as file:
|
170 |
for chunk in response.iter_content(chunk_size=8192):
|
|
|
182 |
os.remove(save_path)
|
183 |
return None
|
184 |
|
185 |
+
# 下载成功后更新缓存
|
186 |
+
if save_path:
|
187 |
+
download_cache[url] = save_path
|
188 |
+
return save_path
|
189 |
|
190 |
except Exception as e:
|
191 |
# 发生异常时清理文件
|
192 |
+
if save_path and os.path.exists(save_path):
|
193 |
os.remove(save_path)
|
194 |
return None
|
195 |
|
|
|
258 |
|
259 |
with gr.Blocks() as demo:
|
260 |
gr.Markdown("# 模型转换工具")
|
261 |
+
model_type_opt = ['从链接下载', '直接上传文件']
|
262 |
with gr.Row():
|
263 |
with gr.Column():
|
264 |
+
input_type = gr.Radio(model_type_opt, label='模型文件来源')
|
265 |
url_input = gr.Textbox(label='模型链接')
|
266 |
+
file_input = gr.File(label='模型文件', visible=False)
|
267 |
|
268 |
def show_input(input_type):
|
269 |
+
if input_type == model_type_opt[0]:
|
270 |
return gr.update(visible=True), gr.update(visible=False)
|
271 |
else:
|
272 |
return gr.update(visible=False), gr.update(visible=True)
|
|
|
291 |
os.makedirs(output_dir, exist_ok=True)
|
292 |
log=""
|
293 |
|
294 |
+
if input_type == model_type_opt[0] and url_input:
|
|
|
295 |
log = f'正在下载模型文件: {url_input}\n'
|
296 |
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
|
297 |
yield None, None, log
|
298 |
+
|
299 |
+
if url_input.startswith("https://drive.google.com/file/"):
|
300 |
+
model_input = download_gdrive_file(
|
301 |
+
url=url_input,
|
302 |
+
folder=output_dir,
|
303 |
+
filesize_max=200*1024*1024, # 200MB
|
304 |
+
filesize_min=1024 # 1KB
|
305 |
+
)
|
306 |
+
else:
|
307 |
+
model_input = download_file2folder(
|
308 |
+
url=url_input,
|
309 |
+
folder=output_dir,
|
310 |
+
filesize_max=200*1024*1024, # 200MB
|
311 |
+
filesize_min=1024 # 1KB
|
312 |
+
)
|
313 |
+
|
314 |
+
|
315 |
+
if not model_input:
|
316 |
+
log += f'\n模型文件下载失败\n'
|
317 |
+
print_log(task_counter, f'模型文件载', '失败')
|
318 |
+
yield None, None, log
|
319 |
+
return
|
320 |
+
|
321 |
log += f'\n��型文件已下载到: {model_input}\n'
|
322 |
print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
|
323 |
yield None, None, log
|
324 |
+
elif input_type == model_type_opt[1] and file_input:
|
325 |
model_input = file_input
|
326 |
else:
|
327 |
# 改为通过yield返回错误日志
|
328 |
log = '\n请选择输入类型并提供有效的输入!'
|
329 |
yield None, None, log
|
330 |
return
|
|
|
331 |
|
332 |
onnx_path = None
|
333 |
mnn_path = None
|
|
|
353 |
examples_column = gr.Column(visible=True)
|
354 |
with examples_column:
|
355 |
examples = [
|
356 |
+
[model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
|
357 |
+
[model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
|
358 |
+
[model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
|
359 |
+
[model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
|
360 |
+
[model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
|
361 |
+
[model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
|
362 |
+
[model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
|
363 |
]
|
364 |
example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
|
365 |
|
pth2onnx.py
CHANGED
@@ -73,7 +73,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
|
|
73 |
print(f'文件名 {filename} 包含匹配模式。')
|
74 |
else:
|
75 |
base_path = f"{base_path}-x{scale}"
|
76 |
-
print("final use_fp16", str(use_fp16) )
|
77 |
onnx_path = base_path + ("-Grayscale" if channel==1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
|
78 |
|
79 |
# 处理相对路径情况
|
|
|
73 |
print(f'文件名 {filename} 包含匹配模式。')
|
74 |
else:
|
75 |
base_path = f"{base_path}-x{scale}"
|
76 |
+
# print("final use_fp16", str(use_fp16) )
|
77 |
onnx_path = base_path + ("-Grayscale" if channel==1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
|
78 |
|
79 |
# 处理相对路径情况
|
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ onnx
|
|
5 |
onnxsim
|
6 |
mnn
|
7 |
gradio
|
|
|
|
5 |
onnxsim
|
6 |
mnn
|
7 |
gradio
|
8 |
+
gdown
|