Dylan committed
Commit a4690cb · 1 Parent(s): 98efca2

added description agents -- dummy

Files changed (4)
  1. agents.py +146 -0
  2. app.backup.py +9 -0
  3. app.py +59 -4
  4. helpers.py +9 -0
agents.py ADDED
@@ -0,0 +1,146 @@
+import torch
+from langgraph.graph import END, StateGraph
+from typing import TypedDict, Any
+
+from transformers import (
+    AutoProcessor,
+    BitsAndBytesConfig,
+    Gemma3ForConditionalGeneration,
+)
+
+
+def get_quantization_config():
+    return BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+    )
+
+
+# Define the state schema
+class State(TypedDict):
+    image: Any
+    voice: str
+    caption: str
+    description: str
+
+
+# Build the workflow graph
+def build_graph():
+    workflow = StateGraph(State)
+
+    # Add nodes
+    workflow.add_node("caption_image", caption_image)
+    workflow.add_node("describe_with_voice", describe_with_voice)
+
+    # Add edges
+    workflow.set_entry_point("caption_image")
+    workflow.add_edge("caption_image", "describe_with_voice")
+    workflow.add_edge("describe_with_voice", END)
+
+    # Compile the graph
+    return workflow.compile()
+
+
+model_id = "google/gemma-3-4b-it"
+
+# Initialize processor and model
+processor = AutoProcessor.from_pretrained(model_id)
+model = Gemma3ForConditionalGeneration.from_pretrained(
+    model_id,
+    # quantization_config=get_quantization_config(),
+    device_map="auto",
+    torch_dtype=torch.float16,
+)
+
+
+def describe_with_voice(state: State) -> State:
+    state["description"] = "Dummy description"
+    return state
+
+
+def caption_image(state: State) -> State:
+    state["caption"] = "Dummy caption"
+    return state
+
+
+def describe_with_voice2(state: State) -> State:
+    caption = state["caption"]
+    voice = state["voice"]
+
+    # Voice prompt templates
+    voice_prompts = {
+        "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
+        "forgetful wizard": "You are a forgetful and easily distracted wizard.",
+        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
+    }
+    messages = [
+        {"role": "system", "content": [{"type": "text", "text": voice_prompts.get(voice, "")}]},
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
+            ],
+        },
+    ]
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(model.device, dtype=torch.float16)
+    input_len = inputs["input_ids"].shape[-1]
+
+    with torch.inference_mode():
+        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        generation = generation[0][input_len:]
+
+    description = processor.decode(generation, skip_special_tokens=True)
+
+    state["description"] = description
+
+    return state
+
+
+def caption_image2(state: State) -> State:
+    # image is PIL
+    image = state["image"]
+
+    # Load models (in practice, do this once and cache)
+    messages = [
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
+                }
+            ],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": "Describe this image."},
+            ],
+        },
+    ]
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(model.device, dtype=torch.float16)
+    input_len = inputs["input_ids"].shape[-1]
+
+    with torch.inference_mode():
+        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        generation = generation[0][input_len:]
+
+    caption = processor.decode(generation, skip_special_tokens=True)
+
+    state["caption"] = caption
+
+    return state
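
The graph registered above uses the dummy nodes; the Gemma-backed variants (caption_image2, describe_with_voice2) are defined but not yet wired in. A minimal sketch of how the graph could point at them instead (a hypothetical helper, not part of this commit; it assumes the definitions in agents.py above):

def build_full_graph():
    # Same topology as build_graph(), but using the Gemma-backed nodes.
    workflow = StateGraph(State)
    workflow.add_node("caption_image", caption_image2)
    workflow.add_node("describe_with_voice", describe_with_voice2)
    workflow.set_entry_point("caption_image")
    workflow.add_edge("caption_image", "describe_with_voice")
    workflow.add_edge("describe_with_voice", END)
    return workflow.compile()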
app.backup.py ADDED
@@ -0,0 +1,9 @@
+import gradio as gr
+
+
+def greet(name):
+    return "Hello " + name + "!!"
+
+
+demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+demo.launch()
app.py CHANGED
@@ -1,7 +1,62 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+from agents import build_graph
+
+# Initialize the graph
+graph = build_graph()
+
+
+def process_and_display(image, voice):
+    # Initialize state
+    state = {"image": image, "voice": voice, "caption": "", "description": ""}
+
+    # Run the graph
+    result = graph.invoke(state)
+
+    # Return the caption and description
+    return result["caption"], result["description"]
+
+
+def create_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Image Description with Voice Personas")
+        gr.Markdown("""
+        This app takes an image and generates a description using a selected voice persona.
+
+        1. Upload an image
+        2. Select a voice persona from the dropdown
+        3. Click "Generate Description" to see the results
+        """)
+
+        with gr.Row():
+            with gr.Column():
+                image_input = gr.Image(type="pil", label="Upload an Image")
+                voice_dropdown = gr.Dropdown(
+                    choices=[
+                        "scurvy-ridden pirate",
+                        "forgetful wizard",
+                        "sarcastic teenager",
+                    ],
+                    label="Select a Voice",
+                    value="scurvy-ridden pirate",
+                )
+                submit_button = gr.Button("Generate Description")
+
+            with gr.Column():
+                caption_output = gr.Textbox(label="Image Caption")
+                description_output = gr.Textbox(label="Voice Description")
+
+        submit_button.click(
+            fn=process_and_display,
+            inputs=[image_input, voice_dropdown],
+            outputs=[caption_output, description_output],
+        )
+
+    return demo
+
+
+# Launch the app
+demo = create_interface()
+
+if __name__ == "__main__":
+    demo.launch()
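
For a quick check outside the Gradio UI, the same flow can be exercised directly. A hypothetical snippet, not part of this commit ("example.jpg" is a placeholder path; importing agents still loads the Gemma model at module import time):

from PIL import Image

from agents import build_graph

graph = build_graph()
state = {
    "image": Image.open("example.jpg"),  # placeholder image path
    "voice": "forgetful wizard",
    "caption": "",
    "description": "",
}
result = graph.invoke(state)
print(result["caption"])      # "Dummy caption" with the nodes committed here
print(result["description"])  # "Dummy description"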
helpers.py ADDED
@@ -0,0 +1,9 @@
+import base64
+import io
+
+
+def image_to_base64(image):
+    """Convert PIL Image to base64 encoded string"""
+    img_byte_arr = io.BytesIO()
+    image.save(img_byte_arr, format="JPEG")
+    return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
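
A short usage sketch for the helper (hypothetical example, not part of the commit):

from PIL import Image

from helpers import image_to_base64

img = Image.new("RGB", (64, 64), color="navy")  # stand-in image for the sketch
encoded = image_to_base64(img)
print(encoded[:40] + "...")  # base64 string, e.g. for embedding in JSON or a data URI

Note that JPEG cannot store an alpha channel, so an RGBA image would need .convert("RGB") before being passed to this helper.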