prithivMLmods's picture
Update app.py
a9be97c verified
raw
history blame
7.25 kB
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
import time
from PIL import Image
import torch
import spaces
import cv2
import numpy as np
# Helper: Downsample video to extract a fixed number of frames.
def downsample_video(video_path, num_frames=10):
    """Extract up to ``num_frames`` evenly spaced frames from a video file.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Maximum number of frames to sample. Indices are spread
            evenly from the first to the last frame of the video.

    Returns:
        A list of RGB ``PIL.Image`` frames in playback order. The list is
        empty when ``num_frames`` is not positive, the file cannot be opened,
        or the container reports no frame count. It can be shorter than
        ``num_frames`` when the video has fewer frames or a seek/read fails.
    """
    if num_frames <= 0:
        return []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 (or a negative value) when the frame count
        # is unknown; np.linspace would then produce negative seek positions.
        if total_frames <= 0:
            return []
        # Evenly spaced indices. np.unique drops the duplicates that appear
        # when the video has fewer frames than requested (and keeps order).
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, num_frames, dtype=int)
        )
        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if not ok:
                continue
            # OpenCV decodes to BGR; PIL images are expected in RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        return frames
    finally:
        # Release the capture handle even if seeking or decoding raises.
        cap.release()
# Load the SmolVLM processor and model once at import time; every request
# reuses these module-level objects. Weights are kept in bfloat16 on the GPU.
_SMOLVLM_CHECKPOINT = "HuggingFaceTB/SmolVLM-Instruct"
processor = AutoProcessor.from_pretrained(_SMOLVLM_CHECKPOINT)
model = AutoModelForVision2Seq.from_pretrained(
    _SMOLVLM_CHECKPOINT,
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda")
def _build_generation_args(decoding_strategy, temperature, max_new_tokens,
                           repetition_penalty, top_p):
    """Map the UI decoding controls to keyword arguments for ``model.generate``.

    Raises:
        gr.Error: if ``decoding_strategy`` is not a supported option.
    """
    if decoding_strategy not in ("Greedy", "Top P Sampling"):
        raise gr.Error(f"Unsupported decoding strategy: {decoding_strategy!r}")
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    # A temperature of 0 means "always take the most likely token", i.e. greedy
    # decoding. transformers rejects do_sample=True with temperature <= 0, and
    # because generate() runs in a background thread that error would leave the
    # streamer waiting forever, so that case is routed to greedy decoding.
    if decoding_strategy == "Top P Sampling" and temperature > 0:
        generation_args.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        generation_args["do_sample"] = False
    return generation_args


def _load_rgb_image(path):
    """Open an image file and return an RGB copy, closing the file afterwards.

    Raises:
        gr.Error: if the file cannot be read as an image (for example a video
            attached without the ``@video-infer`` prefix).
    """
    try:
        with Image.open(path) as img:
            return img.convert("RGB")
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise gr.Error(
            "Could not open an attached file as an image. "
            "To analyze a video, start your message with @video-infer."
        ) from err


def _prepare_inputs(messages, images):
    """Render the chat template, run the processor and move tensors to the GPU."""
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    # Text-only prompts pass no image batch rather than an empty one.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )
    return {k: v.to("cuda") for k, v in inputs.items()}


def _stream_generation(inputs, generation_args):
    """Run ``model.generate`` in a background thread and yield the growing reply."""
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Merge the tensors INTO the decoding settings; rebuilding the kwargs from
    # ``inputs`` alone would drop repetition_penalty / sampling options.
    kwargs = {**generation_args, **inputs, "streamer": streamer}
    thread = Thread(target=model.generate, kwargs=kwargs)
    thread.start()
    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()


@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    """Stream a SmolVLM reply for one multimodal chat turn.

    Messages that start with ``@video-infer`` treat the first attached file as
    a video: frames are sampled with ``downsample_video`` and sent as labelled
    images together with the remaining prompt text. Otherwise every attached
    file is opened as an image and sent with the text query.

    Args:
        input_dict: Multimodal textbox value with ``"text"`` and ``"files"``
            keys (files are presumably local file paths — confirm against the
            Gradio version in use).
        history: Chat history supplied by ``gr.ChatInterface``; unused, so
            each turn is answered independently.
        decoding_strategy: ``"Greedy"`` or ``"Top P Sampling"``.
        temperature: Sampling temperature (sampling only; <= 0 means greedy).
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold (sampling only).

    Yields:
        ``"..."`` as a placeholder, then the accumulated reply after each
        streamed chunk; or a single error message for an unusable video request.

    Raises:
        gr.Error: when the text query is missing, an attachment is not a
            readable image, or the decoding strategy is unsupported.
    """
    text = input_dict["text"]
    files = input_dict["files"]

    if text.lower().startswith("@video-infer"):
        # --- Video inference branch ---
        prompt_text = text[len("@video-infer"):].strip()
        if not files:
            yield "Error: Please provide a video file for @video-infer."
            return
        # Only the first attachment is used as the video.
        frames = downsample_video(files[0])
        if not frames:
            yield "Error: Could not extract frames from the video."
            return
        # User prompt first, then each frame preceded by a text label.
        content = [{"type": "text", "text": prompt_text}]
        for idx, frame in enumerate(frames, start=1):
            content.append({"type": "text", "text": f"Frame {idx}:"})
            content.append({"type": "image", "image": frame})
        messages = [{"role": "user", "content": content}]
        images = frames
    else:
        # --- Default image inference branch ---
        images = [_load_rgb_image(path) for path in files]
        if not text:
            if images:
                raise gr.Error("Please input a text query along with the image(s).")
            raise gr.Error("Please input a query and optionally image(s).")
        messages = [{
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [
                {"type": "text", "text": text}
            ],
        }]

    inputs = _prepare_inputs(messages, images)
    generation_args = _build_generation_args(
        decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p
    )
    yield from _stream_generation(inputs, generation_args)
# Gradio ChatInterface: accepts both image and video attachments. The
# additional_inputs are passed to model_inference positionally after the
# message and history: decoding strategy, temperature, max new tokens,
# repetition penalty, top-p.
demo = gr.ChatInterface(
    fn=model_inference,
    description="# **SmolVLM Video Infer**",
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            # The previous text ("Higher values ...") was copied from a slider
            # and does not describe a two-option choice.
            info="Greedy always picks the most likely token; Top P Sampling samples from the most likely tokens.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values will produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values is equivalent to sampling more low-probability tokens.",
        )
    ],
    cache_examples=False
)
demo.launch(debug=True)