import operator
from typing import Annotated, Any, TypedDict

import torch
from langgraph.graph import END, StateGraph
from langgraph.types import Send
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)

from helpers import image_to_base64
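
# A 4-bit NF4 quantization config (roughly a 4x smaller memory footprint at a
# small quality cost). It is defined here but currently unused: the
# quantization_config argument is commented out where the model is loaded below.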
def get_quantization_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Define the state schema
class State(TypedDict):
    image: Any
    voices: list
    caption: str
    descriptions: Annotated[list, operator.add]
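
# A note on the schema above: the operator.add annotation gives "descriptions"
# a reducer, so LangGraph merges updates from parallel branches by list
# concatenation instead of overwriting the key. Each describe_with_voice
# branch therefore contributes one entry.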


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_conditional_edges(
        "caption_image", map_describe, ["describe_with_voice"]
    )
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()
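
# The conditional edge above is the fan-out point: map_describe returns Send
# objects rather than a node name, and LangGraph spawns one
# describe_with_voice branch per Send, each with its own payload.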


model_id = "google/gemma-3-4b-it"

# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()


def describe_with_voice(state: State):
    caption = state["caption"]
    # Use the voice passed in via Send; fall back to the first selected
    # voice, and default to "shakespearian" if none was provided.
    voice = state.get("voice", state.get("voices", ["shakespearian"])[0])

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
        "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
        "shakespearian": "Talk like one of Shakespeare's characters.",
    }
    system_prompt = (
        voice_prompts.get(voice, "You are a pirate.")
        + " Output 5-10 sentences. Utilize markdown for dramatic text formatting."
    )

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **inputs, max_new_tokens=1000, do_sample=True, temperature=0.9
        )
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)
    formatted_description = f"## {voice.title()}\n\n{description}"
    print(formatted_description)

    # Note that the return value is a list so the operator.add reducer can
    # merge results from parallel branches.
    return {"descriptions": [formatted_description]}


def map_describe(state: State) -> list:
    # Create a Send object for each selected voice
    selected_voices = state["voices"]

    # Generate description tasks for each selected voice
    send_objects = []
    for voice in selected_voices:
        send_objects.append(
            Send("describe_with_voice", {"caption": state["caption"], "voice": voice})
        )
    return send_objects


def caption_image(state: State):
    # The incoming image is a PIL.Image; encode it as base64 for the processor
    image = state["image"]
    image = image_to_base64(image)

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)
    print(caption)

    return {"caption": caption}