yaleh commited on
Commit
b090732
·
1 Parent(s): 504903f

Dict state works.

Browse files
meta_prompt/meta_prompt.py CHANGED
@@ -8,7 +8,7 @@ from langgraph.errors import GraphRecursionError
8
  from langgraph.graph import StateGraph, START, END
9
  from langchain_core.runnables.base import RunnableLike
10
  from pydantic import BaseModel
11
- from typing import Annotated, Dict, Optional, Union
12
  from .consts import *
13
 
14
  def first_non_empty(a, b):
@@ -22,6 +22,17 @@ def last_non_empty(a, b):
22
  def assign(a, b):
23
  return b
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  class AgentState(BaseModel):
26
  """
27
  Represents the state of an agent in a conversation.
@@ -40,18 +51,18 @@ class AgentState(BaseModel):
40
  - best_system_message (str, optional): The best system message.
41
  - best_output_age (int): The age of the best output.
42
  """
43
- max_output_age: Annotated[int, assign] = 0
44
- user_message: Annotated[Optional[str], assign] = None
45
- expected_output: Annotated[Optional[str], assign] = None
46
  acceptance_criteria: Annotated[Optional[str], last_non_empty] = None
47
  system_message: Annotated[Optional[str], last_non_empty] = None
48
- output: Annotated[Optional[str], assign] = None
49
- suggestions: Annotated[Optional[str], assign] = None
50
- accepted: Annotated[bool, assign] = False
51
- analysis: Annotated[Optional[str], assign] = None
52
- best_output: Annotated[Optional[str], assign] = None
53
- best_system_message: Annotated[Optional[str], assign] = None
54
- best_output_age: Annotated[int, assign] = 0
55
 
56
  class MetaPromptGraph:
57
  """
@@ -229,7 +240,7 @@ class MetaPromptGraph:
229
  NODE_PROMPT_INITIAL_DEVELOPER,
230
  "system_message",
231
  x),
232
- x))
233
  workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
234
  lambda x: self._optional_action(
235
  "acceptance_criteria",
@@ -237,7 +248,7 @@ class MetaPromptGraph:
237
  NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
238
  "acceptance_criteria",
239
  x),
240
- x))
241
  # workflow.add_node(START)
242
 
243
  workflow.add_edge(START, NODE_PROMPT_INITIAL_DEVELOPER)
@@ -331,8 +342,8 @@ class MetaPromptGraph:
331
  def _optional_action(
332
  self, target_attribute: str,
333
  action: RunnableLike,
334
- state: AgentState
335
- ) -> AgentState:
336
  """
337
  Optionally invokes an action if the target attribute is not set or empty.
338
 
@@ -345,13 +356,22 @@ class MetaPromptGraph:
345
  Returns:
346
  AgentState: Updated state.
347
  """
348
- if not getattr(state, target_attribute, None) or getattr(state, target_attribute) == "":
349
- if action:
350
- state = action(state)
351
- return state
 
 
 
 
 
 
352
 
353
 
354
- def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
 
 
 
355
  """
356
  Prompt a specific node with the given state and update the state with the response.
357
 
@@ -368,26 +388,37 @@ class MetaPromptGraph:
368
  """
369
 
370
  logger = self.logger.getChild(node)
371
- formatted_messages = self.prompt_templates[node].format_messages(**state.model_dump())
372
-
 
 
 
 
373
  for message in formatted_messages:
374
  logger.debug({
375
- 'node': node,
376
  'action': 'invoke',
377
- 'type': message.type,
378
  'message': message.content
379
  })
380
-
381
  response = self.llms[node].invoke(formatted_messages)
382
  logger.debug({
383
- 'node': node,
384
  'action': 'response',
385
- 'type': response.type,
386
  'message': response.content
387
  })
388
-
389
- setattr(state, target_attribute, response.content)
390
- return state
 
 
 
 
 
 
 
391
 
392
  def _output_history_analyzer(self, state: AgentState) -> AgentState:
393
  """
@@ -439,20 +470,30 @@ class MetaPromptGraph:
439
  "# Output ID closer to Expected Output: B" in analysis or
440
  (self.aggressive_exploration and
441
  "# Output ID closer to Expected Output: A" not in analysis)):
442
- state.best_output = state.output
443
- state.best_system_message = state.system_message
444
- state.best_output_age = 0
 
 
 
 
 
445
  logger.debug("Best output updated to the current output:\n%s",
446
- state.output)
447
  else:
448
- state.best_output_age += 1
449
- # rollback output and system message
450
- state.output = state.best_output
451
- state.system_message = state.best_system_message
 
 
 
 
 
452
  logger.debug("Best output age incremented to %s",
453
- state.best_output_age)
454
 
455
- return state
456
 
457
  def _prompt_analyzer(self, state: AgentState) -> AgentState:
458
  """
@@ -479,12 +520,20 @@ class MetaPromptGraph:
479
  logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
480
  'type': response.type, 'message': response.content})
481
 
482
- state.analysis = response.content
483
- state.accepted = "Accept: Yes" in response.content
484
 
485
- logger.debug("Accepted: %s", state.accepted)
486
 
487
- return state
 
 
 
 
 
 
 
 
488
 
489
  def _should_exit_on_max_age(self, state: AgentState) -> str:
490
  """
 
8
  from langgraph.graph import StateGraph, START, END
9
  from langchain_core.runnables.base import RunnableLike
10
  from pydantic import BaseModel
11
+ from typing import Annotated, Dict, Optional, Union, TypedDict
12
  from .consts import *
13
 
14
  def first_non_empty(a, b):
 
22
  def assign(a, b):
23
  return b
24
 
25
+ class InitialAgentState(TypedDict):
26
+ """
27
+ Represents the state of an agent in a conversation.
28
+ """
29
+ max_output_age: int
30
+ user_message: Optional[str]
31
+ expected_output: Optional[str]
32
+ acceptance_criteria: Annotated[Optional[str], last_non_empty]
33
+ system_message: Annotated[Optional[str], last_non_empty]
34
+
35
+
36
  class AgentState(BaseModel):
37
  """
38
  Represents the state of an agent in a conversation.
 
51
  - best_system_message (str, optional): The best system message.
52
  - best_output_age (int): The age of the best output.
53
  """
54
+ max_output_age: int = 0
55
+ user_message: Optional[str] = None
56
+ expected_output: Optional[str] = None
57
  acceptance_criteria: Annotated[Optional[str], last_non_empty] = None
58
  system_message: Annotated[Optional[str], last_non_empty] = None
59
+ output: Optional[str] = None
60
+ suggestions: Optional[str] = None
61
+ accepted: bool = False
62
+ analysis: Optional[str] = None
63
+ best_output: Optional[str] = None
64
+ best_system_message: Optional[str] = None
65
+ best_output_age: int = 0
66
 
67
  class MetaPromptGraph:
68
  """
 
240
  NODE_PROMPT_INITIAL_DEVELOPER,
241
  "system_message",
242
  x),
243
+ InitialAgentState(**(x.model_dump()))))
244
  workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
245
  lambda x: self._optional_action(
246
  "acceptance_criteria",
 
248
  NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
249
  "acceptance_criteria",
250
  x),
251
+ InitialAgentState(**(x.model_dump()))))
252
  # workflow.add_node(START)
253
 
254
  workflow.add_edge(START, NODE_PROMPT_INITIAL_DEVELOPER)
 
342
  def _optional_action(
343
  self, target_attribute: str,
344
  action: RunnableLike,
345
+ state: Union[AgentState, InitialAgentState]
346
+ ) -> Union[AgentState, InitialAgentState]:
347
  """
348
  Optionally invokes an action if the target attribute is not set or empty.
349
 
 
356
  Returns:
357
  AgentState: Updated state.
358
  """
359
+ result = {
360
+ target_attribute: state.get(target_attribute, "")
361
+ if isinstance(state, dict)
362
+ else getattr(state, target_attribute, "")
363
+ }
364
+
365
+ if action is not None and not result[target_attribute]:
366
+ result = action(state)
367
+
368
+ return result
369
 
370
 
371
+ def _prompt_node(
372
+ self, node: str, target_attribute: str,
373
+ state: Union[AgentState, InitialAgentState]
374
+ ) -> Union[AgentState, InitialAgentState]:
375
  """
376
  Prompt a specific node with the given state and update the state with the response.
377
 
 
388
  """
389
 
390
  logger = self.logger.getChild(node)
391
+ formatted_messages = (
392
+ self.prompt_templates[node].format_messages(
393
+ **(state.model_dump() if isinstance(state, BaseModel) else state)
394
+ )
395
+ )
396
+
397
  for message in formatted_messages:
398
  logger.debug({
399
+ 'node': node,
400
  'action': 'invoke',
401
+ 'type': message.type,
402
  'message': message.content
403
  })
404
+
405
  response = self.llms[node].invoke(formatted_messages)
406
  logger.debug({
407
+ 'node': node,
408
  'action': 'response',
409
+ 'type': response.type,
410
  'message': response.content
411
  })
412
+
413
+ # if isinstance(state, dict):
414
+ # # state[target_attribute] = response.content
415
+ # # Create a dict with the target key only
416
+ # state = {target_attribute: response.content}
417
+ # else:
418
+ # setattr(state, target_attribute, response.content)
419
+ # return state
420
+
421
+ return {target_attribute: response.content}
422
 
423
  def _output_history_analyzer(self, state: AgentState) -> AgentState:
424
  """
 
470
  "# Output ID closer to Expected Output: B" in analysis or
471
  (self.aggressive_exploration and
472
  "# Output ID closer to Expected Output: A" not in analysis)):
473
+ # state.best_output = state.output
474
+ # state.best_system_message = state.system_message
475
+ # state.best_output_age = 0
476
+ result_dict = {
477
+ "best_output": state.output,
478
+ "best_system_message": state.system_message,
479
+ "best_output_age": 0
480
+ }
481
  logger.debug("Best output updated to the current output:\n%s",
482
+ result_dict["best_output"])
483
  else:
484
+ # state.best_output_age += 1
485
+ # # rollback output and system message
486
+ # state.output = state.best_output
487
+ # state.system_message = state.best_system_message
488
+ result_dict = {
489
+ "output": state.best_output,
490
+ "system_message": state.best_system_message,
491
+ "best_output_age": state.best_output_age + 1
492
+ }
493
  logger.debug("Best output age incremented to %s",
494
+ result_dict["best_output_age"])
495
 
496
+ return result_dict
497
 
498
  def _prompt_analyzer(self, state: AgentState) -> AgentState:
499
  """
 
520
  logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
521
  'type': response.type, 'message': response.content})
522
 
523
+ # state.analysis = response.content
524
+ # state.accepted = "Accept: Yes" in response.content
525
 
526
+ # logger.debug("Accepted: %s", state.accepted)
527
 
528
+ # return state
529
+
530
+ result_dict = {
531
+ "analysis": response.content,
532
+ "accepted": "Accept: Yes" in response.content
533
+ }
534
+ logger.debug("Accepted: %s", result_dict["accepted"])
535
+
536
+ return result_dict
537
 
538
  def _should_exit_on_max_age(self, state: AgentState) -> str:
539
  """
tests/meta_prompt_graph_test.py CHANGED
@@ -180,6 +180,7 @@ class TestMetaPromptGraph(unittest.TestCase):
180
 
181
  llms = {
182
  NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
 
183
  NODE_PROMPT_DEVELOPER: optimizer_llm,
184
  NODE_PROMPT_EXECUTOR: executor_llm,
185
  NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
@@ -192,7 +193,7 @@ class TestMetaPromptGraph(unittest.TestCase):
192
  user_message="How do I reverse a list in Python?",
193
  expected_output="Use the `[::-1]` slicing technique or the "
194
  "`list.reverse()` method.",
195
- acceptance_criteria="Similar in meaning, text length and style."
196
  )
197
  output_state = meta_prompt_graph(input_state, recursion_limit=25)
198
 
 
180
 
181
  llms = {
182
  NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
183
+ NODE_ACCEPTANCE_CRITERIA_DEVELOPER: optimizer_llm,
184
  NODE_PROMPT_DEVELOPER: optimizer_llm,
185
  NODE_PROMPT_EXECUTOR: executor_llm,
186
  NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
 
193
  user_message="How do I reverse a list in Python?",
194
  expected_output="Use the `[::-1]` slicing technique or the "
195
  "`list.reverse()` method.",
196
+ # acceptance_criteria="Similar in meaning, text length and style."
197
  )
198
  output_state = meta_prompt_graph(input_state, recursion_limit=25)
199