import os
import random
import uuid
import json
import time
import asyncio
import contextlib
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
import edge_tts

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.image_utils import load_image

# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Prompt prefix (case-insensitive) that routes a request to the video branch.
VIDEO_INFER_TAG = "@video-infer"

# NOTE: not referenced by the code in this file; the text model is placed via
# device_map="auto" and the multimodal model/inputs via .to("cuda").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load text-only model and tokenizer (Pocket Llama)
model_id = "prithivMLmods/Pocket-Llama2-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# Load multimodal processor and model (Callisto OCR3)
MODEL_ID = "prithivMLmods/Callisto-OCR3-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()

# Edge TTS voices mapping for prompt tags. Keys are lowercase because the
# prompt is lowercased before matching.
TTS_VOICE_MAP = {
    "@jennyneural": "en-US-JennyNeural",
    "@guyneural": "en-US-GuyNeural",
    "@palomaneural": "es-US-PalomaNeural",
    "@alonsoneural": "es-US-AlonsoNeural",
    "@madhurneural": "hi-IN-MadhurNeural",
}


async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Synthesize *text* with an Edge TTS voice and save it as an MP3.

    Args:
        text: Text to speak.
        voice: Edge TTS voice name (e.g. a value of ``TTS_VOICE_MAP``).
        output_file: Destination path of the MP3 file.

    Returns:
        ``output_file``, the path the audio was written to.
    """
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file


def clean_chat_history(chat_history):
    """Return a new list with only the chat entries whose "content" is a string.

    Entries that are not dicts, or whose content is a file/component payload,
    are dropped so the remaining messages can be fed to a text chat template.
    A ``None`` history is treated as empty.
    """
    return [
        msg
        for msg in (chat_history or [])
        if isinstance(msg, dict) and isinstance(msg.get("content"), str)
    ]


def downsample_video(video_path, num_frames: int = 10):
    """Sample ``num_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path of the video file to read with OpenCV.
        num_frames: How many frame positions to sample (default 10).

    Returns:
        A list of ``(PIL.Image, timestamp)`` tuples in RGB. ``timestamp`` is the
        frame position in seconds rounded to 2 decimals, or ``None`` when the
        video reports no usable FPS. Frames that fail to decode are skipped, and
        an empty list is returned when the frame count is unavailable.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        # A frame count <= 0 (e.g. the file could not be opened) would make
        # linspace produce negative/duplicate positions.
        if total_frames <= 0 or num_frames <= 0:
            return frames

        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            # Some containers report fps == 0; avoid inf/nan timestamps.
            timestamp = round(i / fps, 2) if fps > 0 else None
            frames.append((Image.fromarray(image), timestamp))
    finally:
        vidcap.release()
    return frames


def progress_bar_html(label: str) -> str:
    """
    Returns an HTML snippet for a thin progress bar with a label.
    The progress bar is styled as a light cyan animated bar.
    """
    # CSS braces are doubled because this is an f-string.
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #B0E0E6; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''


def _split_tts_tag(text):
    """Detect and remove a leading Edge TTS tag (case-insensitive).

    Returns ``(voice, remaining_text)`` when *text* starts with a key of
    ``TTS_VOICE_MAP``, otherwise ``(None, text)`` unchanged.
    """
    stripped = text.strip()
    lowered = stripped.lower()
    for tag, voice in TTS_VOICE_MAP.items():
        if lowered.startswith(tag):
            # Slice the *stripped* text: the match was tested on it, so leading
            # whitespace in the raw input cannot shift the cut.
            return voice, stripped[len(tag):].strip()
    return None, text


def _stream_generation(target_model, generation_kwargs, streamer, label,
                       strip_im_end=False, delay=0.0):
    """Run ``target_model.generate`` in a worker thread and stream its output.

    Yields the progress bar for *label*, then the accumulated reply after each
    streamed chunk, then the final reply once more. The final reply is also the
    generator's return value (available to callers via ``yield from``).

    Raises:
        gr.Error: if ``generate`` raised in the worker thread.
    """
    errors = []

    def _worker():
        try:
            target_model.generate(**generation_kwargs)
        except Exception as exc:
            # generate() did not reach streamer.end(); end the stream here so
            # the consumer loop below does not block forever on the queue.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_worker)
    thread.start()
    yield progress_bar_html(label)

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        if strip_im_end:
            buffer = buffer.replace("<|im_end|>", "")
        if delay:
            time.sleep(delay)
        yield buffer

    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]
    yield buffer
    return buffer


def _build_video_messages(prompt, video_path, frame_paths):
    """Build Callisto OCR3 chat messages for *prompt*, interleaved with video frames.

    Each sampled frame of *video_path* (skipped when ``None``) is saved to a
    uniquely named PNG, referenced by path, and preceded by a "Frame <t>:" text
    part. Every PNG path is appended to *frame_paths* before it is written so
    the caller can delete all of them even if a later step fails.
    """
    user_content = [{"type": "text", "text": prompt}]
    if video_path is not None:
        for index, (image, timestamp) in enumerate(downsample_video(video_path)):
            image_path = f"video_frame_{uuid.uuid4().hex}.png"
            frame_paths.append(image_path)
            image.save(image_path)
            # Fall back to the sample index when the video's FPS is unknown.
            label = f"Frame {timestamp}:" if timestamp is not None else f"Frame {index}:"
            user_content.append({"type": "text", "text": label})
            user_content.append({"type": "image", "url": image_path})
    return [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": user_content},
    ]


def _generate_video(prompt, files, max_new_tokens, sampling_kwargs):
    """Stream a Callisto OCR3 reply for *prompt* about the first file, read as a video."""
    frame_paths = []
    try:
        messages = _build_video_messages(prompt, files[0] if files else None, frame_paths)
        # Enable truncation to avoid token/feature mismatch.
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKEN_LENGTH,
        ).to("cuda")
    finally:
        # The frames were only needed while the processor built the inputs;
        # remove them so every request does not leave PNGs on disk.
        for path in frame_paths:
            with contextlib.suppress(OSError):
                os.remove(path)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        **sampling_kwargs,
    }
    return (yield from _stream_generation(
        model_m, generation_kwargs, streamer,
        "Processing video with Callisto OCR3",
        strip_im_end=True, delay=0.01,
    ))


def _generate_images(text, files, max_new_tokens):
    """Stream a Callisto OCR3 reply for *text* about every attached file, read as images."""
    images = [load_image(path) for path in files]
    messages = [{
        "role": "user",
        "content": [
            *({"type": "image", "image": image} for image in images),
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Enable truncation explicitly here as well.
    inputs = processor(
        text=[prompt_full],
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    ).to("cuda")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # NOTE(review): unlike the video/text branches, the sampling sliders are not
    # forwarded here, so the model's own generation defaults apply — confirm
    # this is intended for OCR.
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    return (yield from _stream_generation(
        model_m, generation_kwargs, streamer,
        "Processing image with Callisto OCR3",
        strip_im_end=True, delay=0.01,
    ))


def _generate_text(text, chat_history, max_new_tokens, sampling_kwargs):
    """Stream a Pocket Llama reply to *text* given the string-only chat history."""
    conversation = clean_chat_history(chat_history)
    conversation.append({"role": "user", "content": text})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest part of the conversation is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "num_beams": 1,
        **sampling_kwargs,
    }
    return (yield from _stream_generation(
        model, generation_kwargs, streamer, "Processing With Pocket Llama 3B",
    ))


@spaces.GPU
def generate(input_dict: dict, chat_history: list[dict],
             max_new_tokens: int = 1024,
             temperature: float = 0.6,
             top_p: float = 0.9,
             top_k: int = 50,
             repetition_penalty: float = 1.2):
    """Stream a chatbot reply for one multimodal chat turn.

    Routing, in order:
      * An optional leading Edge TTS tag (a key of ``TTS_VOICE_MAP`` such as
        @JennyNeural, @GuyNeural, @PalomaNeural, @AlonsoNeural or @MadhurNeural,
        case-insensitive) is removed from the prompt; the finished reply is then
        also spoken.
      * "@video-infer": Callisto OCR3 on frames sampled from the first file.
      * Any attached files: Callisto OCR3 on the files as images.
      * Otherwise: Pocket Llama text chat over the string-only chat history.

    Args:
        input_dict: MultimodalTextbox value with "text" and optional "files".
        chat_history: Previous messages (messages format).
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Generation settings from the UI sliders.

    Yields:
        A progress-bar HTML snippet, then the growing reply text, and finally a
        ``gr.Audio`` component when a TTS tag was given and the reply is non-empty.
    """
    text = input_dict["text"]
    files = input_dict.get("files") or []

    tts_voice, text = _split_tts_tag(text)

    sampling_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

    stripped = text.strip()
    if stripped.lower().startswith(VIDEO_INFER_TAG):
        prompt = stripped[len(VIDEO_INFER_TAG):].strip()
        final_response = yield from _generate_video(prompt, files, max_new_tokens, sampling_kwargs)
    elif files:
        final_response = yield from _generate_images(text, files, max_new_tokens)
    else:
        final_response = yield from _generate_text(text, chat_history, max_new_tokens, sampling_kwargs)

    # If a TTS voice was specified, convert the final response to speech.
    # Edge TTS cannot synthesize empty text, so skip it for blank replies.
    if tts_voice and final_response.strip():
        output_file = asyncio.run(text_to_speech(final_response, tts_voice))
        yield gr.Audio(output_file, autoplay=True)


# Create the Gradio ChatInterface.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[
        ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
        [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
        ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
        [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
        [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
        ["@GuyNeural Explain how rainbows are formed."],
        ["@PalomaNeural What is the water cycle?"],
        ["@AlonsoNeural Who was Pablo Picasso and why is he famous?"],
        ["@MadhurNeural What are the key principles of Ayurveda?"],
    ],
    cache_examples=False,
    description="# **Pocket Llama**",
    type="messages",
    fill_height=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)