yaleh committed
Commit d535943 · 1 Parent(s): ec72229

Update Python docstrings.

Files changed (1)
  1. meta_prompt/meta_prompt.py +210 -154
meta_prompt/meta_prompt.py CHANGED
@@ -23,18 +23,18 @@ class AgentState(TypedDict):
     Represents the state of an agent in a conversation.
 
     Attributes:
-    - max_output_age (int): The maximum age of the output.
-    - user_message (str, optional): The user's message.
-    - expected_output (str, optional): The expected output.
-    - acceptance_criteria (str, optional): The acceptance criteria.
-    - system_message (str, optional): The system message.
-    - output (str, optional): The output.
-    - suggestions (str, optional): The suggestions.
-    - accepted (bool): Whether the output is accepted.
-    - analysis (str, optional): The analysis.
-    - best_output (str, optional): The best output.
-    - best_system_message (str, optional): The best system message.
-    - best_output_age (int): The age of the best output.
+        max_output_age (int): The maximum age of the output.
+        user_message (str, optional): The user's message.
+        expected_output (str, optional): The expected output.
+        acceptance_criteria (str, optional): The acceptance criteria.
+        system_message (str, optional): The system message.
+        output (str, optional): The output.
+        suggestions (str, optional): The suggestions.
+        accepted (bool, optional): Whether the output is accepted.
+        analysis (str, optional): The analysis.
+        best_output (str, optional): The best output.
+        best_system_message (str, optional): The best system message.
+        best_output_age (int, optional): The age of the best output.
     """
     max_output_age: Optional[int]
     user_message: Optional[str]
@@ -80,41 +80,34 @@ class MetaPromptGraph:
         """
         return META_PROMPT_NODES
 
-    def __init__(self,
-                 llms: Union[BaseLanguageModel,
-                             Dict[str, BaseLanguageModel]] = {},
-                 prompts: Dict[str, ChatPromptTemplate] = {},
-                 aggressive_exploration: bool = False,
-                 logger: Optional[logging.Logger] = None,
-                 verbose=False):
+    def __init__(
+        self,
+        llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]] = {},
+        prompts: Dict[str, ChatPromptTemplate] = {},
+        aggressive_exploration: bool = False,
+        logger: Optional[logging.Logger] = None,
+        verbose: bool = False,
+    ):
         """
         Initializes the MetaPromptGraph instance.
 
         Args:
-        - llms (Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
-          optional): The language models for the graph nodes. Defaults to {}.
-        - prompts (Dict[str, ChatPromptTemplate], optional): The custom
-          prompt templates for the graph nodes. Defaults to {}.
-        - logger (Optional[logging.Logger], optional): The logger for
-          the graph. Defaults to None.
-        - verbose (bool, optional): Whether to set the logger level to
-          DEBUG. Defaults to False.
-
-        Initializes the logger, sets the language models and prompt
-        templates for the graph nodes, and updates the prompt templates
-        with custom ones if provided.
+            llms: The language models for the graph nodes.
+            prompts: The custom prompt templates for the graph nodes.
+            aggressive_exploration: Whether to use aggressive exploration.
+            logger: The logger for the graph.
+            verbose: Whether to set the logger level to DEBUG.
+
+        Initializes the logger, sets the language models and prompt templates
+        for the graph nodes, and updates the prompt templates with custom ones
+        if provided.
         """
         self.logger = logger or logging.getLogger(__name__)
         if self.logger is not None:
-            if verbose:
-                self.logger.setLevel(logging.DEBUG)
-            else:
-                self.logger.setLevel(logging.INFO)
+            self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 
         if isinstance(llms, BaseLanguageModel):
-            # if llms is a single language model, wrap it in a dictionary
-            self.llms: Dict[str, BaseLanguageModel] = {
-                node: llms for node in self.get_node_names()}
+            self.llms = {node: llms for node in self.get_node_names()}
        else:
             self.llms: Dict[str, BaseLanguageModel] = llms
         self.prompt_templates: Dict[str,
@@ -125,60 +118,82 @@ class MetaPromptGraph:
 
 
     def _create_acceptance_criteria_workflow(self) -> StateGraph:
+        """
+        Create a workflow state graph for acceptance criteria.
+
+        Returns:
+            StateGraph: A state graph representing the workflow.
+        """
         workflow = StateGraph(AgentState)
-        workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
-                          lambda x: self._prompt_node(
-                              NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
-                              "acceptance_criteria",
-                              x))
+        workflow.add_node(
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+            lambda x: self._prompt_node(
+                NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+                "acceptance_criteria",
+                x
+            )
+        )
         workflow.add_edge(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, END)
         workflow.set_entry_point(NODE_ACCEPTANCE_CRITERIA_DEVELOPER)
         return workflow
 
 
     def _create_prompt_initial_developer_workflow(self) -> StateGraph:
+        """
+        Create a workflow state graph for the initial developer prompt.
+
+        Returns:
+            StateGraph: A state graph representing the workflow.
+        """
         workflow = StateGraph(AgentState)
-        workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
-                          lambda x: self._prompt_node(
-                              NODE_PROMPT_INITIAL_DEVELOPER,
-                              "system_message",
-                              x))
+        workflow.add_node(
+            NODE_PROMPT_INITIAL_DEVELOPER,
+            lambda x: self._prompt_node(
+                NODE_PROMPT_INITIAL_DEVELOPER,
+                "system_message",
+                x
+            )
+        )
         workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER, END)
         workflow.set_entry_point(NODE_PROMPT_INITIAL_DEVELOPER)
         return workflow
 
 
     def _create_workflow(self) -> StateGraph:
-        """Create a workflow state graph.
-
-        Args:
-            including_initial_developer: Flag indicating whether to include the
-                initial developer node in the workflow.
+        """
+        Create a workflow state graph for the meta-prompt.
 
         Returns:
             StateGraph: A state graph representing the workflow.
         """
+
         workflow = StateGraph(AgentState)
 
-        workflow.add_node(NODE_PROMPT_DEVELOPER,
-                          lambda x: self._prompt_node(
-                              NODE_PROMPT_DEVELOPER,
-                              "system_message",
-                              x))
-        workflow.add_node(NODE_PROMPT_EXECUTOR,
-                          lambda x: self._prompt_node(
-                              NODE_PROMPT_EXECUTOR,
-                              "output",
-                              x))
-        workflow.add_node(NODE_OUTPUT_HISTORY_ANALYZER,
-                          lambda x: self._output_history_analyzer(x))
-        workflow.add_node(NODE_PROMPT_ANALYZER,
-                          lambda x: self._prompt_analyzer(x))
-        workflow.add_node(NODE_PROMPT_SUGGESTER,
-                          lambda x: self._prompt_node(
-                              NODE_PROMPT_SUGGESTER,
-                              "suggestions",
-                              x))
+        # Add nodes
+        workflow.add_node(
+            NODE_PROMPT_DEVELOPER,
+            lambda x: self._prompt_node(
+                NODE_PROMPT_DEVELOPER, "system_message", x
+            )
+        )
+        workflow.add_node(
+            NODE_PROMPT_EXECUTOR,
+            lambda x: self._prompt_node(NODE_PROMPT_EXECUTOR, "output", x)
+        )
+        workflow.add_node(
+            NODE_OUTPUT_HISTORY_ANALYZER,
+            lambda x: self._output_history_analyzer(x)
+        )
+        workflow.add_node(
+            NODE_PROMPT_ANALYZER,
+            lambda x: self._prompt_analyzer(x)
+        )
+        workflow.add_node(
+            NODE_PROMPT_SUGGESTER,
+            lambda x: self._prompt_node(
+                NODE_PROMPT_SUGGESTER, "suggestions", x
+            )
+        )
 
         # Connect nodes
         workflow.add_edge(NODE_PROMPT_DEVELOPER, NODE_PROMPT_EXECUTOR)
@@ -195,7 +210,6 @@ class MetaPromptGraph:
                 END: END
             }
         )
-
         workflow.add_conditional_edges(
             NODE_PROMPT_ANALYZER,
             lambda x: self._should_exit_on_acceptable_output(x),
@@ -205,26 +219,33 @@ class MetaPromptGraph:
             }
         )
 
-        workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
-                          lambda x: self._optional_action(
-                              "system_message",
-                              lambda x: self._prompt_node(
-                                  NODE_PROMPT_INITIAL_DEVELOPER,
-                                  "system_message",
-                                  x),
-                              x))
-        workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
-                          lambda x: self._optional_action(
-                              "acceptance_criteria",
-                              lambda x: self._prompt_node(
-                                  NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
-                                  "acceptance_criteria",
-                                  x),
-                              x))
+        # Add optional nodes
+        workflow.add_node(
+            NODE_PROMPT_INITIAL_DEVELOPER,
+            lambda x: self._optional_action(
+                "system_message",
+                lambda x: self._prompt_node(
+                    NODE_PROMPT_INITIAL_DEVELOPER, "system_message", x
+                ),
+                x
+            )
+        )
+        workflow.add_node(
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+            lambda x: self._optional_action(
+                "acceptance_criteria",
+                lambda x: self._prompt_node(
+                    NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+                    "acceptance_criteria",
+                    x
+                ),
+                x
+            )
+        )
 
+        # Add edges to optional nodes
         workflow.add_edge(START, NODE_PROMPT_INITIAL_DEVELOPER)
         workflow.add_edge(START, NODE_ACCEPTANCE_CRITERIA_DEVELOPER)
-
         workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER, NODE_PROMPT_EXECUTOR)
         workflow.add_edge(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, NODE_PROMPT_EXECUTOR)
 
@@ -232,6 +253,14 @@ class MetaPromptGraph:
 
 
     def run_acceptance_criteria_graph(self, state: AgentState) -> AgentState:
+        """Run the acceptance criteria graph with the given state.
+
+        Args:
+            state (AgentState): The current state of the agent.
+
+        Returns:
+            AgentState: The output state of the agent after invoking the graph.
+        """
         self.logger.debug("Creating acceptance criteria workflow")
         workflow = self._create_acceptance_criteria_workflow()
         memory = MemorySaver()
@@ -241,9 +270,17 @@ class MetaPromptGraph:
         output_state = graph.invoke(state, config)
         self.logger.debug("Output state: %s", pprint.pformat(output_state))
         return output_state
-
+
 
     def run_prompt_initial_developer_graph(self, state: AgentState) -> AgentState:
+        """Run the prompt initial developer graph with the given state.
+
+        Args:
+            state (AgentState): The current state of the agent.
+
+        Returns:
+            AgentState: The output state of the agent after invoking the graph.
+        """
         self.logger.debug("Creating prompt initial developer workflow")
         workflow = self._create_prompt_initial_developer_workflow()
         memory = MemorySaver()
@@ -255,16 +292,18 @@ class MetaPromptGraph:
         return output_state
 
 
-    def run_meta_prompt_graph(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+    def run_meta_prompt_graph(
+        self, state: AgentState, recursion_limit: int = 25
+    ) -> AgentState:
         """
         Invoke the meta-prompt workflow with the given state and recursion limit.
 
         This method creates a workflow based on the presence of an initial system
         message, compiles the workflow with a memory saver, and invokes the graph
-        with the given state. If a recursion limit is reached, it returns the best
-        state found so far.
+        with the given state. If a recursion limit is reached, it returns the
+        best state found so far.
 
-        Parameters:
+        Args:
             state (AgentState): The current state of the agent, containing
                 necessary context for message formatting.
             recursion_limit (int): The maximum number of recursive calls
@@ -274,51 +313,52 @@ class MetaPromptGraph:
            AgentState: The output state of the agent after invoking the workflow.
         """
         workflow = self._create_workflow()
-
         memory = MemorySaver()
         graph = workflow.compile(checkpointer=memory)
-        config = {"configurable": {"thread_id": "1"},
-                  "recursion_limit": recursion_limit}
+        config = {
+            "configurable": {"thread_id": "1"},
+            "recursion_limit": recursion_limit,
+        }
 
         try:
-            self.logger.debug("Invoking graph with state: %s",
-                              pprint.pformat(state))
-
+            self.logger.debug("Invoking graph with state: %s", pprint.pformat(state))
             output_state = graph.invoke(state, config)
-
             self.logger.debug("Output state: %s", pprint.pformat(output_state))
-
             return output_state
         except GraphRecursionError as e:
-            self.logger.info(
-                "Recursion limit reached. Returning the best state found so far.")
+            self.logger.info("Recursion limit reached. Returning the best state found so far.")
             checkpoint_states = graph.get_state(config)
 
-            # if the length of states is bigger than 0, print the best system message and output
-            if len(checkpoint_states) > 0:
+            if checkpoint_states:
                 output_state = checkpoint_states[0]
                 return output_state
             else:
-                self.logger.info(
-                    "No checkpoint states found. Returning the input state.")
-
-                return state
+                self.logger.info("No checkpoint states found. Returning the input state.")
+                return state
 
 
-    def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+    def __call__(
+        self, state: AgentState, recursion_limit: int = 25
+    ) -> AgentState:
+        """Invoke the meta-prompt workflow with the given state and recursion limit.
+
+        Args:
+            state (AgentState): The current state of the agent.
+            recursion_limit (int): The maximum number of recursive calls allowed.
+
+        Returns:
+            AgentState: The output state of the agent after invoking the workflow.
+        """
         return self.run_meta_prompt_graph(state, recursion_limit)
 
 
     def _optional_action(
-        self, target_attribute: str,
-        action: RunnableLike,
-        state: AgentState
+        self, target_attribute: str, action: RunnableLike, state: AgentState
     ) -> AgentState:
         """
         Optionally invokes an action if the target attribute is not set or empty.
 
         Args:
-            node (str): Node identifier.
             target_attribute (str): State attribute to be updated.
             action (RunnableLike): Action to be invoked. Defaults to None.
             state (AgentState): Current agent state.
@@ -327,9 +367,11 @@ class MetaPromptGraph:
             AgentState: Updated state.
         """
         result = {
-            target_attribute: state.get(target_attribute, "")
-            if isinstance(state, dict)
-            else getattr(state, target_attribute, "")
+            target_attribute: (
+                state.get(target_attribute, "")
+                if isinstance(state, dict)
+                else getattr(state, target_attribute, "")
+            )
         }
 
         if action is not None and not result[target_attribute]:
@@ -339,16 +381,14 @@ class MetaPromptGraph:
 
 
     def _prompt_node(
-        self, node: str, target_attribute: str,
-        state: AgentState
+        self, node: str, target_attribute: str, state: AgentState
     ) -> AgentState:
-        """
-        Prompt a specific node with the given state and update the state with the response.
+        """Prompt a specific node with the given state and update the state with the response.
 
         This method formats messages using the prompt template associated with the node,
         logs the invocation and response, and updates the state with the response content.
 
-        Parameters:
+        Args:
            node (str): Node identifier to be prompted.
            target_attribute (str): State attribute to be updated with response content.
            state (AgentState): Current agent state with necessary context for message formatting.
@@ -365,32 +405,37 @@ class MetaPromptGraph:
         )
 
         for message in formatted_messages:
-            logger.debug({
-                'node': node,
-                'action': 'invoke',
-                'type': message.type,
-                'message': message.content
-            })
+            logger.debug(
+                {
+                    'node': node,
+                    'action': 'invoke',
+                    'type': message.type,
+                    'message': message.content
+                }
+            )
 
         response = self.llms[node].invoke(formatted_messages)
-        logger.debug({
-            'node': node,
-            'action': 'response',
-            'type': response.type,
-            'message': response.content
-        })
+        logger.debug(
+            {
+                'node': node,
+                'action': 'response',
+                'type': response.type,
+                'message': response.content
+            }
+        )
 
         return {target_attribute: response.content}
 
+
     def _output_history_analyzer(self, state: AgentState) -> AgentState:
         """
         Analyzes the output history and updates the best output and its age.
 
         This method checks if the best output is initialized, formats the prompt for
-        the output history analyzer, invokes the language model, and updates the best
-        output and its age based on the response.
+        the output history analyzer, invokes the language model, and updates the
+        best output and its age based on the response.
 
-        Parameters:
+        Args:
             state (AgentState): Current state of the agent with necessary context
                 for message formatting.
 
@@ -403,8 +448,8 @@ class MetaPromptGraph:
             state["best_output"] = state["output"]
             state["best_system_message"] = state["system_message"]
            state["best_output_age"] = 0
-            logger.debug(
-                "Best output initialized to the current output:\n%s", state["output"])
+            logger.debug("Best output initialized to the current output:\n%s",
+                         state["output"])
             return state
 
         prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
@@ -429,9 +474,9 @@ class MetaPromptGraph:
         analysis = response.content
 
         if (state["best_output"] is None or
-            "# Output ID closer to Expected Output: B" in analysis or
-            (self.aggressive_exploration and
-             "# Output ID closer to Expected Output: A" not in analysis)):
+                "# Output ID closer to Expected Output: B" in analysis or
+                (self.aggressive_exploration and
+                 "# Output ID closer to Expected Output: A" not in analysis)):
            result_dict = {
                "best_output": state["output"],
                "best_system_message": state["system_message"],
@@ -450,6 +495,7 @@ class MetaPromptGraph:
 
         return result_dict
 
+
     def _prompt_analyzer(self, state: AgentState) -> AgentState:
         """
         Analyzes the prompt and updates the state with the analysis and
@@ -468,12 +514,20 @@ class MetaPromptGraph:
             **state)
 
         for message in prompt:
-            logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'invoke',
-                          'type': message.type, 'message': message.content})
+            logger.debug({
+                'node': NODE_PROMPT_ANALYZER,
+                'action': 'invoke',
+                'type': message.type,
+                'message': message.content
+            })
 
         response = self.llms[NODE_PROMPT_ANALYZER].invoke(prompt)
-        logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
-                      'type': response.type, 'message': response.content})
+        logger.debug({
+            'node': NODE_PROMPT_ANALYZER,
+            'action': 'response',
+            'type': response.type,
+            'message': response.content
+        })
 
         result_dict = {
             "analysis": response.content,
@@ -483,6 +537,7 @@ class MetaPromptGraph:
 
         return result_dict
 
+
     def _should_exit_on_max_age(self, state: AgentState) -> str:
         """
         Determines whether to exit the workflow based on the maximum output age.
@@ -494,21 +549,22 @@ class MetaPromptGraph:
             str: The decision to continue, rerun, or end the workflow.
         """
         if state["max_output_age"] <= 0:
-            # always continue if max age is 0
-            return "continue"
-
+            return "continue"  # always continue if max age is 0
+
         if state["best_output_age"] >= state["max_output_age"]:
             return END
-
+
         if state["best_output_age"] > 0:
             # skip prompt_analyzer and prompt_suggester, goto prompt_developer
-            return "rerun"
-
+            return "rerun"
+
         return "continue"
 
+
     def _should_exit_on_acceptable_output(self, state: AgentState) -> str:
         """
-        Determines whether to exit the workflow based on the acceptance status of the output.
+        Determines whether to exit the workflow based on the acceptance status of
+        the output.
 
         Args:
            state (AgentState): The current state of the agent.
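
For context, a minimal usage sketch of the class this commit documents. It assumes the module layout shown above (meta_prompt/meta_prompt.py exporting MetaPromptGraph and AgentState) and langchain_openai for the model; the model name and state values are illustrative only, not part of the commit.

# Usage sketch only: assumes `pip install langchain-openai` and that the
# package exports MetaPromptGraph and AgentState as shown in the diff above.
from langchain_openai import ChatOpenAI
from meta_prompt.meta_prompt import AgentState, MetaPromptGraph

# A single model is fanned out to every node by __init__ (see hunk at -80);
# verbose=True sets the logger level to DEBUG, per the new docstring.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # illustrative model
graph = MetaPromptGraph(llms=llm, verbose=True)

# AgentState is a TypedDict; only the fields relevant to this run are set.
state = AgentState(
    user_message="apple, banana",  # illustrative task input
    expected_output="APPLE, BANANA",
    acceptance_criteria="Output must be uppercase and comma-separated.",
    max_output_age=3,
)

# __call__ delegates to run_meta_prompt_graph; recursion_limit bounds the
# LangGraph run, and on GraphRecursionError the best state found is returned.
output_state = graph(state, recursion_limit=25)
print(output_state.get("best_system_message"))
print(output_state.get("best_output"))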