# VideoMAE masked-autoencoder visualisation demo (Gradio app).
import gradio as gr | |
import torch | |
from augmentations import get_videomae_transform | |
from models import load_model | |
from utils import create_plot, get_frames, get_videomae_outputs, prepare_frames_masks | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
transform = get_videomae_transform() | |
def get_visualisations(mask_ratio, video_path):
    """Run VideoMAE on a video and return a plot of its reconstructions.

    Args:
        mask_ratio: Fraction of patches to mask (0.25–0.95 from the UI slider).
        video_path: Path to the input video file supplied by the Gradio widget.

    Returns:
        A matplotlib figure (via ``create_plot``) showing the original,
        masked, and reconstructed frames.
    """
    frames, ids = get_frames(path=video_path, transform=transform)
    # NOTE: the checkpoint is reloaded on every call because the mask set
    # depends on ``mask_ratio``; loading is per-request by design here.
    model, masks, patch_size = load_model(
        path="assets/checkpoint.pth",
        mask_ratio=mask_ratio,
        device=device,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        frames, masks = prepare_frames_masks(frames, masks, device)
        outputs = model(frames, masks)
        visualisations = get_videomae_outputs(
            frames=frames,
            masks=masks,
            outputs=outputs,
            ids=ids,
            patch_size=patch_size,
            device=device,
        )
    return create_plot(visualisations)
# Build the Gradio UI: a video input, a masking-ratio slider, and a button
# that runs the model and renders the visualisation plot.
with gr.Blocks() as app:
    video = gr.Video(
        value="assets/example.mp4",
    )
    mask_ratio_slider = gr.Slider(
        minimum=0.25, maximum=0.95, step=0.05, value=0.75, label="masking ratio"
    )
    btn = gr.Button("Run")
    btn.click(
        get_visualisations,
        inputs=[mask_ratio_slider, video],
        outputs=gr.Plot(label="VideoMAE Outputs", format="png"),
    )

if __name__ == "__main__":
    app.launch()