import gc
import os
import subprocess
import sys

import cv2
import gradio as gr
import matplotlib  # New import for the updated colormap API
import matplotlib.cm as cm
import numpy as np

from utils.dc_utils import read_video_frames, save_video

title = """**RGBD SBS output**"""
description = """**Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays. Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""


def _resize_to_max_res(frame, max_res):
    """Downscale *frame* so its longer side is at most *max_res*; no-op when max_res <= 0."""
    if max_res <= 0:
        return frame
    h, w = frame.shape[:2]
    longest = max(h, w)
    if longest <= max_res:
        return frame
    scale = max_res / longest
    return cv2.resize(frame, (max(1, int(w * scale)), max(1, int(h * scale))))


def _inferno_bgr_lut():
    """Build a (256, 3) uint8 table mapping gray levels 0..255 to inferno colours in BGR order."""
    cmap = matplotlib.colormaps["inferno"]
    # Integer input indexes the colormap's 256-entry table directly, which yields the same
    # colours as evaluating cmap(gray / 255.0) per pixel, but is computed only once.
    rgb = (cmap(np.arange(256))[:, :3] * 255).astype(np.uint8)
    # OpenCV writes BGR, so reverse the channel order of matplotlib's RGB output.
    return rgb[:, ::-1].copy()


def _render_depth(depth_frame, grayscale, convert_from_color, inferno_lut):
    """Turn a decoded (BGR) depth frame into the 3-channel BGR image shown beside the RGB frame."""
    if grayscale:
        if not convert_from_color:
            # Use the depth frame exactly as decoded.
            return depth_frame
        gray = cv2.cvtColor(depth_frame, cv2.COLOR_BGR2GRAY)
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    if not depth_frame.any():
        # An all-zero frame carries no depth; pass it through unchanged.
        return depth_frame
    gray = cv2.cvtColor(depth_frame, cv2.COLOR_BGR2GRAY)
    return inferno_lut[gray]


def _blur_depth(depth_vis, blur):
    """Gaussian-blur the depth image; *blur* in [0, 1] maps to an odd kernel size of at most 31."""
    if blur <= 0:
        return depth_vis
    # int(blur * 20) * 2 + 1 is always odd; the cap keeps the strongest setting from washing out depth edges.
    kernel_size = min(int(blur * 20) * 2 + 1, 31)
    return cv2.GaussianBlur(depth_vis, (kernel_size, kernel_size), 0)


def _mux_audio(video_path, audio_source):
    """Copy the first audio stream (if any) of *audio_source* into *video_path* in place via ffmpeg.

    Best effort: if ffmpeg is missing or fails, *video_path* is left as the silent video.
    """
    root, ext = os.path.splitext(video_path)
    temp_path = f"{root}_audio{ext}"
    cmd = [
        "ffmpeg", "-y",
        "-i", video_path,
        "-i", audio_source,
        "-c:v", "copy",
        "-c:a", "aac",
        "-map", "0:v:0",
        "-map", "1:a:0?",  # trailing '?' makes the audio stream optional
        "-shortest",
        temp_path,
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError:
        print("Warning: ffmpeg not found; returning video without audio")
        return
    if result.returncode != 0 or not os.path.exists(temp_path):
        print("Warning: ffmpeg audio merge failed; returning video without audio")
        print(result.stderr.decode(errors="replace")[-2000:])
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return
    os.replace(temp_path, video_path)


def stitch_rgbd_videos(
    processed_video: str,
    depth_vis_video: str,
    max_len: int = -1,
    target_fps: int = -1,
    max_res: int = 1280,
    stitch: bool = True,
    grayscale: bool = True,
    convert_from_color: bool = True,
    blur: float = 0.3,
    output_dir: str = './outputs',
    input_size: int = 518,
):
    """Write an RGB | depth side-by-side video, with the RGB video's audio muxed back in.

    The two videos are read frame by frame in lockstep; processing stops at the end of the
    shorter one.

    Args:
        processed_video: Path to the RGB video (also the audio source).
        depth_vis_video: Path to the depth visualisation video.
        max_len: Maximum number of output frames; <= 0 means no limit.
        target_fps: Desired sampling rate; <= 0 keeps the source rate. Frames 0, stride,
            2*stride, ... are kept, where stride = max(round(source_fps / target_fps), 1).
        max_res: Longest-side limit for the RGB frame; <= 0 disables resizing.
        stitch: When False nothing is produced and None is returned.
        grayscale: Render depth as gray (True) or with the inferno colormap (False).
        convert_from_color: In grayscale mode, convert the depth frame to gray first
            instead of using it as decoded.
        blur: Depth blur strength in [0, 1]; 0 disables blurring.
        output_dir: Directory for the output file (created if missing).
        input_size: Unused; kept for interface compatibility.

    Returns:
        Path of the stitched ``*_RGBD.mp4`` file, or None on failure / when ``stitch`` is False.
    """
    if not processed_video or not depth_vis_video:
        print("Error: Both an input video and a depth video are required")
        return None

    os.makedirs(output_dir, exist_ok=True)
    if not stitch:
        return None

    # UI sliders may deliver floats; these are used as counts / pixel sizes.
    max_len = int(max_len)
    max_res = int(max_res)

    base_name = os.path.splitext(os.path.basename(processed_video))[0]
    stitched_video_path = os.path.join(output_dir, base_name[:20] + '_RGBD.mp4')

    cap_rgb = cv2.VideoCapture(processed_video)
    cap_depth = cv2.VideoCapture(depth_vis_video)
    out = None
    processed_count = 0
    try:
        if not cap_rgb.isOpened() or not cap_depth.isOpened():
            print("Error: Could not open one or both videos")
            return None

        original_fps = cap_rgb.get(cv2.CAP_PROP_FPS)
        if not original_fps > 0:  # OpenCV reports 0 when the container has no fps; also catches NaN
            original_fps = target_fps if target_fps > 0 else 30.0
        if target_fps <= 0:
            target_fps = original_fps
        stride = max(round(original_fps / target_fps), 1)
        # Write at the rate frames are actually sampled so the output keeps the source
        # duration and stays in sync with the audio muxed in afterwards.
        output_fps = original_fps / stride

        total_frames_rgb = int(cap_rgb.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video fps: {original_fps}, target fps: {target_fps}, output fps: {output_fps}, total frames: {total_frames_rgb}")

        inferno_lut = None if grayscale else _inferno_bgr_lut()

        frame_idx = -1
        while max_len <= 0 or processed_count < max_len:
            ret_rgb, rgb_frame = cap_rgb.read()
            ret_depth, depth_frame = cap_depth.read()
            if not ret_rgb or not ret_depth:
                break
            frame_idx += 1
            if frame_idx % stride != 0:
                continue

            rgb_frame = _resize_to_max_res(rgb_frame, max_res)
            depth_vis = _render_depth(depth_frame, grayscale, convert_from_color, inferno_lut)
            depth_vis = _blur_depth(depth_vis, blur)

            # Match the depth panel to the (possibly resized) RGB frame.
            h, w = rgb_frame.shape[:2]
            depth_vis = cv2.resize(depth_vis, (w, h)).astype(np.uint8, copy=False)
            stitched = cv2.hconcat([rgb_frame, depth_vis])

            if out is None:
                # Size the writer from the first real output frame: cv2.VideoWriter silently
                # drops frames whose size differs from the size it was opened with.
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(stitched_video_path, fourcc, output_fps,
                                      (stitched.shape[1], stitched.shape[0]))
                if not out.isOpened():
                    print(f"Error: Could not open video writer for {stitched_video_path}")
                    return None

            out.write(stitched)
            processed_count += 1
            if processed_count % 10 == 0:
                print(f"Processed {processed_count} frames...")
    finally:
        cap_rgb.release()
        cap_depth.release()
        if out is not None:
            out.release()

    if processed_count == 0:
        print("Error: Could not read first frame from one or both videos")
        return None

    # Merge audio from the input video into the stitched video.
    _mux_audio(stitched_video_path, processed_video)

    print(f"Completed processing {processed_count} frames")
    return stitched_video_path


def construct_demo():
    """Build the Gradio Blocks UI that feeds uploaded videos and settings to stitch_rgbd_videos."""
    with gr.Blocks(analytics_enabled=False) as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown("### If you find this work useful, please help ⭐ the [Github Repo](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for your attention!")

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                # Video input component for file upload.
                processed_video = gr.Video(label="Input Video with Audio")
            with gr.Column(scale=1):
                depth_vis_video = gr.Video(label="Depth Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    stitched_video = gr.Video(label="Stitched RGBD Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)

        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                with gr.Accordion("Advanced Settings", open=False):
                    max_len = gr.Slider(label="Max process length", minimum=-1, maximum=1000, value=-1, step=1)
                    target_fps = gr.Slider(label="Target FPS", minimum=-1, maximum=30, value=-1, step=1)
                    max_res = gr.Slider(label="Max side resolution", minimum=480, maximum=1920, value=1920, step=1)
                    stitch_option = gr.Checkbox(label="Stitch RGB & Depth Videos", value=True)
                    grayscale_option = gr.Checkbox(label="Output Depth as Grayscale", value=True)
                    convert_from_color_option = gr.Checkbox(label="Convert Grayscale from Color", value=True)
                    blur_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Depth Blur (can reduce edge artifacts on display)", value=0.3)
                generate_btn = gr.Button("Generate")
            with gr.Column(scale=1):
                pass  # spacer column

        generate_btn.click(
            fn=stitch_rgbd_videos,
            inputs=[processed_video, depth_vis_video, max_len, target_fps, max_res,
                    stitch_option, grayscale_option, convert_from_color_option, blur_slider],
            outputs=stitched_video,
        )

    return demo


if __name__ == "__main__":
    demo = construct_demo()
    demo.queue(max_size=4).launch()