"""Gradio demo: zero-shot video classification with X-CLIP models.

The user provides either a YouTube URL or a local video file, plus a
comma-separated list of candidate labels. Frames are sampled from the video,
the currently selected X-CLIP model scores every label against them, and the
UI shows per-label probabilities alongside a GIF of the sampled frames.
"""

from typing import Dict, List, Optional, Tuple

import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel

from utils import (
    convert_frames_to_gif,
    download_youtube_video,
    get_num_total_frames,
    sample_frames_from_video_file,
)

# Sampling rate passed to sample_frames_from_video_file. predict() treats it
# as a stride: it is lowered so that rate * num_model_frames fits in short videos.
FRAME_SAMPLING_RATE = 4
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Model IDs offered in the UI dropdown.
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    # The original list had these two IDs fused into one invalid string.
    "microsoft/xclip-large-patch14-kinetics-600",
    "microsoft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]

# Active processor/model pair; replaced as a unit by select_model().
processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)

# examples = [
#     [
#         "https://www.youtu.be/l1dBM8ZECao",
#         "sleeping dog,cat fight club,birds of prey",
#     ],
#     [
#         "https://youtu.be/VMj-3S1tku0",
#         "programming course,eating spaghetti,playing football",
#     ],
#     [
#         "https://youtu.be/BRw7rvLdGzU",
#         "game of thrones,the lord of the rings,vikings",
#     ],
# ]


def select_model(model_name: str) -> None:
    """Load ``model_name`` and make it the active processor/model pair.

    Both objects are loaded before either global is rebound. If a download
    fails, the previously active (still matching) pair stays in place.
    """
    global processor, model
    new_processor = AutoProcessor.from_pretrained(model_name)
    new_model = AutoModel.from_pretrained(model_name)
    processor, model = new_processor, new_model


def _parse_labels(labels_text: Optional[str]) -> List[str]:
    """Split comma-separated label text into clean, unique, non-empty labels.

    Raises:
        gr.Error: if no usable label remains.
    """
    stripped = (label.strip() for label in (labels_text or "").split(","))
    # dict.fromkeys removes duplicates while keeping first-seen order.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise gr.Error("Please provide at least one label, separated by commas.")
    return labels


def _resolve_video_path(youtube_url_or_file_path: Optional[str]) -> str:
    """Return a local video path, downloading first when given an http(s) URL.

    Raises:
        gr.Error: if no URL/file was provided (empty textbox or no upload).
    """
    source = (youtube_url_or_file_path or "").strip()
    if not source:
        raise gr.Error("Please provide a YouTube URL or upload a video file.")
    if source.startswith("http"):
        return download_youtube_video(source)
    return youtube_url_or_file_path


def _compute_frame_sampling_rate(
    num_total_frames: int, num_model_input_frames: int
) -> int:
    """Pick the sampling rate so the sampled window fits inside the video.

    Short videos get a smaller rate. The rate is clamped to at least 1,
    because integer division yields 0 when the video has fewer frames than
    the model expects.
    """
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        return max(1, num_total_frames // num_model_input_frames)
    return FRAME_SAMPLING_RATE


def predict(
    youtube_url_or_file_path: Optional[str], labels_text: Optional[str]
) -> Tuple[Dict[str, float], str]:
    """Score each candidate label against a video with the active model.

    Args:
        youtube_url_or_file_path: a URL starting with "http" (downloaded via
            ``download_youtube_video``) or a local video file path.
        labels_text: candidate labels separated by commas.

    Returns:
        A ``(label_to_prob, gif_path)`` tuple. ``label_to_prob`` maps each
        label to its softmax probability across the given labels;
        ``gif_path`` is the GIF built from the sampled frames.

    Raises:
        gr.Error: if the video source or the labels are missing.
    """
    # Validate cheap input first so bad label text never triggers a download.
    labels = _parse_labels(labels_text)
    video_path = _resolve_video_path(youtube_url_or_file_path)

    # Snapshot the pair so a concurrent select_model() cannot mix models.
    current_processor, current_model = processor, model

    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = current_model.config.vision_config.num_frames
    frame_sampling_rate = _compute_frame_sampling_rate(
        num_total_frames, num_model_input_frames
    )

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = current_processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = current_model(**inputs)
    # A single video is passed, so row 0 holds its scores for every label.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()

    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    return label_to_prob, gif_path


app = gr.Blocks()
with app:
    gr.Markdown(
        "# **Zero-shot Video Classification with Huggingface Transformers**"
    )

    with gr.Row():
        with gr.Column():
            model_names_dropdown = gr.Dropdown(
                choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
                label="Model:",
                show_label=True,
                value=DEFAULT_MODEL,
            )
            model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)
            with gr.Tab(label="Youtube URL"):
                gr.Markdown(
                    "### **Provide a Youtube video URL and a list of labels separated by commas**"
                )
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                youtube_url_predict_btn = gr.Button(value="Predict")
            with gr.Tab(label="Local File"):
                gr.Markdown(
                    "### **Upload a video file and provide a list of labels separated by commas**"
                )
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                local_video_predict_btn = gr.Button(value="Predict")
        with gr.Column():
            video_gif = gr.Image(
                label="Input Clip",
                show_label=True,
            )
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # gr.Markdown("**Examples:**")
    # gr.Examples(
    #     examples,
    #     [youtube_url, youtube_url_labels_text],
    #     [predictions, video_gif],
    #     fn=predict,
    #     cache_examples=True,
    # )

    youtube_url_predict_btn.click(
        predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )

app.launch()