sjtu-deepvision committed
Commit 1cedc13 · verified · 1 Parent(s): 6ddc0ca

Upload app.py

Files changed (1)
  1. app.py +73 -261
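To inspect this exact revision locally, the commit hash above can be pinned when downloading the Space. A minimal sketch using huggingface_hub's snapshot_download; the repo id below is a placeholder, since this page only shows the owner sjtu-deepvision:

from huggingface_hub import snapshot_download

# Placeholder repo id: substitute the actual Space id for this page.
local_dir = snapshot_download(
    repo_id="sjtu-deepvision/<space-name>",
    repo_type="space",
    revision="1cedc13",  # the commit shown in the header above
)
print(local_dir)  # path to the downloaded snapshot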
app.py CHANGED
@@ -1,287 +1,99 @@
-# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# --------------------------------------------------------------------------
-# If you find this code useful, we kindly ask you to cite our paper in your work.
-# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
-# More information about the method can be found at https://marigoldmonodepth.github.io
-# --------------------------------------------------------------------------
-from __future__ import annotations
-
-import functools
 import os
-import tempfile
-
-import gradio as gr
-import imageio as imageio
 import numpy as np
-import spaces
-import torch as torch
-torch.backends.cuda.matmul.allow_tf32 = True
+import torch
 from PIL import Image
-from gradio_imageslider import ImageSlider
-from tqdm import tqdm
-
-from pathlib import Path
-import gradio
-from gradio.utils import get_cache_folder
-
+import gradio as gr
 from DAI.pipeline_all import DAIPipeline
-
 from DAI.controlnetvae import ControlNetVAEModel
-
 from DAI.decoder import CustomAutoencoderKL
-
-from diffusers import (
-    AutoencoderKL,
-    UNet2DConditionModel,
-)
-
+from diffusers import AutoencoderKL, UNet2DConditionModel
 from transformers import CLIPTextModel, AutoTokenizer
 
-
-class Examples(gradio.helpers.Examples):
-    def __init__(self, *args, directory_name=None, **kwargs):
-        super().__init__(*args, **kwargs, _initiated_directly=False)
-        if directory_name is not None:
-            self.cached_folder = get_cache_folder() / directory_name
-            self.cached_file = Path(self.cached_folder) / "log.csv"
-        self.create()
-
-
-def process_image_check(path_input):
-    if path_input is None:
-        raise gr.Error(
-            "Missing image in the first pane: upload a file or use one from the gallery below."
-        )
-
-def process_image(
-    pipe,
-    vae_2,
-    path_input,
-):
-    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
-    print(f"Processing image {name_base}{name_ext}")
-
-    path_output_dir = tempfile.mkdtemp()
-    path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
-    input_image = Image.open(path_input)
-    # pipe_out = pipe(
-    #     input_image,
-    #     match_input_resolution=False,
-    #     processing_resolution=default_image_processing_resolution
-    # )
-
-    # resolution = 0
-    # if max(input_image.size) < 768:
-    #     resolution = None
-    resolution = None
-
+# Initialize device and model paths
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+weight_dtype = torch.float32
+pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
+pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
+
+# Load the model components
+controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
+unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
+vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
+vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path2, subfolder="vae").to(device)
+text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path2, subfolder="text_encoder").to(device)
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False)
+
+# Create the pipeline
+pipe = DAIPipeline(
+    vae=vae,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    unet=unet,
+    controlnet=controlnet,
+    safety_checker=None,
+    scheduler=None,
+    feature_extractor=None,
+    t_start=0,
+).to(device)
+
+# Function to process the image
+def process_image(input_image):
+    # Convert Gradio input to PIL Image
+    input_image = Image.fromarray(input_image)
+
+    # Process the image
     pipe_out = pipe(
         image=input_image,
         prompt="remove glass reflection",
         vae_2=vae_2,
-        processing_resolution=resolution,
+        processing_resolution=None,
    )
 
+    # Convert the output to an image
    processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    processed_frame = (processed_frame[0] * 255).astype(np.uint8)
    processed_frame = Image.fromarray(processed_frame)
-    processed_frame.save(path_out_png)
-    yield [input_image, path_out_png]
-
-
-def run_demo_server(pipe, vae_2):
-    process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2))
-
-    gradio_theme = gr.themes.Default()
-
-    with gr.Blocks(
-        theme=gradio_theme,
-        title="DAI",
-        css="""
-            #download {
-                height: 118px;
-            }
-            .slider .inner {
-                width: 5px;
-                background: #FFF;
-            }
-            .viewport {
-                aspect-ratio: 4/3;
-            }
-            .tabs button.selected {
-                font-size: 20px !important;
-                color: crimson !important;
-            }
-            h1 {
-                text-align: center;
-                display: block;
-            }
-            h2 {
-                text-align: center;
-                display: block;
-            }
-            h3 {
-                text-align: center;
-                display: block;
-            }
-            .md_feedback li {
-                margin-bottom: 0px !important;
-            }
-        """,
-        head="""
-            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
-            <script>
-                window.dataLayer = window.dataLayer || [];
-                function gtag() {dataLayer.push(arguments);}
-                gtag('js', new Date());
-                gtag('config', 'G-1FWSVCGZTG');
-            </script>
-        """,
-    ) as demo:
-        gr.Markdown(
-            """
-            # Dereflection Any Image
-            <p align="center">
-            """
+    return processed_frame
+
+# Gradio interface
+def create_gradio_interface():
+    # Example images
+    example_images = [
+        os.path.join("files", "image", f"{i}.png") for i in range(1, 9)
+    ]
+
+    with gr.Blocks() as demo:
+        gr.Markdown("# Dereflection Any Image")
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input Image", type="numpy")
+                submit_btn = gr.Button("Remove Reflection", variant="primary")
+            with gr.Column():
+                output_image = gr.Image(label="Processed Image")
+
+        # Add examples
+        gr.Examples(
+            examples=example_images,
+            inputs=input_image,
+            outputs=output_image,
+            fn=process_image,
+            cache_examples=False,  # do not cache example outputs
+            label="Example Images",
        )
 
-        with gr.Tabs(elem_classes=["tabs"]):
-            with gr.Tab("Image"):
-                with gr.Row():
-                    with gr.Column():
-                        image_input = gr.Image(
-                            label="Input Image",
-                            type="filepath",
-                        )
-                        with gr.Row():
-                            image_submit_btn = gr.Button(
-                                value="Dereflection", variant="primary"
-                            )
-                            image_reset_btn = gr.Button(value="Reset")
-                    with gr.Column():
-                        image_output_slider = ImageSlider(
-                            label="outputs",
-                            type="filepath",
-                            show_download_button=True,
-                            show_share_button=True,
-                            interactive=False,
-                            elem_classes="slider",
-                            # position=0.25,
-                        )
-
-        Examples(
-            fn=process_pipe_image,
-            examples=sorted([
-                os.path.join("files", "image", name)
-                for name in os.listdir(os.path.join("files", "image"))
-            ]),
-            inputs=[image_input],
-            outputs=[image_output_slider],
-            cache_examples=False,
-            directory_name="examples_image",
-        )
-
-        ### Image tab
-        image_submit_btn.click(
-            fn=process_image_check,
-            inputs=image_input,
-            outputs=None,
-            preprocess=False,
-            queue=False,
-        ).success(
-            fn=process_pipe_image,
-            inputs=[
-                image_input,
-            ],
-            outputs=[image_output_slider],
-            concurrency_limit=1,
-        )
-
-        image_reset_btn.click(
-            fn=lambda: (
-                None,
-                None,
-                None,
-            ),
-            inputs=[],
-            outputs=[
-                image_input,
-                image_output_slider,
-            ],
-            queue=False,
-        )
-
-    ### Server launch
-
-    demo.queue(
-        api_open=False,
-    ).launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-    )
+        submit_btn.click(
+            fn=process_image,
+            inputs=input_image,
+            outputs=output_image,
+        )
+
+    return demo
+
+# Main function to launch the Gradio app
 
 
 def main():
-    os.system("pip freeze")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    weight_dtype = torch.float32
-    pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
-    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
-    revision = None
-    variant = None
-
-    # Load the model
-    controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
-    vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
-
-    vae = AutoencoderKL.from_pretrained(
-        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
-    ).to(device)
-
-    text_encoder = CLIPTextModel.from_pretrained(
-        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
-    ).to(device)
-    tokenizer = AutoTokenizer.from_pretrained(
-        pretrained_model_name_or_path2,
-        subfolder="tokenizer",
-        revision=revision,
-        use_fast=False,
-    )
-    pipe = DAIPipeline(
-        vae=vae,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        unet=unet,
-        controlnet=controlnet,
-        safety_checker=None,
-        scheduler=None,
-        feature_extractor=None,
-        t_start=0,
-    ).to(device)
-
-    try:
-        import xformers
-        pipe.enable_xformers_memory_efficient_attention()
-    except:
-        pass  # run without xformers
-
-    run_demo_server(pipe, vae_2)
+    demo = create_gradio_interface()
+    demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 if __name__ == "__main__":
     main()
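Because the rewritten app.py builds the pipeline at module import time and exposes process_image as a plain function over numpy arrays, the model can also be driven without the Gradio UI. A minimal sketch, assuming the new file above is saved as app.py on the import path and that example.png is a hypothetical local photo:

import numpy as np
from PIL import Image

import app  # importing triggers the module-level model loading shown above

array = np.array(Image.open("example.png").convert("RGB"))
result = app.process_image(array)  # returns a PIL.Image
result.save("example_dereflected.png")

The trade-off of this refactor: moving model loading from main() into module scope keeps the demo code short, but any import of app now pays the full model-download and device-load cost up front.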