import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
import time
from PIL import Image
import torch
import spaces
import cv2
import numpy as np
# Helper function to return a progress bar HTML snippet.
def progress_bar_html(label: str) -> str:
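    """Return an HTML snippet showing an animated progress bar with the given label.

    Yielded to the chat while generation is running so the user sees activity.
    """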
return f'''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #FFB6C1; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
# Example prompts for the chat interface.
examples = [
[{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
[{"text": "Transcription of the letter", "files": ["examples/222.png"]}],
[{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
[{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
[{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
[{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
[{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
[{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
]
# Helper: Downsample video to extract a fixed number of frames.
def downsample_video(video_path, num_frames=10):
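    """Extract `num_frames` evenly spaced frames from the video as PIL RGB images."""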
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
# Calculate evenly spaced frame indices.
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
# Convert BGR to RGB and then to a PIL image.
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
frames.append(frame)
cap.release()
return frames
# Load processor and model.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
"HuggingFaceTB/SmolVLM-Instruct",
torch_dtype=torch.bfloat16,
).to("cuda")
@spaces.GPU
def model_inference(
input_dict, history, decoding_strategy, temperature, max_new_tokens,
repetition_penalty, top_p
):
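    """Streaming chat handler for gr.ChatInterface.

    Messages prefixed with `@video-infer` are treated as video queries: the first
    attached file is downsampled into frames that are fed to the model together
    with the prompt. All other messages are handled as (multi-)image + text
    queries. Partial outputs are yielded as they stream from the model.
    """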
text = input_dict["text"]
# --- Video Inference Branch ---
if text.lower().startswith("@video-infer"):
# Remove the command prefix to get the prompt.
prompt_text = text[len("@video-infer"):].strip()
if not input_dict["files"]:
yield "Error: Please provide a video file for @video-infer."
return
# Assume the first file is a video.
video_file = input_dict["files"][0]
frames = downsample_video(video_file)
if not frames:
yield "Error: Could not extract frames from the video."
return
# Build a chat content: include the user prompt and then each frame labeled.
content = [{"type": "text", "text": prompt_text}]
for idx, frame in enumerate(frames):
content.append({"type": "text", "text": f"Frame {idx+1}:"})
content.append({"type": "image", "image": frame})
resulting_messages = [{
"role": "user",
"content": content
}]
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
# Process the extracted frames as images.
inputs = processor(text=prompt, images=[frames], return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Setup generation parameters.
generation_args = {
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
}
assert decoding_strategy in ["Greedy", "Top P Sampling"]
if decoding_strategy == "Greedy":
generation_args["do_sample"] = False
elif decoding_strategy == "Top P Sampling":
generation_args["temperature"] = temperature
generation_args["do_sample"] = True
generation_args["top_p"] = top_p
generation_args.update(inputs)
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        # Keep the decoding settings assembled above and add the streamer.
        generation_args["streamer"] = streamer
buffer = ""
thread = Thread(target=model.generate, kwargs=generation_args)
thread.start()
yield progress_bar_html("Processing Video with SmolVLM")
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
return
# --- Default Image Inference Branch ---
# Process input images if provided.
if len(input_dict["files"]) > 1:
images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
elif len(input_dict["files"]) == 1:
images = [Image.open(input_dict["files"][0]).convert("RGB")]
else:
images = []
# Validate input.
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
if text == "" and images:
gr.Error("Please input a text query along with the image(s).")
resulting_messages = [{
"role": "user",
"content": [{"type": "image"} for _ in range(len(images))] + [
{"type": "text", "text": text}
]
}]
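    # Each {"type": "image"} placeholder is expanded by the chat template into the
    # image tokens corresponding to one of the attached images.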
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[images], return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
generation_args = {
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
}
assert decoding_strategy in ["Greedy", "Top P Sampling"]
if decoding_strategy == "Greedy":
generation_args["do_sample"] = False
elif decoding_strategy == "Top P Sampling":
generation_args["temperature"] = temperature
generation_args["do_sample"] = True
generation_args["top_p"] = top_p
generation_args.update(inputs)
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Keep the decoding settings assembled above and add the streamer.
    generation_args["streamer"] = streamer
buffer = ""
thread = Thread(target=model.generate, kwargs=generation_args)
thread.start()
yield progress_bar_html("Processing Video with SmolVLM")
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
# Gradio ChatInterface: Allow both image and video file types.
demo = gr.ChatInterface(
fn=model_inference,
description="# **SmolVLM Video Infer `@video-infer for video understanding`**",
examples=examples,
textbox=gr.MultimodalTextbox(
label="Query Input",
file_types=["image", "video"],
file_count="multiple"
),
stop_btn="Stop Generation",
multimodal=True,
additional_inputs=[
gr.Radio(
["Top P Sampling", "Greedy"],
value="Greedy",
label="Decoding strategy",
info="Higher values is equivalent to sampling more low-probability tokens.",
),
gr.Slider(
minimum=0.0,
maximum=5.0,
value=0.4,
step=0.1,
interactive=True,
label="Sampling temperature",
info="Higher values will produce more diverse outputs.",
),
gr.Slider(
minimum=8,
maximum=1024,
value=512,
step=1,
interactive=True,
label="Maximum number of new tokens to generate",
),
gr.Slider(
minimum=0.01,
maximum=5.0,
value=1.2,
step=0.01,
interactive=True,
label="Repetition penalty",
info="1.0 is equivalent to no penalty",
),
gr.Slider(
minimum=0.01,
maximum=0.99,
value=0.8,
step=0.01,
interactive=True,
label="Top P",
info="Higher values is equivalent to sampling more low-probability tokens.",
)
],
cache_examples=False
)
demo.launch(debug=True)