import operator
from typing import Annotated, Any, TypedDict

import torch
from langgraph.graph import END, StateGraph
from langgraph.types import Send
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)

from helpers import image_to_base64


def get_quantization_config():
    # Optional 4-bit NF4 quantization config for memory-constrained GPUs
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Define the state schema
class State(TypedDict):
    image: Any
    voices: list
    caption: str
    # operator.add lets parallel branches append their results instead of overwriting
    descriptions: Annotated[list, operator.add]


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_conditional_edges(
        "caption_image", map_describe, ["describe_with_voice"]
    )
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()


model_id = "google/gemma-3-4b-it"

# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()


def describe_with_voice(state: State):
    caption = state["caption"]
    # Select a voice; default to the Shakespearean one if none was provided
    voice = state.get("voice", state.get("voices", ["shakespearian"])[0])

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
        "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
        "shakespearian": "Talk like one of Shakespeare's characters.",
    }

    system_prompt = (
        voice_prompts.get(voice, "You are a pirate.")
        + " Output 5-10 sentences. Utilize markdown for dramatic text formatting."
    )
messages = [ { "role": "system", "content": [{"type": "text", "text": system_prompt}], }, { "role": "user", "content": [ {"type": "text", "text": f"Describe the following:\n\n{caption}"} ], }, ] inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(model.device, dtype=torch.bfloat16) input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): generation = model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.9) generation = generation[0][input_len:] description = processor.decode(generation, skip_special_tokens=True) formatted_description = f"## {voice.title()}\n\n{description}" print(formatted_description) # note that the return value is a list return {"descriptions": [formatted_description]} def map_describe(state: State) -> list: # Create a Send object for each selected voice selected_voices = state["voices"] # Generate description tasks for each selected voice send_objects = [] for voice in selected_voices: send_objects.append( Send("describe_with_voice", {"caption": state["caption"], "voice": voice}) ) return send_objects def caption_image(state: State): # image is PIL image = state["image"] image = image_to_base64(image) # Load models (in practice, do this once and cache) messages = [ { "role": "system", "content": [ { "type": "text", "text": "You are a helpful assistant that will describe images in 3-5 sentences.", } ], }, { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}, ], }, ] inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(model.device, dtype=torch.bfloat16) input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False) generation = generation[0][input_len:] caption = processor.decode(generation, skip_special_tokens=True) print(caption) return {"caption" : caption}