sjtu-deepvision commited on
Commit
c9620cb
·
verified ·
1 Parent(s): a41990b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -13
app.py CHANGED
@@ -41,25 +41,22 @@ pipe = DAIPipeline(
41
  ).to(device)
42
 
43
  @spaces.GPU
44
- def process_image(input_image):
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
48
- # 根据输入图像的大小设置分辨率
49
- max_side = max(input_image.size)
50
- if max_side < 768:
51
- resolution = 768
52
- elif max_side > 2560:
53
- resolution = 2560
54
  else:
55
- resolution = 0
56
 
57
  # 处理图像
58
  pipe_out = pipe(
59
  image=input_image,
60
  prompt="remove glass reflection",
61
  vae_2=vae_2,
62
- processing_resolution=resolution,
63
  )
64
 
65
  # 将输出转换为图像
@@ -73,19 +70,26 @@ def process_image(input_image):
73
  def create_gradio_interface():
74
  # 示例图像
75
  example_images = [
76
- os.path.join("files", "image", f"{i}.png") for i in range(1, 14)
77
  ]
78
  title = "# Dereflection Any Image"
79
  description = """Official demo for **Dereflection Any Image**.
80
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
81
 
82
-
83
  with gr.Blocks() as demo:
84
  gr.Markdown(title)
85
  gr.Markdown(description)
86
  with gr.Row():
87
  with gr.Column():
88
  input_image = gr.Image(label="Input Image", type="numpy")
 
 
 
 
 
 
 
 
89
  submit_btn = gr.Button("Remove Reflection", variant="primary")
90
  with gr.Column():
91
  output_image = gr.Image(label="Processed Image")
@@ -93,7 +97,7 @@ def create_gradio_interface():
93
  # 添加示例
94
  gr.Examples(
95
  examples=example_images,
96
- inputs=input_image,
97
  outputs=output_image,
98
  fn=process_image,
99
  cache_examples=False, # 缓存结果以加快加载速度
@@ -103,7 +107,7 @@ def create_gradio_interface():
103
  # 绑定按钮点击事件
104
  submit_btn.click(
105
  fn=process_image,
106
- inputs=input_image,
107
  outputs=output_image,
108
  )
109
 
 
41
  ).to(device)
42
 
43
  @spaces.GPU
44
+ def process_image(input_image, resolution_choice):
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
48
+ # 根据用户选择设置处理分辨率
49
+ if resolution_choice == "768":
50
+ processing_resolution = None
 
 
 
51
  else:
52
+ processing_resolution = 0 # 使用原始分辨率
53
 
54
  # 处理图像
55
  pipe_out = pipe(
56
  image=input_image,
57
  prompt="remove glass reflection",
58
  vae_2=vae_2,
59
+ processing_resolution=processing_resolution,
60
  )
61
 
62
  # 将输出转换为图像
 
70
  def create_gradio_interface():
71
  # 示例图像
72
  example_images = [
73
+ [os.path.join("files", "image", f"{i}.png"), "768"] for i in range(1, 14)
74
  ]
75
  title = "# Dereflection Any Image"
76
  description = """Official demo for **Dereflection Any Image**.
77
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
78
 
 
79
  with gr.Blocks() as demo:
80
  gr.Markdown(title)
81
  gr.Markdown(description)
82
  with gr.Row():
83
  with gr.Column():
84
  input_image = gr.Image(label="Input Image", type="numpy")
85
+ resolution_choice = gr.Radio(
86
+ choices=["768", "Original Resolution"],
87
+ label="Processing Resolution",
88
+ value="768", # 默认选择原始分辨率
89
+ )
90
+ gr.Markdown(
91
+ "Select the resolution for processing the image. Higher resolution may take longer to process."
92
+ )
93
  submit_btn = gr.Button("Remove Reflection", variant="primary")
94
  with gr.Column():
95
  output_image = gr.Image(label="Processed Image")
 
97
  # 添加示例
98
  gr.Examples(
99
  examples=example_images,
100
+ inputs=[input_image, resolution_choice], # 输入组件列表
101
  outputs=output_image,
102
  fn=process_image,
103
  cache_examples=False, # 缓存结果以加快加载速度
 
107
  # 绑定按钮点击事件
108
  submit_btn.click(
109
  fn=process_image,
110
+ inputs=[input_image, resolution_choice], # 输入组件列表
111
  outputs=output_image,
112
  )
113