sjtu-deepvision commited on
Commit
8895d65
·
verified ·
1 Parent(s): 588230c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -58
app.py CHANGED
@@ -41,19 +41,16 @@ pipe = DAIPipeline(
41
  ).to(device)
42
 
43
  @spaces.GPU
44
- def process_image(input_image, resolution_option):
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
48
- # 根据用户选择的分辨率选项设置 processing_resolution
49
- processing_resolution = 2560 if resolution_option == "2560" else None
50
-
51
  # 处理图像
52
  pipe_out = pipe(
53
  image=input_image,
54
  prompt="remove glass reflection",
55
  vae_2=vae_2,
56
- processing_resolution=processing_resolution,
57
  )
58
 
59
  # 将输出转换为图像
@@ -67,67 +64,41 @@ def process_image(input_image, resolution_option):
67
  def create_gradio_interface():
68
  # 示例图像
69
  example_images = [
70
- os.path.join("files", "image", f"{i}.png") for i in range(1, 10)
 
 
71
  ]
72
  title = "# Dereflection Any Image"
73
  description = """Official demo for **Dereflection Any Image**.
74
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
75
 
 
76
  with gr.Blocks() as demo:
77
  gr.Markdown(title)
78
  gr.Markdown(description)
79
-
80
- # 创建两个 Tab
81
- with gr.Tabs():
82
- with gr.Tab("Image"):
83
- with gr.Row():
84
- with gr.Column():
85
- input_image_default = gr.Image(label="Input Image", type="numpy")
86
- submit_btn_default = gr.Button("Remove Reflection", variant="primary")
87
- with gr.Column():
88
- output_image_default = gr.Image(label="Processed Image")
89
-
90
- # 添加示例
91
- gr.Examples(
92
- examples=example_images,
93
- inputs=input_image_default,
94
- outputs=output_image_default,
95
- fn=lambda x: process_image(x, "None"),
96
- cache_examples=False, # 缓存结果以加快加载速度
97
- label="Example Images",
98
- )
99
-
100
- # 绑定按钮点击事件
101
- submit_btn_default.click(
102
- fn=lambda x: process_image(x, "None"),
103
- inputs=input_image_default,
104
- outputs=output_image_default,
105
- )
106
-
107
- with gr.Tab("2K (2560)"):
108
- with gr.Row():
109
- with gr.Column():
110
- input_image_high = gr.Image(label="Input Image", type="numpy")
111
- submit_btn_high = gr.Button("Remove Reflection", variant="primary")
112
- with gr.Column():
113
- output_image_high = gr.Image(label="Processed Image")
114
-
115
- # 添加示例
116
- gr.Examples(
117
- examples=example_images,
118
- inputs=input_image_high,
119
- outputs=output_image_high,
120
- fn=lambda x: process_image(x, "2560"),
121
- cache_examples=False, # 缓存结果以加快加载速度
122
- label="Example Images",
123
- )
124
-
125
- # 绑定按钮点击事件
126
- submit_btn_high.click(
127
- fn=lambda x: process_image(x, "2560"),
128
- inputs=input_image_high,
129
- outputs=output_image_high,
130
- )
131
 
132
  return demo
133
 
 
41
  ).to(device)
42
 
43
  @spaces.GPU
44
+ def process_image(input_image):
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
 
 
 
48
  # 处理图像
49
  pipe_out = pipe(
50
  image=input_image,
51
  prompt="remove glass reflection",
52
  vae_2=vae_2,
53
+ processing_resolution=None,
54
  )
55
 
56
  # 将输出转换为图像
 
64
  def create_gradio_interface():
65
  # 示例图像
66
  example_images = [
67
+ os.path.join("files", "image", filename)
68
+ for filename in os.listdir(os.path.join("files", "image"))
69
+ if filename.endswith((".png", ".jpg", ".jpeg"))
70
  ]
71
  title = "# Dereflection Any Image"
72
  description = """Official demo for **Dereflection Any Image**.
73
  Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
74
 
75
+
76
  with gr.Blocks() as demo:
77
  gr.Markdown(title)
78
  gr.Markdown(description)
79
+ with gr.Row():
80
+ with gr.Column():
81
+ input_image = gr.Image(label="Input Image", type="numpy")
82
+ submit_btn = gr.Button("Remove Reflection", variant="primary")
83
+ with gr.Column():
84
+ output_image = gr.Image(label="Processed Image")
85
+
86
+ # 添加示例
87
+ gr.Examples(
88
+ examples=example_images,
89
+ inputs=input_image,
90
+ outputs=output_image,
91
+ fn=process_image,
92
+ cache_examples=False, # 缓存结果以加快加载速度
93
+ label="Example Images",
94
+ )
95
+
96
+ # 绑定按钮点击事件
97
+ submit_btn.click(
98
+ fn=process_image,
99
+ inputs=input_image,
100
+ outputs=output_image,
101
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  return demo
104