File size: 2,523 Bytes
ca389f6
cc017e3
ca389f6
bab98df
d61c863
 
aef8896
 
d61c863
 
 
 
 
3117ce4
 
7da106d
f38d7e4
 
 
 
d61c863
 
8469132
9c9ead3
2189a33
d61c863
 
16cff9a
 
d61c863
66964b3
d61c863
 
 
16cff9a
d61c863
30f8b65
f38d7e4
30f8b65
 
 
d61c863
 
 
16cff9a
d61c863
 
 
83f3811
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
# Runtime bootstrap: install the pinned Gradio version, then log installed
# packages and GPU status. ``python -m pip`` (via sys.executable) targets the
# interpreter actually running this app rather than whichever ``pip`` happens
# to be first on PATH. Like the original os.system calls, this is best-effort:
# exit codes are ignored and a missing executable does not abort start-up.
import subprocess
import sys

for _cmd in (
    [sys.executable, "-m", "pip", "install", "gradio==2.3.0a0"],
    [sys.executable, "-m", "pip", "freeze"],
    ["nvidia-smi"],
):
    try:
        subprocess.run(_cmd, check=False)
    except FileNotFoundError:
        # e.g. nvidia-smi on a CPU-only host; the shell would have printed
        # "not found" and carried on, so do the same.
        print(f"command not found: {_cmd[0]}", file=sys.stderr)
import torch
import gradio as gr
from moviepy.editor import *

# Load the Robust Video Matting network and its video-conversion helper from
# torch.hub. Swap "mobilenetv3" for "resnet50" to use the other backbone.
# convert_video accepts a model on either cpu or cuda, so place the model on
# the GPU when one is available instead of always running on the CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").to(_device)  # or "resnet50"

convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

def inference(video, max_duration=10):
    """Run Robust Video Matting on an uploaded video.

    Args:
        video: Path to the input video file, as supplied by the Gradio
            ``Video`` input component.
        max_duration: Longest accepted clip, in seconds. Longer inputs are not
            processed; ``"trim.mp4"`` is returned for all three outputs instead.

    Returns:
        A 3-tuple of file paths ``(composition, alpha, foreground)`` matching
        the three Gradio ``Video`` output components.
    """
    # Only the duration is needed from moviepy. Close the clip straight away so
    # its ffmpeg reader is not leaked on every request.
    clip = VideoFileClip(video)
    try:
        duration = clip.duration
    finally:
        clip.close()
    print(duration)
    if duration > max_duration:
        # NOTE(review): trim.mp4 is presumably a bundled notice video asking the
        # user to shorten their clip — confirm it ships alongside this app.
        return "trim.mp4", "trim.mp4", "trim.mp4"
    convert_video(
        model,                          # The loaded model, can be on any device (cpu or cuda).
        input_source=video,             # A video file or an image sequence directory.
        input_resize=(512, 512),        # [Optional] Resize the input (also the output).
        downsample_ratio=None,          # [Optional] If None, make downsampled max size be 512px.
        output_type="video",            # Choose "video" or "png_sequence".
        output_composition="com.mp4",   # File path if video; directory path if png sequence.
        output_alpha="pha.mp4",         # [Optional] Output the raw alpha prediction.
        output_foreground="fgr.mp4",    # [Optional] Output the raw foreground prediction.
        output_video_mbps=4,            # Output video mbps. Not needed for png sequence.
        seq_chunk=8,                    # Process n frames at once for better parallelism.
        num_workers=1,                  # Only for image sequence input. Reader threads.
        progress=True,                  # Print conversion progress.
    )
    # NOTE(review): these fixed output paths are overwritten by every request;
    # this relies on the interface queue running requests one at a time.
    return "com.mp4", "pha.mp4", "fgr.mp4"
  
title = "Robust Video Matting"
description = "Gradio demo for Robust Video Matting. To use it, simply upload your video, currently only mp4 and ogg formats are supported. Please trim video to 10 seconds or less. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"

# Web UI: one uploaded video in, three videos out. The output order mirrors
# the (composition, alpha, foreground) tuple returned by ``inference``.
input_video = gr.inputs.Video(label="Input")
output_videos = [
    gr.outputs.Video(label="Output Composition"),
    gr.outputs.Video(label="Output Alpha"),
    gr.outputs.Video(label="Output Foreground"),
]

demo = gr.Interface(
    fn=inference,
    inputs=input_video,
    outputs=output_videos,
    title=title,
    description=description,
    article=article,
    enable_queue=True,
)
demo.launch(debug=True)