tumuyan2 commited on
Commit
a4a3999
·
1 Parent(s): b96054d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -31,17 +31,12 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
31
  log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n"
32
  yield [], log
33
  if input2 == None or input2.strip() == "":
34
- if isinstance(input1, str):
35
- input2 = os.path.splitext(os.path.basename(input1))[0]
36
- else:
37
- input2 = os.path.splitext(os.path.basename(input1.name))[0]
38
-
39
  if input2 == "":
40
  input2 = str(task_id)
41
  log += f"未提供文件名,使用{input2}\n"
42
  print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
43
  yield [], log
44
- input2 = "output"
45
 
46
  try:
47
  # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持
@@ -153,6 +148,7 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
153
  log += "生成张量…\n"
154
  yield [], log
155
  pt_path = output_folder + "/" + input2 + ".pt"
 
156
  input_tensor0 = torch.rand(shape0) if any(shape0) else None
157
  input_tensor1 = torch.rand(shape1) if any(shape1) else None
158
  if input_tensor0 is not None and input_tensor1 is not None:
@@ -190,6 +186,22 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
190
  torch_model = model.model
191
  print_log(task_id, input2, "获得模型对象", "完成")
192
  yield [], log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  if os.path.exists(pt_path):
194
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
195
  log += "跳过转换为TorchScript模型\n"
@@ -222,14 +234,14 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
222
  yield [], log
223
  returncode = process.poll()
224
  if returncode != 0:
225
- log += f"执行命令: {command} 失败,返回码: {returncode}\n"
226
  else:
227
- log += f"执行命令: {command} 成功\n"
228
  except Exception as e:
229
  log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
230
 
231
  output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
232
- log += f"任务完成,输出文件: {output_files}\n"
233
  print_log(task_id, input2, "执行命令", "完成")
234
  yield output_files, log
235
  except Exception as e:
@@ -243,8 +255,11 @@ with gr.Blocks() as demo:
243
  with gr.Row():
244
  # 左侧列,包含输入组件和按钮
245
  with gr.Column():
 
 
246
  with gr.Row():
247
  input1 = gr.Textbox(label="粘贴地址")
 
248
  input1_file = gr.File(label="上传文件")
249
  input2 = gr.Textbox(label="自定义文件名")
250
  # 修改为字符串输入控件
@@ -275,13 +290,13 @@ with gr.Blocks() as demo:
275
 
276
  # 添加范例
277
  examples = [
278
- ["https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "", "1,3,128,128", "0,0,0,0"],
279
- ["https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth", "", "1,3,128,128", "0,0,0,0"],
280
- ["https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0"]
281
  ]
282
  gr.Examples(
283
  examples=examples,
284
- inputs=[input1, input2, shape0_str, shape1_str],
285
  outputs=[output, log_textbox],
286
  fn=start_process
287
  )
 
31
  log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n"
32
  yield [], log
33
  if input2 == None or input2.strip() == "":
34
+ input2 = os.path.splitext(os.path.basename(input1))[0]
 
 
 
 
35
  if input2 == "":
36
  input2 = str(task_id)
37
  log += f"未提供文件名,使用{input2}\n"
38
  print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
39
  yield [], log
 
40
 
41
  try:
42
  # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持
 
148
  log += "生成张量…\n"
149
  yield [], log
150
  pt_path = output_folder + "/" + input2 + ".pt"
151
+ onnx_path = output_folder + "/" + input2 + ".onnx"
152
  input_tensor0 = torch.rand(shape0) if any(shape0) else None
153
  input_tensor1 = torch.rand(shape1) if any(shape1) else None
154
  if input_tensor0 is not None and input_tensor1 is not None:
 
186
  torch_model = model.model
187
  print_log(task_id, input2, "获得模型对象", "完成")
188
  yield [], log
189
+
190
+
191
+ if os.path.exists(onnx_path):
192
+ print_log(task_id, input2, "转换为ONNX模型", "跳过")
193
+ log += "跳过转换为ONNX模型\n"
194
+ yield [], log
195
+ else:
196
+ print_log(task_id, input2, "转换为ONNX模型", "开始")
197
+ log += "转换为ONNX模型…\n"
198
+ yield [], log
199
+ # 使用 torch.onnx.export 进行模型转换
200
+ # 将列表转换为元组
201
+ shape_tuple = tuple(shape0)
202
+ torch.onnx.export(torch_model, torch.rand(shape_tuple), onnx_path, verbose=True, input_names=["data"], output_names=["output"])
203
+
204
+
205
  if os.path.exists(pt_path):
206
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
207
  log += "跳过转换为TorchScript模型\n"
 
234
  yield [], log
235
  returncode = process.poll()
236
  if returncode != 0:
237
+ log += f"执行命令失败,返回码: {returncode},命令: {command} \n"
238
  else:
239
+ log += f"执行命令成功: {command} \n"
240
  except Exception as e:
241
  log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
242
 
243
  output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
244
+ log += f"任务完成\n"
245
  print_log(task_id, input2, "执行命令", "完成")
246
  yield output_files, log
247
  except Exception as e:
 
255
  with gr.Row():
256
  # 左侧列,包含输入组件和按钮
257
  with gr.Column():
258
+ # 添加文本提示
259
+ gr.Markdown("请输入的url,或者上传一个文件。限制文件为小于100M的*.pth模型")
260
  with gr.Row():
261
  input1 = gr.Textbox(label="粘贴地址")
262
+ # 新增文件上传组件
263
  input1_file = gr.File(label="上传文件")
264
  input2 = gr.Textbox(label="自定义文件名")
265
  # 修改为字符串输入控件
 
290
 
291
  # 添加范例
292
  examples = [
293
+ ["https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", None, "", "1,3,128,128", "0,0,0,0"],
294
+ ["https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth", None, "", "1,3,128,128", "0,0,0,0"],
295
+ ["https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", None, "", "1,3,128,128", "0,0,0,0"]
296
  ]
297
  gr.Examples(
298
  examples=examples,
299
+ inputs=[input1, input1_file, input2, shape0_str, shape1_str],
300
  outputs=[output, log_textbox],
301
  fn=start_process
302
  )