# Depth_Stitcher / app.py
# Author: mtwohey2 — last commit e9985cb ("Update app.py", verified)
import os
import gc
import cv2
import gradio as gr
import numpy as np
import matplotlib.cm as cm
import matplotlib # New import for the updated colormap API
import subprocess
import sys
from utils.dc_utils import read_video_frames, save_video
# Markdown shown at the top of the Gradio page.
title = "**RGBD SBS output**"
description = (
    "**Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays.\n"
    "Please refer to our [paper](https://arxiv.org/abs/2501.12375), "
    "[project page](https://videodepthanything.github.io/), and "
    "[github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."
)
def _fit_to_max_res(frame, max_res):
    """Return ``frame`` downscaled so its longer side is at most ``max_res``.

    The frame is returned unchanged when ``max_res <= 0`` or it already fits.
    """
    if max_res <= 0:
        return frame
    h, w = frame.shape[:2]
    longest = max(h, w)
    if longest <= max_res:
        return frame
    scale = max_res / longest
    return cv2.resize(frame, (int(w * scale), int(h * scale)))


def _render_depth(depth_frame, grayscale, convert_from_color, blur, cmap=None):
    """Convert one decoded depth-video frame into the BGR image for the right half.

    Args:
        depth_frame: frame returned by ``cv2.VideoCapture.read`` (BGR order);
            assumed 3-channel uint8 — TODO confirm for unusual depth encodings.
        grayscale: produce a 3-channel gray image instead of a colormapped one.
        convert_from_color: with ``grayscale``, collapse the frame to luminance
            first; otherwise the frame is passed through unchanged.
        blur: Gaussian blur strength; values ``<= 0`` disable blurring.
        cmap: matplotlib colormap used when ``grayscale`` is False
            (defaults to "inferno").

    Returns:
        The processed depth image (same height/width as ``depth_frame``).
    """
    if grayscale:
        if convert_from_color:
            # OpenCV decodes video frames in BGR order, so use the BGR weights.
            depth_gray = cv2.cvtColor(depth_frame, cv2.COLOR_BGR2GRAY)
            depth_vis = np.stack([depth_gray] * 3, axis=-1)
        else:
            depth_vis = depth_frame
    elif np.max(depth_frame) > 0:
        if cmap is None:
            cmap = matplotlib.colormaps["inferno"]
        depth_gray = cv2.cvtColor(depth_frame, cv2.COLOR_BGR2GRAY)
        depth_rgb = (cmap(depth_gray / 255.0)[..., :3] * 255).astype(np.uint8)
        # matplotlib colormaps emit RGB, but cv2.VideoWriter expects BGR.
        depth_vis = cv2.cvtColor(depth_rgb, cv2.COLOR_RGB2BGR)
    else:
        # An all-zero frame carries no depth information; pass it through.
        depth_vis = depth_frame
    if blur > 0:
        # Odd kernel size that grows with ``blur``, capped at 31 (original tuning).
        kernel_size = min(int(blur * 20) * 2 + 1, 31)
        depth_vis = cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)
    return depth_vis


def _merge_audio(video_path, audio_source):
    """Best-effort: mux the first audio stream of ``audio_source`` into ``video_path`` in place.

    On any ffmpeg failure (binary missing, non-zero exit) a warning is printed
    and ``video_path`` is left as the silent video.
    """
    root, ext = os.path.splitext(video_path)
    temp_path = f"{root}_audio{ext}"
    cmd = [
        "ffmpeg",
        "-y",
        "-i", video_path,
        "-i", audio_source,
        "-c:v", "copy",
        "-c:a", "aac",
        "-map", "0:v:0",
        "-map", "1:a:0?",  # trailing '?' keeps inputs without audio working
        "-shortest",
        temp_path,
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError:
        print("Warning: ffmpeg not found; output has no audio")
        return
    if result.returncode != 0 or not os.path.exists(temp_path):
        print("Warning: ffmpeg audio merge failed; output has no audio")
        print(result.stderr.decode(errors="replace")[-2000:])
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return
    os.replace(temp_path, video_path)


def stitch_rgbd_videos(
    processed_video: str,
    depth_vis_video: str,
    max_len: int = -1,
    target_fps: int = -1,
    max_res: int = 1280,
    stitch: bool = True,
    grayscale: bool = True,
    convert_from_color: bool = True,
    blur: float = 0.3,
    output_dir: str = './outputs',
    input_size: int = 518,
):
    """Stitch an RGB video and its depth video side by side (RGB left, depth right).

    Args:
        processed_video: path to the colour video; its audio (if any) is muxed
            into the result. ``None``/empty (nothing uploaded) returns ``None``.
        depth_vis_video: path to the depth video, read in lockstep with
            ``processed_video``; each depth frame is resized to the RGB frame size.
        max_len: maximum number of output frames; ``<= 0`` means no limit.
        target_fps: desired output frame rate; ``<= 0`` keeps the source rate.
            Frames are dropped with an integer stride and written at
            ``source_fps / stride`` so playback speed stays correct.
        max_res: longest side of the RGB frame after downscaling; ``<= 0`` disables.
        stitch: when False nothing is produced and ``None`` is returned.
        grayscale: render depth as gray rather than with the inferno colormap.
        convert_from_color: with ``grayscale``, convert the depth frame to luminance.
        blur: Gaussian blur strength for the depth half (``<= 0`` disables).
        output_dir: directory for the output file; created if missing.
        input_size: unused; kept for call-site compatibility.

    Returns:
        Path of ``<output_dir>/<first 20 chars of input name>_RGBD.mp4``, or
        ``None`` if stitching is disabled or fails.
    """
    if not processed_video or not depth_vis_video:
        print("Error: Both an input video and a depth video are required")
        return None
    os.makedirs(output_dir, exist_ok=True)
    if not stitch:
        return None

    base_name = os.path.splitext(os.path.basename(processed_video))[0]
    stitched_video_path = os.path.join(output_dir, base_name[:20] + '_RGBD.mp4')

    cap_rgb = cv2.VideoCapture(processed_video)
    cap_depth = cv2.VideoCapture(depth_vis_video)
    out = None
    processed_count = 0
    try:
        if not cap_rgb.isOpened() or not cap_depth.isOpened():
            print("Error: Could not open one or both videos")
            return None

        original_fps = cap_rgb.get(cv2.CAP_PROP_FPS)
        if original_fps <= 0:
            # Some containers report no frame rate; fall back to a common default.
            print("Warning: Could not read input FPS, assuming 30")
            original_fps = 30.0
        if target_fps <= 0:
            target_fps = original_fps
        stride = max(round(original_fps / target_fps), 1)
        # Write at the rate the stride actually yields so the output duration
        # matches the input (and its audio track).
        output_fps = original_fps / stride
        total_frames_rgb = int(cap_rgb.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video fps: {original_fps}, target fps: {target_fps}, output fps: {output_fps}, total frames: {total_frames_rgb}")

        # Look the colormap up once instead of once per frame.
        cmap = None if grayscale else matplotlib.colormaps["inferno"]
        frame_count = 0
        while True:
            ret_rgb, rgb_frame = cap_rgb.read()
            ret_depth, depth_frame = cap_depth.read()
            if not ret_rgb or not ret_depth:
                break  # stop at the end of the shorter video
            frame_count += 1
            # Keep every ``stride``-th frame (1-based count).
            if frame_count % stride != 0:
                continue
            if max_len > 0 and processed_count >= max_len:
                break

            rgb_frame = _fit_to_max_res(rgb_frame, max_res)
            h, w = rgb_frame.shape[:2]
            depth_vis = _render_depth(depth_frame, grayscale, convert_from_color, blur, cmap)
            depth_vis = cv2.resize(depth_vis, (w, h)).astype(np.uint8)
            stitched = cv2.hconcat([rgb_frame, depth_vis])

            if out is None:
                # Size the writer from the first stitched frame (after the
                # max_res downscale); frames whose size differs from the size
                # given to VideoWriter are not written.
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(
                    stitched_video_path, fourcc, output_fps,
                    (stitched.shape[1], stitched.shape[0]),
                )
                if not out.isOpened():
                    print("Error: Could not open video writer")
                    return None
            out.write(stitched)
            processed_count += 1

            if processed_count % 10 == 0:
                print(f"Processed {processed_count} frames...")
            if processed_count % 50 == 0:
                gc.collect()  # keep peak memory down on long videos
    finally:
        cap_rgb.release()
        cap_depth.release()
        if out is not None:
            out.release()

    if processed_count == 0:
        print("Error: No frames were written")
        return None

    # Merge audio from the input video into the stitched video using ffmpeg.
    _merge_audio(stitched_video_path, processed_video)
    print(f"Completed processing {processed_count} frames")
    return stitched_video_path
def construct_demo():
    """Build the Gradio Blocks UI for the RGBD stitcher.

    Two video uploads (colour video with audio, depth video) plus the advanced
    settings are passed positionally to ``stitch_rgbd_videos`` when "Generate"
    is clicked; its return value is shown in the output player.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet queued or launched).
    """
    with gr.Blocks(analytics_enabled=False) as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown("### If you find this work useful, please help ⭐ the [Github Repo](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for your attention!")
        # Top row: the two input videos and the stitched output player.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                # Video input component for file upload.
                processed_video = gr.Video(label="Input Video with Audio")
            with gr.Column(scale=1):
                depth_vis_video = gr.Video(label="Depth Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    stitched_video = gr.Video(label="Stitched RGBD Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
        # Bottom row: collapsible settings and the Generate button.
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                with gr.Accordion("Advanced Settings", open=False):
                    # -1 means "no limit" / "keep the source frame rate" in stitch_rgbd_videos.
                    max_len = gr.Slider(label="Max process length", minimum=-1, maximum=1000, value=-1, step=1)
                    target_fps = gr.Slider(label="Target FPS", minimum=-1, maximum=30, value=-1, step=1)
                    max_res = gr.Slider(label="Max side resolution", minimum=480, maximum=1920, value=1920, step=1)
                    stitch_option = gr.Checkbox(label="Stitch RGB & Depth Videos", value=True)
                    grayscale_option = gr.Checkbox(label="Output Depth as Grayscale", value=True)
                    convert_from_color_option = gr.Checkbox(label="Convert Grayscale from Color", value=True)
                    blur_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Depth Blur (can reduce edge artifacts on display)", value=0.3)
                generate_btn = gr.Button("Generate")
            with gr.Column(scale=1):
                # Empty spacer column.
                pass
        # The order of `inputs` must match stitch_rgbd_videos' positional
        # parameters (processed_video ... blur); output_dir and input_size
        # keep their defaults.
        generate_btn.click(
            fn=stitch_rgbd_videos,
            inputs=[processed_video, depth_vis_video, max_len, target_fps, max_res, stitch_option, grayscale_option, convert_from_color_option, blur_slider],
            outputs=stitched_video,
        )
    return demo
if __name__ == "__main__":
    # Build the UI, limit the pending-event queue to four entries, then serve it.
    demo = construct_demo()
    demo.queue(max_size=4)
    demo.launch()