File size: 1,908 Bytes
37840e7
 
c23facb
 
 
5514789
 
c283f36
5514789
 
5212158
 
c07c44f
 
 
 
5212158
5514789
212d34e
07dbe58
d2eaa46
5514789
 
 
36f850c
5514789
212d34e
 
5514789
 
 
 
 
c283f36
212d34e
5514789
 
212d34e
 
 
5514789
10674e9
5514789
 
 
 
 
 
c5c5a80
5212158
 
5514789
 
5212158
5514789
8f8d235
 
5514789
c283f36
5212158
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
# Replace torch.jit.script with a no-op identity so decorated functions run
# in plain eager mode instead of being TorchScript-compiled.
# NOTE(review): presumably a workaround for jit errors when loading ZoeDepth
# via torch.hub — confirm before removing.
torch.jit.script = lambda f: f
import gradio as gr
import spaces

from zoedepth.utils.misc import colorize, save_raw_16bit
from zoedepth.utils.geometry import depth_to_points, create_triangles

from PIL import Image
import numpy as np

# Page-level CSS: keep displayed images within 500px height without cropping.
css = """
img {
    max-height: 500px;
    object-fit: contain;
}
"""

# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Download (on first run) and load the pretrained ZoeD_N metric-depth model
# from torch.hub; eval() disables dropout/batch-norm updates for inference.
# The model is moved to the right device lazily inside process_image.
MODEL = torch.hub.load('isl-org/ZoeDepth', "ZoeD_N", pretrained=True).eval()

# ----------- Depth functions
def save_raw_16bit(depth, fpath="raw.png"):
    """Convert a depth map to a uint16 array (16-bit PNG depth encoding).

    Despite its name this function performs no disk I/O: it only scales the
    depth values by 256 (so fractional depth survives the integer cast) and
    casts to ``np.uint16``.

    NOTE(review): this definition shadows the ``save_raw_16bit`` imported
    from ``zoedepth.utils.misc`` at the top of the file — intentional here,
    since the imported one writes to disk and this app only needs the array.

    Args:
        depth: Depth values as a ``torch.Tensor`` or any array-like.
            Tensors are squeezed and moved to the CPU first.
        fpath: Unused legacy parameter, kept so the signature stays
            backward-compatible with the zoedepth helper it shadows.

    Returns:
        ``np.ndarray`` of dtype ``uint16`` with values ``depth * 256``.
    """
    if isinstance(depth, torch.Tensor):
        depth = depth.squeeze().cpu().numpy()

    # Go through float64 explicitly: multiplying a uint8 array (what
    # colorize() produces) by the Python int 256 overflows under NumPy 2's
    # NEP 50 promotion rules, and a plain Python list would be replicated
    # by `* 256` instead of scaled.
    depth = np.asarray(depth, dtype=np.float64)
    depth = depth * 256  # scale for 16-bit png
    return depth.astype(np.uint16)

@spaces.GPU(enable_queue=True)
def process_image(image: Image.Image):
    """Run ZoeDepth on *image* and return the depth map as a PIL image.

    The shared MODEL is moved to the GPU when one is available, the depth
    prediction is colorized, and its first channel is scaled to uint16
    before being wrapped back into a PIL image.
    """
    global MODEL
    rgb = image.convert("RGB")

    target = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL.to(target)
    depth = MODEL.infer_pil(rgb)

    # Keep only the first channel of the colorized map, then encode as uint16.
    first_channel = colorize(depth)[:, :, 0]
    return Image.fromarray(save_raw_16bit(first_channel))

# ----------- End of depth functions


# Markdown rendered at the top of the Gradio page.
title = "# ZoeDepth"
description = """Unofficial demo for **ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth**."""

# Build the UI: one tab holding an input image and an output depth map side
# by side, plus a button wired to process_image. The click handler is also
# exposed over the HTTP API under the name "generate_depth".
with gr.Blocks(css=css) as API:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tab("Depth Prediction"):
        with gr.Row():
            inputs=gr.Image(label="Input Image", type='pil', height=500)  # Input is an image
            outputs=gr.Image(label="Depth Map", type='pil', height=500)  # Output is also an image
        generate_btn = gr.Button(value="Generate")
        generate_btn.click(process_image, inputs=inputs, outputs=outputs, api_name="generate_depth")

# Start the Gradio server only when run as a script (not on import).
if __name__ == '__main__':
    API.launch()