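# LangGraph pipeline: caption an image with Gemma 3, then restyle the caption
# in a chosen persona voice.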
from typing import Any, TypedDict

import torch
from langgraph.graph import END, StateGraph
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)


def get_quantization_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Define the state schema
class State(TypedDict):
    image: Any
    voice: str
    caption: str
    description: str


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    # Add nodes
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_edge("caption_image", "describe_with_voice")
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()


model_id = "google/gemma-3-4b-it"

# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.float16,
)


def describe_with_voice(state: State) -> State:
    # Stub node; describe_with_voice2 below is the full model-backed version
    state["description"] = "Dummy description"
    return state


def caption_image(state: State) -> State:
    # Stub node; caption_image2 below is the full model-backed version
    state["caption"] = "Dummy caption"
    return state


def describe_with_voice2(state: State) -> State:
    caption = state["caption"]
    voice = state["voice"]

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
    }
    # Fall back to a generic persona if the requested voice is not in the table
    system_prompt = voice_prompts.get(voice, f"You are a {voice}.")

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)
    state["description"] = description

    return state


def caption_image2(state: State) -> State:
    # image is a PIL.Image supplied in the initial state
    image = state["image"]

    # The processor and model are loaded once at module level above
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)
    state["caption"] = caption

    return state
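

# Minimal usage sketch (assumes a local image file "example.jpg"; the graph is
# wired with the stub nodes above, so this prints the dummy description unless
# caption_image2 / describe_with_voice2 are registered in build_graph instead).
if __name__ == "__main__":
    from PIL import Image

    graph = build_graph()
    result = graph.invoke(
        {"image": Image.open("example.jpg"), "voice": "scurvy-ridden pirate"}
    )
    print(result["description"])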