yaleh committed · Commit ec72229 · 1 Parent(s): b090732

Updated unit tests.

meta_prompt/meta_prompt.py CHANGED
@@ -1,5 +1,4 @@
 import logging
-import operator
 import pprint
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import ChatPromptTemplate
@@ -19,21 +18,7 @@ def last_non_empty(a, b):
     # return the last non-none value
     return next((s for s in (b, a) if s), None)
 
-def assign(a, b):
-    return b
-
-class InitialAgentState(TypedDict):
-    """
-    Represents the state of an agent in a conversation.
-    """
-    max_output_age: int
-    user_message: Optional[str]
-    expected_output: Optional[str]
-    acceptance_criteria: Annotated[Optional[str], last_non_empty]
-    system_message: Annotated[Optional[str], last_non_empty]
-
-
-class AgentState(BaseModel):
+class AgentState(TypedDict):
     """
     Represents the state of an agent in a conversation.
 
@@ -51,18 +36,18 @@ class AgentState(BaseModel):
     - best_system_message (str, optional): The best system message.
     - best_output_age (int): The age of the best output.
     """
-    max_output_age: int = 0
-    user_message: Optional[str] = None
-    expected_output: Optional[str] = None
-    acceptance_criteria: Annotated[Optional[str], last_non_empty] = None
-    system_message: Annotated[Optional[str], last_non_empty] = None
-    output: Optional[str] = None
-    suggestions: Optional[str] = None
-    accepted: bool = False
-    analysis: Optional[str] = None
-    best_output: Optional[str] = None
-    best_system_message: Optional[str] = None
-    best_output_age: int = 0
+    max_output_age: Optional[int]
+    user_message: Optional[str]
+    expected_output: Optional[str]
+    acceptance_criteria: Annotated[Optional[str], last_non_empty]
+    system_message: Annotated[Optional[str], last_non_empty]
+    output: Optional[str]
+    suggestions: Optional[str]
+    accepted: Optional[bool]
+    analysis: Optional[str]
+    best_output: Optional[str]
+    best_system_message: Optional[str]
+    best_output_age: Optional[int]
 
 class MetaPromptGraph:
     """
@@ -220,19 +205,6 @@ class MetaPromptGraph:
             }
         )
 
-        # # Set entry point based on including_initial_developer flag
-        # if including_initial_developer:
-        #     workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
-        #                       lambda x: self._prompt_node(
-        #                           NODE_PROMPT_INITIAL_DEVELOPER,
-        #                           "system_message",
-        #                           x))
-        #     workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER,
-        #                       NODE_PROMPT_EXECUTOR)
-        #     workflow.set_entry_point(NODE_PROMPT_INITIAL_DEVELOPER)
-        # else:
-        #     workflow.set_entry_point(NODE_PROMPT_EXECUTOR)
-
         workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
                           lambda x: self._optional_action(
                               "system_message",
@@ -240,7 +212,7 @@ class MetaPromptGraph:
                                   NODE_PROMPT_INITIAL_DEVELOPER,
                                   "system_message",
                                   x),
-                              InitialAgentState(**(x.model_dump()))))
+                              x))
         workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
                           lambda x: self._optional_action(
                               "acceptance_criteria",
@@ -248,15 +220,13 @@
                                   NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
                                   "acceptance_criteria",
                                   x),
-                              InitialAgentState(**(x.model_dump()))))
-        # workflow.add_node(START)
+                              x))
 
         workflow.add_edge(START, NODE_PROMPT_INITIAL_DEVELOPER)
         workflow.add_edge(START, NODE_ACCEPTANCE_CRITERIA_DEVELOPER)
 
         workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER, NODE_PROMPT_EXECUTOR)
         workflow.add_edge(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, NODE_PROMPT_EXECUTOR)
-        # workflow.set_entry_point(START)
 
         return workflow
 
@@ -342,8 +312,8 @@ class MetaPromptGraph:
     def _optional_action(
        self, target_attribute: str,
        action: RunnableLike,
-       state: Union[AgentState, InitialAgentState]
-    ) -> Union[AgentState, InitialAgentState]:
+       state: AgentState
+    ) -> AgentState:
        """
        Optionally invokes an action if the target attribute is not set or empty.
 
@@ -370,8 +340,8 @@ class MetaPromptGraph:
 
    def _prompt_node(
        self, node: str, target_attribute: str,
-       state: Union[AgentState, InitialAgentState]
-    ) -> Union[AgentState, InitialAgentState]:
+       state: AgentState
+    ) -> AgentState:
        """
        Prompt a specific node with the given state and update the state with the response.
 
@@ -410,14 +380,6 @@ class MetaPromptGraph:
             'message': response.content
         })
 
-        # if isinstance(state, dict):
-        #     # state[target_attribute] = response.content
-        #     # Create a dict with the target key only
-        #     state = {target_attribute: response.content}
-        # else:
-        #     setattr(state, target_attribute, response.content)
-        # return state
-
         return {target_attribute: response.content}
 
    def _output_history_analyzer(self, state: AgentState) -> AgentState:
@@ -437,16 +399,16 @@
        """
        logger = self.logger.getChild(NODE_OUTPUT_HISTORY_ANALYZER)
 
-        if state.best_output is None:
-            state.best_output = state.output
-            state.best_system_message = state.system_message
-            state.best_output_age = 0
+        if state["best_output"] is None:
+            state["best_output"] = state["output"]
+            state["best_system_message"] = state["system_message"]
+            state["best_output_age"] = 0
             logger.debug(
-                "Best output initialized to the current output:\n%s", state.output)
+                "Best output initialized to the current output:\n%s", state["output"])
             return state
 
        prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
-            **state.model_dump())
+            **state)
 
        for message in prompt:
            logger.debug({
@@ -466,29 +428,22 @@
 
        analysis = response.content
 
-        if (state.best_output is None or
+        if (state["best_output"] is None or
                "# Output ID closer to Expected Output: B" in analysis or
                (self.aggressive_exploration and
                 "# Output ID closer to Expected Output: A" not in analysis)):
-            # state.best_output = state.output
-            # state.best_system_message = state.system_message
-            # state.best_output_age = 0
            result_dict = {
-                "best_output": state.output,
-                "best_system_message": state.system_message,
+                "best_output": state["output"],
+                "best_system_message": state["system_message"],
                "best_output_age": 0
            }
            logger.debug("Best output updated to the current output:\n%s",
                         result_dict["best_output"])
        else:
-            # state.best_output_age += 1
-            # # rollback output and system message
-            # state.output = state.best_output
-            # state.system_message = state.best_system_message
            result_dict = {
-                "output": state.best_output,
-                "system_message": state.best_system_message,
-                "best_output_age": state.best_output_age + 1
+                "output": state["best_output"],
+                "system_message": state["best_system_message"],
+                "best_output_age": state["best_output_age"] + 1
            }
            logger.debug("Best output age incremented to %s",
                         result_dict["best_output_age"])
@@ -510,7 +465,7 @@ class MetaPromptGraph:
        """
        logger = self.logger.getChild(NODE_PROMPT_ANALYZER)
        prompt = self.prompt_templates[NODE_PROMPT_ANALYZER].format_messages(
-            **state.model_dump())
+            **state)
 
        for message in prompt:
            logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'invoke',
@@ -520,13 +475,6 @@
        logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
                      'type': response.type, 'message': response.content})
 
-        # state.analysis = response.content
-        # state.accepted = "Accept: Yes" in response.content
-
-        # logger.debug("Accepted: %s", state.accepted)
-
-        # return state
-
        result_dict = {
            "analysis": response.content,
            "accepted": "Accept: Yes" in response.content
@@ -545,14 +493,14 @@
        Returns:
            str: The decision to continue, rerun, or end the workflow.
        """
-        if state.max_output_age <= 0:
+        if state["max_output_age"] <= 0:
            # always continue if max age is 0
            return "continue"
 
-        if state.best_output_age >= state.max_output_age:
+        if state["best_output_age"] >= state["max_output_age"]:
            return END
 
-        if state.best_output_age > 0:
+        if state["best_output_age"] > 0:
            # skip prompt_analyzer and prompt_suggester, goto prompt_developer
            return "rerun"
 
@@ -568,4 +516,4 @@
        Returns:
            str: The decision to continue or end the workflow.
        """
-        return "continue" if not state.accepted else END
+        return "continue" if not state["accepted"] else END
tests/meta_prompt_graph_test.py CHANGED
@@ -39,7 +39,7 @@ class TestMetaPromptGraph(unittest.TestCase):
         )
 
         assert (
-            updated_state.output == "Mocked response content"
+            updated_state['output'] == "Mocked response content"
         ), "The output attribute should be updated with the mocked response content"
 
 
@@ -78,13 +78,13 @@ class TestMetaPromptGraph(unittest.TestCase):
         updated_state = meta_prompt_graph._output_history_analyzer(state)
 
         assert (
-            updated_state.best_output == state.output
+            updated_state['best_output'] == state['output']
         ), "Best output should be updated to the current output."
         assert (
-            updated_state.best_system_message == state.system_message
+            updated_state['best_system_message'] == state['system_message']
         ), "Best system message should be updated to the current system message."
         assert (
-            updated_state.best_output_age == 0
+            updated_state['best_output_age'] == 0
         ), "Best output age should be reset to 0."
 
 
@@ -104,10 +104,13 @@ class TestMetaPromptGraph(unittest.TestCase):
         }
         meta_prompt_graph = MetaPromptGraph(llms=llms)
         state = AgentState(
-            output="Test output", expected_output="Expected output"
+            output="Test output", expected_output="Expected output",
+            acceptance_criteria="Acceptance criteria: ...",
+            system_message="System message: ...",
+            max_output_age=2
         )
         updated_state = meta_prompt_graph._prompt_analyzer(state)
-        assert updated_state.accepted is True
+        assert updated_state['accepted'] is True
 
 
     def test_get_node_names(self):
@@ -138,7 +141,8 @@ class TestMetaPromptGraph(unittest.TestCase):
             user_message="How do I reverse a list in Python?",
             expected_output="Use the `[::-1]` slicing technique or the "
             "`list.reverse()` method.",
-            acceptance_criteria="Similar in meaning, text length and style."
+            acceptance_criteria="Similar in meaning, text length and style.",
+            max_output_age=2
         )
         output_state = meta_prompt_graph(input_state, recursion_limit=25)
 
@@ -190,6 +194,7 @@ class TestMetaPromptGraph(unittest.TestCase):
 
         meta_prompt_graph = MetaPromptGraph(llms=llms)
         input_state = AgentState(
+            max_output_age=2,
             user_message="How do I reverse a list in Python?",
             expected_output="Use the `[::-1]` slicing technique or the "
             "`list.reverse()` method.",
@@ -239,7 +244,8 @@ class TestMetaPromptGraph(unittest.TestCase):
         input_state = AgentState(
             user_message="How do I reverse a list in Python?",
             expected_output="The output should use the `reverse()` method.",
-            acceptance_criteria="The output should be correct and efficient."
+            acceptance_criteria="The output should be correct and efficient.",
+            max_output_age=2
         )
 
         output_state = meta_prompt_graph(input_state)
@@ -277,7 +283,8 @@ class TestMetaPromptGraph(unittest.TestCase):
         input_state = AgentState(
             user_message="How do I reverse a list in Python?",
             expected_output="The output should use the `reverse()` method.",
-            acceptance_criteria="The output should be correct and efficient."
+            acceptance_criteria="The output should be correct and efficient.",
+            max_output_age=2
         )
 
         output_state = meta_prompt_graph(input_state)
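A note on why every `AgentState(...)` call in these tests now spells out `acceptance_criteria`, `system_message`, or `max_output_age`: a `TypedDict` constructor is just `dict(...)`, so the defaults the old Pydantic model supplied (`max_output_age: int = 0`, `accepted: bool = False`, ...) are gone, and a missing key only fails when the graph reads it. A minimal illustration (the two-field `State` here is a stand-in, not the real class):

from typing import Optional, TypedDict


class State(TypedDict):
    max_output_age: Optional[int]
    output: Optional[str]


full = State(output="Test output", max_output_age=2)  # just a plain dict
assert full["max_output_age"] == 2

partial = {"output": "Test output"}  # type checkers flag this; runtime allows it
try:
    partial["max_output_age"]        # fails only on access, e.g. when max age is checked
except KeyError:
    print("missing keys surface as KeyError when the graph reads them")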