#!/usr/bin/env python

import pathlib

import gradio as gr
import matplotlib as mpl
import numpy as np
import PIL.Image
import spaces
import torch
from gradio_imageslider import ImageSlider
from transformers import DepthProForDepthEstimation, DepthProImageProcessorFast
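
# Load DepthPro once at startup; inference runs on GPU when available.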
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
cmap = mpl.colormaps.get_cmap("Spectral_r")
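
# DepthPro predicts metric depth and the camera field of view from a single image.
# `@spaces.GPU` requests a ZeroGPU worker for each call (here up to 20 seconds).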
@spaces.GPU(duration=20)
@torch.inference_mode()
def run(image: PIL.Image.Image) -> tuple[tuple[PIL.Image.Image, PIL.Image.Image], str, str, str, str]:
    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
    # Resize the prediction back to the input resolution and recover camera parameters.
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs,
        target_sizes=[(image.height, image.width)],
    )
    depth_raw = post_processed_output[0]["predicted_depth"]  # metric depth in meters
    depth_min = depth_raw.min().item()
    depth_max = depth_raw.max().item()
    # Visualize inverse depth so that nearby regions receive the strongest colors.
    inverse_depth = 1 / depth_raw
    normalized_inverse_depth = (inverse_depth - inverse_depth.min()) / (inverse_depth.max() - inverse_depth.min())
    normalized_inverse_depth = normalized_inverse_depth * 255.0
    normalized_inverse_depth = normalized_inverse_depth.detach().cpu().numpy()
    normalized_inverse_depth = PIL.Image.fromarray(normalized_inverse_depth.astype("uint8"))
    # Map the 8-bit inverse depth through the colormap and drop the alpha channel.
    colored_inverse_depth = PIL.Image.fromarray(
        (cmap(np.array(normalized_inverse_depth))[:, :, :3] * 255).astype(np.uint8)
    )
    field_of_view = post_processed_output[0]["field_of_view"].item()
    focal_length = post_processed_output[0]["focal_length"].item()
    return (
        (image, colored_inverse_depth),
        f"{field_of_view:.2f}",
        f"{focal_length:.2f}",
        f"{depth_min:.2f}",
        f"{depth_max:.2f}",
    )
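
# UI: input image and a run button on the left; an input/depth comparison slider and
# the estimated camera parameters on the right. Units follow DepthPro's outputs:
# field of view in degrees, focal length in pixels, depth range in meters.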
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# DepthPro")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            run_button = gr.Button()
        with gr.Column():
            output_image = ImageSlider()
            with gr.Row():
                output_field_of_view = gr.Textbox(label="Field of View")
                output_focal_length = gr.Textbox(label="Focal Length")
                output_depth_min = gr.Textbox(label="Depth Min")
                output_depth_max = gr.Textbox(label="Depth Max")
    gr.Examples(
        examples=sorted(pathlib.Path("images").glob("*.jpg")),
        inputs=input_image,
        fn=run,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=[
            output_image,
            output_field_of_view,
            output_focal_length,
            output_depth_min,
            output_depth_max,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()
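
# To try this locally (a sketch; assumes the Space's `images/` example folder and
# `style.css` are present alongside this script):
#   pip install torch transformers gradio gradio_imageslider spaces
#   python app.py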