import torch
from langgraph.graph import END, StateGraph
from typing import TypedDict, Any
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)


def get_quantization_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Define the state schema
class State(TypedDict):
    image: Any
    voice: str
    caption: str
    description: str


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    # Add nodes
    # workflow.add_node("caption_image", caption_image_dummy)
    # workflow.add_node("describe_with_voice", describe_with_voice_dummy)
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_edge("caption_image", "describe_with_voice")
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()


model_id = "google/gemma-3-4b-it"

# Initialize processor and model once at module load; both graph nodes reuse them.
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.bfloat16,  # match the bfloat16 cast applied to the inputs below
)


def describe_with_voice_dummy(state: State) -> State:
    print("Describe")
    voice = state["voice"]
    state["description"] = f"Dummy description from {voice}"
    return state


def caption_image_dummy(state: State) -> State:
    print("Caption")
    voice = state["voice"]
    state["caption"] = f"Dummy caption from {voice}"
    return state


def describe_with_voice(state: State) -> State:
    caption = state["caption"]
    voice = state["voice"]

    # Voice prompt templates; fall back to a generic persona for unknown voices
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
    }
    system_prompt = voice_prompts.get(voice, f"You are a {voice}.")

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)
    state["description"] = description
    print(description)

    return state


def caption_image(state: State) -> State:
    # image is a PIL image
    image = state["image"]

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)
    state["caption"] = caption
    print(caption)

    return state
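

# Minimal usage sketch: build the compiled graph and run it end to end.
# This assumes Pillow is installed and that "example.jpg" is a placeholder
# path you replace with a real image file; only the "image" and "voice"
# fields need to be supplied, since the two nodes fill in "caption" and
# "description" as the graph runs.
if __name__ == "__main__":
    from PIL import Image

    graph = build_graph()
    img = Image.open("example.jpg")  # placeholder path, swap in your own image

    result = graph.invoke({"image": img, "voice": "scurvy-ridden pirate"})
    print(result["description"])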