import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import edge_tts

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    AutoModelForImageTextToText,
)
from transformers.image_utils import load_image
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# Application description and CSS
DESCRIPTION = """
# QwQ Edge 💬
"""

css = '''
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}
'''

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# -------------------------
# Load Text-only Model
# -------------------------
model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# -------------------------
# TTS Settings
# -------------------------
TTS_VOICES = [
    "en-US-JennyNeural",  # @tts1
    "en-US-GuyNeural",  # @tts2
]


async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Convert text to speech using Edge TTS and save it as an MP3.

    Args:
        text: Text to synthesize.
        voice: Edge TTS voice name (see ``TTS_VOICES``).
        output_file: Destination path. Callers serving concurrent users should
            pass a unique name so requests do not overwrite each other's audio.

    Returns:
        ``output_file``, the path the audio was written to.
    """
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file


def clean_chat_history(chat_history):
    """
    Filter out any chat entries whose "content" is not a string.
    This helps prevent errors when concatenating previous messages.
    """
    cleaned = []
    for msg in chat_history:
        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
            cleaned.append(msg)
    return cleaned


# -------------------------
# Load Multimodal Model (Qwen2-VL)
# -------------------------
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()

# -------------------------
# Load Aya-Vision Model (New Feature)
# -------------------------
AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
aya_model = AutoModelForImageTextToText.from_pretrained(
    AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
)
aya_tokenizer = AutoTokenizer.from_pretrained(AYA_MODEL_ID)

# -------------------------
# Stable Diffusion XL Settings & Pipeline
# -------------------------
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
if not MODEL_ID_SD:
    # Fail fast with an actionable message instead of handing None to from_pretrained.
    raise RuntimeError(
        "Environment variable MODEL_VAL_PATH must be set to an SDXL model repository path."
    )
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation

sd_pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID_SD,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
).to(device)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)

if torch.cuda.is_available():
    sd_pipe.text_encoder = sd_pipe.text_encoder.half()

if USE_TORCH_COMPILE:
    # A DiffusionPipeline is not an nn.Module, so compile its UNet (the per-step
    # hot path) directly, as the diffusers torch.compile guide does.
    sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)

if ENABLE_CPU_OFFLOAD:
    sd_pipe.enable_model_cpu_offload()

MAX_SEED = np.iinfo(np.int32).max


def save_image(img: Image.Image) -> str:
    """Save a PIL image with a unique filename and return the path."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in [0, MAX_SEED] if ``randomize_seed``, else ``seed``."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU(duration=60, enable_queue=True)
def generate_image_fn(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    num_images: int = 1,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images using the SDXL pipeline.

    Images are produced in chunks of ``BATCH_SIZE`` prompts. One seeded
    generator is shared across chunks.

    Returns:
        A tuple ``(image_paths, seed)``. ``image_paths`` lists the saved PNG
        files and ``seed`` is the seed actually used.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    options = {
        "prompt": [prompt] * num_images,
        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "generator": generator,
        "output_type": "pil",
    }
    if use_resolution_binning:
        options["use_resolution_binning"] = True

    images = []
    # Process in batches
    for i in range(0, num_images, BATCH_SIZE):
        batch_options = options.copy()
        batch_options["prompt"] = options["prompt"][i:i + BATCH_SIZE]
        if batch_options.get("negative_prompt") is not None:
            batch_options["negative_prompt"] = options["negative_prompt"][i:i + BATCH_SIZE]
        if device.type == "cuda":
            with torch.autocast("cuda", dtype=torch.float16):
                outputs = sd_pipe(**batch_options)
        else:
            outputs = sd_pipe(**batch_options)
        images.extend(outputs.images)

    image_paths = [save_image(img) for img in images]
    return image_paths, seed


def _strip_command(text, command):
    """Return the text that follows ``command`` if ``text`` starts with it.

    Matching ignores case and surrounding whitespace, so ``" @IMAGE cat"``
    with command ``"@image"`` yields ``"cat"``. Returns ``None`` when ``text``
    does not start with ``command``. ``command`` must be lowercase.
    """
    stripped = text.strip()
    if stripped.lower().startswith(command):
        return stripped[len(command):].strip()
    return None


def _load_images(files):
    """Load every uploaded file path with transformers' ``load_image``."""
    return [load_image(file) for file in files]


def _aya_vision_stream(prompt, files):
    """Answer ``prompt``, optionally about the uploaded images, with Aya-Vision.

    Yields a status message, then the decoded reply (prompt tokens excluded).
    """
    images = _load_images(files)
    content = [{"type": "image", "image": image} for image in images]
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    yield "Processing with Aya-Vision..."
    inputs = aya_processor.apply_chat_template(
        messages,
        padding=True,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(aya_model.device)
    # Remove deprecated parameter if present to avoid conflicts.
    inputs.pop("num_logits_to_keep", None)
    gen_tokens = aya_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.3,
    )
    # generate() returns the prompt followed by the completion; drop the prompt
    # tokens so the reply does not echo the chat template and user message.
    prompt_length = inputs["input_ids"].shape[1]
    yield aya_tokenizer.decode(gen_tokens[0][prompt_length:], skip_special_tokens=True)


def _image_stream(prompt):
    """Generate one SDXL image for ``prompt``; yield a status, then the image."""
    yield "Generating image..."
    image_paths, _used_seed = generate_image_fn(
        prompt=prompt,
        negative_prompt="",
        use_negative_prompt=False,
        seed=1,
        width=1024,
        height=1024,
        guidance_scale=3,
        num_inference_steps=25,
        randomize_seed=True,
        use_resolution_binning=True,
        num_images=1,
    )
    yield gr.Image(image_paths[0])


def _qwen2vl_stream(text, files, max_new_tokens):
    """Stream a Qwen2-VL answer to ``text`` about the uploaded images.

    Yields "Thinking..." and then the accumulated reply after each new chunk.
    """
    images = _load_images(files)
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt], images=images, return_tensors="pt", padding=True
    ).to(model_m.device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    # generate() blocks, so run it in a thread and consume tokens from the streamer.
    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield "Thinking..."
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
    thread.join()


def _text_stream(conversation, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Stream a reply from the text-only model for ``conversation``.

    Yields the accumulated reply after each chunk and then once more when done.
    Returns the final reply text (available via ``yield from``).
    """
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest user turn survives trimming.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
    thread.join()

    final_response = "".join(outputs)
    yield final_response
    return final_response


@spaces.GPU
def generate(
    input_dict: dict,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """
    Generates chatbot responses with support for multimodal input, TTS, and image generation.

    Special commands (matched case-insensitively at the start of the message):
    - "@tts1" or "@tts2": triggers text-to-speech (text-only replies, fresh history).
    - "@image": triggers image generation using the SDXL pipeline.
    - "@aya-vision": triggers image-text-to-text generation using the Aya-Vision model.

    Messages with attached files (and no command above) go to Qwen2-VL. All
    other messages go to the text-only model with the cleaned chat history.

    Args:
        input_dict: Multimodal textbox value with a "text" key and an optional
            "files" list.
        chat_history: Prior messages in Gradio "messages" format.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings. Qwen2-VL uses only ``max_new_tokens``.

    Yields:
        Progressive reply strings, or a ``gr.Image`` / ``gr.Audio`` component.
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])

    # -------------------------
    # Aya-Vision Feature
    # -------------------------
    aya_prompt = _strip_command(text, "@aya-vision")
    if aya_prompt is not None:
        yield from _aya_vision_stream(aya_prompt, files)
        return  # Exit early after processing with Aya-Vision

    # -------------------------
    # Image Generation Feature (@image)
    # -------------------------
    image_prompt = _strip_command(text, "@image")
    if image_prompt is not None:
        yield from _image_stream(image_prompt)
        return  # Exit early

    # -------------------------
    # TTS Feature (@tts1 or @tts2)
    # -------------------------
    tts_prefix = "@tts"
    voice = None
    for voice_index, voice_name in enumerate(TTS_VOICES, start=1):
        remainder = _strip_command(text, f"{tts_prefix}{voice_index}")
        if remainder is not None:
            voice = voice_name
            text = remainder
            break

    if voice is not None:
        # Clear previous chat history for a fresh TTS request.
        conversation = [{"role": "user", "content": text}]
    else:
        # Remove any stray @tts tags and build the conversation history.
        text = text.replace(tts_prefix, "").strip()
        conversation = clean_chat_history(chat_history)
        conversation.append({"role": "user", "content": text})

    # -------------------------
    # Multimodal Input (with files) using Qwen2-VL
    # -------------------------
    if files:
        yield from _qwen2vl_stream(text, files, max_new_tokens)
        return

    # -------------------------
    # Text-only Generation
    # -------------------------
    final_response = yield from _text_stream(
        conversation, max_new_tokens, temperature, top_p, top_k, repetition_penalty
    )

    if voice is not None:
        # A unique file per request keeps concurrent TTS users from clobbering each other.
        audio_path = f"{uuid.uuid4()}.mp3"
        output_file = asyncio.run(text_to_speech(final_response, voice, audio_path))
        yield gr.Audio(output_file, autoplay=True)


demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[
        [{"text": "@aya-vision Extract JSON from the image", "files": ["examples/document.jpg"]}],
        [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
        ["@tts1 Who is Nikola Tesla, and why did he die?"],
        ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
        ["Write a Python function to check if a number is prime."],
        ["@tts2 What causes rainbows to form?"],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    # To create a public link, set share=True in launch().
    demo.queue(max_size=20).launch(share=True)