sjtu-deepvision committed
Commit 651dfe7 · verified · 1 Parent(s): 86d66c4

Upload app.py

Files changed (1):
  1. app.py (+207 -33)

app.py CHANGED

@@ -1,12 +1,46 @@
- import gradio as gr
- from PIL import Image
- import spaces
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+ # More information about the method can be found at https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
  import functools
  import os
  import tempfile
+
+ import gradio as gr
+ import imageio as imageio
  import numpy as np
+ import spaces
  import torch as torch
  torch.backends.cuda.matmul.allow_tf32 = True
+ from PIL import Image
+ from gradio_imageslider import ImageSlider
+ from tqdm import tqdm
+
+ from pathlib import Path
+ import gradio
+ from gradio.utils import get_cache_folder
+
+ from DAI.pipeline_all import DAIPipeline
+
+ from DAI.controlnetvae import ControlNetVAEModel
+
+ from DAI.decoder import CustomAutoencoderKL

  from diffusers import (
      AutoencoderKL,
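
Note on a context line this hunk keeps: torch.backends.cuda.matmul.allow_tf32 = True lets matrix multiplications on Ampere-or-newer GPUs run in TensorFloat-32, trading a little precision for speed. A minimal sketch of the related PyTorch flags (the cuDNN line below is added for illustration only and is not part of this commit):

    import torch

    # What app.py sets: matmuls may use TF32 on supported GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    # The analogous flag for cuDNN convolutions; not set in this commit.
    torch.backends.cudnn.allow_tf32 = True
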
@@ -15,26 +49,46 @@ from diffusers import (

  from transformers import CLIPTextModel, AutoTokenizer

- from DAI.pipeline_all import DAIPipeline

- from DAI.controlnetvae import ControlNetVAEModel
+ class Examples(gradio.helpers.Examples):
+     def __init__(self, *args, directory_name=None, **kwargs):
+         super().__init__(*args, **kwargs, _initiated_directly=False)
+         if directory_name is not None:
+             self.cached_folder = get_cache_folder() / directory_name
+             self.cached_file = Path(self.cached_folder) / "log.csv"
+         self.create()

- from DAI.decoder import CustomAutoencoderKL

- def process_image(pipe, vae_2, image):
-     # Save the input image to a temporary file
-     temp_input_path = tempfile.mktemp(suffix=".png")
-     image.save(temp_input_path)
+ def process_image_check(path_input):
+     if path_input is None:
+         raise gr.Error(
+             "Missing image in the first pane: upload a file or use one from the gallery below."
+         )

-     name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
+ def process_image(
+     pipe,
+     vae_2,
+     path_input,
+ ):
+     name_base, name_ext = os.path.splitext(os.path.basename(path_input))
      print(f"Processing image {name_base}{name_ext}")

      path_output_dir = tempfile.mkdtemp()
      path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
+     input_image = Image.open(path_input)
+     # pipe_out = pipe(
+     #     input_image,
+     #     match_input_resolution=False,
+     #     processing_resolution=default_image_processing_resolution
+     # )
+
+     # resolution = 0
+     # if max(input_image.size) < 768:
+     #     resolution = None
      resolution = None

      pipe_out = pipe(
-         image=image,
+         image=input_image,
          prompt="remove glass reflection",
          vae_2=vae_2,
          processing_resolution=resolution,
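
Note on the rewritten handler: process_image now receives a file path (the UI declares gr.Image(type="filepath")) and yields an [input, output] pair, which the ImageSlider component renders as a before/after comparison; the Examples subclass above only redirects Gradio's example-cache folder via get_cache_folder(). A minimal sketch of the same handler contract, assuming the third-party gradio_imageslider package and using a blur as a stand-in for the dereflection pipeline:

    import os
    import tempfile

    import gradio as gr
    from PIL import Image, ImageFilter
    from gradio_imageslider import ImageSlider

    def toy_process(path_input):
        input_image = Image.open(path_input)
        path_out = os.path.join(tempfile.mkdtemp(), "out.png")
        # Stand-in "processing"; the real app runs the diffusion pipeline here.
        input_image.filter(ImageFilter.GaussianBlur(8)).save(path_out)
        # Yielding a (before, after) pair streams the comparison to the slider.
        yield [input_image, path_out]

    with gr.Blocks() as demo:
        image_input = gr.Image(type="filepath")  # the handler gets a path string
        image_output = ImageSlider(type="filepath", interactive=False)
        gr.Button("Run").click(fn=toy_process, inputs=image_input, outputs=image_output)

    demo.launch()
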
@@ -44,13 +98,148 @@ def process_image(pipe, vae_2, image):
      processed_frame = (processed_frame[0] * 255).astype(np.uint8)
      processed_frame = Image.fromarray(processed_frame)
      processed_frame.save(path_out_png)
+     yield [input_image, path_out_png]

-     return processed_frame

- if __name__ == "__main__":
+ def run_demo_server(pipe, vae_2):
+     process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2))
+
+     gradio_theme = gr.themes.Default()
+
+     with gr.Blocks(
+         theme=gradio_theme,
+         title="DAI",
+         css="""
+             #download {
+                 height: 118px;
+             }
+             .slider .inner {
+                 width: 5px;
+                 background: #FFF;
+             }
+             .viewport {
+                 aspect-ratio: 4/3;
+             }
+             .tabs button.selected {
+                 font-size: 20px !important;
+                 color: crimson !important;
+             }
+             h1 {
+                 text-align: center;
+                 display: block;
+             }
+             h2 {
+                 text-align: center;
+                 display: block;
+             }
+             h3 {
+                 text-align: center;
+                 display: block;
+             }
+             .md_feedback li {
+                 margin-bottom: 0px !important;
+             }
+         """,
+         head="""
+             <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
+             <script>
+                 window.dataLayer = window.dataLayer || [];
+                 function gtag() {dataLayer.push(arguments);}
+                 gtag('js', new Date());
+                 gtag('config', 'G-1FWSVCGZTG');
+             </script>
+         """,
+     ) as demo:
+         gr.Markdown(
+             """
+             # Dereflection Any Image
+             <p align="center">
+             """
+         )
+
+         with gr.Tabs(elem_classes=["tabs"]):
+             with gr.Tab("Image"):
+                 with gr.Row():
+                     with gr.Column():
+                         image_input = gr.Image(
+                             label="Input Image",
+                             type="filepath",
+                         )
+                         with gr.Row():
+                             image_submit_btn = gr.Button(
+                                 value="Dereflection", variant="primary"
+                             )
+                             image_reset_btn = gr.Button(value="Reset")
+                     with gr.Column():
+                         image_output_slider = ImageSlider(
+                             label="outputs",
+                             type="filepath",
+                             show_download_button=True,
+                             show_share_button=True,
+                             interactive=False,
+                             elem_classes="slider",
+                             # position=0.25,
+                         )
+
+                 Examples(
+                     fn=process_pipe_image,
+                     examples=sorted([
+                         os.path.join("files", "image", name)
+                         for name in os.listdir(os.path.join("files", "image"))
+                     ]),
+                     inputs=[image_input],
+                     outputs=[image_output_slider],
+                     cache_examples=False,
+                     directory_name="examples_image",
+                 )
+
+         ### Image tab
+         image_submit_btn.click(
+             fn=process_image_check,
+             inputs=image_input,
+             outputs=None,
+             preprocess=False,
+             queue=False,
+         ).success(
+             fn=process_pipe_image,
+             inputs=[
+                 image_input,
+             ],
+             outputs=[image_output_slider],
+             concurrency_limit=1,
+         )
+
+         image_reset_btn.click(
+             fn=lambda: (
+                 None,
+                 None,
+                 None,
+             ),
+             inputs=[],
+             outputs=[
+                 image_input,
+                 image_output_slider,
+             ],
+             queue=False,
+         )
+
+     ### Server launch
+
+     demo.queue(
+         api_open=False,
+     ).launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+     )
+
+
+ def main():
+     os.system("pip freeze")
+
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
      weight_dtype = torch.float32
-     pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
+     pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
      pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
      revision = None
      variant = None
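
Note on the handler wrapping above: functools.partial pre-binds the heavy pipe and vae_2 objects, so the callable Gradio sees has the one-parameter signature (path_input) that matches the single input component; spaces.GPU then requests a ZeroGPU device for the duration of each call on Hugging Face Spaces (and is a no-op elsewhere). A minimal sketch of the binding, with placeholder strings standing in for the model objects:

    import functools

    def process_image(pipe, vae_2, path_input):
        # The real handler runs the diffusion pipeline; this only shows the binding.
        return f"ran {pipe} + {vae_2} on {path_input}"

    # After binding, the signature Gradio sees is (path_input) only.
    process_pipe_image = functools.partial(process_image, "pipe-object", "vae-object")
    print(process_pipe_image("photo.png"))  # ran pipe-object + vae-object on photo.png
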
@@ -91,23 +280,8 @@ if __name__ == "__main__":
      except:
          pass # run without xformers

-     # Cache example images in memory
-     example_images_dir = "files/image"
-     example_images = []
-     for i in range(1, 9):
-         image_path = os.path.join(example_images_dir, f"{i}.png")
-         if os.path.exists(image_path):
-             example_images.append([Image.open(image_path)])
-
-     # Create a Gradio interface
-     interface = gr.Interface(
-         fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
-         inputs=gr.Image(type="pil"),
-         outputs=gr.Image(type="pil"),
-         title="Dereflection Any Image",
-         description="Upload an image to remove glass reflections.",
-         examples=example_images,
-     )
+     run_demo_server(pipe, vae_2)

-     interface.launch()

+ if __name__ == "__main__":
+     main()
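
Note on the event wiring in run_demo_server: the submit button chains a cheap check to the expensive call, image_submit_btn.click(fn=process_image_check, ..., queue=False).success(fn=process_pipe_image, ...); if process_image_check raises gr.Error, the chained .success() stage never fires, so no GPU slot is consumed on empty input. As an aside, the reset lambda returns three Nones for two output components; two would match. A minimal sketch of the same validate-then-run pattern:

    import gradio as gr

    def check(path_input):
        if path_input is None:
            raise gr.Error("Upload an image first.")

    def heavy(path_input):
        return path_input  # stand-in for the GPU pipeline call

    with gr.Blocks() as demo:
        inp = gr.Image(type="filepath")
        out = gr.Image(type="filepath", interactive=False)
        btn = gr.Button("Run")
        # The validation runs unqueued; heavy fires only if check did not raise.
        btn.click(fn=check, inputs=inp, outputs=None, queue=False).success(
            fn=heavy, inputs=inp, outputs=out, concurrency_limit=1
        )

    demo.launch()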