Dylan committed on
Commit
598dcfa
·
1 Parent(s): 0160c44

added parallel map to call model multiple times

Browse files
Files changed (2) hide show
  1. agents.py +20 -12
  2. app.py +6 -1
agents.py CHANGED
@@ -1,6 +1,9 @@
 
 
1
  import torch
2
  from langgraph.graph import END, StateGraph
3
- from typing import TypedDict, Any
 
4
 
5
  from transformers import (
6
  AutoProcessor,
@@ -23,23 +26,21 @@ class State(TypedDict):
23
  image: Any
24
  voice: str
25
  caption: str
26
- description: str
27
 
28
 
29
  # Build the workflow graph
30
  def build_graph():
31
  workflow = StateGraph(State)
32
 
33
- # Add nodes
34
- # workflow.add_node("caption_image", caption_image_dummy)
35
- # workflow.add_node("describe_with_voice", describe_with_voice_dummy)
36
-
37
  workflow.add_node("caption_image", caption_image)
38
  workflow.add_node("describe_with_voice", describe_with_voice)
39
 
40
  # Add edges
41
  workflow.set_entry_point("caption_image")
42
- workflow.add_edge("caption_image", "describe_with_voice")
 
 
43
  workflow.add_edge("describe_with_voice", END)
44
 
45
  # Compile the graph
@@ -76,18 +77,19 @@ def describe_with_voice(state: State) -> State:
76
  caption = state["caption"]
77
  voice = state["voice"]
78
 
79
- caption = "A golden retriever that seems to be smiling straight to the camera"
80
-
81
  # Voice prompt templates
82
  voice_prompts = {
83
  "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
84
  "forgetful wizard": "You are a forgetful and easily distracted wizard.",
85
  "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
 
 
86
  }
 
87
  messages = [
88
  {
89
  "role": "system",
90
- "content": [{"type": "text", "text": voice_prompts.get(voice)}],
91
  },
92
  {
93
  "role": "user",
@@ -111,16 +113,22 @@ def describe_with_voice(state: State) -> State:
111
 
112
  description = processor.decode(generation, skip_special_tokens=True)
113
 
114
- state["description"] = description
 
115
  print(description)
116
 
117
  return state
118
 
119
 
 
 
 
 
 
120
  def caption_image(state: State) -> State:
121
  # image is PIL
122
  image = state["image"]
123
- image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
124
 
125
  # Load models (in practice, do this once and cache)
126
  messages = [
 
1
+ import operator
2
+ from helpers import image_to_base64
3
  import torch
4
  from langgraph.graph import END, StateGraph
5
+ from langgraph.types import Send
6
+ from typing import Annotated, TypedDict, Any
7
 
8
  from transformers import (
9
  AutoProcessor,
 
26
  image: Any
27
  voice: str
28
  caption: str
29
+ descriptions: Annotated[list, operator.add]
30
 
31
 
32
  # Build the workflow graph
33
  def build_graph():
34
  workflow = StateGraph(State)
35
 
 
 
 
 
36
  workflow.add_node("caption_image", caption_image)
37
  workflow.add_node("describe_with_voice", describe_with_voice)
38
 
39
  # Add edges
40
  workflow.set_entry_point("caption_image")
41
+
42
+ workflow.add_conditional_edges("caption_image", map_describe, ["describe_with_voice"])
43
+ # workflow.add_edge("caption_image", "describe_with_voice")
44
  workflow.add_edge("describe_with_voice", END)
45
 
46
  # Compile the graph
 
77
  caption = state["caption"]
78
  voice = state["voice"]
79
 
 
 
80
  # Voice prompt templates
81
  voice_prompts = {
82
  "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
83
  "forgetful wizard": "You are a forgetful and easily distracted wizard.",
84
  "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
85
+ "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
86
+ "shakespearian": "Talk like one of Shakespeare's characters. ",
87
  }
88
+ system_prompt = voice_prompts.get(voice, "You are a pirate.") + " Output 5-10 sentences."
89
  messages = [
90
  {
91
  "role": "system",
92
+ "content": [{"type": "text", "text": system_prompt}],
93
  },
94
  {
95
  "role": "user",
 
113
 
114
  description = processor.decode(generation, skip_special_tokens=True)
115
 
116
+ # note that the return value is a list; NOTE(review): key looks like it should be "descriptions" (plural) to match the State reducer field and app.py's result["descriptions"] — confirm
117
+ state["description"] = [description]
118
  print(description)
119
 
120
  return state
121
 
122
 
123
def map_describe(state: State) -> list:
    """Fan the caption out to the `describe_with_voice` node three times in parallel.

    Returns a list of `Send` objects, one per parallel branch; each carries its
    own payload dict with the caption and selected voice.
    """
    # Build a distinct Send (and payload dict) per branch. The original
    # `[Send(...)] * 3` replicated one shared object, so all three branches
    # aliased the same payload dict — any downstream mutation would leak
    # across branches.
    return [
        Send(
            "describe_with_voice",
            {"caption": state["caption"], "voice": state["voice"]},
        )
        for _ in range(3)
    ]
126
+
127
+
128
  def caption_image(state: State) -> State:
129
  # image is PIL
130
  image = state["image"]
131
+ image = image_to_base64(image)
132
 
133
  # Load models (in practice, do this once and cache)
134
  messages = [
app.py CHANGED
@@ -15,8 +15,11 @@ def process_and_display(image, voice):
15
  # Run the graph
16
  result = graph.invoke(state)
17
 
 
 
 
18
  # Return the caption and description
19
- return result["caption"], result["description"]
20
 
21
 
22
  def create_interface():
@@ -38,6 +41,8 @@ def create_interface():
38
  "scurvy-ridden pirate",
39
  "forgetful wizard",
40
  "sarcastic teenager",
 
 
41
  ],
42
  label="Select a Voice",
43
  value="scurvy-ridden pirate",
 
15
  # Run the graph
16
  result = graph.invoke(state)
17
 
18
+ descriptions:list[str] = result["descriptions"]
19
+ description = "\n---\n".join(descriptions)
20
+
21
  # Return the caption and description
22
+ return result["caption"], description
23
 
24
 
25
  def create_interface():
 
41
  "scurvy-ridden pirate",
42
  "forgetful wizard",
43
  "sarcastic teenager",
44
+ "private investigator",
45
+ "shakespearian"
46
  ],
47
  label="Select a Voice",
48
  value="scurvy-ridden pirate",