import time
from threading import Thread
from datetime import datetime, timedelta

import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import hf_hub_download

# -----------------------------------------------------------------------------
# Constants & Device Setup
# -----------------------------------------------------------------------------
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------
def progress_bar_html(label: str) -> str:
    """Return a small animated HTML progress bar carrying the given label."""
    return f'''
    <div style="display: flex; align-items: center;">
        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
        <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
            <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
        </div>
    </div>
    <style>
    @keyframes loading {{
        0% {{ transform: translateX(-100%); }}
        100% {{ transform: translateX(100%); }}
    }}
    </style>
    '''

def load_system_prompt(repo_id: str, filename: str) -> str:
    """
    Download a system prompt template from the given Hugging Face repo and
    fill in its placeholders (e.g. {name}, {today}, {yesterday}).
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r") as file:
        system_prompt = file.read()
    today = datetime.today().strftime("%Y-%m-%d")
    yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    model_name = repo_id.split("/")[-1]
    return system_prompt.format(name=model_name, today=today, yesterday=yesterday)

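# Illustrative note (not part of the original template): with the MODEL_ID used below,
# {name} resolves to "Mistral-Small-3.1-24B-Instruct-2503" and {today}/{yesterday}
# resolve to the current and previous calendar dates in YYYY-MM-DD form.
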
def downsample_video(video_path: str):
    """
    Extracts 10 evenly spaced frames from the video.
    Returns a list of tuples (PIL.Image, timestamp_in_seconds).
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames > 0 and fps > 0:
        frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                timestamp = round(i / fps, 2)
                frames.append((pil_image, timestamp))
    vidcap.release()
    return frames

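# For example (illustrative numbers): a 30 fps clip with 300 frames yields 10
# (PIL.Image, timestamp) pairs with timestamps spaced roughly 1.1 seconds apart,
# since np.linspace picks evenly spaced frame indices from 0 to total_frames - 1.
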
def build_prompt(chat_history, current_input_text, video_frames=None, image_files=None):
    """
    Build a conversation prompt string.
    The system prompt is added first, then the previous chat history, and finally the
    current input. If video_frames or image_files are provided, placeholder notes for
    them are appended to the prompt.
    """
    prompt = f"System: {SYSTEM_PROMPT}\n"
    # Append chat history (if any).
    for msg in chat_history:
        role = msg.get("role", "").capitalize()
        content = msg.get("content", "")
        # History entries for file uploads may carry non-string content; keep only text turns.
        if not isinstance(content, str):
            continue
        prompt += f"{role}: {content}\n"
    prompt += f"User: {current_input_text}\n"
    if video_frames:
        for _, timestamp in video_frames:
            prompt += f"[Video Frame at {timestamp} sec]\n"
    if image_files:
        for _ in image_files:
            prompt += "[Image Input]\n"
    prompt += "Assistant: "
    return prompt

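# For illustration only: with an empty history, the text "Explain what is in this image."
# and one attached image, build_prompt produces roughly:
#   System: <contents of SYSTEM_PROMPT>
#   User: Explain what is in this image.
#   [Image Input]
#   Assistant:
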
# -----------------------------------------------------------------------------
# Load Mistral Model & System Prompt
# -----------------------------------------------------------------------------
MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
SYSTEM_PROMPT = load_system_prompt(MODEL_ID, "SYSTEM_PROMPT.txt")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# device_map="auto" already places the weights, so no extra .to(device) call is
# needed (calling .to() on a model dispatched by accelerate can raise an error).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# -----------------------------------------------------------------------------
# Main Generation Function
# -----------------------------------------------------------------------------
def generate(
    input_dict: dict,
    chat_history: list,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    text = input_dict.get("text", "")
    files = input_dict.get("files", []) or []
    # Separate video files from images based on file extension.
    video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
    video_files = [f for f in files if str(f).lower().endswith(video_extensions)]
    image_files = [f for f in files if not str(f).lower().endswith(video_extensions)]
    video_frames = None
    if video_files:
        # Process only the first video file.
        video_path = video_files[0]
        video_frames = downsample_video(video_path)
    # Build the full prompt from the system prompt, chat history, current text, and file inputs.
    prompt = build_prompt(chat_history, text, video_frames, image_files)
    # Tokenize the prompt and move it to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Set up a streamer for incremental output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=20.0)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    # Launch generation in a separate thread so the UI can stream tokens as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    yield progress_bar_html("Processing with Mistral")
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

# -----------------------------------------------------------------------------
# Gradio Interface
# -----------------------------------------------------------------------------
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[
        [{"text": "Describe the content of the video.", "files": ["examples/sample_video.mp4"]}],
        [{"text": "Explain what is in this image.", "files": ["examples/sample_image.jpg"]}],
        [{"text": "Tell me a fun fact about space.", "files": []}],
    ],
    cache_examples=False,
    type="messages",
    description="# **Mistral Chatbot with Video Inference**\nA chatbot built with Mistral (via Transformers) that supports text, image, and video (frame-extraction) inputs.",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type your message here. Optionally attach images or a video.",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
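
# To run this script locally (illustrative sketch; package names follow the imports
# above, and "app.py" stands for whatever filename this script is saved under),
# something like the following should work on a CUDA GPU with enough memory for the
# 24B model in float16:
#   pip install gradio torch transformers accelerate huggingface_hub opencv-python pillow numpy
#   python app.py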