Spaces:
Runtime error
Runtime error
# MIT License | |
# Copyright (c) 2022 Intelligent Systems Lab Org | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
# File author: Zhenyu Li | |
import os | |
import cv2 | |
import torch | |
import gradio as gr | |
from estimator.models.patchfusion import PatchFusion | |
from gradio_imageslider import ImageSlider | |
from torchvision import transforms | |
import numpy as np | |
import torch.nn.functional as F | |
from PIL import Image | |
import tempfile | |
# import spaces | |
import matplotlib | |
from huggingface_hub import hf_hub_download | |
import copy | |
from zoedepth.models.zoedepth import ZoeDepth | |
# Custom stylesheet handed to gr.Blocks(css=css) below. It caps the height of
# the elements carrying these ids; '#img-display-input' and
# '#img-display-output' are the elem_id values given to the input image and
# the output slider.
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
"""
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Hub id of the PatchFusion checkpoint loaded below. The commented list that
# follows records the other ids noted by the author.
model_name = 'Zhyever/patchfusion_depth_anything_vitl14'
# 'Zhyever/patchfusion_depth_anything_vits14',
# 'Zhyever/patchfusion_depth_anything_vitb14',
# 'Zhyever/patchfusion_depth_anything_vitl14',
# 'Zhyever/patchfusion_zoedepth'
# High-resolution PatchFusion model, moved to DEVICE and switched to eval mode.
model = PatchFusion.from_pretrained(model_name).to(DEVICE).eval()
# Separate checkpoint for the base network; the demo runs it on the
# low-resolution input to produce the "coarse" side of the comparison slider.
pf_ckp = hf_hub_download(repo_id="Zhyever/PatchFusion", filename="DepthAnything_vitl.pt")
# The base network is built from the coarse-branch section of the PatchFusion
# config.
depth_anything_base = ZoeDepth.build(**model.config.coarse_branch)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the hub artifact is ever replaced. Consider
# weights_only=True once it is confirmed the checkpoint loads that way.
depth_anything_base.load_state_dict(torch.load(pf_ckp, map_location='cpu')['model'])
depth_anything_base = depth_anything_base.to(DEVICE).eval()
# @spaces.GPU | |
def predict_depth(model, image_lr, image_hr, mode, patch_number, resolution_h, resolution_w, patch_split_number_h, patch_split_number_w):
    """Run one tiled inference pass and return the model's prediction.

    Args:
        model: Callable invoked with the keyword arguments ``mode``,
            ``cai_mode``, ``process_num``, ``image_lr``, ``image_hr`` and
            ``tile_cfg``; it must return a ``(result, log_dict)`` pair.
        image_lr: Low-resolution input, forwarded to the model unchanged.
        image_hr: High-resolution input, forwarded to the model unchanged.
        mode: Tiling mode string. The value ``'r'`` is expanded to
            ``'r<patch_number>'``; any other value is forwarded as-is.
        patch_number: Number of patches appended to the mode when
            ``mode == 'r'``; ignored otherwise.
        resolution_h: Height stored in ``tile_cfg['image_raw_shape']``.
        resolution_w: Width stored in ``tile_cfg['image_raw_shape']``.
        patch_split_number_h: Height entry of ``tile_cfg['patch_split_num']``.
        patch_split_number_w: Width entry of ``tile_cfg['patch_split_num']``.

    Returns:
        The first element of the pair returned by ``model``.
    """
    process_num = 4
    tile_cfg = {
        'image_raw_shape': (resolution_h, resolution_w),
        'patch_split_num': (patch_split_number_h, patch_split_number_w),
    }

    if mode == 'r':
        # The count may arrive as a float (e.g. 256.0 from a UI widget);
        # coerce it so the mode string is 'r256' rather than 'r256.0'.
        mode = f'r{int(patch_number)}'

    result, _ = model(
        mode='infer',
        cai_mode=mode,
        process_num=process_num,
        image_lr=image_lr,
        image_hr=image_hr,
        tile_cfg=tile_cfg,
    )
    return result
# def colorize_depth(depth): | |
# depth = (depth - depth.min()) / (depth.max() - depth.min()) | |
# depth = (1 - depth) * 255.0 | |
# depth = depth.astype(np.uint8) | |
# colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1] | |
# return colored_depth | |
def colorize_depth(depth_map, min_depth=0, max_depth=0, cmap='Spectral'):
    """Render a depth map as an RGB uint8 image using a matplotlib colormap.

    The colour range is taken from the 0.03rd and 99.97th percentiles of the
    input, values outside that range are clipped, and the normalised map is
    passed through the colormap.

    Args:
        depth_map: ``torch.Tensor`` or ``numpy.ndarray`` that still has at
            least two dimensions after size-1 axes are squeezed out.
        min_depth: Ignored. Kept only for signature compatibility; the range
            is always derived from the percentiles described above.
        max_depth: Ignored, same as ``min_depth``.
        cmap: Name looked up in ``matplotlib.colormaps``.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with the colormap's RGB channels
        (alpha dropped) for the first map of the input.

    Raises:
        TypeError: ``depth_map`` is neither a tensor nor an ndarray.
        ValueError: ``depth_map`` has fewer than two non-singleton dimensions.
    """
    # Convert to numpy before any numpy call so CUDA / grad-tracking tensors
    # are handled too.
    if isinstance(depth_map, torch.Tensor):
        depth = depth_map.detach().cpu().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.squeeze()
    else:
        raise TypeError(
            f"depth_map must be a torch.Tensor or numpy.ndarray, got {type(depth_map).__name__}"
        )
    if depth.ndim < 2:
        raise ValueError("Invalid dimension")

    percentile = 0.03
    low = np.percentile(depth, percentile)
    high = np.percentile(depth, 100 - percentile)

    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = matplotlib.colormaps[cmap]
    span = high - low
    if span > 0:
        normalized = ((depth - low) / span).clip(0, 1)
    else:
        # Flat (or all-NaN) map: avoid 0/0 and render it with the colormap's
        # lowest colour instead of NaN pixels.
        normalized = np.zeros(depth.shape, dtype=np.float64)

    img_colored_np = cm(normalized, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = (img_colored_np * 255.0).astype(np.uint8)
    # Only the first map of the batch is returned.
    return img_colored_np[0, :, :, :]
def gallery_fn(image, coarse, fine, **kwargs):
    """Arrange one example row into the values for the example outputs.

    Returns a two-item list: the ``(coarse, fine)`` pair first, then the
    original ``image``. Extra keyword arguments are accepted and ignored.
    """
    slider_pair = (coarse, fine)
    return [slider_pair, image]
# Markdown texts rendered at the top of the demo page.
title = "# PatchFusion"
description = """Official demo for **PatchFusion: An End-to-End Tile-Based Framework for High-Resolution Monocular Metric Depth Estimation**.
PatchFusion is a deep learning model for high-resolution metric depth estimation from a single image.
Please refer to our [project webpage](https://zhyever.github.io/patchfusion), [paper](https://arxiv.org/abs/2312.02284) or [github](https://github.com/zhyever/PatchFusion) for more details.
We use [Depth-Anything vitl14](https://github.com/LiheYoung/Depth-Anything) as the default model in this demo.
You can slide the output to compare the depth prediction from the base low-resolution model and PatchFusion.
"""
# The raw-depth formula in the last line must match the encoding used when the
# 16-bit PNG is written (prediction * 65535 / 80), not a fixed factor of 256.
markdown = """
PatchFusion works on default 4k (2160x3840) resolution. All input images will be resized to 4k before passing through PatchFusion as default. Users can increase the processing resolution in the advanced option (resolution_h, resolution_w).
PatchFusion has three modes: m1, m2, and rn. They are corresponding to the different tiling strategies P16, P49, and Rn in our paper. We rename it for general usage. Please refer to the documentation provided in [github inference instruction](https://github.com/zhyever/PatchFusion/docs/user_infer.md).
We provide customized tiling variations for m1 and m2. Users can adjust the patch split number in the advanced option (patch_split_number_h, patch_split_number_w) (default: 4x4 splitting).
The output depth map is resized to the original image resolution. Download for better visualization quality. 16-Bit Raw Depth = (pred_depth * 65535 / 80).to(uint16).
"""
# Resizer exposed by the PatchFusion model; used to build the low-resolution
# input from the uploaded image.
image_resizer = model.resizer

with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")
    gr.Markdown(markdown)

    # Hidden file slots: they only serve as inputs for the cached examples
    # at the bottom (coarse / fine prediction images).
    prediction_coarse = gr.File(
        visible=False,
    )
    prediction_fine = gr.File(
        visible=False,
    )

    with gr.Row():
        with gr.Accordion("Advanced options", open=False):
            resolution_h = gr.Slider(label="Proccessing resolution height (Default 4K, 2160)", minimum=256, maximum=2700, value=2160, step=10)
            resolution_w = gr.Slider(label="Proccessing resolution width (Default 4K, 3840)", minimum=256, maximum=4800, value=3840, step=10)
            mode = gr.Radio(["m2", "r"], label="Tiling mode", info="We recommand using M2 for fast evaluation and R with 256 patches for best visualization results, respectively", elem_id='mode', value='m2')
            # mode = gr.Radio(["m1", "m2", "r"], label="Tiling mode", info="We recommand using M2 for fast evaluation and R with 256 patches for best visualization results, respectively", elem_id='mode', value='m1')
            patch_number = gr.Slider(1, 256, label="Please decide the number of random patches (Only useful in mode=R)", step=1, value=256)
            patch_split_number_h = gr.Slider(label="Patch split number of height dimension (Default 4)", minimum=2, maximum=8, value=4, step=1)
            patch_split_number_w = gr.Slider(label="Patch split number of width dimension (Default 4)", minimum=2, maximum=8, value=4, step=1)
            # Distinct elem_id: reusing 'mode' here would give two widgets
            # the same DOM id.
            color_map = gr.Radio(["turbo_r", "magma_r", "Spectral", "gray"], label="Colormap used to render depth map", elem_id='color_map', value='turbo_r')

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')

    with gr.Row():
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)

    raw_file = gr.File(label="16-Bit Raw Depth, Multiplier:65535/80")
    submit = gr.Button("Submit")

    def on_submit(image, mode, patch_number, resolution_h, resolution_w, patch_split_number_h, patch_split_number_w, color_map):
        """Predict depth for the uploaded image and prepare both outputs.

        Runs PatchFusion on the image resized to the processing resolution
        and the base model on the low-resolution image, resizes both
        predictions back to the input size and colourises them.

        Returns:
            ``[(coarse_colored, fine_colored), png_path]``: the image pair
            for the comparison slider and the path of a temporary 16-bit PNG
            holding ``prediction * 65535 / 80``.

        Raises:
            gr.Error: no image was provided.
        """
        if image is None:
            raise gr.Error("Please upload an input image before submitting.")

        h, w = image.shape[:2]

        # NOTE(review): gr.Image(type='numpy') is documented to deliver RGB,
        # so this conversion effectively swaps the channels to BGR. Left
        # untouched because it changes what the model sees; confirm against
        # the model's expected preprocessing before altering it.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        image = transforms.ToTensor()(np.asarray(image))  # raw image

        # make sure the resolution and patch number are valid:
        # widget values may arrive as floats, and each side of the processing
        # resolution is rounded down to a multiple of 2 * split number.
        patch_split_number_h = int(patch_split_number_h)
        patch_split_number_w = int(patch_split_number_w)
        multiple_h = 2 * patch_split_number_h
        multiple_w = 2 * patch_split_number_w
        resolution_h = int(resolution_h) // multiple_h * multiple_h
        resolution_w = int(resolution_w) // multiple_w * multiple_w

        # Pure inference: do not build autograd graphs for 4K inputs.
        with torch.no_grad():
            image_lr = image_resizer(image.unsqueeze(dim=0)).float().to(DEVICE)
            image_hr = F.interpolate(image.unsqueeze(dim=0), (resolution_h, resolution_w), mode='bicubic', align_corners=True).float().to(DEVICE)

            pf_prediction = predict_depth(model, image_lr, image_hr, mode, patch_number, resolution_h, resolution_w, patch_split_number_h, patch_split_number_w)
            coarse_prediction = depth_anything_base(image_lr)['metric_depth']

            pf_prediction = F.interpolate(pf_prediction, (h, w))[0, 0].detach().cpu().numpy()
            coarse_prediction = F.interpolate(coarse_prediction, (h, w))[0, 0].detach().cpu().numpy()

        pf_prediction_colored = colorize_depth(pf_prediction, cmap=color_map)
        coarse_prediction_colored = colorize_depth(coarse_prediction, cmap=color_map)

        max_depth = 80
        # Clip before the cast: values beyond max_depth would otherwise wrap
        # around in uint16 and appear as near depth.
        scaled_depth = np.clip(pf_prediction * (65535 / max_depth), 0, 65535)
        raw_depth = Image.fromarray(scaled_depth.astype('uint16'))

        # Reserve a path, close the handle, then write to it; the file is
        # kept (delete=False) so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            raw_path = tmp.name
        raw_depth.save(raw_path)

        return [(coarse_prediction_colored, pf_prediction_colored), raw_path]

    ips = [input_image, mode, patch_number, resolution_h, resolution_w, patch_split_number_h, patch_split_number_w, color_map]
    submit.click(on_submit, inputs=ips, outputs=[depth_image_slider, raw_file])

    # Pre-computed examples: gallery_fn maps (image, coarse, fine) files onto
    # the slider and the input image without running the model.
    inputs = [input_image, prediction_coarse, prediction_fine]
    outputs = [depth_image_slider, input_image,]
    examples = gr.Examples(
        inputs=inputs,
        outputs=outputs,
        examples=[
            [
                "examples/example_1.jpeg",
                "examples/example_1_coarse.png",
                "examples/example_1_fine.png",
            ],
            [
                "examples/example_2.jpeg",
                "examples/example_2_coarse.png",
                "examples/example_2_fine.png",
            ],
            [
                "examples/example_3.jpeg",
                "examples/example_3_coarse.png",
                "examples/example_3_fine.png",
            ]],
        cache_examples=True,
        fn=gallery_fn,)
# Script entry point: enable request queuing, then start the server with a
# public share link and error messages surfaced in the UI.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(share=True, show_error=True)