mtwohey2 committed (verified) · Commit a612409 · 1 Parent(s): 9126d65

Update app.py

Files changed (1)
  1. app.py +20 -83
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import gc
-import torch
 import cv2
 import gradio as gr
 import numpy as np
@@ -10,62 +9,17 @@ import subprocess
 import sys
 import spaces
 
-from video_depth_anything.video_depth import VideoDepthAnything
 from utils.dc_utils import read_video_frames, save_video
-from huggingface_hub import hf_hub_download
 
-# Examples for the Gradio Demo.
-# Each example now contains 8 parameters:
-# [video_path, max_len, target_fps, max_res, stitch, grayscale, convert_from_color, blur]
-examples = [
-    ['assets/example_videos/octopus_01.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/chicken_01.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/gorilla_01.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/Tokyo-Walk_rgb.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/4158877-uhd_3840_2160_30fps_rgb.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/4511004-uhd_3840_2160_24fps_rgb.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/1753029-hd_1920_1080_30fps.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/davis_burnout.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/example_5473765-l.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/Istanbul-26920.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/obj_1.mp4', -1, -1, 1280, True, True, True, 0.3],
-    ['assets/example_videos/sheep_cut1.mp4', -1, -1, 1280, True, True, True, 0.3],
-]
-
-# Use GPU if available; otherwise, use CPU.
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-# Model configuration for different encoder variants.
-model_configs = {
-    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
-    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
-}
-encoder2name = {
-    'vits': 'Small',
-    'vitl': 'Large',
-}
-encoder = 'vitl'
-model_name = encoder2name[encoder]
-
-# Initialize the model.
-video_depth_anything = VideoDepthAnything(**model_configs[encoder])
-filepath = hf_hub_download(
-    repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
-    filename=f"video_depth_anything_{encoder}.pth",
-    repo_type="model"
-)
-video_depth_anything.load_state_dict(torch.load(filepath, map_location='cpu'))
-video_depth_anything = video_depth_anything.to(DEVICE).eval()
-
-title = "# Video Depth Anything + RGBD sbs output"
+title = "#RGBD sbs output"
 description = """**Video Depth Anything** + RGBD sbs output for viewing with Looking Glass Factory displays.
 Please refer to our [paper](https://arxiv.org/abs/2501.12375), [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""
 
 @spaces.GPU(enable_queue=True)
 
-def infer_video_depth(
-    input_video: str,
+def stitch_rgbd_videos(
+    processed_video: str,
+    depth_vis_video: str,
     max_len: int = -1,
     target_fps: int = -1,
     max_res: int = 1280,
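
With this change the Space no longer runs depth inference itself: stitch_rgbd_videos takes an RGB video plus a pre-rendered depth-visualization video. A minimal sketch of a direct call (both file paths are hypothetical; defaults mirror the signature above):

# Sketch only: invoking the new entry point outside Gradio.
# Both paths are hypothetical examples.
outputs = stitch_rgbd_videos(
    processed_video="assets/example_videos/davis_rollercoaster.mp4",  # RGB source
    depth_vis_video="outputs/davis_rollercoaster_vis.mp4",            # depth video rendered beforehand
    max_len=-1,      # -1 keeps every frame
    target_fps=-1,   # -1 keeps the source frame rate
    max_res=1280,    # downscale cap when reading frames
)
stitched_path = outputs[0]  # the function returns [stitched_video_path]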
@@ -77,24 +31,18 @@ def infer_video_depth(
     input_size: int = 518,
 ):
     # 1. Read input video frames for inference (downscaled to max_res).
-    frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
-    # 2. Perform depth inference using the model.
-    depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device=DEVICE)
-
-    video_name = os.path.basename(input_video)
+    frames, target_fps = read_video_frames(processed_video, max_len, target_fps, max_res)
+
+    video_name = os.path.basename(processed_video)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
-
-    # Save the preprocessed (RGB) video and the generated depth visualization.
-    processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_src.mp4')
-    depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
-    save_video(frames, processed_video_path, fps=fps)
-    save_video(depths, depth_vis_path, fps=fps, is_depths=True)
-
+
     stitched_video_path = None
     if stitch:
         # For stitching: read the original video in full resolution (without downscaling).
-        full_frames, _ = read_video_frames(input_video, max_len, target_fps, max_res=-1)
+        full_frames, _ = read_video_frames(processed_video, max_len, target_fps, max_res=-1)
+        depths, _ = read_video_frames(depth_vis_video, max_len, target_fps, max_res=-1)
+
         # For each frame, create a visual depth image from the inferenced depths.
         d_min, d_max = depths.min(), depths.max()
         stitched_frames = []
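
The per-frame stitching loop sits in unchanged lines between this hunk and the next, so it does not appear in the diff. The sketch below only illustrates what such a loop does, inferred from the visible context (global d_min/d_max normalization, the grayscale and blur options, side-by-side output); the exact operations are assumptions:

# Illustrative only: the real loop lives outside the diff hunks.
# `grayscale` and `blur` are the option names from the removed examples
# comment; how they are applied here is an assumption.
for rgb, d in zip(full_frames, depths):
    # Normalize the depth frame to 0-255 using the global range.
    scale = max(float(d_max - d_min), 1e-6)
    d_norm = ((d.astype(np.float32) - d_min) / scale * 255.0).astype(np.uint8)
    if grayscale and d_norm.ndim == 3:
        gray = cv2.cvtColor(d_norm, cv2.COLOR_RGB2GRAY)
        d_norm = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    if blur > 0:
        k = 2 * int(blur * 10) + 1  # map the 0-1 slider to an odd kernel size
        d_norm = cv2.GaussianBlur(d_norm, (k, k), 0)
    # Match the RGB frame's size, then stitch side by side (RGBD sbs).
    d_norm = cv2.resize(d_norm, (rgb.shape[1], rgb.shape[0]))
    stitched_frames.append(np.concatenate([rgb, d_norm], axis=1))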
@@ -134,7 +82,7 @@ def infer_video_depth(
         base_name = os.path.splitext(video_name)[0]
         short_name = base_name[:20]
         stitched_video_path = os.path.join(output_dir, short_name + '_RGBD.mp4')
-        save_video(stitched_frames, stitched_video_path, fps=fps)
+        save_video(stitched_frames, stitched_video_path, fps=target_fps)
 
         # Merge audio from the input video into the stitched video using ffmpeg.
         temp_audio_path = stitched_video_path.replace('_RGBD.mp4', '_RGBD_audio.mp4')
@@ -142,7 +90,7 @@ def infer_video_depth(
             "ffmpeg",
             "-y",
             "-i", stitched_video_path,
-            "-i", input_video,
+            "-i", processed_video,
             "-c:v", "copy",
             "-c:a", "aac",
             "-map", "0:v:0",
@@ -154,10 +102,9 @@ def infer_video_depth(
         os.replace(temp_audio_path, stitched_video_path)
 
     gc.collect()
-    torch.cuda.empty_cache()
 
-    # Return the preprocessed RGB video, depth visualization, and (if created) the stitched video.
-    return [processed_video_path, depth_vis_path, stitched_video_path]
+    # Return stitched video.
+    return [stitched_video_path]
 
 def construct_demo():
     with gr.Blocks(analytics_enabled=False) as demo:
@@ -168,11 +115,10 @@ def construct_demo():
         with gr.Row(equal_height=True):
             with gr.Column(scale=1):
                 # Video input component for file upload.
-                input_video = gr.Video(label="Input Video")
+                processed_video = gr.Video(label="Input Video")
+                depth_vis_video = gr.Video(label="Generated Depth Video")
             with gr.Column(scale=2):
                 with gr.Row(equal_height=True):
-                    processed_video = gr.Video(label="Preprocessed Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
-                    depth_vis_video = gr.Video(label="Generated Depth Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
                     stitched_video = gr.Video(label="Stitched RGBD Video", interactive=False, autoplay=True, loop=True, show_share_button=True, scale=5)
 
         with gr.Row(equal_height=True):
@@ -189,19 +135,10 @@ def construct_demo():
             with gr.Column(scale=2):
                 pass
 
-        gr.Examples(
-            examples=examples,
-            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, convert_from_color_option, blur_slider],
-            outputs=[processed_video, depth_vis_video, stitched_video],
-            fn=infer_video_depth,
-            cache_examples=False,
-            cache_mode="lazy",
-        )
-
         generate_btn.click(
-            fn=infer_video_depth,
-            inputs=[input_video, max_len, target_fps, max_res, stitch_option, grayscale_option, convert_from_color_option, blur_slider],
-            outputs=[processed_video, depth_vis_video, stitched_video],
+            fn=stitch_rgbd_videos,
+            inputs=[processed_video, depth_vis_video, max_len, target_fps, max_res, stitch_option, grayscale_option, convert_from_color_option, blur_slider],
+            outputs=[stitched_video],
         )
 
     return demo
 
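
The diff ends inside construct_demo, so the file's launch code never appears in either hunk. A typical Spaces entry point, included purely as an assumption about the unchanged tail of app.py:

# Hypothetical launch block; app.py's actual tail is outside the diff.
if __name__ == "__main__":
    demo = construct_demo()
    demo.queue()   # queue requests, in line with @spaces.GPU(enable_queue=True)
    demo.launch()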