tumuyan2 commited on
Commit
62053b6
·
1 Parent(s): 55f5f0f
Files changed (3) hide show
  1. app.py +165 -45
  2. pth2onnx.py +1 -1
  3. requirements.txt +1 -0
app.py CHANGED
@@ -3,18 +3,31 @@ import requests
3
  import os
4
  import subprocess
5
  from typing import Union
6
- from datetime import datetime
7
  from pth2onnx import convert_pth_to_onnx
8
- from urllib.parse import urlparse
 
 
 
 
9
  from typing import Optional
10
 
11
- # 新增日志开关
 
 
 
 
 
 
 
 
 
 
 
12
  log_to_terminal = True
13
-
14
- # 新增全局任务计数器
15
  task_counter = 0
 
16
 
17
- # 新增日志函数
18
  def print_log(task_id, stage, status):
19
  if log_to_terminal:
20
  print(f"任务{task_id}: [{status}] {stage}")
@@ -31,28 +44,90 @@ def download_file(url: str, save_path: str):
31
  with open(save_path, 'wb') as f:
32
  f.write(response.content)
33
 
34
- def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
 
 
 
 
 
 
 
35
  """
36
- URL下载文件到指定文件夹,并进行文件大小检查
37
-
38
  Args:
39
- url: 要下载的文件URL
40
- folder: 保存文件的目标文件夹路径
41
- filesize_max: 最大允许文件大小(字节),超过此值将中断下载并删除文件
42
- filesize_min: 最小允许文件大小(字节),小于此值将删除文件并返回None
43
-
44
  Returns:
45
- 成功下载的文件名,如果下载失败或文件大小不符合要求则返回None
46
  """
47
- # 解析URL获取文件名
48
- parsed_url = urlparse(url)
49
- filename = os.path.basename(parsed_url.path)
50
- if not filename:
51
- return None # 无法从URL获取文件名
52
 
53
- # 确保目标文件夹存在
54
- os.makedirs(folder, exist_ok=True)
55
- save_path = os.path.join(folder, filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  try:
58
  # 发送HTTP请求,流式下载
@@ -64,6 +139,32 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
64
  if total_size > filesize_max:
65
  return None # 文件大小超过最大值,不下载
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  downloaded_size = 0
68
  with open(save_path, 'wb') as file:
69
  for chunk in response.iter_content(chunk_size=8192):
@@ -81,11 +182,14 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
81
  os.remove(save_path)
82
  return None
83
 
84
- return filename
 
 
 
85
 
86
  except Exception as e:
87
  # 发生异常时清理文件
88
- if os.path.exists(save_path):
89
  os.remove(save_path)
90
  return None
91
 
@@ -154,14 +258,15 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
154
 
155
  with gr.Blocks() as demo:
156
  gr.Markdown("# 模型转换工具")
 
157
  with gr.Row():
158
  with gr.Column():
159
- input_type = gr.Radio(['模型链接', '上传模型文件'], label='输入类型')
160
  url_input = gr.Textbox(label='模型链接')
161
- file_input = gr.File(label='上传模型文件', visible=False)
162
 
163
  def show_input(input_type):
164
- if input_type == '模型链接':
165
  return gr.update(visible=True), gr.update(visible=False)
166
  else:
167
  return gr.update(visible=False), gr.update(visible=True)
@@ -186,30 +291,43 @@ with gr.Blocks() as demo:
186
  os.makedirs(output_dir, exist_ok=True)
187
  log=""
188
 
189
- if input_type == '模型链接' and url_input:
190
- # 新增:下载模型文件到 output_dir
191
  log = f'正在下载模型文件: {url_input}\n'
192
  print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
193
  yield None, None, log
194
-
195
- filename = download_file2folder(
196
- url=url_input,
197
- folder=output_dir,
198
- filesize_max=200*1024*1024, # 200MB
199
- filesize_min=1024 # 1KB
200
- )
201
- model_input = os.path.join(output_dir, filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  log += f'\n��型文件已下载到: {model_input}\n'
203
  print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
204
  yield None, None, log
205
- elif input_type == '上传模型文件' and file_input:
206
  model_input = file_input
207
  else:
208
  # 改为通过yield返回错误日志
209
  log = '\n请选择输入类型并提供有效的输入!'
210
  yield None, None, log
211
  return
212
-
213
 
214
  onnx_path = None
215
  mnn_path = None
@@ -235,11 +353,13 @@ with gr.Blocks() as demo:
235
  examples_column = gr.Column(visible=True)
236
  with examples_column:
237
  examples = [
238
- ["模型链接", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
239
- ["模型链接", "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
240
- ["模型链接", "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
241
- ["模型链接", "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
242
- ["模型链接", "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"]
 
 
243
  ]
244
  example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
245
 
 
3
  import os
4
  import subprocess
5
  from typing import Union
 
6
  from pth2onnx import convert_pth_to_onnx
7
+ import re
8
+ import time
9
+ from urllib.parse import urlparse, unquote
10
+ import gdown
11
+ import sys
12
  from typing import Optional
13
 
14
+ # def format_bytes(size: int) -> str:
15
+ # """将字节数格式化为更易读的单位 (KB, MB, GB)"""
16
+ # if size is None:
17
+ # return "未知大小"
18
+ # power = 1024
19
+ # n = 0
20
+ # power_labels = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
21
+ # while size > power and n < len(power_labels) -1 :
22
+ # size /= power
23
+ # n += 1
24
+ # return f"{size:.2f} {power_labels[n]}B"
25
+ # 日志开关
26
  log_to_terminal = True
 
 
27
  task_counter = 0
28
+ download_cache = {} # 格式: {url: 文件路径}
29
 
30
+ # 日志函数
31
  def print_log(task_id, stage, status):
32
  if log_to_terminal:
33
  print(f"任务{task_id}: [{status}] {stage}")
 
44
  with open(save_path, 'wb') as f:
45
  f.write(response.content)
46
 
47
+
48
+
49
+ def download_gdrive_file(
50
+ url: str,
51
+ folder: str,
52
+ filesize_max: int,
53
+ filesize_min: int = 0
54
+ ) -> Optional[str]:
55
  """
56
+ Google Drive 链接获取文件信息,检查大小后下载文件。
57
+
58
  Args:
59
+ url (str): Google Drive 的分享链接。
60
+ folder (str): 用于保存文件的目标文件夹路径。
61
+ filesize_max (int): 允许下载的最大文件大小(单位:字节)。
62
+ filesize_min (int): 允许下载的最小文件大小(单位:字节),默认为 0。
63
+
64
  Returns:
65
+ Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None
66
  """
67
+ print(f"--- 开始处理链接: {url} ---")
 
 
 
 
68
 
69
+ # 1. 获取文件信息
70
+ try:
71
+ info = gdown.get_file_info(url)
72
+ if not info:
73
+ print(f"错误:无法从链接获取文件信息。请检查链接是否有效且文件是否已公开分享。", file=sys.stderr)
74
+ return None
75
+
76
+ filename = info.get('name')
77
+ filesize = info.get('size')
78
+
79
+ if not filename or filesize is None:
80
+ print(f"错误:无法获取完整的文件名或文件大小。元数据: {info}", file=sys.stderr)
81
+ return None
82
+
83
+ except Exception as e:
84
+ print(f"错误:获取文件信息时发生异常: {e}", file=sys.stderr)
85
+ return None
86
+
87
+ # 2. 验证文件大小
88
+ if filesize < filesize_min or filesize > filesize_max:
89
+ return None
90
+
91
+
92
+ # 3. 准备下载路径并创建文件夹
93
+ try:
94
+ os.makedirs(folder, exist_ok=True)
95
+ output_path = os.path.join(folder, filename)
96
+ except OSError as e:
97
+ print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
98
+ return None
99
+
100
+ # 4. 下载文件
101
+ try:
102
+ print(f"开始下载google drive文件到: {output_path}")
103
+ downloaded_path = gdown.download(url, output_path, quiet=True)
104
+ if downloaded_path and os.path.exists(downloaded_path):
105
+ print(f"下载成功!文件已保存至: {downloaded_path}")
106
+ return downloaded_path
107
+ else:
108
+ print("错误:gdown 下载过程未返回有效路径或文件下载后不存在。", file=sys.stderr)
109
+ return None
110
+
111
+ except Exception as e:
112
+ print(f"错误:下载过程中发生异常: {e}", file=sys.stderr)
113
+ # 如果下载中断,清理可能不完整的文件
114
+ if os.path.exists(output_path):
115
+ os.remove(output_path)
116
+ return None
117
+
118
+
119
+
120
+ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
121
+ global download_cache
122
+
123
+ # 检查缓存是否存在且文件有效
124
+ if url in download_cache:
125
+ cached_path = download_cache[url]
126
+ if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
127
+ print(f" 使用缓存文件: {cached_path}")
128
+ return cached_path
129
+
130
+ save_path = None # 初始化save_path为None,避免except块中引用错误
131
 
132
  try:
133
  # 发送HTTP请求,流式下载
 
139
  if total_size > filesize_max:
140
  return None # 文件大小超过最大值,不下载
141
 
142
+ # 提取文件名
143
+ filename = None
144
+ content_disposition = response.headers.get('Content-Disposition')
145
+ if content_disposition:
146
+ # 使用正则表达式提取filename
147
+ match = re.search(r'filename="?([^"]+)"?', content_disposition)
148
+ if match:
149
+ filename = match.group(1)
150
+ # 处理可能的URL编码文件名
151
+ filename = unquote(filename)
152
+
153
+ # 如果响应头没有,从URL解析
154
+ if not filename:
155
+ parsed_url = urlparse(url)
156
+ filename = os.path.basename(parsed_url.path)
157
+ # 处理URL编码的文件名
158
+ filename = unquote(filename)
159
+
160
+ # 如果仍然没有文件名,生成默认文件名
161
+ if not filename:
162
+ filename = f"download_{int(time.time())}.bin"
163
+
164
+ # 确保目标文件夹存在
165
+ os.makedirs(folder, exist_ok=True)
166
+ save_path = os.path.join(folder, filename)
167
+
168
  downloaded_size = 0
169
  with open(save_path, 'wb') as file:
170
  for chunk in response.iter_content(chunk_size=8192):
 
182
  os.remove(save_path)
183
  return None
184
 
185
+ # 下载成功后更新缓存
186
+ if save_path:
187
+ download_cache[url] = save_path
188
+ return save_path
189
 
190
  except Exception as e:
191
  # 发生异常时清理文件
192
+ if save_path and os.path.exists(save_path):
193
  os.remove(save_path)
194
  return None
195
 
 
258
 
259
  with gr.Blocks() as demo:
260
  gr.Markdown("# 模型转换工具")
261
+ model_type_opt = ['从链接下载', '直接上传文件']
262
  with gr.Row():
263
  with gr.Column():
264
+ input_type = gr.Radio(model_type_opt, label='模型文件来源')
265
  url_input = gr.Textbox(label='模型链接')
266
+ file_input = gr.File(label='模型文件', visible=False)
267
 
268
  def show_input(input_type):
269
+ if input_type == model_type_opt[0]:
270
  return gr.update(visible=True), gr.update(visible=False)
271
  else:
272
  return gr.update(visible=False), gr.update(visible=True)
 
291
  os.makedirs(output_dir, exist_ok=True)
292
  log=""
293
 
294
+ if input_type == model_type_opt[0] and url_input:
 
295
  log = f'正在下载模型文件: {url_input}\n'
296
  print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
297
  yield None, None, log
298
+
299
+ if url_input.startswith("https://drive.google.com/file/"):
300
+ model_input = download_gdrive_file(
301
+ url=url_input,
302
+ folder=output_dir,
303
+ filesize_max=200*1024*1024, # 200MB
304
+ filesize_min=1024 # 1KB
305
+ )
306
+ else:
307
+ model_input = download_file2folder(
308
+ url=url_input,
309
+ folder=output_dir,
310
+ filesize_max=200*1024*1024, # 200MB
311
+ filesize_min=1024 # 1KB
312
+ )
313
+
314
+
315
+ if not model_input:
316
+ log += f'\n模型文件下载失败\n'
317
+ print_log(task_counter, f'模型文件载', '失败')
318
+ yield None, None, log
319
+ return
320
+
321
  log += f'\n��型文件已下载到: {model_input}\n'
322
  print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
323
  yield None, None, log
324
+ elif input_type == model_type_opt[1] and file_input:
325
  model_input = file_input
326
  else:
327
  # 改为通过yield返回错误日志
328
  log = '\n请选择输入类型并提供有效的输入!'
329
  yield None, None, log
330
  return
 
331
 
332
  onnx_path = None
333
  mnn_path = None
 
353
  examples_column = gr.Column(visible=True)
354
  with examples_column:
355
  examples = [
356
+ [model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
357
+ [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
358
+ [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
359
+ [model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
360
+ [model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
361
+ [model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
362
+ [model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
363
  ]
364
  example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
365
 
pth2onnx.py CHANGED
@@ -73,7 +73,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
73
  print(f'文件名 {filename} 包含匹配模式。')
74
  else:
75
  base_path = f"{base_path}-x{scale}"
76
- print("final use_fp16", str(use_fp16) )
77
  onnx_path = base_path + ("-Grayscale" if channel==1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
78
 
79
  # 处理相对路径情况
 
73
  print(f'文件名 {filename} 包含匹配模式。')
74
  else:
75
  base_path = f"{base_path}-x{scale}"
76
+ # print("final use_fp16", str(use_fp16) )
77
  onnx_path = base_path + ("-Grayscale" if channel==1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
78
 
79
  # 处理相对路径情况
requirements.txt CHANGED
@@ -5,3 +5,4 @@ onnx
5
  onnxsim
6
  mnn
7
  gradio
 
 
5
  onnxsim
6
  mnn
7
  gradio
8
+ gdown