Spaces: Running on Zero

guangkaixu committed
Commit 562fd4c
1 Parent(s): cdba047
upload
Files changed:
- README.md +15 -4
- app.py +289 -0
- images/depth/.DS_Store +0 -0
- images/depth/anime_1.jpg +0 -0
- images/depth/anime_2.jpg +0 -0
- images/depth/anime_3.jpg +0 -0
- images/depth/anime_4.jpg +0 -0
- images/depth/anime_5.jpg +0 -0
- images/depth/anime_6.jpg +0 -0
- images/depth/anime_7.jpg +0 -0
- images/depth/line_1.jpg +0 -0
- images/depth/line_2.jpg +0 -0
- images/depth/line_3.jpg +0 -0
- images/depth/line_4.jpg +0 -0
- images/depth/line_5.jpg +0 -0
- images/depth/line_6.jpg +0 -0
- images/depth/real_1.jpg +0 -0
- images/depth/real_10.jpg +0 -0
- images/depth/real_11.jpg +0 -0
- images/depth/real_12.jpg +0 -0
- images/depth/real_13.jpg +0 -0
- images/depth/real_14.jpg +0 -0
- images/depth/real_15.jpg +0 -0
- images/depth/real_16.jpg +0 -0
- images/depth/real_17.jpg +0 -0
- images/depth/real_18.jpg +0 -0
- images/depth/real_19.jpg +0 -0
- images/depth/real_2.jpg +0 -0
- images/depth/real_20.jpg +0 -0
- images/depth/real_21.jpg +0 -0
- images/depth/real_22.jpg +0 -0
- images/depth/real_23.jpg +0 -0
- images/depth/real_24.jpg +0 -0
- images/depth/real_3.jpg +0 -0
- images/depth/real_4.jpg +0 -0
- images/depth/real_5.jpg +0 -0
- images/depth/real_6.jpg +0 -0
- images/depth/real_7.jpg +0 -0
- images/depth/real_8.jpg +0 -0
- images/depth/real_9.jpg +0 -0
- pipeline_genpercept.py +355 -0
README.md
CHANGED
@@ -1,13 +1,24 @@
 ---
-title: GenPercept
+title: GenPercept: Diffusion Models Trained with Large Data Are Transferable Visual Models
 emoji: ⚡
 colorFrom: indigo
 colorTo: red
 sdk: gradio
 sdk_version: 4.25.0
 app_file: app.py
-pinned:
-
+pinned: true
+models:
+- guangkaixu/GenPercept
+license: cc0-1.0
 ---
 
-
+If you find it useful, please cite our paper:
+
+```
+@article{xu2024diffusion,
+  title={Diffusion Models Trained with Large Data Are Transferable Visual Models},
+  author={Xu, Guangkai and Ge, Yongtao and Liu, Mingyu and Fan, Chengxiang and Xie, Kangyang and Zhao, Zhiyue and Chen, Hao and Shen, Chunhua},
+  journal={arXiv preprint arXiv:2403.06090},
+  year={2024}
+}
+```
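The new `models:` entry ties the Space to the guangkaixu/GenPercept checkpoint repository. As orientation only (not part of this commit), the weights it points at could be pulled locally with the standard huggingface_hub client; the `local_dir` below is an illustrative choice, not something the Space defines:

```python
# Sketch: fetch the checkpoint repository referenced by the README's `models:` entry.
# `local_dir` is an arbitrary illustrative path, not defined anywhere in this commit.
from huggingface_hub import snapshot_download

ckpt_dir = snapshot_download(repo_id="guangkaixu/GenPercept", local_dir="./GenPercept-ckpt")
print("Checkpoint downloaded to:", ckpt_dir)
```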
app.py
ADDED
@@ -0,0 +1,289 @@
# Copyright 2024 Guangkai Xu, Zhejiang University. All rights reserved.
#
# Licensed under the CC0-1.0 license;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://github.com/aim-uofa/GenPercept/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# This code is based on the Marigold and diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/aim-uofa/GenPercept#%EF%B8%8F-citation
# More information about the method can be found at https://github.com/aim-uofa/GenPercept
# --------------------------------------------------------------------------

from __future__ import annotations

import functools
import os
import tempfile
import warnings

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel
from gradio_imageslider import ImageSlider

from gradio_patches.examples import Examples
from pipeline_genpercept import GenPerceptPipeline

warnings.filterwarnings(
    "ignore", message=".*LoginButton created outside of a Blocks context.*"
)

default_image_processing_res = 768
default_image_reproducible = True


def process_image_check(path_input):
    if path_input is None:
        raise gr.Error(
            "Missing image in the first pane: upload a file or use one from the gallery below."
        )


def process_image(
    pipe,
    path_input,
    processing_res=default_image_processing_res,
):
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")

    input_image = Image.open(path_input)

    pipe_out = pipe(
        input_image,
        processing_res=processing_res,
        batch_size=1 if processing_res == 0 else 0,
        show_progress_bar=False,
    )

    depth_pred = pipe_out.pred_np
    depth_colored = pipe_out.pred_colored
    depth_16bit = (depth_pred * 65535.0).astype(np.uint16)

    np.save(path_out_fp32, depth_pred)
    Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
    depth_colored.save(path_out_vis)

    return (
        [path_out_16bit, path_out_vis],
        [path_out_16bit, path_out_fp32, path_out_vis],
    )


def run_demo_server(pipe):
    # Only the image (depth) pipeline is exposed in this demo.
    process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))

    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="GenPercept",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:

        gr.Markdown(
            """
            # GenPercept: Diffusion Models Trained with Large Data Are Transferable Visual Models
            <p align="center">
            <a title="arXiv" href="https://arxiv.org/abs/2403.06090" target="_blank" rel="noopener noreferrer"
                style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
            </a>
            <a title="Github" href="https://github.com/aim-uofa/GenPercept" target="_blank" rel="noopener noreferrer"
                style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/aim-uofa/GenPercept?label=GitHub%20%E2%98%85&logo=github&color=C8C"
                    alt="badge-github-stars">
            </a>
            </p>
            <p align="justify">
                GenPercept leverages the prior knowledge of stable diffusion models to estimate detailed visual perception results.
                It achieves remarkable transferable performance on fundamental vision perception tasks using a moderate amount of target data
                (even synthetic data only). Compared to previous methods, our inference process only requires one step and therefore runs faster.
            </p>
            """
        )

        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Tab("Depth Estimation"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        with gr.Row():
                            image_submit_btn = gr.Button(
                                value="Estimate Depth", variant="primary"
                            )
                            image_reset_btn = gr.Button(value="Reset")
                        with gr.Accordion("Advanced options", open=False):
                            image_processing_res = gr.Radio(
                                [
                                    ("Native", 0),
                                    ("Recommended", 768),
                                ],
                                label="Processing resolution",
                                value=default_image_processing_res,
                            )
                    with gr.Column():
                        image_output_slider = ImageSlider(
                            label="Predicted depth of gray / color (red-near, blue-far)",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                        image_output_files = gr.Files(
                            label="Depth outputs",
                            elem_id="download",
                            interactive=False,
                        )

                filenames = []
                filenames.extend(["anime_%d.jpg" % (i + 1) for i in range(7)])
                filenames.extend(["line_%d.jpg" % (i + 1) for i in range(6)])
                filenames.extend(["real_%d.jpg" % (i + 1) for i in range(24)])
                Examples(
                    fn=process_pipe_image,
                    examples=[
                        os.path.join("images", "depth", name)
                        for name in filenames
                    ],
                    inputs=[image_input],
                    outputs=[image_output_slider, image_output_files],
                    cache_examples=True,
                    directory_name="examples_image",
                )

        ### Image tab
        image_submit_btn.click(
            fn=process_image_check,
            inputs=image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_image,
            inputs=[
                image_input,
                image_processing_res,
            ],
            outputs=[image_output_slider, image_output_files],
            concurrency_limit=1,
        )

        image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                image_input,
                image_output_slider,
                image_output_files,
                image_processing_res,
            ],
            queue=False,
        )

    ### Server launch

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
    os.system("pip freeze")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    vae = AutoencoderKL.from_pretrained("./", subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained("./", subfolder="unet")
    empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None]  # [1, 77, 1024]

    pipe = GenPerceptPipeline(
        vae=vae,
        unet=unet,
        empty_text_embed=empty_text_embed,
    )
    try:
        import xformers
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # run without xformers

    pipe = pipe.to(device)
    run_demo_server(pipe)


if __name__ == "__main__":
    main()
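`process_image` above saves three artifacts per request: a float32 `.npy` depth map, a 16-bit PNG (the normalized depth scaled by 65535), and a colorized PNG. A minimal sketch of how a user of the download panel might recover the float depth from the 16-bit PNG; the file name is a placeholder:

```python
# Sketch: recover normalized depth in [0, 1] from the 16-bit PNG written by process_image.
# "example_depth_16bit.png" is a placeholder file name, not produced by this commit.
import numpy as np
from PIL import Image

depth_16bit = np.asarray(Image.open("example_depth_16bit.png"), dtype=np.uint16)
depth = depth_16bit.astype(np.float32) / 65535.0  # undo the * 65535.0 scaling used when saving
print(depth.shape, float(depth.min()), float(depth.max()))
```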
images/depth/.DS_Store
ADDED
Binary file (6.15 kB)
images/depth/anime_1.jpg
ADDED
images/depth/anime_2.jpg
ADDED
images/depth/anime_3.jpg
ADDED
images/depth/anime_4.jpg
ADDED
images/depth/anime_5.jpg
ADDED
images/depth/anime_6.jpg
ADDED
images/depth/anime_7.jpg
ADDED
images/depth/line_1.jpg
ADDED
images/depth/line_2.jpg
ADDED
images/depth/line_3.jpg
ADDED
images/depth/line_4.jpg
ADDED
images/depth/line_5.jpg
ADDED
images/depth/line_6.jpg
ADDED
images/depth/real_1.jpg
ADDED
images/depth/real_10.jpg
ADDED
images/depth/real_11.jpg
ADDED
images/depth/real_12.jpg
ADDED
images/depth/real_13.jpg
ADDED
images/depth/real_14.jpg
ADDED
images/depth/real_15.jpg
ADDED
images/depth/real_16.jpg
ADDED
images/depth/real_17.jpg
ADDED
images/depth/real_18.jpg
ADDED
images/depth/real_19.jpg
ADDED
images/depth/real_2.jpg
ADDED
images/depth/real_20.jpg
ADDED
images/depth/real_21.jpg
ADDED
images/depth/real_22.jpg
ADDED
images/depth/real_23.jpg
ADDED
images/depth/real_24.jpg
ADDED
images/depth/real_3.jpg
ADDED
images/depth/real_4.jpg
ADDED
images/depth/real_5.jpg
ADDED
images/depth/real_6.jpg
ADDED
images/depth/real_7.jpg
ADDED
images/depth/real_8.jpg
ADDED
images/depth/real_9.jpg
ADDED
pipeline_genpercept.py
ADDED
@@ -0,0 +1,355 @@
# --------------------------------------------------------
# Diffusion Models Trained with Large Data Are Transferable Visual Models (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024 Zhejiang University
# Licensed under The CC0 1.0 License [see LICENSE for details]
# By Guangkai Xu
# Based on the Marigold and diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------

import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from PIL import Image
from typing import List, Dict, Union
from torch.utils.data import DataLoader, TensorDataset

from diffusers import (
    DiffusionPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
)
from diffusers.utils import BaseOutput

from util.image_util import chw2hwc, colorize_depth_maps, resize_max_res, norm_to_rgb, resize_res
from util.batchsize import find_batch_size


class GenPerceptOutput(BaseOutput):

    pred_np: np.ndarray
    pred_colored: Image.Image


class GenPerceptPipeline(DiffusionPipeline):

    vae_scale_factor = 0.18215
    task_infos = {
        'depth': dict(task_channel_num=1, interpolate='bilinear'),
        'seg': dict(task_channel_num=3, interpolate='nearest'),
        'sr': dict(task_channel_num=3, interpolate='nearest'),
        'normal': dict(task_channel_num=3, interpolate='bilinear'),
    }

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        customized_head=None,
        empty_text_embed=None,
    ):
        super().__init__()

        self.empty_text_embed = empty_text_embed

        # register
        register_dict = dict(
            unet=unet,
            vae=vae,
            customized_head=customized_head,
        )
        self.register_modules(**register_dict)

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        mode: str = 'depth',
        resize_hard: bool = False,
        processing_res: int = 768,
        match_input_res: bool = True,
        batch_size: int = 0,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
    ) -> GenPerceptOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (Image):
                Input RGB (or gray-scale) image.
            processing_res (int, optional):
                Maximum resolution of processing.
                If set to 0: will not resize at all.
                Defaults to 768.
            match_input_res (bool, optional):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` is not None.
                Defaults to True.
            batch_size (int, optional):
                Inference batch size.
                If set to 0, the script will automatically decide the proper batch size.
                Defaults to 0.
            show_progress_bar (bool, optional):
                Display a progress bar over inference batches.
                Defaults to True.
            color_map (str, optional):
                Colormap used to colorize the depth map.
                Defaults to "Spectral".
        Returns:
            `GenPerceptOutput`
        """

        device = self.device

        task_channel_num = self.task_infos[mode]['task_channel_num']

        if not match_input_res:
            assert (
                processing_res is not None
            ), "Value error: `processing_res` must be set when `match_input_res` is False."
        assert processing_res >= 0

        # ----------------- Image Preprocess -----------------

        if type(input_image) == torch.Tensor:  # [B, 3, H, W]
            rgb_norm = input_image.to(device)
            input_size = input_image.shape[2:]
            bs_imgs = rgb_norm.shape[0]
            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
            rgb_norm = rgb_norm.to(self.dtype)
        else:
            # if len(rgb_paths) > 0 and 'kitti' in rgb_paths[0]:
            #     # kb crop
            #     height = input_image.size[1]
            #     width = input_image.size[0]
            #     top_margin = int(height - 352)
            #     left_margin = int((width - 1216) / 2)
            #     input_image = input_image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # TODO: check the kitti evaluation resolution here.
            input_size = (input_image.size[1], input_image.size[0])
            # Resize image
            if processing_res > 0:
                if resize_hard:
                    input_image = resize_res(
                        input_image, max_edge_resolution=processing_res
                    )
                else:
                    input_image = resize_max_res(
                        input_image, max_edge_resolution=processing_res
                    )
            input_image = input_image.convert("RGB")
            image = np.asarray(input_image)

            # Normalize rgb values
            rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
            rgb_norm = rgb / 255.0 * 2.0 - 1.0
            rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
            rgb_norm = rgb_norm[None].to(device)
            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
            bs_imgs = 1

        # ----------------- Prediction -----------------

        single_rgb_dataset = TensorDataset(rgb_norm)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=1,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict label maps (batched)
        pred_list = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader

        for batch in iterable:
            (batched_img,) = batch
            pred = self.single_infer(
                rgb_in=batched_img,
                mode=mode,
            )
            pred_list.append(pred.detach().clone())
        preds = torch.concat(pred_list, axis=0).squeeze()
        preds = preds.view(bs_imgs, task_channel_num, preds.shape[-2], preds.shape[-1])  # [bs_imgs, task_channel_num, H, W]

        if match_input_res:
            preds = F.interpolate(preds, input_size, mode=self.task_infos[mode]['interpolate'])

        # ----------------- Post processing -----------------
        if mode == 'depth':
            if len(preds.shape) == 4:
                preds = preds[:, 0]  # [bs_imgs, H, W]
            # Scale prediction to [0, 1]
            min_d = preds.view(bs_imgs, -1).min(dim=1)[0]
            max_d = preds.view(bs_imgs, -1).max(dim=1)[0]
            preds = (preds - min_d[:, None, None]) / (max_d[:, None, None] - min_d[:, None, None])
            preds = preds.cpu().numpy().astype(np.float32)
            # Colorize
            pred_colored_img_list = []
            for i in range(bs_imgs):
                pred_colored_chw = colorize_depth_maps(
                    preds[i], 0, 1, cmap=color_map
                ).squeeze()  # [3, H, W], value in (0, 1)
                pred_colored_chw = (pred_colored_chw * 255).astype(np.uint8)
                pred_colored_hwc = chw2hwc(pred_colored_chw)
                pred_colored_img = Image.fromarray(pred_colored_hwc)
                pred_colored_img_list.append(pred_colored_img)

            return GenPerceptOutput(
                pred_np=np.squeeze(preds),
                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
            )

        elif mode == 'seg' or mode == 'sr':
            if not self.customized_head:
                # shift to [0, 1]
                preds = (preds + 1.0) / 2.0
                # shift to [0, 255]
                preds = preds * 255
                # Clip output range
                preds = preds.clip(0, 255).cpu().numpy().astype(np.uint8)
            else:
                raise NotImplementedError

            pred_colored_img_list = []
            for i in range(preds.shape[0]):
                pred_colored_hwc = chw2hwc(preds[i])
                pred_colored_img = Image.fromarray(pred_colored_hwc)
                pred_colored_img_list.append(pred_colored_img)

            return GenPerceptOutput(
                pred_np=np.squeeze(preds),
                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
            )

        elif mode == 'normal':
            if not self.customized_head:
                preds = preds.clip(-1, 1).cpu().numpy()  # [-1, 1]
            else:
                raise NotImplementedError

            pred_colored_img_list = []
            for i in range(preds.shape[0]):
                pred_colored_chw = norm_to_rgb(preds[i])
                pred_colored_hwc = chw2hwc(pred_colored_chw)
                normal_colored_img_i = Image.fromarray(pred_colored_hwc)
                pred_colored_img_list.append(normal_colored_img_i)

            return GenPerceptOutput(
                pred_np=np.squeeze(preds),
                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
            )

        else:
            raise NotImplementedError

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        mode: str = 'depth',
    ) -> torch.Tensor:
        """
        Perform an individual prediction without ensembling.

        Args:
            rgb_in (torch.Tensor):
                Input RGB image, normalized to [-1, 1].
            mode (str):
                Perception task ('depth', 'seg', 'sr' or 'normal').

        Returns:
            torch.Tensor: Predicted label map.
        """
        device = rgb_in.device
        bs_imgs = rgb_in.shape[0]
        timesteps = torch.tensor([1]).long().repeat(bs_imgs).to(device)

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        batch_embed = self.empty_text_embed
        batch_embed = batch_embed.repeat((rgb_latent.shape[0], 1, 1)).to(device)  # [bs_imgs, 77, 1024]

        # Forward!
        if self.customized_head:
            unet_features = self.unet(rgb_latent, timesteps, encoder_hidden_states=batch_embed, return_feature_only=True)[0][::-1]
            pred = self.customized_head(unet_features)
        else:
            unet_output = self.unet(
                rgb_latent, timesteps, encoder_hidden_states=batch_embed
            )  # [bs_imgs, 4, h, w]
            unet_pred = unet_output.sample
            pred_latent = -unet_pred
            pred_latent = pred_latent.to(device)
            pred = self.decode_pred(pred_latent)
            if mode == 'depth':
                # mean of output channels
                pred = pred.mean(dim=1, keepdim=True)
            # clip prediction
            pred = torch.clip(pred, -1.0, 1.0)
        return pred

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (torch.Tensor):
                Input RGB image to be encoded.

        Returns:
            torch.Tensor: Image latent.
        """
        try:
            # encode
            h_temp = self.vae.encoder(rgb_in)
            moments = self.vae.quant_conv(h_temp)
        except Exception:
            # fall back to float32 if the VAE cannot run in the current dtype
            h_temp = self.vae.encoder(rgb_in.float())
            moments = self.vae.quant_conv(h_temp.float())

        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.vae_scale_factor
        return rgb_latent

    def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode the prediction latent into a prediction label map.

        Args:
            pred_latent (torch.Tensor):
                Prediction latent to be decoded.

        Returns:
            torch.Tensor: Decoded prediction label map.
        """
        # unscale latent
        pred_latent = pred_latent / self.vae_scale_factor
        # decode
        z = self.vae.post_quant_conv(pred_latent)
        pred = self.vae.decoder(z)

        return pred
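Read together with `main()` in `app.py`, the pipeline above suggests the following standalone usage. This is a minimal sketch assuming the same repository layout (`vae/` and `unet/` subfolders plus `empty_text_embed.npy` next to the script); `input.jpg` and the output file names are placeholders:

```python
# Sketch: run GenPerceptPipeline outside the Gradio demo, mirroring app.py's main().
# The paths ("./", "empty_text_embed.npy", "input.jpg") are assumptions taken from app.py.
import numpy as np
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from PIL import Image

from pipeline_genpercept import GenPerceptPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae = AutoencoderKL.from_pretrained("./", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("./", subfolder="unet")
empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).float()[None]  # [1, 77, 1024]

pipe = GenPerceptPipeline(unet=unet, vae=vae, empty_text_embed=empty_text_embed).to(device)

out = pipe(Image.open("input.jpg"), mode="depth", processing_res=768)
np.save("input_depth.npy", out.pred_np)           # normalized depth in [0, 1]
out.pred_colored.save("input_depth_colored.png")  # Spectral-colorized visualization
```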