sjtu-deepvision committed · verified
Commit a390bd6 · 1 Parent(s): e64df08

Upload app.py

Files changed (1): app.py +50 -253
app.py CHANGED
@@ -1,40 +1,12 @@
-# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# --------------------------------------------------------------------------
-# If you find this code useful, we kindly ask you to cite our paper in your work.
-# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
-# More information about the method can be found at https://marigoldmonodepth.github.io
-# --------------------------------------------------------------------------
-from __future__ import annotations
-
-import functools
+import gradio as gr
+from PIL import Image
+from DAI.pipeline_all import DAIPipeline
 import os
 import tempfile
-
-import gradio as gr
 import numpy as np
-import spaces
 import torch as torch
 torch.backends.cuda.matmul.allow_tf32 = True
-from PIL import Image
-from gradio_imageslider import ImageSlider
-from tqdm import tqdm
 
-from pathlib import Path
-import gradio
-from gradio.utils import get_cache_folder
-from DAI.pipeline_all import DAIPipeline
 
 from diffusers import (
     AutoencoderKL,
@@ -47,67 +19,20 @@ from DAI.controlnetvae import ControlNetVAEModel
 
 from DAI.decoder import CustomAutoencoderKL
 
-
-class Examples(gradio.helpers.Examples):
-    def __init__(self, *args, directory_name=None, **kwargs):
-        super().__init__(*args, **kwargs, _initiated_directly=False)
-        if directory_name is not None:
-            self.cached_folder = get_cache_folder() / directory_name
-            self.cached_file = Path(self.cached_folder) / "log.csv"
-        self.create()
-
-
-default_seed = 2024
-default_batch_size = 1
-
-def process_image_check(path_input):
-    if path_input is None:
-        raise gr.Error(
-            "Missing image in the first pane: upload a file or use one from the gallery below."
-        )
-
-def resize_image(input_image, resolution):
-    # Ensure input_image is a PIL Image object
-    if not isinstance(input_image, Image.Image):
-        raise ValueError("input_image should be a PIL Image object")
-
-    # Convert image to numpy array
-    input_image_np = np.asarray(input_image)
-
-    # Get image dimensions
-    H, W, C = input_image_np.shape
-    H = float(H)
-    W = float(W)
-
-    # Calculate the scaling factor
-    k = float(resolution) / min(H, W)
-
-    # Determine new dimensions
-    H *= k
-    W *= k
-    H = int(np.round(H / 64.0)) * 64
-    W = int(np.round(W / 64.0)) * 64
-
-    # Resize the image using PIL's resize method
-    img = input_image.resize((W, H), Image.Resampling.LANCZOS)
-
-    return img
+def process_image(pipe, vae_2, image):
+    # Save the input image to a temporary file
+    temp_input_path = tempfile.mktemp(suffix=".png")
+    image.save(temp_input_path)
 
-def process_image(
-    pipe,
-    vae_2,
-    path_input,
-):
-    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+    name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
     print(f"Processing image {name_base}{name_ext}")
 
     path_output_dir = tempfile.mkdtemp()
    path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
-    input_image = Image.open(path_input)
     resolution = None
 
     pipe_out = pipe(
-        image=input_image,
+        image=image,
         prompt="remove glass reflection",
         vae_2=vae_2,
         processing_resolution=resolution,
@@ -117,192 +42,64 @@ def process_image(
     processed_frame = (processed_frame[0] * 255).astype(np.uint8)
     processed_frame = Image.fromarray(processed_frame)
     processed_frame.save(path_out_png)
-    yield [input_image, path_out_png]
 
-def run_demo_server(pipe, vae_2):
-    process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2))
-
-    gradio_theme = gr.themes.Default()
-
-    with gr.Blocks(
-        theme=gradio_theme,
-        title="Dereflection Any Image",
-        css="""
-            #download {
-                height: 118px;
-            }
-            .slider .inner {
-                width: 5px;
-                background: #FFF;
-            }
-            .viewport {
-                aspect-ratio: 4/3;
-            }
-            .tabs button.selected {
-                font-size: 20px !important;
-                color: crimson !important;
-            }
-            h1 {
-                text-align: center;
-                display: block;
-            }
-            h2 {
-                text-align: center;
-                display: block;
-            }
-            h3 {
-                text-align: center;
-                display: block;
-            }
-            .md_feedback li {
-                margin-bottom: 0px !important;
-            }
-        """,
-        head="""
-            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
-            <script>
-                window.dataLayer = window.dataLayer || [];
-                function gtag() {dataLayer.push(arguments);}
-                gtag('js', new Date());
-                gtag('config', 'G-1FWSVCGZTG');
-            </script>
-        """,
-    ) as demo:
-        gr.Markdown(
-            """
-            # Dereflection Any Image
-            <p align="center">
-            """
-        )
-
-        with gr.Tabs(elem_classes=["tabs"]):
-            with gr.Tab("Image"):
-                with gr.Row():
-                    with gr.Column():
-                        image_input = gr.Image(
-                            label="Input Image",
-                            type="filepath",
-                        )
-                        with gr.Row():
-                            image_submit_btn = gr.Button(
-                                value="remove reflection", variant="primary"
-                            )
-                            image_reset_btn = gr.Button(value="Reset")
-                    with gr.Column():
-                        image_output_slider = ImageSlider(
-                            label="outputs",
-                            type="filepath",
-                            show_download_button=True,
-                            show_share_button=True,
-                            interactive=False,
-                            elem_classes="slider",
-                            # position=0.25,
-                        )
-
-                Examples(
-                    fn=process_pipe_image,
-                    examples=sorted([
-                        os.path.join("files", "image", name)
-                        for name in os.listdir(os.path.join("files", "image"))
-                    ]),
-                    inputs=[image_input],
-                    outputs=[image_output_slider],
-                    cache_examples=False,
-                    directory_name="examples_image",
-                )
-
-        ### Image tab
-        image_submit_btn.click(
-            fn=process_image_check,
-            inputs=image_input,
-            outputs=None,
-            preprocess=False,
-            queue=False,
-        ).success(
-            fn=process_pipe_image,
-            inputs=[
-                image_input,
-            ],
-            outputs=[image_output_slider],
-            concurrency_limit=1,
-        )
-
-        image_reset_btn.click(
-            fn=lambda: (
-                None,
-                None,
-                None,
-            ),
-            inputs=[],
-            outputs=[
-                image_input,
-                image_output_slider,
-            ],
-            queue=False,
-        )
-
-    ### Server launch
-
-    demo.queue(
-        api_open=False,
-    ).launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-    )
-
-
-def main():
-    os.system("pip freeze")
-
+    return processed_frame
+
+if __name__ == "__main__":
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     weight_dtype = torch.float32
-    model_dir = "./weights"
     pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
     pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
     revision = None
     variant = None
+
     # Load the model
     controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
     vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
 
-    # Load other components of the pipeline
     vae = AutoencoderKL.from_pretrained(
-        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
-    ).to(device)
+        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
+    ).to(device)
 
     text_encoder = CLIPTextModel.from_pretrained(
-        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
-    ).to(device)
+        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
+    ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(
-        pretrained_model_name_or_path2,
-        subfolder="tokenizer",
-        revision=revision,
-        use_fast=False,
-    )
+        pretrained_model_name_or_path2,
+        subfolder="tokenizer",
+        revision=revision,
+        use_fast=False,
+    )
     pipe = DAIPipeline(
-        vae=vae,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        unet=unet,
-        controlnet=controlnet,
-        safety_checker=None,
-        scheduler=None,
-        feature_extractor=None,
-        t_start=0,
-    ).to(device)
-
-    try:
-        import xformers
-        pipe.enable_xformers_memory_efficient_attention()
-    except:
-        pass  # run without xformers
-
-    run_demo_server(pipe, vae_2)
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        unet=unet,
+        controlnet=controlnet,
+        safety_checker=None,
+        scheduler=None,
+        feature_extractor=None,
+        t_start=0,
+    ).to(device)
+
+    # Cache example images in memory
+    example_images_dir = "files/image"
+    example_images = []
+    for i in range(1, 9):
+        image_path = os.path.join(example_images_dir, f"{i}.png")
+        if os.path.exists(image_path):
+            example_images.append([Image.open(image_path)])
+
+    # Create a Gradio interface
+    interface = gr.Interface(
+        fn=lambda image: process_image(pipe, vae_2, image),
+        inputs=gr.Image(type="pil"),
+        outputs=gr.Image(type="pil"),
+        title="Dereflection Any Image",
+        description="Upload an image to remove glass reflections.",
+        examples=example_images,
+    )
 
-if __name__ == "__main__":
-    main()
+    interface.launch()
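
For anyone reading this commit in isolation, here is a minimal sketch of the same inference path without the Gradio wrapper. It reuses only calls that appear in the diff above, with one exception: the hunk `@@ -117,192 +42,64 @@` elides how the output array is read from `pipe_out`, so the `pipe_out.prediction` readout below (and treating it as a float array in [0, 1]) is an assumption, not the committed code.

# Minimal sketch, assuming the DAI package from this Space is importable and
# both model repos are reachable; the loading calls mirror app.py above.
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import AutoTokenizer, CLIPTextModel

from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from DAI.pipeline_all import DAIPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dai_repo = "JichenHu/dereflection-any-image-v0"
sd_repo = "stabilityai/stable-diffusion-2-1"

controlnet = ControlNetVAEModel.from_pretrained(dai_repo, subfolder="controlnet", torch_dtype=torch.float32).to(device)
unet = UNet2DConditionModel.from_pretrained(dai_repo, subfolder="unet", torch_dtype=torch.float32).to(device)
vae_2 = CustomAutoencoderKL.from_pretrained(dai_repo, subfolder="vae_2", torch_dtype=torch.float32).to(device)
vae = AutoencoderKL.from_pretrained(sd_repo, subfolder="vae").to(device)
text_encoder = CLIPTextModel.from_pretrained(sd_repo, subfolder="text_encoder").to(device)
tokenizer = AutoTokenizer.from_pretrained(sd_repo, subfolder="tokenizer", use_fast=False)

pipe = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
).to(device)

# processing_resolution=None keeps the input at its native size, as in app.py.
pipe_out = pipe(
    image=Image.open("input.png"),
    prompt="remove glass reflection",
    vae_2=vae_2,
    processing_resolution=None,
)

# Assumed readout: the diff elides this step; `.prediction` is a guess modeled
# on Marigold-style pipelines, not confirmed by the committed code.
frame = (pipe_out.prediction[0] * 255).astype(np.uint8)
Image.fromarray(frame).save("output.png")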