sjtu-deepvision commited on
Commit
59880dc
·
verified ·
1 Parent(s): 9a81057

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -12
app.py CHANGED
@@ -1,10 +1,9 @@
1
- import spaces
2
  import os
3
  import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
- from gradio_imageslider import ImageSlider # 导入 ImageSlider 组件
8
 
9
  # 延迟 CUDA 初始化
10
  weight_dtype = torch.float32
@@ -59,8 +58,7 @@ def process_image(input_image):
59
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
60
  processed_frame = Image.fromarray(processed_frame)
61
 
62
- # 返回输入图像和处理后的图像
63
- return input_image, processed_frame
64
 
65
  # 创建 Gradio 界面
66
  def create_gradio_interface():
@@ -76,18 +74,13 @@ def create_gradio_interface():
76
  input_image = gr.Image(label="Input Image", type="numpy")
77
  submit_btn = gr.Button("Remove Reflection", variant="primary")
78
  with gr.Column():
79
- # 使用 ImageSlider 显示前后对比
80
- output_slider = ImageSlider(
81
- label="Before & After",
82
- show_download_button=True,
83
- show_share_button=True,
84
- )
85
 
86
  # 添加示例
87
  gr.Examples(
88
  examples=example_images,
89
  inputs=input_image,
90
- outputs=output_slider,
91
  fn=process_image,
92
  cache_examples=False, # 缓存结果以加快加载速度
93
  label="Example Images",
@@ -97,7 +90,7 @@ def create_gradio_interface():
97
  submit_btn.click(
98
  fn=process_image,
99
  inputs=input_image,
100
- outputs=output_slider,
101
  )
102
 
103
  return demo
 
1
+ import spaces # 必须放在最前面
2
  import os
3
  import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
 
58
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
59
  processed_frame = Image.fromarray(processed_frame)
60
 
61
+ return processed_frame
 
62
 
63
  # 创建 Gradio 界面
64
  def create_gradio_interface():
 
74
  input_image = gr.Image(label="Input Image", type="numpy")
75
  submit_btn = gr.Button("Remove Reflection", variant="primary")
76
  with gr.Column():
77
+ output_image = gr.Image(label="Processed Image")
 
 
 
 
 
78
 
79
  # 添加示例
80
  gr.Examples(
81
  examples=example_images,
82
  inputs=input_image,
83
+ outputs=output_image,
84
  fn=process_image,
85
  cache_examples=False, # 缓存结果以加快加载速度
86
  label="Example Images",
 
90
  submit_btn.click(
91
  fn=process_image,
92
  inputs=input_image,
93
+ outputs=output_image,
94
  )
95
 
96
  return demo