import operator
from typing import Annotated, Any, TypedDict

import torch
from langgraph.graph import END, StateGraph
from langgraph.types import Send
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)

from helpers import image_to_base64
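
# `image_to_base64` lives in a local helpers module that is not shown here.
# A minimal sketch of what such a helper typically looks like (hypothetical;
# the real implementation may differ):
#
#     import base64
#     import io
#
#     def image_to_base64(image):
#         buffer = io.BytesIO()
#         image.save(buffer, format="PNG")
#         return base64.b64encode(buffer.getvalue()).decode("utf-8")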


def get_quantization_config():
    # 4-bit NF4 quantization with double quantization and fp16 compute.
    # Only takes effect if passed as quantization_config= when loading the model.
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
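
# To load the model in 4-bit instead of bf16, enable the config above
# (a sketch; assumes the bitsandbytes package and a CUDA GPU are available):
#
#     model = Gemma3ForConditionalGeneration.from_pretrained(
#         model_id,
#         quantization_config=get_quantization_config(),
#         device_map="auto",
#     ).eval()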


# Define the state schema
class State(TypedDict):
    image: Any
    voice: str
    caption: str
    # operator.add is the reducer: lists returned by parallel branches
    # are concatenated rather than overwriting one another
    descriptions: Annotated[list, operator.add]


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)
    # Add edges
    workflow.set_entry_point("caption_image")
    # map_describe fans out to three parallel describe_with_voice runs
    workflow.add_conditional_edges(
        "caption_image", map_describe, ["describe_with_voice"]
    )
    # workflow.add_edge("caption_image", "describe_with_voice")  # single-branch alternative
    workflow.add_edge("describe_with_voice", END)
    # Compile the graph
    return workflow.compile()
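
# Compiled topology:
#
#     START -> caption_image --(map_describe: 3 Sends)--> describe_with_voice -> END
#
# Each Send carries its own {caption, voice} payload; the operator.add
# reducer on `descriptions` concatenates the three branch results.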


model_id = "google/gemma-3-4b-it"

# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()


def describe_with_voice_dummy(state: State) -> dict:
    # Stand-in for describe_with_voice when testing the graph without a model
    print("Describe")
    voice = state["voice"]
    return {"descriptions": [f"Dummy description from {voice}"]}


def caption_image_dummy(state: State) -> State:
    # Stand-in for caption_image when testing the graph without a model
    print("Caption")
    voice = state["voice"]
    state["caption"] = f"Dummy caption from {voice}"
    return state


def describe_with_voice(state: State) -> dict:
    caption = state["caption"]
    voice = state["voice"]
    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
        "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
        "shakespearian": "Talk like one of Shakespeare's characters.",
    }
    system_prompt = (
        voice_prompts.get(voice, "You are a pirate.") + " Output 5-10 sentences."
    )
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(
            **inputs, max_new_tokens=1000, do_sample=True, temperature=0.7
        )
        generation = generation[0][input_len:]
    description = processor.decode(generation, skip_special_tokens=True)
    print(description)
    # Return only the `descriptions` update, wrapped in a list: three parallel
    # branches write to this channel, and the operator.add reducer concatenates
    # their lists. Writing other keys back here would mean concurrent updates
    # to channels that have no reducer.
    return {"descriptions": [description]}


def map_describe(state: State) -> list:
    # Fan out: one Send per desired description, so describe_with_voice
    # runs three times in parallel on the same caption/voice pair
    return [
        Send(
            "describe_with_voice",
            {"caption": state["caption"], "voice": state["voice"]},
        )
        for _ in range(3)
    ]


def caption_image(state: State) -> State:
    # image arrives as a PIL image; encode it for the chat template
    image = state["image"]
    image = image_to_base64(image)
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        # Greedy decoding keeps the caption deterministic
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]
    caption = processor.decode(generation, skip_special_tokens=True)
    state["caption"] = caption
    print(caption)
    return state
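

# A minimal invocation sketch (assumes Pillow and a local "example.jpg",
# neither of which is part of this file):
if __name__ == "__main__":
    from PIL import Image

    graph = build_graph()
    result = graph.invoke(
        {"image": Image.open("example.jpg"), "voice": "sarcastic teenager"}
    )
    # One entry per parallel branch, three in total
    for description in result["descriptions"]:
        print(description)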