import torch
from langgraph.graph import END, StateGraph
from typing import TypedDict, Any

from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)


def get_quantization_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Define the state schema
class State(TypedDict):
    image: Any
    voice: str
    caption: str
    description: str


# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)

    # Add nodes
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)

    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_edge("caption_image", "describe_with_voice")
    workflow.add_edge("describe_with_voice", END)

    # Compile the graph
    return workflow.compile()


model_id = "google/gemma-3-4b-it"

# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # Uncomment to load the weights in 4-bit via bitsandbytes instead of fp16:
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.float16,
)


# Stub nodes wired into the graph; describe_with_voice2 / caption_image2 below
# are the full Gemma-backed implementations.
def describe_with_voice(state: State) -> State:
    state["description"] = "Dummy description"
    return state


def caption_image(state: State) -> State:
    state["caption"] = "Dummy caption"
    return state

def describe_with_voice2(state: State) -> State:
    caption = state["caption"]
    voice = state["voice"]

    # Voice prompt templates
    voice_prompts = {
        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
    }
    messages = [
        {
            "role": "system",
            "content": [
                # Fall back to a generic persona if the voice is not in the table.
                {"type": "text", "text": voice_prompts.get(voice, f"You are a {voice}.")}
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)  # match the model's fp16 weights
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)

    state["description"] = description

    return state


def caption_image2(state: State) -> State:
    # image is PIL
    image = state["image"]

    # Load models (in practice, do this once and cache)
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)  # pixel_values must match the fp16 weights
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    caption = processor.decode(generation, skip_special_tokens=True)

    state["caption"] = caption

    return state
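

# --- Usage sketch (not part of the original file) ---
# A minimal example of running the compiled graph end-to-end. The image path
# "example.jpg" is a placeholder; the graph currently runs the stub nodes, so
# swap caption_image2 / describe_with_voice2 into build_graph() to exercise Gemma.
if __name__ == "__main__":
    from PIL import Image

    graph = build_graph()
    result = graph.invoke(
        {
            "image": Image.open("example.jpg"),
            "voice": "scurvy-ridden pirate",
            "caption": "",
            "description": "",
        }
    )
    print(result["description"])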