import torch
from langgraph.graph import END, StateGraph
from typing import TypedDict, Any
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Gemma3ForConditionalGeneration,
)

def get_quantization_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

# Define the state schema
class State(TypedDict):
    image: Any
    voice: str
    caption: str
    description: str

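# State flow: the caller supplies "image" and "voice"; caption_image fills in
# "caption", and describe_with_voice fills in "description".
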
# Build the workflow graph
def build_graph():
    workflow = StateGraph(State)
    # Add nodes
    # workflow.add_node("caption_image", caption_image_dummy)
    # workflow.add_node("describe_with_voice", describe_with_voice_dummy)
    workflow.add_node("caption_image", caption_image)
    workflow.add_node("describe_with_voice", describe_with_voice)
    # Add edges
    workflow.set_entry_point("caption_image")
    workflow.add_edge("caption_image", "describe_with_voice")
    workflow.add_edge("describe_with_voice", END)
    # Compile the graph
    return workflow.compile()

model_id = "google/gemma-3-4b-it"
# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=get_quantization_config(),
    device_map="auto",
    torch_dtype=torch.bfloat16,  # match the bfloat16 cast applied to the inputs below
)

def describe_with_voice_dummy(state: State) -> State:
print("Describe")
voice = state["voice"]
state["description"] = f"Dummy description from {voice}"
return state
def caption_image_dummy(state: State) -> State:
print("Caption")
voice = state["voice"]
state["caption"] = f"Dummy caption from {voice}"
return state
def describe_with_voice(state: State) -> State:
caption = state["caption"]
voice = state["voice"]
# Voice prompt templates
voice_prompts = {
"scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
"forgetful wizard": "You are a forgetful and easily distracted wizard.",
"sarcastic teenager": "You are a sarcastic and disinterested teenager.",
}
messages = [
{
"role": "system",
"content": [{"type": "text", "text": voice_prompts.get(voice)}],
},
{
"role": "user",
"content": [
{"type": "text", "text": f"Describe the following:\n\n{caption}"}
],
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
generation = generation[0][input_len:]
description = processor.decode(generation, skip_special_tokens=True)
state["description"] = description
print(description)
return state
def caption_image(state: State) -> State:
    # image is a PIL image; the processor and model are loaded once at module level
    image = state["image"]
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]
    caption = processor.decode(generation, skip_special_tokens=True)
    state["caption"] = caption
    print(caption)
    return state
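
# --- Example invocation (illustrative sketch, not part of the original Space) ---
# A minimal way to exercise the compiled graph locally. The image path
# "example.jpg" and the chosen voice are assumptions for illustration; any PIL
# image and any of the voices defined above should work the same way.
if __name__ == "__main__":
    from PIL import Image

    graph = build_graph()
    result = graph.invoke(
        {
            "image": Image.open("example.jpg"),
            "voice": "scurvy-ridden pirate",
            "caption": "",
            "description": "",
        }
    )
    print(result["description"])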