#!/usr/bin/env python
"""Gradio chat demo for Google's Gemma 3 12B instruction-tuned vision-language model.

Streams token-by-token responses to text + image prompts via a
``gr.ChatInterface`` with a multimodal textbox.
"""
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

model_id = "google/gemma-3-12b-it"
# padding_side="left" so generation starts right after the prompt tokens.
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)


def process_new_user_message(message: dict) -> list[dict]:
    """Convert a MultimodalTextbox payload into chat-template content parts.

    Args:
        message: dict with a ``"text"`` string and a ``"files"`` list of
            local file paths for uploaded images.

    Returns:
        A list of content parts: the text part first, then one image part
        per uploaded file.
    """
    return [
        {"type": "text", "text": message["text"]},
        *[{"type": "image", "url": path} for path in message["files"]],
    ]


def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio "messages"-format history into chat-template messages.

    Gradio records each uploaded image as its own user item, so consecutive
    user items are merged into a single multi-part user message before the
    following assistant turn.

    Args:
        history: Gradio chat history; each item has ``"role"`` and
            ``"content"`` (a string for text, or a tuple whose first element
            is an image file path — assumed from the MultimodalTextbox
            upload format).

    Returns:
        Messages in the structure expected by ``processor.apply_chat_template``.
    """
    messages: list[dict] = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": item["content"]}]}
            )
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                # Non-string content is a tuple holding the uploaded image's path.
                current_user_content.append({"type": "image", "url": content[0]})
    # BUG FIX: user content accumulated after the final assistant turn was
    # previously discarded; flush it so no part of the conversation is lost.
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages


@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    """Generate a streamed response to the latest user message.

    Args:
        message: MultimodalTextbox payload (``"text"`` + ``"files"``).
        history: Prior conversation in Gradio "messages" format.
        system_prompt: Optional system instruction prepended to the chat.
        max_new_tokens: Generation length cap.

    Yields:
        The accumulated response text after each new token, as required by
        ``gr.ChatInterface`` streaming.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    # Run generation on a worker thread so we can consume the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output


examples = [
    [
        {
            "text": "caption this image",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            # FIX: corrected ungrammatical prompt ("What's the sign says?").
            "text": "What does the sign say?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare and contrast the two images.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List all the objects in the image and their colors.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
    [
        {
            "text": "Describe the atmosphere of the scene.",
            "files": ["assets/sample-images/05.png"],
        }
    ],
    [
        {
            "text": "Write a poem inspired by the visual elements of the images.",
            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
        }
    ],
    [
        {
            "text": "Compose a short musical piece inspired by the visual elements of the images.",
            "files": [
                "assets/sample-images/07-1.png",
                "assets/sample-images/07-2.png",
                "assets/sample-images/07-3.png",
                "assets/sample-images/07-4.png",
            ],
        }
    ],
    [
        {
            "text": "Write a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Create a short story based on the sequence of images.",
            "files": [
                "assets/sample-images/09-1.png",
                "assets/sample-images/09-2.png",
                "assets/sample-images/09-3.png",
                "assets/sample-images/09-4.png",
                "assets/sample-images/09-5.png",
            ],
        }
    ],
    [
        {
            "text": "Describe the creatures that would live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read text in the image.",
            "files": ["assets/additional-examples/1.png"],
        }
    ],
    [
        {
            "text": "When is this ticket dated and how much did it cost?",
            "files": ["assets/additional-examples/2.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image into markdown.",
            "files": ["assets/additional-examples/3.png"],
        }
    ],
    [
        {
            "text": "Evaluate this integral.",
            "files": ["assets/additional-examples/4.png"],
        }
    ],
]

demo = gr.ChatInterface(
    fn=run,
    type="messages",
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=500),
    ],
    stop_btn=False,
    title="Gemma 3 12B it",
    description="",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()