# fun-image-caption / agents.py
import operator
from typing import Annotated, Any, TypedDict

import torch
from langgraph.graph import END, StateGraph
from langgraph.types import Send
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)

from helpers import image_to_base64


def get_quantization_config():
    """Return a 4-bit NF4 quantization config for bitsandbytes loading."""
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
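
# To load the model in 4-bit instead of bf16 (smaller memory footprint, some
# quality cost), uncomment the quantization_config line in from_pretrained
# below; this assumes the bitsandbytes package is installed.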


# Define the state schema
class State(TypedDict):
    image: Any  # input PIL image
    voices: list  # voices selected by the caller
    voice: str  # per-branch voice injected by map_describe via Send
    caption: str  # neutral caption produced by caption_image
    descriptions: Annotated[list, operator.add]  # merged across parallel branches
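
# Because "descriptions" uses the operator.add reducer, LangGraph concatenates
# the single-item lists returned by parallel describe_with_voice runs, e.g.
# ["## Pirate\n..."] + ["## Wizard\n..."] -> one list with both entries.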
# Build the workflow graph
def build_graph():
workflow = StateGraph(State)
workflow.add_node("caption_image", caption_image)
workflow.add_node("describe_with_voice", describe_with_voice)
# Add edges
workflow.set_entry_point("caption_image")
workflow.add_conditional_edges("caption_image", map_describe, ["describe_with_voice"])
workflow.add_edge("describe_with_voice", END)
# Compile the graph
return workflow.compile()
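
# Resulting topology: caption_image fans out through map_describe into one
# describe_with_voice branch per selected voice, then every branch ends at END:
#
#   caption_image --Send--> describe_with_voice (one run per voice) --> END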
model_id = "google/gemma-3-4b-it"
# Initialize the processor and model once at import time (shared by both graph nodes)
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
# quantization_config=get_quantization_config(),
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()


def describe_with_voice(state: State):
    caption = state["caption"]
    # Use the per-branch voice set by map_describe; fall back to the first
    # selected voice, defaulting to "shakespearian".
    voice = state.get("voice", state.get("voices", ["shakespearian"])[0])
# Voice prompt templates
voice_prompts = {
"scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
"forgetful wizard": "You are a forgetful and easily distracted wizard.",
"sarcastic teenager": "You are a sarcastic and disinterested teenager.",
"private investigator": "You are a Victorian-age detective. Suave and intellectual.",
"shakespearian": "Talk like one of Shakespeare's characters. ",
}
    system_prompt = (
        voice_prompts.get(voice, "You are a pirate.")
        + " Output 5-10 sentences. Utilize markdown for dramatic text formatting."
    )
messages = [
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}],
},
{
"role": "user",
"content": [
{"type": "text", "text": f"Describe the following:\n\n{caption}"}
],
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
        generation = model.generate(
            **inputs, max_new_tokens=1000, do_sample=True, temperature=0.9
        )
generation = generation[0][input_len:]
description = processor.decode(generation, skip_special_tokens=True)
formatted_description = f"## {voice.title()}\n\n{description}"
print(formatted_description)
    # Return a single-item list so the operator.add reducer on "descriptions"
    # can merge results from parallel branches.
    return {"descriptions": [formatted_description]}


def map_describe(state: State) -> list:
    # Fan out: one Send task per selected voice, each carrying the shared
    # caption plus its own voice.
    return [
        Send("describe_with_voice", {"caption": state["caption"], "voice": voice})
        for voice in state["voices"]
    ]
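
# For example, voices=["scurvy-ridden pirate", "forgetful wizard"] produces two
# Send objects, so describe_with_voice runs twice in parallel, once per voice.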


def caption_image(state: State):
    # Convert the incoming PIL image to base64 so it can be embedded in the
    # chat message below.
    image = image_to_base64(state["image"])
messages = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are a helpful assistant that will describe images in 3-5 sentences.",
}
],
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": "Describe this image."},
],
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
generation = generation[0][input_len:]
caption = processor.decode(generation, skip_special_tokens=True)
print(caption)
return {"caption" : caption}