import operator
from helpers import image_to_base64
import torch
from langgraph.graph import END, StateGraph
from langgraph.types import Send
from typing import Annotated, TypedDict, Any

from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)


def get_quantization_config():
    # Optional 4-bit NF4 quantization (double quantization, fp16 compute) to cut
    # the model's memory footprint; pass this as quantization_config when loading
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Define the state schema
class State(TypedDict):
    image: Any
    voices: list
    caption: str
    # operator.add acts as a reducer so parallel describe_with_voice branches
    # append to this list instead of overwriting each other
    descriptions: Annotated[list, operator.add]


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")

    # Fan out: map_describe returns one Send per selected voice
    workflow.add_conditional_edges(
        "caption_image", map_describe, ["describe_with_voice"]
    )
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()


model_id = "google/gemma-3-4b-it"

# Initialize the processor and model once at module load; both graph nodes reuse them
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # Uncomment to load the weights in 4-bit NF4 instead of bfloat16:
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()


def describe_with_voice(state: State):
    caption = state["caption"]
    # Use the per-branch "voice" passed in via Send; otherwise fall back to the
    # first selected voice, defaulting to shakespearian
    voice = state.get("voice", state.get("voices", ["shakespearian"])[0])

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
        "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
        "shakespearian": "Talk like one of Shakespeare's characters. ",
    }
    system_prompt = voice_prompts.get(voice, "You are a pirate.") + " Output 5-10 sentences. Utilize markdown for dramatic text formatting."
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.9)
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)

    formatted_description = f"## {voice.title()}\n\n{description}"
    print(formatted_description)

    # Return a single-element list so the operator.add reducer can concatenate
    # results from the parallel voice branches
    return {"descriptions": [formatted_description]}


def map_describe(state: State) -> list:
    # Create one Send object per selected voice so the descriptions are
    # generated in parallel branches
    selected_voices = state["voices"]

    # Each Send carries only the fields describe_with_voice needs
    return [
        Send("describe_with_voice", {"caption": state["caption"], "voice": voice})
        for voice in selected_voices
    ]


def caption_image(state: State):
    # The incoming image is a PIL image; convert it to base64 so it can be
    # embedded in the chat template
    image = state["image"]
    image = image_to_base64(image)

    # Build the captioning prompt (the processor and model are already loaded at module level)
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)
    print(caption)

    return {"caption" : caption}