Dylan committed
Commit b5b9453 · 1 Parent(s): 7d14b9f

some formatting

Files changed (2)
  1. agents.py +24 -31
  2. app.py +17 -11
agents.py CHANGED
@@ -24,7 +24,7 @@ def get_quantization_config():
 # Define the state schema
 class State(TypedDict):
     image: Any
-    voice: str
+    voices: list
     caption: str
     descriptions: Annotated[list, operator.add]
 
@@ -40,7 +40,6 @@ def build_graph():
     workflow.set_entry_point("caption_image")
 
     workflow.add_conditional_edges("caption_image", map_describe, ["describe_with_voice"])
-    # workflow.add_edge("caption_image", "describe_with_voice")
     workflow.add_edge("describe_with_voice", END)
 
     # Compile the graph
@@ -59,23 +58,10 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
 ).eval()
 
 
-def describe_with_voice_dummy(state: State) -> State:
-    print("Describe")
-    voice = state["voice"]
-    state["description"] = f"Dummy description from {voice}"
-    return state
-
-
-def caption_image_dummy(state: State) -> State:
-    print("Caption")
-    voice = state["voice"]
-    state["caption"] = f"Dummy caption from {voice}"
-    return state
-
-
-def describe_with_voice(state: State) -> State:
+def describe_with_voice(state: State):
     caption = state["caption"]
-    voice = state["voice"]
+    # select one by default shakespeare
+    voice = state.get("voice", state.get("voices", ["shakespearian"])[0])
 
     # Voice prompt templates
     voice_prompts = {
@@ -108,24 +94,33 @@ def describe_with_voice(state: State) -> State:
     input_len = inputs["input_ids"].shape[-1]
 
     with torch.inference_mode():
-        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.7)
+        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=True, temperature=0.9)
         generation = generation[0][input_len:]
 
     description = processor.decode(generation, skip_special_tokens=True)
 
-    # note that the return value is a list
-    state["description"] = [description]
-    print(description)
+    formatted_description = f"#{voice.title()}\n{description}"
+    print(formatted_description)
 
-    return state
+    # note that the return value is a list
+    return {"descriptions": [formatted_description]}
 
 
 def map_describe(state: State) -> list:
-    # return list of `Send` objects (3)
-    return [Send("describe_with_voice", {"caption" : state["caption"], "voice": state["voice"]})] * 3
-
-
-def caption_image(state: State) -> State:
+    # Create a Send object for each selected voice
+    selected_voices = state["voices"]
+
+    # Generate description tasks for each selected voice
+    send_objects = []
+    for voice in selected_voices:
+        send_objects.append(
+            Send("describe_with_voice", {"caption": state["caption"], "voice": voice})
+        )
+
+    return send_objects
+
+
+def caption_image(state: State):
     # image is PIL
     image = state["image"]
     image = image_to_base64(image)
@@ -163,8 +158,6 @@ def caption_image(state: State) -> State:
     generation = generation[0][input_len:]
 
     caption = processor.decode(generation, skip_special_tokens=True)
-
-    state["caption"] = caption
     print(caption)
 
-    return state
+    return {"caption" : caption}
app.py CHANGED
@@ -8,9 +8,12 @@ graph = build_graph()
 
 
 @spaces.GPU(duration=60)
-def process_and_display(image, voice):
+def process_and_display(image, voices):
+    if not voices:  # If no voices selected
+        return "Please select at least one voice persona.", "No voice personas selected."
+
     # Initialize state
-    state = {"image": image, "voice": voice, "caption": "", "description": ""}
+    state = {"image": image, "voices": voices, "caption": "", "descriptions": []}
 
     # Run the graph
     result = graph.invoke(state, {"max_concurrency" : 1})
@@ -26,11 +29,13 @@ def create_interface():
     with gr.Blocks() as demo:
         gr.Markdown("# Image Description with Voice Personas")
         gr.Markdown("""
-        This app takes an image and generates a description using a selected voice persona.
+        This app takes an image and generates descriptions using selected voice personas.
 
         1. Upload an image
-        2. Select a voice persona from the dropdown
+        2. Select voice personas from the multi-select dropdown
         3. Click "Generate Description" to see the results
+
+        The descriptions will be generated in parallel for all selected voices.
         """)
 
         with gr.Row():
@@ -39,19 +44,20 @@ def create_interface():
                 voice_dropdown = gr.Dropdown(
                     choices=[
                         "scurvy-ridden pirate",
-                        "forgetful wizard",
-                        "sarcastic teenager",
                         "private investigator",
+                        "sarcastic teenager",
+                        "forgetful wizard",
                         "shakespearian"
                     ],
-                    label="Select a Voice",
-                    value="scurvy-ridden pirate",
+                    label="Select Voice Personas (max 2 recommended)",
+                    multiselect=True,
+                    value=["scurvy-ridden pirate", "private investigator"]
                 )
                 submit_button = gr.Button("Generate Description")
 
             with gr.Column():
-                caption_output = gr.Textbox(label="Image Caption")
-                description_output = gr.Textbox(label="Voice Description")
+                caption_output = gr.Textbox(label="Image Caption", lines=4)
+                description_output = gr.Textbox(label="Voice Descriptions", lines=10)
 
                 submit_button.click(
                     fn=process_and_display,
@@ -66,4 +72,4 @@ def create_interface():
 demo = create_interface()
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
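
The first hunk above shows only the top of process_and_display; the code that turns the graph result into values for the two output textboxes falls outside the diff context. Here is a hedged sketch of how that tail might look, given the caption_output / description_output pair wired into submit_button.click; the join separator and the final return shape are assumptions beyond what this diff shows.

def process_and_display(image, voices):
    if not voices:  # If no voices selected
        return "Please select at least one voice persona.", "No voice personas selected."

    # Initialize state
    state = {"image": image, "voices": voices, "caption": "", "descriptions": []}

    # Run the graph; max_concurrency=1 keeps the per-voice branches sequential
    result = graph.invoke(state, {"max_concurrency": 1})

    # Assumed tail: one "#Voice\n..." block per selected voice, joined for display
    caption = result["caption"]
    descriptions = "\n\n".join(result["descriptions"])
    return caption, descriptions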