import time
from threading import Thread

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer


# Helper function to return a progress bar HTML snippet.
# Only the {label} text is essential; the animated-bar styling below is illustrative.
def progress_bar_html(label: str) -> str:
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #e0e0e0; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #4caf50; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''


# Adding examples.
examples = [
    [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
    [{"text": "Transcription of the letter", "files": ["examples/222.png"]}],
    [{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
    [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
    [{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
    [{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
    [{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
]


# Helper: Downsample video to extract a fixed number of frames.
def downsample_video(video_path, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Calculate evenly spaced frame indices.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # Convert BGR to RGB and then to a PIL image.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
    cap.release()
    return frames


# Load processor and model.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
).to("cuda")


@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature,
    max_new_tokens, repetition_penalty, top_p
):
    text = input_dict["text"]

    # --- Video Inference Branch ---
    if text.lower().startswith("@video-infer"):
        # Remove the command prefix to get the prompt.
        prompt_text = text[len("@video-infer"):].strip()
        if not input_dict["files"]:
            yield "Error: Please provide a video file for @video-infer."
            return
        # Assume the first file is a video.
        video_file = input_dict["files"][0]
        frames = downsample_video(video_file)
        if not frames:
            yield "Error: Could not extract frames from the video."
            return
        # Build the chat content: the user prompt followed by each labeled frame.
        content = [{"type": "text", "text": prompt_text}]
        for idx, frame in enumerate(frames):
            content.append({"type": "text", "text": f"Frame {idx + 1}:"})
            content.append({"type": "image", "image": frame})
        resulting_messages = [{
            "role": "user",
            "content": content
        }]
        prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
        # Process the extracted frames as images.
        inputs = processor(text=prompt, images=[frames], return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        # Setup generation parameters.
        generation_args = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
        }
        assert decoding_strategy in ["Greedy", "Top P Sampling"]
        if decoding_strategy == "Greedy":
            generation_args["do_sample"] = False
        elif decoding_strategy == "Top P Sampling":
            generation_args["do_sample"] = True
            generation_args["temperature"] = temperature
            generation_args["top_p"] = top_p
        # Stream the generation from a background thread.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_args.update(inputs)
        generation_args["streamer"] = streamer

        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()

        buffer = ""
        yield progress_bar_html("Processing Video with SmolVLM")
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.01)
            yield buffer
        return

    # --- Default Image Inference Branch ---
    # Open input images if provided.
    if len(input_dict["files"]) > 1:
        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
    elif len(input_dict["files"]) == 1:
        images = [Image.open(input_dict["files"][0]).convert("RGB")]
    else:
        images = []

    # Validate input.
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")

    resulting_messages = [{
        "role": "user",
        "content": [{"type": "image"} for _ in range(len(images))] + [
            {"type": "text", "text": text}
        ]
    }]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # Pass images only when present so that text-only queries still work.
    inputs = processor(text=prompt, images=[images] if images else None, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Setup generation parameters.
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    assert decoding_strategy in ["Greedy", "Top P Sampling"]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["do_sample"] = True
        generation_args["temperature"] = temperature
        generation_args["top_p"] = top_p
    # Stream the generation from a background thread.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args.update(inputs)
    generation_args["streamer"] = streamer

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    buffer = ""
    yield progress_bar_html("Processing Image(s) with SmolVLM")
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer


# Gradio ChatInterface: allow both image and video file types.
demo = gr.ChatInterface(
    fn=model_inference,
    description="# **SmolVLM Video Infer** (use `@video-infer` for video understanding)",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            info="Greedy always picks the most likely token; Top P Sampling draws from the top-p probability mass.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values sample more low-probability tokens.",
        ),
    ],
    cache_examples=False,
)

demo.launch(debug=True)
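
# Example usage once the app is running (assuming the files under examples/ exist locally):
#   Image query:  type "Explain the Image" and attach examples/3.jpg
#   Video query:  type "@video-infer Describe the video" and attach examples/Missing.mp4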