import re
import uuid
import time
import cv2
from datetime import datetime, timedelta

import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from vllm import LLM
from vllm.sampling_params import SamplingParams


# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
def progress_bar_html(label: str) -> str:
    """Return an HTML snippet rendering an animated progress bar with a label."""
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''


def downsample_video(video_path: str, num_frames: int = 10):
    """
    Downsample a video to a fixed number of evenly spaced frames.
    Returns a list of (PIL.Image, timestamp in seconds) tuples.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    # Pick evenly spaced frame indices across the whole clip.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            # Convert BGR (OpenCV) to RGB, then to a PIL Image.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
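
# Illustrative usage (hypothetical file path): downsample_video("examples/sample_video.mp4", num_frames=4)
# would return something like [(<PIL.Image>, 0.0), (<PIL.Image>, 3.3), (<PIL.Image>, 6.7), (<PIL.Image>, 10.0)]
# for a ~10-second clip; each timestamp is frame_index / fps rounded to two decimals.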


def load_system_prompt(repo_id: str, filename: str) -> str:
    """
    Load the system prompt from a file in the given Hugging Face Hub repo
    and format it with the model name and the current/previous dates.
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r") as file:
        system_prompt = file.read()
    today = datetime.today().strftime("%Y-%m-%d")
    yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    model_name = repo_id.split("/")[-1]
    return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
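
# SYSTEM_PROMPT.txt is expected to contain Python str.format placeholders, roughly like
# (illustrative, not the actual file contents):
#   "You are {name}, a helpful assistant. Today's date is {today}; yesterday was {yesterday}."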


# -----------------------------------------------------------------------------
# Global Settings and Model Initialization
# -----------------------------------------------------------------------------
# Model details (adjust as needed).
MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

# Load the system prompt from the Hub (SYSTEM_PROMPT.txt must exist in the model repo).
SYSTEM_PROMPT = load_system_prompt(MODEL_ID, "SYSTEM_PROMPT.txt")
# If you prefer a hardcoded system prompt, you can use:
# SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, and ends with an ASCII cat."

# Initialize the Mistral LLM via vLLM.
# Note: this 24B-parameter model needs substantial GPU memory (the bf16 weights alone are
# on the order of 48 GB), so a high-VRAM GPU or tensor parallelism is required.
llm = LLM(model=MODEL_ID, tokenizer_mode="mistral")
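# Depending on the vLLM version, Mistral-format checkpoints and multi-image prompts may need
# extra options; an illustrative (unverified for this exact setup) alternative initialization:
# llm = LLM(
#     model=MODEL_ID,
#     tokenizer_mode="mistral",
#     config_format="mistral",            # read the Mistral-native params.json config
#     load_format="mistral",              # load the consolidated Mistral weight files
#     limit_mm_per_prompt={"image": 10},  # allow up to 10 images (e.g. video frames) per prompt
# )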


# -----------------------------------------------------------------------------
# Main Generation Function
# -----------------------------------------------------------------------------
def generate(
    input_dict: dict,
    chat_history: list,
    max_new_tokens: int = 512,
    temperature: float = 0.15,
    top_p: float = 0.9,
    top_k: int = 50,
):
    """
    Main generation function for the Mistral chatbot.
    It supports:
      - Text-only inference.
      - Image inference (attaches image file paths).
      - Video inference (extracts and attaches sampled video frames).
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])

    # Start the conversation with the system prompt.
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT}
    ]

    # Check whether any files were provided and pick the matching branch.
    video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
    if files:
        # If any file is a video, use the video-inference branch.
        if any(str(f).lower().endswith(video_extensions) for f in files):
            # Remove the @video-infer tag if present.
            prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
            video_path = files[0]  # currently only the first video file is processed
            frames = downsample_video(video_path)
            # Build the user content: the prompt followed by each sampled frame.
            user_content = [{"type": "text", "text": prompt_clean}]
            for image, timestamp in frames:
                # Save the frame to a temporary file.
                image_path = f"video_frame_{uuid.uuid4().hex}.png"
                image.save(image_path)
                user_content.append({"type": "text", "text": f"Frame at {timestamp} seconds:"})
                user_content.append({"type": "image_path", "image_path": image_path})
            messages.append({"role": "user", "content": user_content})
        else:
            # Otherwise treat the provided files as images.
            prompt_clean = re.sub(r"@mistral", "", text, flags=re.IGNORECASE).strip().strip('"')
            user_content = [{"type": "text", "text": prompt_clean}]
            for file in files:
                try:
                    image = Image.open(file)
                    image_path = f"image_{uuid.uuid4().hex}.png"
                    image.save(image_path)
                    user_content.append({"type": "image_path", "image_path": image_path})
                except Exception as e:
                    user_content.append({"type": "text", "text": f"Could not open file {file}: {e}"})
            messages.append({"role": "user", "content": user_content})
    else:
        # Text-only branch.
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
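
    # At this point a multimodal user message looks roughly like (illustrative):
    #   {"role": "user", "content": [
    #       {"type": "text", "text": "Describe what you see in the image."},
    #       {"type": "image_path", "image_path": "image_<uuid>.png"},
    #   ]}
    # Whether "image_path" entries are accepted depends on the chat template / vLLM version;
    # some setups expect {"type": "image_url", "image_url": {"url": "file://..."}} instead.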

    # Show a progress bar before generating.
    yield progress_bar_html("Processing with Mistral")

    # Set up sampling parameters.
    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )

    # Run the chat (synchronously) using vLLM.
    outputs = llm.chat(messages, sampling_params=sampling_params)
    final_response = outputs[0].outputs[0].text

    # Simulate streaming output by yielding the response in growing chunks.
    buffer = ""
    chunk_size = 20  # number of characters added per chunk
    for i in range(0, len(final_response), chunk_size):
        buffer = final_response[: i + chunk_size]
        yield buffer
        time.sleep(0.05)
    return
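
# Note: the temporary frame/image PNGs written inside generate() are never deleted; in a
# long-running Space you may want to clean them up after llm.chat() returns.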


# -----------------------------------------------------------------------------
# Gradio Interface Setup
# -----------------------------------------------------------------------------
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=1024, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.05, maximum=2.0, step=0.05, value=0.15),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    ],
    examples=[
        # Example with text only.
        [{"text": "Explain the significance of today in the context of current events.", "files": []}],
        # Example with an image file (make sure the path points to a valid image).
        [{
            "text": "Describe what you see in the image.",
            "files": ["examples/3.jpg"]
        }],
        # Example with a video file (make sure the path points to a valid video).
        [{
            "text": "@video-infer Summarize the events shown in the video.",
            "files": ["examples/sample_video.mp4"]
        }],
    ],
    cache_examples=False,
    type="messages",
    description="# **Mistral Multimodal Chatbot** \nSupports text, image (by reference) and video inference. Use @video-infer in your query when providing a video.",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Enter your query here. Tag with @video-infer if using a video file.",
    ),
    stop_btn="Stop Generation",
    examples_per_page=3,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
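
# To run locally (illustrative): save this file as app.py, run `python app.py`, and open
# the Gradio URL printed in the terminal; on Hugging Face Spaces, app.py is launched
# automatically and share=True is not required.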