import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
import time
from PIL import Image
import torch
import spaces
import cv2
import numpy as np
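# Gradio demo: chat with SmolVLM-Instruct over images, with an additional
# "@video-infer" mode that answers questions about short videos by sampling frames.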
# Helper function to return a progress bar HTML snippet.
def progress_bar_html(label: str) -> str:
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #FFB6C1; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''
# Example queries shown in the chat interface.
examples = [
    [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
    [{"text": "Transcription of the letter", "files": ["examples/222.png"]}],
    [{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
    [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
    [{"text": "@video-infer Explain what is happening in this video?", "files": ["examples/oreo.mp4"]}],
    [{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
    [{"text": "@video-infer What is in the video?", "files": ["examples/redlight.mp4"]}],
]
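# Queries beginning with "@video-infer" are routed to the video branch of
# model_inference() below; everything else is treated as an image/text query.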
# Helper: downsample a video to extract a fixed number of evenly spaced frames.
def downsample_video(video_path, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Calculate evenly spaced frame indices.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # Convert BGR to RGB and then to a PIL image.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
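# Usage sketch (illustrative only; any of the bundled example clips would work):
#   frames = downsample_video("examples/sky.mp4")  # -> list of up to 10 PIL.Image frames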
# Load processor and model.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
).to("cuda")
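# Note: on a ZeroGPU ("Running on Zero") Space the model can be moved to CUDA at
# startup; a GPU is only actually attached while a @spaces.GPU-decorated call runs.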
@spaces.GPU  # Request a GPU for each inference call; required on ZeroGPU ("Running on Zero") Spaces.
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    text = input_dict["text"]

    # --- Video Inference Branch ---
    if text.lower().startswith("@video-infer"):
        # Remove the command prefix to get the prompt.
        prompt_text = text[len("@video-infer"):].strip()
        if not input_dict["files"]:
            yield "Error: Please provide a video file for @video-infer."
            return
        # Assume the first file is a video.
        video_file = input_dict["files"][0]
        frames = downsample_video(video_file)
        if not frames:
            yield "Error: Could not extract frames from the video."
            return
        # Build the chat content: the user prompt followed by each labeled frame.
        content = [{"type": "text", "text": prompt_text}]
        for idx, frame in enumerate(frames):
            content.append({"type": "text", "text": f"Frame {idx+1}:"})
            content.append({"type": "image", "image": frame})
        resulting_messages = [{
            "role": "user",
            "content": content
        }]
        prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
        # Process the extracted frames as images.
        inputs = processor(text=prompt, images=[frames], return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        # Set up generation parameters.
        generation_args = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
        }
        assert decoding_strategy in ["Greedy", "Top P Sampling"]
        if decoding_strategy == "Greedy":
            generation_args["do_sample"] = False
        elif decoding_strategy == "Top P Sampling":
            generation_args["temperature"] = temperature
            generation_args["do_sample"] = True
            generation_args["top_p"] = top_p
        generation_args.update(inputs)
        # Stream tokens from a background generation thread, keeping the decoding settings above.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_args["streamer"] = streamer
        buffer = ""
        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()
        yield progress_bar_html("Processing Video with SmolVLM")
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.01)
            yield buffer
        return
    # --- Default Image Inference Branch ---
    # Process input images if provided.
    if input_dict["files"]:
        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
    else:
        images = []

    # Validate input.
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")

    resulting_messages = [{
        "role": "user",
        "content": [{"type": "image"} for _ in range(len(images))] + [
            {"type": "text", "text": text}
        ]
    }]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # Pass images only when at least one was uploaded (text-only queries are allowed).
    if images:
        inputs = processor(text=prompt, images=[images], return_tensors="pt")
    else:
        inputs = processor(text=prompt, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Set up generation parameters.
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    assert decoding_strategy in ["Greedy", "Top P Sampling"]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p
    generation_args.update(inputs)

    # Stream tokens from a background generation thread, keeping the decoding settings above.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer
    buffer = ""
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    yield progress_bar_html("Processing Image(s) with SmolVLM")
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
# Gradio ChatInterface: allow both image and video file types.
demo = gr.ChatInterface(
    fn=model_inference,
    description="# **SmolVLM Video Infer `@video-infer for video understanding`**",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            info="Greedy picks the most likely token at each step; Top P samples from the smallest set of tokens whose cumulative probability exceeds Top P.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 means no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values allow more low-probability tokens to be sampled.",
        )
    ],
    cache_examples=False
)

demo.launch(debug=True)