# Hugging Face Space status (scraped page banner): Running on Zero
# Copyright 2024 Guangkai Xu, Zhejiang University. All rights reserved. | |
# | |
# Licensed under the CC0-1.0 license; | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# https://github.com/aim-uofa/GenPercept/blob/main/LICENSE | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# -------------------------------------------------------------------------- | |
# This code is based on Marigold and diffusers codebases | |
# https://github.com/prs-eth/marigold | |
# https://github.com/huggingface/diffusers | |
# -------------------------------------------------------------------------- | |
# If you find this code useful, we kindly ask you to cite our paper in your work. | |
# Please find bibtex at: https://github.com/aim-uofa/GenPercept#%EF%B8%8F-citation | |
# More information about the method can be found at https://github.com/aim-uofa/GenPercept | |
# -------------------------------------------------------------------------- | |
from __future__ import annotations | |
import functools | |
import os | |
import tempfile | |
import warnings | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch as torch | |
from PIL import Image | |
from gradio_imageslider import ImageSlider | |
from gradio_patches.examples import Examples | |
from pipeline_genpercept import GenPerceptPipeline | |
from diffusers import ( | |
DiffusionPipeline, | |
UNet2DConditionModel, | |
AutoencoderKL, | |
) | |
# Suppress a Gradio warning about a LoginButton being created outside of a
# Blocks context; the pattern below is matched against the warning message.
warnings.filterwarnings(
    action="ignore",
    message=".*LoginButton created outside of a Blocks context.*",
)

# Default "Processing resolution" shared by the UI radios and the process_*
# handlers (0 would mean "Native").
default_image_processing_res = 768

# NOTE(review): not referenced anywhere in the visible code, and the name is
# misspelled ("reproducuble"); kept as-is in case something imports it.
default_image_reproducuble = True
def process_image_check(path_input):
    """Validate that the user supplied an input image before running a model.

    Used as the first step of each submit chain; raising here makes Gradio
    show the message to the user and skips the chained ``.success(...)`` step.

    Args:
        path_input: value of the input ``gr.Image`` component (a file path),
            or None when nothing was uploaded.

    Raises:
        gr.Error: if ``path_input`` is None.
    """
    if path_input is not None:
        return
    raise gr.Error(
        "Missing image in the first pane: upload a file or use one from the gallery below."
    )
def process_depth(
    pipe,
    path_input,
    processing_res=default_image_processing_res,
):
    """Run depth estimation on one image file and save the results to a temp dir.

    Args:
        pipe: pipeline called with ``mode='depth'``; its output must expose
            ``pred_np`` (array saved with ``np.save``) and ``pred_colored``
            (an image object with ``.save``).
        path_input: path of the input image file.
        processing_res: resolution forwarded to ``pipe``; 0 is the UI's
            "Native" option.

    Returns:
        A pair of path lists: ``[16-bit PNG, colored PNG]`` for the image
        slider and ``[16-bit PNG, fp32 .npy, colored PNG]`` for downloads.
    """
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")

    # Close the source file once inference is done instead of leaking the handle.
    with Image.open(path_input) as input_image:
        pipe_out = pipe(
            input_image,
            processing_res=processing_res,
            # NOTE(review): batch_size=0 presumably lets the pipeline choose;
            # native resolution is pinned to a batch of 1.
            batch_size=1 if processing_res == 0 else 0,
            show_progress_bar=False,
            mode='depth',
        )

    depth_pred = pipe_out.pred_np
    depth_colored = pipe_out.pred_colored

    np.save(path_out_fp32, depth_pred)
    depth_colored.save(path_out_vis)

    # Quantize to 16 bits. Clip first: values outside [0, 1] would overflow
    # the uint16 cast (platform-dependent garbage instead of saturation).
    depth_16bit = (np.clip(depth_pred, 0.0, 1.0) * 65535.0).astype(np.uint16)
    Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")

    return (
        [path_out_16bit, path_out_vis],
        [path_out_16bit, path_out_fp32, path_out_vis],
    )
def process_normal(
    pipe,
    path_input,
    processing_res=default_image_processing_res,
):
    """Run surface-normal estimation on one image file and save the results.

    Args:
        pipe: pipeline called with ``mode='normal'``; its output must expose
            ``pred_np`` (array saved with ``np.save``) and ``pred_colored``
            (an image object with ``.save``).
        path_input: path of the input image file.
        processing_res: resolution forwarded to ``pipe``; 0 is the UI's
            "Native" option.

    Returns:
        A pair of path lists: ``[colored PNG]`` for the gallery and
        ``[fp32 .npy, colored PNG]`` for downloads.
    """
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_normal_fp32.npy")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_colored.png")

    # Close the source file once inference is done instead of leaking the handle.
    with Image.open(path_input) as input_image:
        pipe_out = pipe(
            input_image,
            processing_res=processing_res,
            # NOTE(review): batch_size=0 presumably lets the pipeline choose;
            # native resolution is pinned to a batch of 1.
            batch_size=1 if processing_res == 0 else 0,
            show_progress_bar=False,
            mode='normal',
        )

    normal_pred = pipe_out.pred_np
    normal_colored = pipe_out.pred_colored

    np.save(path_out_fp32, normal_pred)
    normal_colored.save(path_out_vis)

    return (
        [path_out_vis],
        [path_out_fp32, path_out_vis],
    )
def process_dis(
    pipe,
    path_input,
    processing_res=default_image_processing_res,
):
    """Run dichotomous image segmentation on one image file and save the results.

    Args:
        pipe: pipeline called with ``mode='seg'``; its output must expose
            ``pred_np`` (array saved with ``np.save``) and ``pred_colored``
            (an image object with ``.save``).
        path_input: path of the input image file.
        processing_res: resolution forwarded to ``pipe``; 0 is the UI's
            "Native" option.

    Returns:
        A pair of path lists: ``[colored PNG]`` for the gallery and
        ``[fp32 .npy, colored PNG]`` for downloads.
    """
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_dis_fp32.npy")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_dis_colored.png")

    # Close the source file once inference is done instead of leaking the handle.
    with Image.open(path_input) as input_image:
        pipe_out = pipe(
            input_image,
            processing_res=processing_res,
            # NOTE(review): batch_size=0 presumably lets the pipeline choose;
            # native resolution is pinned to a batch of 1.
            batch_size=1 if processing_res == 0 else 0,
            show_progress_bar=False,
            mode='seg',
        )

    dis_pred = pipe_out.pred_np
    dis_colored = pipe_out.pred_colored

    np.save(path_out_fp32, dis_pred)
    dis_colored.save(path_out_vis)

    return (
        [path_out_vis],
        [path_out_fp32, path_out_vis],
    )
def run_demo_server(pipe_depth, pipe_normal, pipe_dis):
    """Build the GenPercept Gradio UI and launch it on 0.0.0.0:7860.

    Three tabs are built: Depth, Normal and Dichotomous Segmentation. Each tab
    has its own input image, submit/reset buttons, "Processing resolution"
    radio, output viewer, downloadable file list and examples gallery.

    Args:
        pipe_depth: pipeline bound into ``process_depth``.
        pipe_normal: pipeline bound into ``process_normal``.
        pipe_dis: pipeline bound into ``process_dis``.
    """
    # Bind each pipeline into its handler; spaces.GPU marks the handlers as
    # needing a (ZeroGPU) GPU while they run.
    process_pipe_depth = spaces.GPU(functools.partial(process_depth, pipe_depth))
    process_pipe_normal = spaces.GPU(functools.partial(process_normal, pipe_normal))
    process_pipe_dis = spaces.GPU(functools.partial(process_dis, pipe_dis))

    app_dir = os.path.dirname(__file__)

    def _processing_res_radio():
        # Each tab gets its own radio. A single shared variable used to be
        # reassigned per tab, so every submit/reset was wired to the last tab's radio.
        return gr.Radio(
            [
                ("Native", 0),
                ("Recommended", 768),
            ],
            label="Processing resolution",
            value=default_image_processing_res,
        )

    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="GenPercept",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        gr.Markdown(
            """
            # GenPercept: Diffusion Models Trained with Large Data Are Transferable Visual Models
            <p align="center">
            <a title="arXiv" href="https://arxiv.org/abs/2403.06090" target="_blank" rel="noopener noreferrer"
            style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
            </a>
            <a title="Github" href="https://github.com/aim-uofa/GenPercept" target="_blank" rel="noopener noreferrer"
            style="display: inline-block;">
            <img src="https://img.shields.io/github/stars/aim-uofa/GenPercept?label=GitHub%20%E2%98%85&logo=github&color=C8C"
            alt="badge-github-stars">
            </a>
            </p>
            <p align="justify">
            GenPercept is a one-step image perception generalist, which leverages the pretrained prior from stable diffusion models to estimate depth/surface normal/matting/segmentation with impressive details.
            It achieves extremely fast inference speed and remarkable generalization capability on these fundamental vision perception tasks.
            </p>
            """
        )

        with gr.Tabs(elem_classes=["tabs"]):
            # ---------------------------------------------------------- Depth
            with gr.Tab("Depth"):
                with gr.Row():
                    with gr.Column():
                        depth_image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        with gr.Row():
                            depth_image_submit_btn = gr.Button(
                                value="Estimate Depth", variant="primary"
                            )
                            depth_image_reset_btn = gr.Button(value="Reset")
                        with gr.Accordion("Advanced options", open=False):
                            depth_processing_res = _processing_res_radio()
                    with gr.Column():
                        depth_image_output_slider = ImageSlider(
                            label="Predicted depth of gray / color (red-near, blue-far)",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                        depth_image_output_files = gr.Files(
                            label="Depth outputs",
                            elem_id="download",
                            interactive=False,
                        )

                depth_filenames = (
                    ["depth_anime_%d.jpg" % (i + 1) for i in range(7)]
                    + ["depth_line_%d.jpg" % (i + 1) for i in range(6)]
                    + ["depth_real_%d.jpg" % (i + 1) for i in range(24)]
                )
                depth_example_folder = os.path.join(app_dir, "images")
                Examples(
                    fn=process_pipe_depth,
                    examples=[
                        os.path.join(depth_example_folder, name)
                        for name in depth_filenames
                    ],
                    inputs=[depth_image_input],
                    outputs=[depth_image_output_slider, depth_image_output_files],
                    cache_examples=False,
                )

            # --------------------------------------------------------- Normal
            with gr.Tab("Normal"):
                with gr.Row():
                    with gr.Column():
                        normal_image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        with gr.Row():
                            normal_image_submit_btn = gr.Button(
                                value="Estimate Normal", variant="primary"
                            )
                            normal_image_reset_btn = gr.Button(value="Reset")
                        with gr.Accordion("Advanced options", open=False):
                            normal_processing_res = _processing_res_radio()
                    with gr.Column():
                        normal_image_output = gr.Gallery(
                            label='Output',
                            show_label=False,
                            elem_id="gallery",
                            columns=[1],
                            height='auto',
                        )
                        normal_image_output_files = gr.Files(
                            label="Normal outputs",
                            elem_id="download",
                            interactive=False,
                        )

                normal_filenames = ["normal_%d.jpg" % (i + 1) for i in range(10)]
                normal_example_folder = os.path.join(app_dir, "normal_images")
                Examples(
                    fn=process_pipe_normal,
                    examples=[
                        os.path.join(normal_example_folder, name)
                        for name in normal_filenames
                    ],
                    inputs=[normal_image_input],
                    outputs=[normal_image_output, normal_image_output_files],
                    directory_name="images_cache",
                    cache_examples=False,
                )

            # ------------------------------------ Dichotomous Segmentation
            with gr.Tab("Dichotomous Segmentation"):
                with gr.Row():
                    with gr.Column():
                        dis_image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        with gr.Row():
                            dis_image_submit_btn = gr.Button(
                                value="Estimate Dichotomous Segmentation", variant="primary"
                            )
                            dis_image_reset_btn = gr.Button(value="Reset")
                        with gr.Accordion("Advanced options", open=False):
                            dis_processing_res = _processing_res_radio()
                    with gr.Column():
                        dis_image_output = gr.Gallery(
                            label='Output',
                            show_label=False,
                            elem_id="gallery",
                            columns=[1],
                            height='auto',
                        )
                        dis_image_output_files = gr.Files(
                            label="DIS outputs",
                            elem_id="download",
                            interactive=False,
                        )

                dis_filenames = ["dis_%d.jpg" % (i + 1) for i in range(10)]
                dis_example_folder = os.path.join(app_dir, "dis_images")
                Examples(
                    fn=process_pipe_dis,
                    examples=[
                        os.path.join(dis_example_folder, name)
                        for name in dis_filenames
                    ],
                    inputs=[dis_image_input],
                    outputs=[dis_image_output, dis_image_output_files],
                    directory_name="images_cache",
                    cache_examples=False,
                )

        ### Event wiring: validate input first, run the model only on success.
        depth_image_submit_btn.click(
            fn=process_image_check,
            inputs=depth_image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_depth,
            inputs=[
                depth_image_input,
                depth_processing_res,
            ],
            outputs=[depth_image_output_slider, depth_image_output_files],
            concurrency_limit=1,
        )

        depth_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                depth_image_input,
                depth_image_output_slider,
                depth_image_output_files,
                depth_processing_res,
            ],
            queue=False,
        )

        normal_image_submit_btn.click(
            fn=process_image_check,
            inputs=normal_image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_normal,
            inputs=[
                normal_image_input,
                normal_processing_res,
            ],
            outputs=[normal_image_output, normal_image_output_files],
            concurrency_limit=1,
        )

        normal_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                normal_image_input,
                normal_image_output,
                normal_image_output_files,
                normal_processing_res,
            ],
            queue=False,
        )

        dis_image_submit_btn.click(
            fn=process_image_check,
            inputs=dis_image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_dis,
            inputs=[
                dis_image_input,
                dis_processing_res,
            ],
            outputs=[dis_image_output, dis_image_output_files],
            concurrency_limit=1,
        )

        dis_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                dis_image_input,
                dis_image_output,
                dis_image_output_files,
                dis_processing_res,
            ],
            queue=False,
        )

    ### Server launch
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    """Load the GenPercept models, build one pipeline per task, and start the demo.

    Loads a shared VAE and three task-specific UNets (depth, normal, DIS) via
    ``from_pretrained``. Loads the precomputed empty-prompt text embedding from
    ``empty_text_embed.npy`` next to this file. Then passes the three
    pipelines to ``run_demo_server``.
    """
    os.system("pip freeze")  # print installed package versions to the log

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Half precision is only reliably supported on CUDA; fall back to fp32 on CPU.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    vae = AutoencoderKL.from_pretrained("guangkaixu/GenPercept", subfolder='vae').to(dtype)
    unet_depth_v1 = UNet2DConditionModel.from_pretrained(
        'guangkaixu/genpercept-depth',
        subfolder="unet",
        use_safetensors=True,
    ).to(dtype)
    unet_normal_v1 = UNet2DConditionModel.from_pretrained(
        'guangkaixu/GenPercept',
        subfolder="unet_normal_v1",
        use_safetensors=True,
    ).to(dtype)
    unet_dis_v1 = UNet2DConditionModel.from_pretrained(
        'guangkaixu/GenPercept',
        subfolder="unet_dis_v1",
        use_safetensors=True,
    ).to(dtype)

    # Resolve next to this file (like the example folders) rather than the CWD.
    embed_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "empty_text_embed.npy"
    )
    # [None] adds a leading batch dim; original note: [1, 77, 1024]
    empty_text_embed = torch.from_numpy(np.load(embed_path)).to(device, dtype)[None]

    pipe_depth = GenPerceptPipeline(vae=vae,
                                    unet=unet_depth_v1,
                                    empty_text_embed=empty_text_embed)
    pipe_normal = GenPerceptPipeline(vae=vae,
                                     unet=unet_normal_v1,
                                     empty_text_embed=empty_text_embed)
    pipe_dis = GenPerceptPipeline(vae=vae,
                                  unet=unet_dis_v1,
                                  empty_text_embed=empty_text_embed)

    # Best effort: use xformers attention when available, otherwise continue
    # with the default attention. Unlike a bare `except:`, this no longer
    # swallows KeyboardInterrupt/SystemExit.
    try:
        import xformers  # noqa: F401  (import only probes that it is installed)
        for pipe in (pipe_depth, pipe_normal, pipe_dis):
            pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"Running without xformers: {exc!r}")

    pipe_depth = pipe_depth.to(device)
    pipe_normal = pipe_normal.to(device)
    pipe_dis = pipe_dis.to(device)

    run_demo_server(pipe_depth, pipe_normal, pipe_dis)
if __name__ == "__main__":
    # Only start the demo when executed as a script, not when imported.
    main()