Spaces:
Sleeping
Sleeping
import gradio as gr | |
from gradio_rerun import Rerun | |
import rerun as rr | |
import rerun.blueprint as rrb | |
import mmcv | |
from timeit import default_timer as timer | |
from typing import Literal | |
from mini_dpvo.config import cfg as base_cfg | |
from mini_dpvo.api.inference import ( | |
log_trajectory, | |
calib_from_dust3r, | |
create_reader, | |
calculate_num_frames, | |
) | |
import torch | |
import numpy as np | |
from pathlib import Path | |
from multiprocessing import Process, Queue | |
from mini_dpvo.dpvo import DPVO | |
from jaxtyping import UInt8, Float64, Float32 | |
from mini_dust3r.model import AsymmetricCroCo3DStereo | |
from tqdm import tqdm | |
import tyro | |
from dataclasses import dataclass | |
if gr.NO_RELOAD: | |
NETWORK_PATH = "checkpoints/dpvo.pth" | |
DEVICE = ( | |
"mps" | |
if torch.backends.mps.is_available() | |
else "cuda" | |
if torch.cuda.is_available() | |
else "cpu" | |
) | |
MODEL = AsymmetricCroCo3DStereo.from_pretrained( | |
"naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt" | |
).to(DEVICE) | |
class GradioDPVOConfig: | |
share: bool = False | |
def run_dpvo( | |
video_file_path: str, | |
jpg_quality: str, | |
stride: int = 1, | |
skip: int = 0, | |
config_type: Literal["accurate", "fast"] = "accurate", | |
progress=gr.Progress(), | |
): | |
# create a stream to send data back to the rerun viewer | |
stream = rr.binary_stream() | |
parent_log_path = Path("world") | |
rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True) | |
blueprint = rrb.Blueprint( | |
collapse_panels=True, | |
) | |
rr.send_blueprint(blueprint) | |
if config_type == "accurate": | |
base_cfg.merge_from_file("config/default.yaml") | |
elif config_type == "fast": | |
base_cfg.merge_from_file("config/fast.yaml") | |
else: | |
raise ValueError("Invalid config type") | |
base_cfg.BUFFER_SIZE = 2048 | |
slam = None | |
start_time = timer() | |
queue = Queue(maxsize=8) | |
reader: Process = create_reader(video_file_path, None, stride, skip, queue) | |
reader.start() | |
# get the first frame | |
progress(progress=0.1, desc="Estimating Camera Intrinsics") | |
_, bgr_hw3, _ = queue.get() | |
K_33_pred = calib_from_dust3r(bgr_hw3, MODEL, DEVICE) | |
intri_np: Float64[np.ndarray, "4"] = np.array( | |
[K_33_pred[0, 0], K_33_pred[1, 1], K_33_pred[0, 2], K_33_pred[1, 2]] | |
) | |
num_frames = calculate_num_frames(video_file_path, stride, skip) | |
path_list = [] | |
with tqdm(total=num_frames, desc="Processing Frames") as pbar: | |
while True: | |
timestep: int | |
bgr_hw3: UInt8[np.ndarray, "h w 3"] | |
intri_np: Float64[np.ndarray, "4"] | |
(timestep, bgr_hw3, _) = queue.get() | |
# queue will have a (-1, image, intrinsics) tuple when the reader is done | |
if timestep < 0: | |
break | |
rr.set_time_sequence(timeline="timestep", sequence=timestep) | |
bgr_3hw: UInt8[torch.Tensor, "h w 3"] = ( | |
torch.from_numpy(bgr_hw3).permute(2, 0, 1).cuda() | |
) | |
intri_torch: Float64[torch.Tensor, "4"] = torch.from_numpy(intri_np).cuda() | |
if slam is None: | |
_, h, w = bgr_3hw.shape | |
slam = DPVO(base_cfg, NETWORK_PATH, ht=h, wd=w) | |
slam(timestep, bgr_3hw, intri_torch) | |
pbar.update(1) | |
if slam.is_initialized: | |
poses: Float32[torch.Tensor, "buffer_size 7"] = slam.poses_ | |
points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = ( | |
slam.points_ | |
) | |
colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_ | |
path_list = log_trajectory( | |
parent_log_path, | |
poses, | |
points, | |
colors, | |
intri_np, | |
bgr_hw3, | |
path_list, | |
jpg_quality, | |
) | |
yield stream.read(), timer() - start_time | |
def on_file_upload(video_file_path: str) -> None: | |
video_reader = mmcv.VideoReader(video_file_path) | |
video_info = f""" | |
**Video Info:** | |
- Number of Frames: {video_reader.frame_cnt} | |
- FPS: {round(video_reader.fps)} | |
""" | |
return video_info | |
def main(gradio_config: GradioDPVOConfig): | |
with gr.Blocks( | |
css=""".gradio-container {margin: 0 !important; min-width: 100%};""", | |
title="Mini-DPVO Demo", | |
) as demo: | |
# scene state is save so that you can change conf_thr, cam_size... without rerunning the inference | |
gr.HTML('<h2 style="text-align: center;">Mini-DPVO Demo</h2>') | |
gr.HTML( | |
'<p style="text-align: center;">Unofficial DPVO demo using the mini-dpvo. Learn more about mini-dpvo <a href="https://github.com/pablovela5620/mini-dpvo">here</a>.</p>' | |
) | |
with gr.Column(): | |
with gr.Row(): | |
video_input = gr.File( | |
height=100, | |
file_count="single", | |
file_types=[".mp4", ".mov", ".MOV", ".webm"], | |
label="Video File", | |
) | |
with gr.Column(): | |
video_info = gr.Markdown( | |
value=""" | |
**Video Info:** | |
""" | |
) | |
time_taken = gr.Number( | |
label="Time Taken (s)", precision=2, interactive=False | |
) | |
with gr.Accordion(label="Advanced", open=False): | |
with gr.Row(): | |
jpg_quality = gr.Radio( | |
label="JPEG Quality %: Lower quality means faster streaming", | |
choices=[10, 50, 90], | |
value=90, | |
type="value", | |
) | |
stride = gr.Slider( | |
label="Stride: How many frames to sample between each prediction", | |
minimum=1, | |
maximum=5, | |
step=1, | |
value=5, | |
) | |
skip = gr.Number( | |
label="Skip: How many frames to skip at the beginning", | |
value=0, | |
precision=0, | |
) | |
config_type = gr.Dropdown( | |
label="Config Type: Choose between accurate and fast", | |
value="fast", | |
choices=["accurate", "fast"], | |
max_choices=1, | |
) | |
with gr.Row(): | |
start_btn = gr.Button("Run") | |
stop_btn = gr.Button("Stop") | |
rr_viewer = Rerun(height=600, streaming=True) | |
# Example videos | |
base_example_params = [50, 4, 0, "fast"] | |
example_dpvo_dir = Path("data/movies") | |
example_iphone_dir = Path("data/iphone") | |
example_video_paths = sorted(example_iphone_dir.glob("*.MOV")) + sorted( | |
example_dpvo_dir.glob("*.MOV") | |
) | |
example_video_paths = [str(path) for path in example_video_paths] | |
gr.Examples( | |
examples=[[path, *base_example_params] for path in example_video_paths], | |
inputs=[video_input, jpg_quality, stride, skip, config_type], | |
outputs=[rr_viewer], | |
fn=run_dpvo, | |
cache_examples=False, | |
) | |
click_event = start_btn.click( | |
fn=run_dpvo, | |
inputs=[video_input, jpg_quality, stride, skip, config_type], | |
outputs=[rr_viewer, time_taken], | |
) | |
stop_btn.click( | |
fn=None, | |
inputs=[], | |
outputs=[], | |
cancels=[click_event], | |
) | |
video_input.upload( | |
fn=on_file_upload, inputs=[video_input], outputs=[video_info] | |
) | |
demo.launch(share=gradio_config.share) | |
if __name__ == "__main__": | |
main(tyro.cli(GradioDPVOConfig)) | |