import gradio as gr
from threading import Thread
import time
from PIL import Image
import torch
import spaces
import cv2
import numpy as np
from transformers import (
AutoProcessor,
AutoModelForVision2Seq,
TextIteratorStreamer
)
# Helper function to return a progress bar HTML snippet.
def progress_bar_html(label: str) -> str:
    # Minimal snippet: the label next to an indeterminate HTML progress bar.
    return f'''
<div style="display: flex; align-items: center; gap: 10px;">
  <span style="font-size: 14px;">{label}</span><progress style="width: 110px;"></progress>
</div>
'''
# Example prompts for image and video inference.
examples=[
[{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
[{"text": "Transcription of the letter", "files": ["examples/222.png"]}],
[{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
[{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
[{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
[{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
[{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
[{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
]
# Helper: Downsample video to extract a fixed number of frames.
def downsample_video(video_path, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        # The video could not be read (missing file or unsupported codec).
        cap.release()
        return []
    # Calculate evenly spaced frame indices across the clip.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
# Convert BGR to RGB and then to a PIL image.
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
frames.append(frame)
cap.release()
return frames
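# Example: downsample_video("examples/sky.mp4") returns up to 10 evenly spaced
# RGB PIL frames that can be fed to the processor alongside the prompt.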
# Load processor and model.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
"HuggingFaceTB/SmolVLM-Instruct",
torch_dtype=torch.bfloat16,
).to("cuda")
@spaces.GPU
def model_inference(
input_dict, history, decoding_strategy, temperature, max_new_tokens,
repetition_penalty, top_p
):
text = input_dict["text"]
# --- Video Inference Branch ---
if text.lower().startswith("@video-infer"):
# Remove the command prefix to get the prompt.
prompt_text = text[len("@video-infer"):].strip()
if not input_dict["files"]:
yield "Error: Please provide a video file for @video-infer."
return
# Assume the first file is a video.
video_file = input_dict["files"][0]
frames = downsample_video(video_file)
if not frames:
yield "Error: Could not extract frames from the video."
return
# Build a chat content: include the user prompt and then each frame labeled.
content = [{"type": "text", "text": prompt_text}]
for idx, frame in enumerate(frames):
content.append({"type": "text", "text": f"Frame {idx+1}:"})
content.append({"type": "image", "image": frame})
resulting_messages = [{
"role": "user",
"content": content
}]
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
# Process the extracted frames as images.
inputs = processor(text=prompt, images=[frames], return_tensors="pt")
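        # The processor returns the tokenized prompt plus image tensors
        # (e.g. input_ids, attention_mask, pixel_values); move them all to the GPU.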
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Setup generation parameters.
generation_args = {
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
}
assert decoding_strategy in ["Greedy", "Top P Sampling"]
if decoding_strategy == "Greedy":
generation_args["do_sample"] = False
elif decoding_strategy == "Top P Sampling":
generation_args["temperature"] = temperature
generation_args["do_sample"] = True
generation_args["top_p"] = top_p
        generation_args.update(inputs)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_args["streamer"] = streamer
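        # Run generation on a background thread and stream partial text to the UI.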
buffer = ""
thread = Thread(target=model.generate, kwargs=generation_args)
thread.start()
yield progress_bar_html("Processing Video with SmolVLM")
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
return
# --- Default Image Inference Branch ---
# Process input images if provided.
if len(input_dict["files"]) > 1:
images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
elif len(input_dict["files"]) == 1:
images = [Image.open(input_dict["files"][0]).convert("RGB")]
else:
images = []
# Validate input.
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
if text == "" and images:
gr.Error("Please input a text query along with the image(s).")
resulting_messages = [{
"role": "user",
"content": [{"type": "image"} for _ in range(len(images))] + [
{"type": "text", "text": text}
]
}]
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
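    # The chat template emits one image placeholder per {"type": "image"} entry;
    # the pixel data itself is passed to the processor as a nested list
    # (one inner list of images per prompt).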
inputs = processor(text=prompt, images=[images], return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
generation_args = {
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
}
assert decoding_strategy in ["Greedy", "Top P Sampling"]
if decoding_strategy == "Greedy":
generation_args["do_sample"] = False
elif decoding_strategy == "Top P Sampling":
generation_args["temperature"] = temperature
generation_args["do_sample"] = True
generation_args["top_p"] = top_p
    generation_args.update(inputs)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer
buffer = ""
thread = Thread(target=model.generate, kwargs=generation_args)
thread.start()
yield progress_bar_html("Processing Video with SmolVLM")
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
# Gradio ChatInterface: Allow both image and video file types.
demo = gr.ChatInterface(
fn=model_inference,
description="# **SmolVLM Video Infer `@video-infer for video understanding`**",
examples=examples,
textbox=gr.MultimodalTextbox(
label="Query Input",
file_types=["image", "video"],
file_count="multiple"
),
stop_btn="Stop Generation",
multimodal=True,
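    # The additional inputs below map positionally to the extra parameters of
    # model_inference: decoding_strategy, temperature, max_new_tokens,
    # repetition_penalty, top_p.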
additional_inputs=[
gr.Radio(
["Top P Sampling", "Greedy"],
value="Greedy",
label="Decoding strategy",
info="Higher values is equivalent to sampling more low-probability tokens.",
),
gr.Slider(
minimum=0.0,
maximum=5.0,
value=0.4,
step=0.1,
interactive=True,
label="Sampling temperature",
info="Higher values will produce more diverse outputs.",
),
gr.Slider(
minimum=8,
maximum=1024,
value=512,
step=1,
interactive=True,
label="Maximum number of new tokens to generate",
),
gr.Slider(
minimum=0.01,
maximum=5.0,
value=1.2,
step=0.01,
interactive=True,
label="Repetition penalty",
info="1.0 is equivalent to no penalty",
),
gr.Slider(
minimum=0.01,
maximum=0.99,
value=0.8,
step=0.01,
interactive=True,
label="Top P",
info="Higher values is equivalent to sampling more low-probability tokens.",
)
],
cache_examples=False
)
demo.launch(debug=True)