sanghan committed on
Commit 11dbf82 · 1 Parent(s): e7b6415

Update app.py

Files changed (1)
  1. app.py +62 -37
app.py CHANGED
@@ -1,37 +1,45 @@
 import torch
 import gradio as gr


 def get_free_memory_gb():
     gpu_index = torch.cuda.current_device()
-    # Get the GPU's properties
     gpu_properties = torch.cuda.get_device_properties(gpu_index)

-    # Get the total and allocated memory
     total_memory = gpu_properties.total_memory
     allocated_memory = torch.cuda.memory_allocated(gpu_index)

-    # Calculate the free memory
     free_memory = total_memory - allocated_memory
     return free_memory / 1024**3


-model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
-
-if torch.cuda.is_available():
-    free_memory = get_free_memory_gb()
-    concurrency_count = int(free_memory // 7)
-    model = model.cuda()
-    print(f"Using GPU with concurrency: {concurrency_count}")
-    print(f"Available video memory: {free_memory} GB")
-else:
-    print("Using CPU")
-    concurrency_count = 1
-
-convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
-
-
 def inference(video):
     convert_video(
         model, # The loaded model, can be on any device (cpu or cuda).
         input_source=video, # A video file or an image sequence directory.
@@ -48,23 +56,40 @@ def inference(video):
     return "com.mp4"


-with gr.Blocks(title="Robust Video Matting") as block:
-    gr.Markdown("# Robust Video Matting")
-    gr.Markdown(
-        "Gradio demo for Robust Video Matting. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
-    )
-    with gr.Row():
-        inp = gr.Video(label="Input Video")
-        out = gr.Video(label="Output Video")
-    btn = gr.Button("Run")
-    btn.click(inference, inputs=inp, outputs=out)
-
-    gr.Examples(
-        examples=[["example.mp4"]],
-        inputs=[inp],
-    )
-    gr.HTML(
-        "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
-    )

-block.queue(api_open=False, max_size=5, concurrency_count=concurrency_count).launch()
 
 
 
+import av
 import torch
 import gradio as gr


+def get_video_length_av(video_path):
+    with av.open(video_path) as container:
+        stream = container.streams.video[0]
+        if container.duration is not None:
+            duration_in_seconds = float(container.duration) / av.time_base
+        else:
+            duration_in_seconds = stream.duration * stream.time_base
+
+    return duration_in_seconds
+
+
+def get_video_dimensions(video_path):
+    with av.open(video_path) as container:
+        video_stream = container.streams.video[0]
+        width = video_stream.width
+        height = video_stream.height
+
+    return width, height
+
+
 def get_free_memory_gb():
     gpu_index = torch.cuda.current_device()
     gpu_properties = torch.cuda.get_device_properties(gpu_index)

     total_memory = gpu_properties.total_memory
     allocated_memory = torch.cuda.memory_allocated(gpu_index)

     free_memory = total_memory - allocated_memory
     return free_memory / 1024**3


 def inference(video):
+    if get_video_length_av(video.name) > 30:
+        raise gr.Error("Length of video cannot be over 30 seconds")
+    if get_video_dimensions(video.name) > (1920, 1920):
+        raise gr.Error("Video resolution must not be higher than 1920x1080")
+
     convert_video(
         model, # The loaded model, can be on any device (cpu or cuda).
         input_source=video, # A video file or an image sequence directory.
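Aside: a minimal standalone sketch, not part of the commit, of how the new PyAV helpers and the added guards behave; the probe() helper and the "sample.mp4" path are made up for illustration. Note that the committed resolution guard compares tuples, which Python evaluates lexicographically, so it fires only when width > 1920 (or width == 1920 and height > 1920), not whenever either dimension exceeds the limit quoted in the error message.

    import av

    def probe(video_path):
        # Same PyAV calls as the helpers above: container.duration is in
        # av.time_base units, stream.duration in stream.time_base units.
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            if container.duration is not None:
                seconds = float(container.duration) / av.time_base
            else:
                seconds = float(stream.duration * stream.time_base)
            return seconds, stream.width, stream.height

    if __name__ == "__main__":
        seconds, width, height = probe("sample.mp4")  # hypothetical local file
        print(f"{seconds:.1f}s, {width}x{height}")

        # Explicit per-dimension check, shown only for comparison with the
        # committed (width, height) > (1920, 1920) tuple comparison.
        too_long = seconds > 30
        too_big = width > 1920 or height > 1920
        print(f"too_long={too_long}, too_big={too_big}")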
 
     return "com.mp4"


+if __name__ == "__main__":
+    model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
+
+    if torch.cuda.is_available():
+        free_memory = get_free_memory_gb()
+        concurrency_count = int(free_memory // 7)
+        model = model.cuda()
+        print(f"Using GPU with concurrency: {concurrency_count}")
+        print(f"Available video memory: {free_memory} GB")
+    else:
+        print("Using CPU")
+        concurrency_count = 1
+
+    convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
+
+    with gr.Blocks(title="Robust Video Matting") as block:
+        gr.Markdown("# Robust Video Matting")
+        gr.Markdown(
+            "Gradio demo for Robust Video Matting. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
+        )
+        with gr.Row():
+            inp = gr.Video(label="Input Video")
+            out = gr.Video(label="Output Video")
+        btn = gr.Button("Run")
+        btn.click(inference, inputs=inp, outputs=out)
+
+        gr.Examples(
+            examples=[["example.mp4"]],
+            inputs=[inp],
+        )
+        gr.HTML(
+            "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
+        )

+    block.queue(
+        api_open=False, max_size=5, concurrency_count=concurrency_count
+    ).launch()
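Note on the queue sizing, reading off the arithmetic in the code above: concurrency_count = int(free_memory // 7) budgets roughly 7 GB of free VRAM per concurrent job. The free-memory value and the 7 GB figure below are hypothetical, taken only to show the rule.

    # Worked example of the sizing rule (values hypothetical):
    free_memory = 24.0                          # GB reported free by get_free_memory_gb()
    concurrency_count = int(free_memory // 7)   # == 3 concurrent jobs

The actual per-job footprint depends on the input video, so the 7 GB divisor is best read as the working assumption baked into the code rather than a measured value.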