prithivMLmods's picture
Update app.py
5cf35b5 verified
raw
history blame
14.2 kB
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
Qwen2VLForConditionalGeneration,
AutoProcessor,
)
from transformers.image_utils import load_image
# Generation limits: the ceiling and default for the "Max new tokens" slider,
# plus the maximum prompt length in tokens (overridable through the
# MAX_INPUT_TOKEN_LENGTH environment variable).
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# First CUDA GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Text-only chat model and its tokenizer: weights are placed with
# device_map="auto" in bfloat16 and the model is switched to eval mode.
model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

# Qwen2-VL vision-language model and processor, used by generate() for image
# and "@video-infer" requests; loaded in float16 and moved to CUDA explicitly.
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model_m = model_m.to("cuda")
model_m.eval()
def clean_chat_history(chat_history):
    """
    Return only the chat entries that can be fed to a text chat template.

    An entry is kept when it is a dict whose "content" value is a str; any
    other entry (not a dict, no "content" key, or non-string content) is
    dropped. The kept dicts are the same objects as in the input, collected
    into a new list in their original order.
    """
    return [
        entry
        for entry in chat_history
        if isinstance(entry, dict) and isinstance(entry.get("content"), str)
    ]
def downsample_video(video_path, num_frames=10):
    """
    Sample up to ``num_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path to a video file, opened with ``cv2.VideoCapture``.
        num_frames: Number of evenly spaced frame positions to sample
            (default 10). Duplicate positions are collapsed, so clips with
            fewer frames than ``num_frames`` yield fewer, distinct frames.

    Returns:
        A list of ``(PIL.Image.Image, timestamp_seconds)`` tuples in playback
        order. Frames that fail to decode are skipped. An empty list is
        returned when the video reports no frames (e.g. it could not be
        opened) or ``num_frames`` is not positive. When the reported FPS is
        not positive, every timestamp is 0.0 instead of nan/inf.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0 or num_frames <= 0:
            return frames
        # np.unique drops repeated indices (clips shorter than num_frames) and
        # keeps them sorted, so frames stay in playback order.
        frame_indices = np.unique(np.linspace(0, total_frames - 1, num_frames, dtype=int))
        for raw_index in frame_indices:
            index = int(raw_index)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
            timestamp = round(index / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
    finally:
        # Release the capture handle even if decoding or conversion raises.
        vidcap.release()
    return frames
def progress_bar_html(label: str) -> str:
    """
    Return an HTML snippet showing ``label`` next to a thin animated bar.

    The bar is a hot-pink (#FF69B4) stripe that slides across a pale
    lavender-blush (#FFF0F5) track, 110px wide and 5px high. The ``loading``
    keyframes are defined in an inline <style> element so the snippet is
    self-contained.

    The label is HTML-escaped so characters such as ``<`` or ``&`` are shown
    literally rather than being interpreted as markup.
    """
    import html

    safe_label = html.escape(label)
    return f'''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{safe_label}</span>
<div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
def _stream_response(generate_fn, generation_kwargs, streamer, label,
                     strip_im_end=True, delay=0.01):
    """
    Run ``generate_fn(**generation_kwargs)`` on a worker thread and stream its text.

    Yields the progress-bar HTML for ``label`` first, then the accumulated
    response after every chunk read from ``streamer``, and finally the full
    response once more after the worker thread has finished. The last yield
    replaces the progress bar even when no text was produced.

    Args:
        generate_fn: A ``generate`` callable that writes into ``streamer``.
        generation_kwargs: Keyword arguments passed to ``generate_fn``.
        streamer: The ``TextIteratorStreamer`` to read text chunks from.
        label: Text shown beside the progress bar.
        strip_im_end: Remove literal "<|im_end|>" markers from the buffer.
        delay: Seconds to sleep after each chunk; 0 disables the pause.
    """
    worker = Thread(target=generate_fn, kwargs=generation_kwargs)
    worker.start()
    yield progress_bar_html(label)
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        if strip_im_end:
            buffer = buffer.replace("<|im_end|>", "")
        if delay:
            time.sleep(delay)
        yield buffer
    worker.join()
    yield buffer


def _generate_video(text, files, max_new_tokens, temperature, top_p, top_k,
                    repetition_penalty):
    """
    Handle an "@video-infer" request with the Qwen2-VL model.

    The prompt is the text after the command. When files are attached, the
    first one is treated as a video. Its sampled frames are saved as temporary
    PNG files and interleaved with "Frame <timestamp>:" captions in the user
    turn. The PNGs are deleted once the processor has built the input tensors.
    """
    # Slice the stripped text so leading whitespace (which the prefix check
    # ignores) does not shift the command offset and eat part of the prompt.
    prompt = text.strip()[len("@video-infer"):].strip()
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    frame_paths = []
    try:
        if files:
            # Assume the first file is a video.
            for image, timestamp in downsample_video(files[0]):
                image_path = f"video_frame_{uuid.uuid4().hex}.png"
                # Record the path before saving so a partially written file is still removed.
                frame_paths.append(image_path)
                image.save(image_path)
                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
                messages[1]["content"].append({"type": "image", "url": image_path})
        inputs = processor.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to("cuda")
    finally:
        # apply_chat_template(tokenize=True) reads the frame files while building
        # the tensors, so they are no longer needed and must not pile up on disk.
        for path in frame_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; a missing or locked file is not fatal
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    yield from _stream_response(model_m.generate, generation_kwargs, streamer,
                                "Processing video with Qwen2VL")


def _generate_multimodal(text, files, max_new_tokens):
    """Answer ``text`` about the attached ``files`` (loaded as images) with Qwen2-VL."""
    images = [load_image(path) for path in files]
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # NOTE(review): the sampling sliders are not forwarded here, so image queries
    # presumably decode with the model's default generation config — confirm intended.
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    yield from _stream_response(model_m.generate, generation_kwargs, streamer, "Thinking...")


def _generate_text(text, chat_history, max_new_tokens, temperature, top_p, top_k,
                   repetition_penalty):
    """Continue the text-only conversation with the small chat model."""
    conversation = clean_chat_history(chat_history)
    conversation.append({"role": "user", "content": text})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest user turn survives.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    yield from _stream_response(model.generate, generation_kwargs, streamer, "Processing...",
                                strip_im_end=False, delay=0)


@spaces.GPU(duration=60, enable_queue=True)
def generate(input_dict: dict, chat_history: list[dict],
             max_new_tokens: int = 1024,
             temperature: float = 0.6,
             top_p: float = 0.9,
             top_k: int = 50,
             repetition_penalty: float = 1.2):
    """
    Stream a chatbot reply, routing the request to one of three back ends.

    - Text starting with "@video-infer" (case-insensitive, leading whitespace
      ignored): Qwen2-VL answers the rest of the text. When a file is
      attached, frames are sampled from the first file, which is assumed to
      be a video.
    - Any other message with attached files: the files are loaded as images
      and sent to Qwen2-VL together with the text.
    - Plain text: the text-only model continues the conversation. String
      messages from ``chat_history`` are kept, and the prompt is trimmed to
      the last MAX_INPUT_TOKEN_LENGTH tokens.

    Yields progress-bar HTML first, then the growing response text, ending
    with the complete response.

    Args:
        input_dict: Multimodal textbox value with "text" and optional "files".
        chat_history: Previous messages as role/content dicts.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Generation settings from the UI sliders (the image branch uses
            only max_new_tokens).
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])
    lower_text = text.strip().lower()
    if lower_text.startswith("@video-infer"):
        yield from _generate_video(text, files, max_new_tokens, temperature, top_p, top_k,
                                   repetition_penalty)
        return
    if files:
        yield from _generate_multimodal(text, files, max_new_tokens)
    else:
        yield from _generate_text(text, chat_history, max_new_tokens, temperature, top_p, top_k,
                                  repetition_penalty)
# Custom CSS for the theme (combining font and global styles).
# Tailwind build directives (`@import "tailwindcss"`, `@theme inline`) are
# intentionally absent. This string is injected as plain CSS, where browsers
# ignore unknown at-rules and any @import that follows other rules.
# The dark-mode block also overrides the notebook background colours, so the
# light --foreground text is not drawn on a light page background.
# NOTE(review): the /fonts/GT Maru/... URLs must be served by the app for the
# custom faces to load; otherwise the sans-serif/monospace fallbacks apply.
custom_css = """
/* ------------------ START FONTS.CSS ------------------ */
/* GT Maru Regular */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Regular.otf') format('opentype');
font-weight: 400;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Regular-Oblique.otf') format('opentype');
font-weight: 400;
font-style: italic;
font-display: swap;
}
/* GT Maru Light */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Light.otf') format('opentype');
font-weight: 300;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Light-Oblique.otf') format('opentype');
font-weight: 300;
font-style: italic;
font-display: swap;
}
/* GT Maru Medium */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Medium.otf') format('opentype');
font-weight: 500;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Medium-Oblique.otf') format('opentype');
font-weight: 500;
font-style: italic;
font-display: swap;
}
/* GT Maru Bold */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Bold.otf') format('opentype');
font-weight: 700;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Bold-Oblique.otf') format('opentype');
font-weight: 700;
font-style: italic;
font-display: swap;
}
/* GT Maru Black */
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Black.otf') format('opentype');
font-weight: 900;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru';
src: url('/fonts/GT Maru/GT-Maru-Black-Oblique.otf') format('opentype');
font-weight: 900;
font-style: italic;
font-display: swap;
}
/* GT Maru Mono Regular */
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Regular.otf') format('opentype');
font-weight: 400;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Regular-Oblique.otf') format('opentype');
font-weight: 400;
font-style: italic;
font-display: swap;
}
/* GT Maru Mono Medium */
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Medium.otf') format('opentype');
font-weight: 500;
font-style: normal;
font-display: swap;
}
@font-face {
font-family: 'GT-Maru-Mono';
src: url('/fonts/GT Maru/GT-Maru-Mono-Medium-Oblique.otf') format('opentype');
font-weight: 500;
font-style: italic;
font-display: swap;
}
/* GT Maru Mega Midi (for special display text) */
@font-face {
font-family: 'GT-Maru-Mega-Midi';
src: url('/fonts/GT Maru/GT-Maru-Mega-Midi.otf') format('opentype');
font-weight: normal;
font-style: normal;
font-display: swap;
}
/* ------------------ END FONTS.CSS ------------------ */
/* ------------------ START GLOBAL.CSS ------------------ */
:root {
--background: #ffffff;
--foreground: #171717;
--font-gt-maru: 'GT-Maru', sans-serif;
--font-gt-maru-mono: 'GT-Maru-Mono', monospace;
--font-gt-maru-mega: 'GT-Maru-Mega-Midi', sans-serif;
--notebook-bg: #f3f4f6; /* Tailwind gray-100 */
--notebook-dot: #d1d5db; /* Tailwind gray-300 for dots */
}
@media (prefers-color-scheme: dark) {
:root {
--background: #0a0a0a;
--foreground: #ededed;
--notebook-bg: #111827; /* Tailwind gray-900 */
--notebook-dot: #374151; /* Tailwind gray-700 for dots */
}
}
body {
color: var(--foreground);
font-family: var(--font-gt-maru);
background-color: var(--notebook-bg);
background-image: radial-gradient(var(--notebook-dot) 1px, transparent 1px);
background-size: 24px 24px;
}
/* Container styling for paper effect */
.paper-container {
background-color: rgba(255, 255, 255, 0.7);
border-radius: 12px;
box-shadow:
0 2px 8px rgba(0, 0, 0, 0.05),
0 8px 20px rgba(0, 0, 0, 0.03);
backdrop-filter: blur(2px);
}
/* Notebook paper dotted background - now for specific elements only */
.notebook-paper-bg {
background-color: var(--notebook-bg);
background-image: radial-gradient(var(--notebook-dot) 1px, transparent 1px);
background-size: 24px 24px;
background-position: 0 0;
border-radius: 12px;
box-shadow: inset 0 0 30px rgba(0, 0, 0, 0.03);
}
/* Shadow effect for paper depth */
.paper-shadow {
box-shadow:
0 1px 3px rgba(0, 0, 0, 0.05),
0 4px 6px rgba(0, 0, 0, 0.03),
inset 0 0 3px rgba(255, 255, 255, 0.8);
}
/* ------------------ END GLOBAL.CSS ------------------ */
"""
# Chat UI built around generate(). The five sliders are passed to generate()
# after the message and history, in this order: max_new_tokens, temperature,
# top_p, top_k, repetition_penalty. The multimodal textbox accepts images and
# videos, and the page is styled with custom_css.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    multimodal=True,
    fill_height=True,
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max new tokens",
        ),
        gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.6, label="Temperature"),
        gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1, maximum=1000, step=1, value=50, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty"),
    ],
    textbox=gr.MultimodalTextbox(
        file_count="multiple",
        file_types=["image", "video"],
        label="Query Input",
    ),
    stop_btn="Stop Generation",
    css=custom_css,
)
if __name__ == "__main__":
    # Queue incoming requests (at most 20 waiting) and start the server with a
    # public share link requested.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(share=True)