gemma-3-12b-it / app.py
hysts's picture
hysts HF staff
Remove local transformers wheel and switch to git source
2d73e21
raw
history blame
5.81 kB
#!/usr/bin/env python
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
# Hub identifier of the checkpoint; the processor and the model are both
# loaded from it so they always match.
model_id = "google/gemma-3-12b-it"

# Loaded once at import time and shared by every request handled by `run`.
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
def process_new_user_message(message: dict) -> list[dict]:
    """Convert an incoming multimodal message into chat-template content parts.

    Args:
        message: Mapping with a ``"text"`` entry (the prompt text) and an
            optional ``"files"`` entry (an iterable of image paths/URLs).

    Returns:
        A list that starts with one ``{"type": "text", "text": ...}`` part and
        continues with one ``{"type": "image", "url": path}`` part per entry of
        ``"files"``, in their original order.

    Raises:
        KeyError: If ``message`` has no ``"text"`` entry.
    """
    # A missing or ``None`` "files" entry means "no attachments" rather than an
    # error, so text-only callers do not have to pass an empty list.
    image_paths = message.get("files") or []
    text_part = {"type": "text", "text": message["text"]}
    image_parts = [{"type": "image", "url": path} for path in image_paths]
    return [text_part, *image_parts]
def process_history(history: list[dict]) -> list[dict]:
    """Fold a flat chat history into role-tagged turns made of content parts.

    Consecutive non-assistant entries are merged into a single ``"user"`` turn
    that is emitted right before the next assistant entry. Every assistant
    entry becomes its own ``"assistant"`` turn holding one text part.

    Non-assistant entries that are not followed by an assistant entry are
    accumulated but never emitted; only turns closed by an assistant reply end
    up in the result.

    Args:
        history: Entries with a ``"role"`` key and a ``"content"`` key. For
            non-assistant entries, ``"content"`` is either a string (text) or
            an indexable value whose first element is used as the image URL.
            NOTE(review): presumably a one-element tuple of file paths as
            produced by the chat UI; confirm against the caller.

    Returns:
        A list of ``{"role": ..., "content": [parts]}`` dictionaries.
    """
    messages = []
    pending_user_parts: list[dict] = []
    for item in history:
        if item["role"] != "assistant":
            content = item["content"]
            if isinstance(content, str):
                part = {"type": "text", "text": content}
            else:
                # Non-string content: only its first element is used.
                part = {"type": "image", "url": content[0]}
            pending_user_parts.append(part)
            continue

        # An assistant entry closes the user turn collected so far (if any).
        if pending_user_parts:
            messages.append({"role": "user", "content": pending_user_parts})
            pending_user_parts = []
        messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
    return messages
# NOTE(review): presumably requests a GPU allocation of up to 120 s per call;
# confirm against the `spaces` package documentation.
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Args:
        message: The new user message; converted with
            ``process_new_user_message``.
        history: Earlier conversation entries; converted with
            ``process_history``.
        system_prompt: When non-empty, prepended to the conversation as a
            ``"system"`` turn.
        max_new_tokens: Forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer
        (each yielded string extends the previous one).
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    # Generation runs in a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so the function does not return while it is still running.
    t.join()
# (prompt, image paths) pairs backing the example gallery. Kept as a flat table
# so adding an example is a one-line change.
_EXAMPLE_SPECS = (
    ("caption this image", ["assets/sample-images/01.png"]),
    ("What's the sign says?", ["assets/sample-images/02.png"]),
    ("Compare and contrast the two images.", ["assets/sample-images/03.png"]),
    ("List all the objects in the image and their colors.", ["assets/sample-images/04.png"]),
    ("Describe the atmosphere of the scene.", ["assets/sample-images/05.png"]),
    (
        "Write a poem inspired by the visual elements of the images.",
        ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
    ),
    (
        "Compose a short musical piece inspired by the visual elements of the images.",
        [
            "assets/sample-images/07-1.png",
            "assets/sample-images/07-2.png",
            "assets/sample-images/07-3.png",
            "assets/sample-images/07-4.png",
        ],
    ),
    ("Write a short story about what might have happened in this house.", ["assets/sample-images/08.png"]),
    (
        "Create a short story based on the sequence of images.",
        [
            "assets/sample-images/09-1.png",
            "assets/sample-images/09-2.png",
            "assets/sample-images/09-3.png",
            "assets/sample-images/09-4.png",
            "assets/sample-images/09-5.png",
        ],
    ),
    ("Describe the creatures that would live in this world.", ["assets/sample-images/10.png"]),
    ("Read text in the image.", ["assets/additional-examples/1.png"]),
    ("When is this ticket dated and how much did it cost?", ["assets/additional-examples/2.png"]),
    ("Read the text in the image into markdown.", ["assets/additional-examples/3.png"]),
    ("Evaluate this integral.", ["assets/additional-examples/4.png"]),
)

# Each example is a one-element row holding a {"text", "files"} message dict.
examples = [[{"text": prompt, "files": image_paths}] for prompt, image_paths in _EXAMPLE_SPECS]
# Header markup shown above the chat: the Space's logo image.
_LOGO_HTML = "<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />"

# Widgets are created in the same order as they appear in the interface
# arguments: the chat input first, then the two extra inputs passed to `run`
# as `system_prompt` and `max_new_tokens`.
_chat_textbox = gr.MultimodalTextbox(file_types=["image"], file_count="multiple")
_system_prompt_box = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
_max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=500)

# Chat UI wired to `run`; examples only fill the input box (they are neither
# executed on click nor cached).
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    textbox=_chat_textbox,
    multimodal=True,
    additional_inputs=[_system_prompt_box, _max_tokens_slider],
    stop_btn=False,
    title="Gemma 3 12B it",
    description=_LOGO_HTML,
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)
# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()