# Copyright 2024 Guangkai Xu, Zhejiang University. All rights reserved.
#
# Licensed under the CC0-1.0 license;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://github.com/aim-uofa/GenPercept/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# This code is based on the Marigold and diffusers codebases:
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find the bibtex at: https://github.com/aim-uofa/GenPercept#%EF%B8%8F-citation
# More information about the method can be found at https://github.com/aim-uofa/GenPercept
# --------------------------------------------------------------------------
from __future__ import annotations

import functools
import os
import tempfile
import warnings
import sys

sys.path.append("../")

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from gradio_imageslider import ImageSlider

from gradio_patches.examples import Examples
from pipeline_genpercept import GenPerceptPipeline
from diffusers import (
    DiffusionPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
)

warnings.filterwarnings(
    "ignore", message=".*LoginButton created outside of a Blocks context.*"
)

default_image_processing_res = 768
default_image_reproducible = True


def process_image_check(path_input):
    # Validate the input before the (expensive) GPU job is queued.
    if path_input is None:
        raise gr.Error(
            "Missing image in the first pane: upload a file or use one from the gallery below."
        )


def process_image(
    pipe,
    path_input,
    processing_res=default_image_processing_res,
):
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    # Write all outputs into a fresh temporary directory.
    path_output_dir = tempfile.mkdtemp()
    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")

    input_image = Image.open(path_input)

    pipe_out = pipe(
        input_image,
        processing_res=processing_res,
        batch_size=1 if processing_res == 0 else 0,
        show_progress_bar=False,
    )

    depth_pred = pipe_out.depth_np
    depth_colored = pipe_out.depth_colored
    depth_16bit = (depth_pred * 65535.0).astype(np.uint16)

    np.save(path_out_fp32, depth_pred)
    Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
    depth_colored.save(path_out_vis)

    return (
        [path_out_16bit, path_out_vis],
        [path_out_16bit, path_out_fp32, path_out_vis],
    )


def run_demo_server(pipe):
    process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
    # process_video and process_bas (handlers for extra tabs in the upstream
    # Marigold demo) are not defined in this file, so they are left disabled
    # here to avoid a NameError; only the image handler is wired up below.
    # process_pipe_video = spaces.GPU(
    #     functools.partial(process_video, pipe), duration=120
    # )
    # process_pipe_bas = spaces.GPU(functools.partial(process_bas, pipe))
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="GenPercept",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
        """,
    ) as demo:
        gr.Markdown(
            """
# GenPercept: Diffusion Models Trained with Large Data Are Transferable Visual Models

[![badge-github-stars](https://img.shields.io/github/stars/aim-uofa/GenPercept)](https://github.com/aim-uofa/GenPercept)

GenPercept leverages the prior knowledge of Stable Diffusion models to estimate detailed visual perception results. It achieves remarkable transferable performance on fundamental vision perception tasks using a moderate amount of target data (even purely synthetic data). Compared to previous methods, our inference requires only a single step and therefore runs faster.

""" ) with gr.Tabs(elem_classes=["tabs"]): with gr.Tab("Depth Estimation"): with gr.Row(): with gr.Column(): image_input = gr.Image( label="Input Image", type="filepath", ) with gr.Row(): image_submit_btn = gr.Button( value="Estimate Depth", variant="primary" ) image_reset_btn = gr.Button(value="Reset") with gr.Accordion("Advanced options", open=False): image_processing_res = gr.Radio( [ ("Native", 0), ("Recommended", 768), ], label="Processing resolution", value=default_image_processing_res, ) with gr.Column(): image_output_slider = ImageSlider( label="Predicted depth of gray / color (red-near, blue-far)", type="filepath", show_download_button=True, show_share_button=True, interactive=False, elem_classes="slider", position=0.25, ) image_output_files = gr.Files( label="Depth outputs", elem_id="download", interactive=False, ) filenames = [] filenames.extend(["anime_%d.jpg" %i+1 for i in range(7)]) filenames.extend(["line_%d.jpg" %i+1 for i in range(6)]) filenames.extend(["real_%d.jpg" %i+1 for i in range(24)]) Examples( fn=process_pipe_image, examples=[ os.path.join("images", "depth", name) for name in filenames ], inputs=[image_input], outputs=[image_output_slider, image_output_files], cache_examples=True, directory_name="examples_image", ) ### Image tab image_submit_btn.click( fn=process_image_check, inputs=image_input, outputs=None, preprocess=False, queue=False, ).success( fn=process_pipe_image, inputs=[ image_input, image_processing_res, ], outputs=[image_output_slider, image_output_files], concurrency_limit=1, ) image_reset_btn.click( fn=lambda: ( None, None, None, default_image_processing_res, ), inputs=[], outputs=[ image_input, image_output_slider, image_output_files, image_processing_res, ], queue=False, ) ### Server launch demo.queue( api_open=False, ).launch( server_name="0.0.0.0", server_port=7860, ) def main(): os.system("pip freeze") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") vae = AutoencoderKL.from_pretrained("./", subfolder='vae') unet = UNet2DConditionModel.from_pretrained('./', subfolder="unet") empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None] # [1, 77, 1024] pipe = GenPerceptPipeline(vae=vae, unet=unet, empty_text_embed=empty_text_embed) try: import xformers pipe.enable_xformers_memory_efficient_attention() except: pass # run without xformers pipe = pipe.to(device) run_demo_server(pipe) if __name__ == "__main__": main()