import gradio as gr
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
    AutoModelForImageTextToText,
    Gemma3ForConditionalGeneration,  # new Gemma3 model import
)
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
from PIL import Image
import requests
from io import BytesIO


# Helper function to return a progress bar HTML snippet.
def progress_bar_html(label: str) -> str:
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 8px; font-size: 14px;">{label}</span>
    <progress style="width: 120px;"></progress>
</div>
    '''


### Load Models & Processors ###

# Qwen2VL OCR model (default)
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"  # or alternate version
qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QV_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()

# Aya-Vision model (trigger with @aya-vision)
AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
aya_model = AutoModelForImageTextToText.from_pretrained(
    AYA_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16
)

# Gemma3-4b model (trigger with @gemma3-4b)
GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    GEMMA3_MODEL_ID,
    device_map="auto"
).eval()
gemma3_processor = AutoProcessor.from_pretrained(GEMMA3_MODEL_ID)


@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"].strip()
    files = input_dict.get("files", [])

    # Branch: Aya-Vision (trigger with @aya-vision)
    if text.lower().startswith("@aya-vision"):
        text_prompt = text[len("@aya-vision"):].strip()
        if not files:
            yield "Error: Please provide an image for the @aya-vision feature."
            return
        image = load_image(files[0])
        yield progress_bar_html("Processing with Aya-Vision-8b")
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_prompt},
            ],
        }]
        inputs = aya_processor.apply_chat_template(
            messages,
            padding=True,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(aya_model.device)
        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3
        )
        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        return

    # Branch: Gemma3-4b (trigger with @gemma3-4b)
    if text.lower().startswith("@gemma3-4b"):
        text_prompt = text[len("@gemma3-4b"):].strip()
        if not files:
            yield "Error: Please provide an image for the @gemma3-4b feature."
            return
        image = load_image(files[0])
        yield progress_bar_html("Processing with Gemma3-4b")
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt}
                ]
            }
        ]
        inputs = gemma3_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(gemma3_model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512, do_sample=False)
        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        return

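    # Queries without an @-prefix fall through to the default OCR model below.
    # Every branch streams the same way: generate() runs on a background Thread
    # while TextIteratorStreamer yields partial text back to the Gradio chat.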
    # Default Branch: Qwen2-VL OCR (for text query with optional images)
    if len(files) > 1:
        images = [load_image(image) for image in files]
    elif len(files) == 1:
        images = [load_image(files[0])]
    else:
        images = []

    if text == "" and not images:
        yield "Error: Please input a query and optionally image(s)."
        return
    if text == "" and images:
        yield "Error: Please input a text query along with the image(s)."
        return

    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = qwen_processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    yield progress_bar_html("Processing with Qwen2VL OCR")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer


# Examples for quick testing.
examples = [
    [{"text": "@gemma3-4b Summarize the letter", "files": ["examples/1.png"]}],
    [{"text": "@gemma3-4b Extract JSON from the image", "files": ["example_images/document.jpg"]}],
    [{"text": "@gemma3-4b Describe the photo", "files": ["examples/3.png"]}],
    [{"text": "@aya-vision Summarize the full image in detail", "files": ["examples/2.jpg"]}],
    [{"text": "@aya-vision Describe this image.", "files": ["example_images/campeones.jpg"]}],
    [{"text": "@aya-vision What is this UI about?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "Extract the table as JSON", "files": ["examples/4.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
    [{"text": "@aya-vision Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
]

# Gradio ChatInterface with a multimodal textbox.
demo = gr.ChatInterface(
    fn=model_inference,
    description=(
        "# **Multimodal OCR & Vision Features**\n\n"
        "Use the following commands to select a model:\n"
        "- `@aya-vision` for Aya-Vision-8b\n"
        "- `@gemma3-4b` for Gemma3-4b\n\n"
        "Default processing is done with Qwen2VL OCR."
    ),
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image"],
        file_count="multiple",
        placeholder="Enter your text query and attach images if needed. Use @aya-vision or @gemma3-4b to choose a feature."
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)
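
# Rough environment sketch (package names only; exact versions are an assumption,
# not pinned by this file). Gemma3ForConditionalGeneration and
# AutoModelForImageTextToText need a recent transformers release, and
# device_map="auto" relies on accelerate:
#   pip install gradio spaces torch transformers accelerate pillow requests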