import operator
from typing import Annotated, Any, TypedDict

import torch
from langgraph.graph import END, StateGraph
from langgraph.types import Send
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)

from helpers import image_to_base64


def get_quantization_config():
    # 4-bit NF4 quantization with double quantization: trades a little
    # precision for a large reduction in VRAM at load time
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
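
# A sketch of loading the model in 4-bit instead of bf16 (assumes
# bitsandbytes is installed and a CUDA GPU is available):
#
#   model = Gemma3ForConditionalGeneration.from_pretrained(
#       model_id,
#       quantization_config=get_quantization_config(),
#       device_map="auto",
#   ).eval()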

# Define the state schema
class State(TypedDict):
    image: Any
    voices: list
    caption: str
    # set per branch by map_describe via Send; nodes read it but never write it
    voice: str
    # operator.add concatenates the lists returned by parallel branches
    descriptions: Annotated[list, operator.add]


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges: caption first, then fan out to one describe_with_voice
    # run per voice (see map_describe below)
    workflow.set_entry_point("caption_image")
    workflow.add_conditional_edges(
        "caption_image", map_describe, ["describe_with_voice"]
    )
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()
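
# The compiled graph maps an input state to a final state. Illustrative
# shape of the result (values are made up, not real model output):
#
#   {
#       "image": <PIL.Image>,
#       "voices": ["scurvy-ridden pirate", "forgetful wizard"],
#       "caption": "A dog chases a ball across a sunny park...",
#       "descriptions": [
#           "## Scurvy-Ridden Pirate\n\n...",
#           "## Forgetful Wizard\n\n...",
#       ],
#   }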
model_id = "google/gemma-3-4b-it" | |
# Initialize processor and model | |
processor = AutoProcessor.from_pretrained(model_id) | |
model = Gemma3ForConditionalGeneration.from_pretrained( | |
model_id, | |
# quantization_config=get_quantization_config(), | |
device_map="auto", | |
torch_dtype=torch.bfloat16, | |
).eval() | |


def describe_with_voice(state: State):
    caption = state["caption"]
    # Use the per-branch voice set by map_describe; otherwise fall back to
    # the first selected voice, defaulting to "shakespearian"
    voice = state.get("voice") or (state.get("voices") or ["shakespearian"])[0]

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
        "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
        "shakespearian": "Talk like one of Shakespeare's characters.",
    }
    system_prompt = (
        voice_prompts.get(voice, "You are a pirate.")
        + " Output 5-10 sentences. Use markdown for dramatic text formatting."
    )

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Sampling with a high temperature keeps the voices varied
        generation = model.generate(
            **inputs, max_new_tokens=1000, do_sample=True, temperature=0.9
        )
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)
    formatted_description = f"## {voice.title()}\n\n{description}"
    print(formatted_description)

    # Return a single-element list: the operator.add reducer on "descriptions"
    # concatenates these lists across parallel branches
    return {"descriptions": [formatted_description]}
def map_describe(state: State) -> list: | |
# Create a Send object for each selected voice | |
selected_voices = state["voices"] | |
# Generate description tasks for each selected voice | |
send_objects = [] | |
for voice in selected_voices: | |
send_objects.append( | |
Send("describe_with_voice", {"caption": state["caption"], "voice": voice}) | |
) | |
return send_objects | |
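
# With voices=["scurvy-ridden pirate", "forgetful wizard"], map_describe
# returns two Send objects, and LangGraph runs describe_with_voice once per
# Send in the same superstep, each with its own {"caption", "voice"} payload.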


def caption_image(state: State):
    # The image arrives as a PIL image; encode it for the chat template
    image = state["image"]
    image = image_to_base64(image)

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Greedy decoding for a stable, factual caption
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)
    print(caption)

    return {"caption": caption}