File size: 3,357 Bytes
843bd97
19c4b4d
 
8a2c153
19c4b4d
d2ba2b6
843bd97
19c4b4d
843bd97
 
21401ba
843bd97
 
 
 
 
 
036e46e
843bd97
 
036e46e
843bd97
 
036e46e
843bd97
 
 
 
 
 
 
dd76a42
843bd97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2ba2b6
dd76a42
 
 
 
 
843bd97
 
 
 
dd76a42
843bd97
dd76a42
21401ba
843bd97
 
 
 
21401ba
 
843bd97
dd76a42
843bd97
 
 
dd76a42
 
19c4b4d
dd76a42
843bd97
 
 
dd76a42
843bd97
 
 
dd76a42
21401ba
 
 
dd76a42
8a2c153
 
 
392f705
843bd97
 
 
 
21401ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider

from depth_anything.dpt import DPT_DINOv2
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

css = """
#img-display-container {
    max-height: 100vh;
    }
#img-display-input {
    max-height: 80vh;
    }
#img-display-output {
    max-height: 80vh;
    }
"""
# Page CSS handed to gr.Blocks below: caps the height of the components whose
# elem_id matches (#img-display-container is not referenced by any component in
# this file — NOTE(review): confirm whether it is still needed).

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Depth Anything with the ViT-L encoder, moved to DEVICE and put in eval mode.
model = DPT_DINOv2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(DEVICE).eval()
# map_location='cpu': the checkpoint may have been saved from CUDA tensors, which
# cannot be deserialized on a CPU-only host. load_state_dict then copies the
# weights into the model's parameters, which already live on DEVICE.
model.load_state_dict(torch.load('checkpoints/depth_anything_vitl14.pth', map_location='cpu'))

# Markdown shown at the top of the demo page (rendered via gr.Markdown).
title = "# Depth Anything"
description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data**.
Please refer to our [paper](https://arxiv.org/abs/2401.10891), [project page](https://depth-anything.github.io), or [github](https://github.com/LiheYoung/Depth-Anything) for more details."""

# Keyword arguments for the Resize step: target 518x518, keep the aspect ratio,
# and round each side to a multiple of 14 (presumably the ViT-L/14 patch size —
# confirm against the encoder). 'lower_bound' presumably means each side ends up
# at least the target size; bicubic interpolation is used for the image.
_RESIZE_OPTIONS = dict(
    width=518,
    height=518,
    resize_target=False,
    keep_aspect_ratio=True,
    ensure_multiple_of=14,
    resize_method='lower_bound',
    image_interpolation_method=cv2.INTER_CUBIC,
)

# Per-channel normalization statistics (the standard ImageNet RGB mean/std).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image dict ({'image': ...}) before
# inference: resize, normalize, then PrepareForNet (layout/dtype conversion for
# the network — see depth_anything.util.transform for specifics).
transform = Compose(
    [
        Resize(**_RESIZE_OPTIONS),
        NormalizeImage(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        PrepareForNet(),
    ]
)

@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    """Return ``model(image)`` computed with autograd disabled.

    The ``spaces.GPU`` decorator presumably requests GPU time on a Hugging Face
    Space for the duration of this call (NOTE(review): confirm with the
    ``spaces`` package docs).

    Args:
        model: the depth network to run.
        image: the preprocessed input batch passed straight to ``model``.

    Returns:
        Whatever ``model`` returns for ``image``, unchanged.
    """
    prediction = model(image)
    return prediction


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0)
    raw_file = gr.File(label="16-bit raw depth (can be considered as disparity)")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Predict depth for one uploaded image.

        Args:
            image: the NumPy array delivered by the ``gr.Image(type='numpy')``
                input above (RGB channel order), or ``None`` when no image
                has been provided.

        Returns:
            A two-element list: ``(original_image, colored_depth)`` for the
            image slider, and the path of a 16-bit PNG holding the raw,
            un-normalized prediction for the file download.

        Raises:
            gr.Error: if Submit is pressed without an image.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")

        original_image = image.copy()
        h, w = image.shape[:2]

        # Gradio already hands over RGB, which is the channel order the ImageNet
        # mean/std in `transform` is laid out for, so no cv2 BGR->RGB conversion
        # is applied here (converting would actually feed the model BGR).
        model_input = image / 255.0
        model_input = transform({'image': model_input})['image']
        model_input = torch.from_numpy(model_input).unsqueeze(0).to(DEVICE)

        depth = predict_depth(model, model_input)
        # Resize the prediction back to the input resolution. depth[None] adds
        # the leading axis F.interpolate needs (assumes the model returns
        # (batch, H, W) — TODO confirm); [0, 0] drops it again.
        depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
        depth = depth.cpu().numpy()  # single device->host copy, reused below

        # Raw prediction as a 16-bit PNG. Clip first so out-of-range values
        # saturate instead of wrapping around in the uint16 cast.
        raw_depth = Image.fromarray(np.clip(depth, 0, 65535).astype('uint16'))
        # Only reserve a unique path here; closing the handle before PIL writes
        # by name avoids leaking the descriptor (and fails on Windows otherwise).
        # delete=False keeps the file around so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            raw_path = tmp.name
        raw_depth.save(raw_path)

        # Min-max scale to 0..255 for visualization. A constant depth map would
        # divide by zero (NaN before the uint8 cast), so render it as all zeros.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            depth = (depth - depth_min) / depth_range * 255.0
        else:
            depth = np.zeros_like(depth)
        depth = depth.astype(np.uint8)
        # applyColorMap produces BGR; reverse the last axis to get RGB for display.
        colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]

        return [(original_image, colored_depth), raw_path]

    submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, raw_file])

    # Every file in ./examples becomes an example input, in filename order;
    # cache_examples=True has Gradio run on_submit on them ahead of time.
    example_files = [os.path.join('examples', filename) for filename in sorted(os.listdir('examples'))]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, raw_file], fn=on_submit, cache_examples=True)


if __name__ == '__main__':
    # Script entry point: enable Gradio's request queue on the demo, then
    # start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()