# Hugging Face Spaces status: Running on Zero
import os | |
import random | |
import uuid | |
import json | |
import time | |
import asyncio | |
from threading import Thread | |
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
TextIteratorStreamer, | |
Qwen2VLForConditionalGeneration, | |
AutoProcessor, | |
) | |
from transformers.image_utils import load_image | |
# Constants for text generation
# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 2048
# Initial value of the "Max new tokens" slider in the UI.
DEFAULT_MAX_NEW_TOKENS = 1024
# Longest prompt (in tokens) fed to the text-only model; generate() keeps only
# the last MAX_INPUT_TOKEN_LENGTH tokens of longer conversations.
# Overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# NOTE(review): `device` is not referenced anywhere else in the visible code;
# the text model uses device_map="auto" and the VL model is hard-wired to "cuda".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load text-only model and tokenizer.
# Used by generate() when the message has no attached files.
model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let accelerate decide placement
    torch_dtype=torch.bfloat16,
)
model.eval()
# Vision-language model (Qwen2-VL) used by generate() for image queries and
# "@video-infer" requests. It is moved straight to "cuda" with no CPU
# fallback (unlike `device` above), so loading requires a CUDA device
# (or the ZeroGPU runtime's handling of .to("cuda") — confirm on deploy).
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()
def clean_chat_history(chat_history):
    """
    Return only the history entries that are dicts with a ``str`` content.

    Anything else (non-dict entries, or dicts whose ``"content"`` is missing
    or not a string) is dropped. The caller passes the result to the text-only
    tokenizer's chat template, which expects plain-text message contents.

    Args:
        chat_history: Iterable of chat entries, in order.

    Returns:
        A new list holding the surviving entries in their original order.
    """
    return [
        entry
        for entry in chat_history
        if isinstance(entry, dict) and isinstance(entry.get("content"), str)
    ]
def downsample_video(video_path, num_frames: int = 10):
    """
    Sample up to ``num_frames`` evenly spaced frames from a video file.

    Args:
        video_path: Path to a video file that OpenCV can open.
        num_frames: Number of frames to sample (default 10). When the video
            has fewer frames than this, each frame is returned at most once.

    Returns:
        A list of ``(PIL.Image.Image, timestamp)`` tuples in playback order.
        ``timestamp`` is the frame position in seconds, rounded to two
        decimals, or ``0.0`` when OpenCV reports no usable FPS. Frames that
        fail to decode are skipped, so the list may be shorter than
        ``num_frames``. It is empty when the video cannot be opened or
        reports no frames.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        if not vidcap.isOpened():
            return []
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        # Without this guard, linspace(0, -1, ...) would yield negative indices.
        if total_frames <= 0 or num_frames <= 0:
            return []
        # np.unique removes the duplicate indices linspace produces when the
        # video is shorter than num_frames, and keeps them sorted.
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, num_frames, dtype=int)
        )
        frames = []
        for idx in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            # Some containers report 0 FPS; avoid dividing by zero.
            timestamp = round(float(idx) / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
        return frames
    finally:
        # Release the capture handle even if decoding raises.
        vidcap.release()
def progress_bar_html(label: str) -> str:
    """
    Build an inline HTML/CSS "busy" indicator shown while a reply streams in.

    The snippet places ``label`` to the left of a 110x5 px track (background
    ``#FFF0F5``). A hot-pink (``#FF69B4``) fill slides across the track
    endlessly via the ``loading`` keyframes animation. ``label`` is inserted
    verbatim, without HTML escaping.
    """
    template = '''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
    # str.format collapses the doubled braces exactly as an f-string would.
    return template.format(label=label)
def _sampling_kwargs(max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Return the sampling arguments shared by every ``generate`` call below."""
    return {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }


def _build_video_inputs(prompt, files):
    """Build Qwen2-VL model inputs from ``prompt`` plus frames sampled from ``files[0]``.

    Frames are written as PNGs to a temporary directory that is removed once
    the processor has turned them into tensors. With no files, only the text
    prompt is used.
    """
    import tempfile

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    with tempfile.TemporaryDirectory() as tmp_dir:
        if files:
            # Assume the first file is a video.
            for image, timestamp in downsample_video(files[0]):
                image_path = os.path.join(tmp_dir, f"video_frame_{uuid.uuid4().hex}.png")
                image.save(image_path)
                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
                messages[1]["content"].append({"type": "image", "url": image_path})
        # With tokenize=True the processor reads the frame files while building
        # the tensors, so the directory can be deleted once this returns.
        return processor.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to("cuda")


def _stream_generation(generate_fn, generation_kwargs, streamer, label):
    """Run ``generate_fn`` on a worker thread and stream the decoded reply.

    Yields a progress-bar snippet first. It then yields the accumulated reply
    after each chunk the streamer produces, with literal ``<|im_end|>`` markers
    removed, and finally the complete reply once more. The last yield replaces
    the progress bar even when nothing was generated.
    """
    worker = Thread(target=generate_fn, kwargs={**generation_kwargs, "streamer": streamer})
    worker.start()
    yield progress_bar_html(label)
    reply = ""
    for chunk in streamer:
        reply = (reply + chunk).replace("<|im_end|>", "")
        yield reply
    # The streamer only stops iterating after generation ends, so this is brief.
    worker.join()
    yield reply


@spaces.GPU
def generate(input_dict: dict, chat_history: list[dict],
             max_new_tokens: int = 1024,
             temperature: float = 0.6,
             top_p: float = 0.9,
             top_k: int = 50,
             repetition_penalty: float = 1.2):
    """
    Stream a chatbot reply for a (possibly multimodal) user message.

    Routing:
      - Text starting with "@video-infer" (case-insensitive, leading whitespace
        ignored): the remainder is the prompt. Frames sampled from the first
        attached file are sent with it to the Qwen2-VL model.
      - Any attached files otherwise: all files are loaded as images and sent
        with the text to the Qwen2-VL model.
      - No files: the text-only model answers, using the string-content part
        of ``chat_history`` plus the new text. The prompt is cut to its last
        MAX_INPUT_TOKEN_LENGTH tokens.

    Args:
        input_dict: Message with a "text" entry and an optional "files" list
            (presumably file paths from the MultimodalTextbox — confirm).
        chat_history: Prior messages as role/content dicts.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings, applied to every branch.

    Yields:
        A progress-bar HTML snippet, then the accumulated reply text.
    """
    text = input_dict["text"]
    files = input_dict.get("files") or []
    sampling = _sampling_kwargs(max_new_tokens, temperature, top_p, top_k, repetition_penalty)

    stripped = text.strip()
    if stripped.lower().startswith("@video-infer"):
        # Slice the *stripped* text: slicing the raw text would be misaligned by
        # the amount of leading whitespace.
        prompt = stripped[len("@video-infer"):].strip()
        inputs = _build_video_inputs(prompt, files)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        yield from _stream_generation(
            model_m.generate, {**inputs, **sampling}, streamer, "Processing video with Qwen2VL"
        )
        return

    if files:
        # NOTE(review): non-image files (e.g. a video sent without
        # "@video-infer") are passed to load_image unchanged.
        images = [load_image(path) for path in files]
        messages = [{
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }]
        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        yield from _stream_generation(model_m.generate, {**inputs, **sampling}, streamer, "Thinking...")
        return

    conversation = clean_chat_history(chat_history)
    conversation.append({"role": "user", "content": text})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the start of the conversation is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    yield from _stream_generation(
        model.generate, {"input_ids": input_ids, "num_beams": 1, **sampling}, streamer, "Processing..."
    )
# Custom CSS for the theme (combining font and global styles).
# This string is handed to gr.ChatInterface(css=...) further down.
# NOTE(review): the @font-face rules load /fonts/GT Maru/*.otf. Nothing in this
# file serves those paths, so confirm they are reachable from the app's origin;
# otherwise the browser falls back to the generic families listed in :root.
# NOTE(review): `@import "tailwindcss";` and `@theme inline { ... }` are Tailwind
# build-time directives, and this string is never run through Tailwind. Also,
# an @import that follows other rules is invalid per the CSS spec. Both are
# most likely ignored by the browser — confirm before relying on them.
custom_css = """
/* ------------------ START FONTS.CSS ------------------ */
/* GT Maru Regular */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Regular.otf') format('opentype');
font-weight: 400;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Regular-Oblique.otf') format('opentype');
font-weight: 400;
font-style: italic;
font-display: swap;
}
/* GT Maru Light */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Light.otf') format('opentype');
font-weight: 300;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Light-Oblique.otf') format('opentype');
font-weight: 300;
font-style: italic;
font-display: swap;
}
/* GT Maru Medium */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Medium.otf') format('opentype');
font-weight: 500;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Medium-Oblique.otf') format('opentype');
font-weight: 500;
font-style: italic;
font-display: swap;
}
/* GT Maru Bold */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Bold.otf') format('opentype');
font-weight: 700;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Bold-Oblique.otf') format('opentype');
font-weight: 700;
font-style: italic;
font-display: swap;
}
/* GT Maru Black */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Black.otf') format('opentype');
font-weight: 900;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Black-Oblique.otf') format('opentype');
font-weight: 900;
font-style: italic;
font-display: swap;
}
/* GT Maru Mono Regular */
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Regular.otf') format('opentype');
font-weight: 400;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Regular-Oblique.otf') format('opentype');
font-weight: 400;
font-style: italic;
font-display: swap;
}
/* GT Maru Mono Medium */
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Medium.otf') format('opentype');
font-weight: 500;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Medium-Oblique.otf') format('opentype');
font-weight: 500;
font-style: italic;
font-display: swap;
}
/* GT Maru Mega Midi (for special display text) */
@font-face {
font-family: 'GT-Maru-Mega-Midi';
src: url('/fonts/GT Maru/GT-Maru-Mega-Midi.otf') format('opentype');
font-weight: normal;
font-style: normal;
font-display: swap;
}
/* ------------------ END FONTS.CSS ------------------ */
/* ------------------ START GLOBAL.CSS ------------------ */
@import "tailwindcss";
:root {
--background: #ffffff;
--foreground: #171717;
--font-gt-maru: 'GT-Maru', sans-serif;
--font-gt-maru-mono: 'GT-Maru-Mono', monospace;
--font-gt-maru-mega: 'GT-Maru-Mega-Midi', sans-serif;
--notebook-bg: #f3f4f6; /* Tailwind gray-100 */
--notebook-dot: #d1d5db; /* Tailwind gray-300 for dots */
}
@theme inline {
--color-background: var(--background);
--color-foreground: var(--foreground);
--font-sans: var(--font-gt-maru);
--font-mono: var(--font-gt-maru-mono);
}
@media (prefers-color-scheme: dark) {
:root {
--background: #0a0a0a;
--foreground: #ededed;
}
}
body {
color: var(--foreground);
font-family: var(--font-gt-maru);
background-color: var(--notebook-bg);
background-image: radial-gradient(var(--notebook-dot) 1px, transparent 1px);
background-size: 24px 24px;
}
/* Container styling for paper effect */
.paper-container {
background-color: rgba(255, 255, 255, 0.7);
border-radius: 12px;
box-shadow:
0 2px 8px rgba(0, 0, 0, 0.05),
0 8px 20px rgba(0, 0, 0, 0.03);
backdrop-filter: blur(2px);
}
/* Notebook paper dotted background - now for specific elements only */
.notebook-paper-bg {
background-color: var(--notebook-bg);
background-image: radial-gradient(var(--notebook-dot) 1px, transparent 1px);
background-size: 24px 24px;
background-position: 0 0;
border-radius: 12px;
box-shadow: inset 0 0 30px rgba(0, 0, 0, 0.03);
}
/* Shadow effect for paper depth */
.paper-shadow {
box-shadow:
0 1px 3px rgba(0, 0, 0, 0.05),
0 4px 6px rgba(0, 0, 0, 0.03),
inset 0 0 3px rgba(255, 255, 255, 0.8);
}
/* ------------------ END GLOBAL.CSS ------------------ */
"""
# Assemble the chat UI around generate().
# The sliders' order mirrors generate()'s parameters after
# (input_dict, chat_history): max_new_tokens, temperature, top_p, top_k,
# repetition_penalty. They are created before the textbox, as in a single
# inline ChatInterface(...) call.
_generation_controls = [
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
]
# Accepts text plus any number of image/video attachments.
_query_box = gr.MultimodalTextbox(
    label="Query Input",
    file_types=["image", "video"],
    file_count="multiple",
)
demo = gr.ChatInterface(
    fn=generate,
    multimodal=True,
    type="messages",
    textbox=_query_box,
    additional_inputs=_generation_controls,
    stop_btn="Stop Generation",
    fill_height=True,
    cache_examples=False,
    css=custom_css,  # theme fonts and global styles
)
if __name__ == "__main__":
    # Allow at most 20 pending requests in the queue, then start the server
    # with a public share link.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(share=True)