Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,37 +1,45 @@
|
|
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def get_free_memory_gb():
|
6 |
gpu_index = torch.cuda.current_device()
|
7 |
-
# Get the GPU's properties
|
8 |
gpu_properties = torch.cuda.get_device_properties(gpu_index)
|
9 |
|
10 |
-
# Get the total and allocated memory
|
11 |
total_memory = gpu_properties.total_memory
|
12 |
allocated_memory = torch.cuda.memory_allocated(gpu_index)
|
13 |
|
14 |
-
# Calculate the free memory
|
15 |
free_memory = total_memory - allocated_memory
|
16 |
return free_memory / 1024**3
|
17 |
|
18 |
|
19 |
-
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
|
20 |
-
|
21 |
-
if torch.cuda.is_available():
|
22 |
-
free_memory = get_free_memory_gb()
|
23 |
-
concurrency_count = int(free_memory // 7)
|
24 |
-
model = model.cuda()
|
25 |
-
print(f"Using GPU with concurrency: {concurrency_count}")
|
26 |
-
print(f"Available video memory: {free_memory} GB")
|
27 |
-
else:
|
28 |
-
print("Using CPU")
|
29 |
-
concurrency_count = 1
|
30 |
-
|
31 |
-
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
|
32 |
-
|
33 |
-
|
34 |
def inference(video):
|
|
|
|
|
|
|
|
|
|
|
35 |
convert_video(
|
36 |
model, # The loaded model, can be on any device (cpu or cuda).
|
37 |
input_source=video, # A video file or an image sequence directory.
|
@@ -48,23 +56,40 @@ def inference(video):
|
|
48 |
return "com.mp4"
|
49 |
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
gr.
|
67 |
-
"
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
block.queue(
|
|
|
|
|
|
1 |
+
import av
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
|
5 |
|
6 |
+
def get_video_length_av(video_path):
|
7 |
+
with av.open(video_path) as container:
|
8 |
+
stream = container.streams.video[0]
|
9 |
+
if container.duration is not None:
|
10 |
+
duration_in_seconds = float(container.duration) / av.time_base
|
11 |
+
else:
|
12 |
+
duration_in_seconds = stream.duration * stream.time_base
|
13 |
+
|
14 |
+
return duration_in_seconds
|
15 |
+
|
16 |
+
|
17 |
+
def get_video_dimensions(video_path):
|
18 |
+
with av.open(video_path) as container:
|
19 |
+
video_stream = container.streams.video[0]
|
20 |
+
width = video_stream.width
|
21 |
+
height = video_stream.height
|
22 |
+
|
23 |
+
return width, height
|
24 |
+
|
25 |
+
|
26 |
def get_free_memory_gb():
|
27 |
gpu_index = torch.cuda.current_device()
|
|
|
28 |
gpu_properties = torch.cuda.get_device_properties(gpu_index)
|
29 |
|
|
|
30 |
total_memory = gpu_properties.total_memory
|
31 |
allocated_memory = torch.cuda.memory_allocated(gpu_index)
|
32 |
|
|
|
33 |
free_memory = total_memory - allocated_memory
|
34 |
return free_memory / 1024**3
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def inference(video):
|
38 |
+
if get_video_length_av(video.name) > 30:
|
39 |
+
raise gr.Error("Length of video cannot be over 30 seconds")
|
40 |
+
if get_video_dimensions(video.name) > (1920, 1920):
|
41 |
+
raise gr.Error("Video resolution must not be higher than 1920x1080")
|
42 |
+
|
43 |
convert_video(
|
44 |
model, # The loaded model, can be on any device (cpu or cuda).
|
45 |
input_source=video, # A video file or an image sequence directory.
|
|
|
56 |
return "com.mp4"
|
57 |
|
58 |
|
59 |
+
if __name__ == "__main__":
|
60 |
+
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
|
61 |
+
|
62 |
+
if torch.cuda.is_available():
|
63 |
+
free_memory = get_free_memory_gb()
|
64 |
+
concurrency_count = int(free_memory // 7)
|
65 |
+
model = model.cuda()
|
66 |
+
print(f"Using GPU with concurrency: {concurrency_count}")
|
67 |
+
print(f"Available video memory: {free_memory} GB")
|
68 |
+
else:
|
69 |
+
print("Using CPU")
|
70 |
+
concurrency_count = 1
|
71 |
+
|
72 |
+
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
|
73 |
+
|
74 |
+
with gr.Blocks(title="Robust Video Matting") as block:
|
75 |
+
gr.Markdown("# Robust Video Matting")
|
76 |
+
gr.Markdown(
|
77 |
+
"Gradio demo for Robust Video Matting. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
|
78 |
+
)
|
79 |
+
with gr.Row():
|
80 |
+
inp = gr.Video(label="Input Video")
|
81 |
+
out = gr.Video(label="Output Video")
|
82 |
+
btn = gr.Button("Run")
|
83 |
+
btn.click(inference, inputs=inp, outputs=out)
|
84 |
+
|
85 |
+
gr.Examples(
|
86 |
+
examples=[["example.mp4"]],
|
87 |
+
inputs=[inp],
|
88 |
+
)
|
89 |
+
gr.HTML(
|
90 |
+
"<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
|
91 |
+
)
|
92 |
|
93 |
+
block.queue(
|
94 |
+
api_open=False, max_size=5, concurrency_count=concurrency_count
|
95 |
+
).launch()
|