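"""LangGraph agents for fun-image-caption.

A two-node workflow: caption an input image with Gemma 3, then rewrite the
caption in a chosen character voice. The graph currently wires in dummy
placeholder nodes; the *2 functions below hold the real model calls.
"""
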
import torch
from langgraph.graph import END, StateGraph
from typing import TypedDict, Any
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
)
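

# Optional 4-bit NF4 quantization config; unused for now because the
# quantization_config argument in from_pretrained below is commented out.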
def get_quantization_config():
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)


# Define the state schema
class State(TypedDict):
image: Any
voice: str
caption: str
description: str


# Build the workflow graph
def build_graph():
workflow = StateGraph(State)
# Add nodes
workflow.add_node("caption_image", caption_image)
workflow.add_node("describe_with_voice", describe_with_voice)
# Add edges
workflow.set_entry_point("caption_image")
workflow.add_edge("caption_image", "describe_with_voice")
workflow.add_edge("describe_with_voice", END)
# Compile the graph
return workflow.compile()


model_id = "google/gemma-3-4b-it"
# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
# quantization_config=get_quantization_config(),
device_map="auto",
torch_dtype=torch.float16,
)
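

# Dummy placeholder nodes wired into the graph; the *2 functions below are
# the full implementations.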
def describe_with_voice(state: State) -> State:
    state["description"] = "Dummy description"
    return state


def caption_image(state: State) -> State:
    state["caption"] = "Dummy caption"
    return state
def describe_with_voice2(state: State) -> State:
caption = state["caption"]
voice = state["voice"]
# Voice prompt templates
voice_prompts = {
"scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
"forgetful wizard": "You are a forgetful and easily distracted wizard.",
"sarcastic teenager": "You are a sarcastic and disinterested teenager.",
}
messages = [
{"role": "system", "content": [voice_prompts.get(voice)]},
{
"role": "user",
"content": [
{"type": "text", "text": f"Describe the following:\n\n{caption}"}
],
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
    ).to(model.device, dtype=torch.float16)  # match the model's torch_dtype
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
description = processor.decode(generation, skip_special_tokens=True)
state["description"] = description
return state
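

# Full implementation of the caption node (not yet wired into the graph):
# caption the input image with Gemma 3.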
def caption_image2(state: State) -> State:
# image is PIL
image = state["image"]
    # The processor and model are already loaded once at module level above.
messages = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are a helpful assistant that will describe images in 3-5 sentences.",
}
],
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": "Describe this image."},
],
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
    ).to(model.device, dtype=torch.float16)  # match the model's torch_dtype
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
caption = processor.decode(generation, skip_special_tokens=True)
state["caption"] = caption
return state
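

# Minimal usage sketch: build the graph and run it end to end. The image path
# below is a placeholder; with the dummy nodes above this prints placeholder text.
if __name__ == "__main__":
    from PIL import Image

    graph = build_graph()
    result = graph.invoke(
        {"image": Image.open("photo.jpg"), "voice": "scurvy-ridden pirate"}
    )
    print(result["caption"])
    print(result["description"])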