import gradio as gr
import torch

from src.augmentations import get_videomae_transform
from src.models import load_model
from src.utils import (
    create_plot,
    get_frames,
    get_videomae_outputs,
    prepare_frames_masks,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_visualisations(mask_ratio, video_path):
    # Load the video and apply the VideoMAE preprocessing transform.
    transform = get_videomae_transform()
    frames, ids = get_frames(path=video_path, transform=transform)

    # Load the model checkpoint together with the masks generated for the
    # requested masking ratio.
    model, masks, patch_size = load_model(
        path="assets/checkpoint.pth",
        mask_ratio=mask_ratio,
        device=device,
    )

    # Run inference without tracking gradients.
    with torch.no_grad():
        frames, masks = prepare_frames_masks(frames, masks, device)
        outputs = model(frames, masks)

        visualisations = get_videomae_outputs(
            frames=frames,
            masks=masks,
            outputs=outputs,
            ids=ids,
            patch_size=patch_size,
            device=device,
        )
    return create_plot(visualisations)


with gr.Blocks() as app:
    gr.Markdown(
        """
# VideoMAE Reconstruction Demo

To read more about self-supervised learning techniques for video, please refer to the
[Lightly AI blog post on Self-Supervised Learning for Videos](https://www.lightly.ai/post/self-supervised-learning-for-videos).
"""  # noqa: E501
    )
    video = gr.Video(value="assets/example.mp4")
    mask_ratio_slider = gr.Slider(
        minimum=0.25, maximum=0.95, step=0.05, value=0.75, label="masking ratio"
    )
    btn = gr.Button("Run")
    btn.click(
        get_visualisations,
        inputs=[mask_ratio_slider, video],
        outputs=gr.Plot(label="VideoMAE Outputs", format="png"),
    )

if __name__ == "__main__":
    app.launch()
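
# Usage (a minimal sketch; assumes this file is saved as app.py at the repo
# root, and that assets/checkpoint.pth and assets/example.mp4 exist as
# referenced above):
#
#     pip install gradio torch
#     python app.py
#
# Gradio serves the demo locally (by default at http://127.0.0.1:7860);
# pass share=True to app.launch() for a temporary public link.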