yaleh committed
Commit 62fb408 · Parent: e8de3ce

Added functions for generating acceptance criteria.

app/gradio_meta_prompt.py CHANGED
```diff
@@ -498,6 +498,25 @@ def process_message_with_expert_llms(user_message: str, expected_output: str,
         recursion_limit, max_output_age, llms, prompt_template_group=prompt_template_group)
 
 
+def generate_acceptance_criteria(user_message, expected_output, model_name):
+    """
+    Generate acceptance criteria based on the user message and expected output.
+    """
+    prompt = f"""Given the following user message and expected output, generate appropriate acceptance criteria:
+
+User Message: {user_message}
+Expected Output: {expected_output}
+
+Generate concise and specific acceptance criteria that can be used to evaluate the quality and relevance of the expected output in relation to the user message. The criteria should focus on key aspects such as relevance, accuracy, completeness, and clarity.
+
+Acceptance Criteria:
+"""
+
+    llm = initialize_llm(model_name)
+    response = llm.invoke(prompt)
+    return response.content if hasattr(response, 'content') else ""
+
+
 class FileConfig(BaseConfig):
     config_file: str = 'config.yml'  # default path
 
@@ -527,25 +546,58 @@ with gr.Blocks(title='Meta Prompt') as demo:
     with gr.Row():
         with gr.Column():
             user_message_input = gr.Textbox(
-                label="User Message", show_copy_button=True)
+                label="User Message",
+                show_copy_button=True
+            )
             expected_output_input = gr.Textbox(
-                label="Expected Output", show_copy_button=True)
-            acceptance_criteria_input = gr.Textbox(
-                label="Acceptance Criteria (Compared with Expected Output [EO])", show_copy_button=True)
-            initial_system_message_input = gr.Textbox(
-                label="Initial System Message", show_copy_button=True, value="")
-            evaluate_initial_system_message_button = gr.Button(
-                value="Evaluate", variant="secondary")
+                label="Expected Output",
+                show_copy_button=True
+            )
+            with gr.Group():
+                with gr.Row():
+                    acceptance_criteria_input = gr.Textbox(
+                        label="Acceptance Criteria (Compared with Expected Output [EO])",
+                        show_copy_button=True,
+                        scale=4  # This makes it take up 3/4 of the row width
+                    )
+                    generate_acceptance_criteria_button = gr.Button(
+                        value="Generate",
+                        variant="secondary",
+                        scale=1  # This makes it take up 1/4 of the row width
+                    )
+            with gr.Group():
+                with gr.Row():
+                    initial_system_message_input = gr.Textbox(
+                        label="Initial System Message",
+                        show_copy_button=True,
+                        value="",
+                        scale=4
+                    )
+                    evaluate_initial_system_message_button = gr.Button(
+                        value="Evaluate",
+                        variant="secondary",
+                        scale=1
+                    )
             recursion_limit_input = gr.Number(
-                label="Recursion Limit", value=config.recursion_limit,
-                precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)
+                label="Recursion Limit",
+                value=config.recursion_limit,
+                precision=0,
+                minimum=1,
+                maximum=config.recursion_limit_max,
+                step=1
+            )
             max_output_age = gr.Number(
-                label="Max Output Age", value=config.max_output_age,
-                precision=0, minimum=1, maximum=config.max_output_age_max, step=1)
+                label="Max Output Age",
+                value=config.max_output_age,
+                precision=0,
+                minimum=1,
+                maximum=config.max_output_age_max,
+                step=1
+            )
             prompt_template_group = gr.Dropdown(
                 label="Prompt Template Group",
                 choices=list(config.prompt_templates.keys()),
-                value=list(config.prompt_templates.keys())[0],
+                value=list(config.prompt_templates.keys())[0]
             )
         with gr.Row():
             with gr.Tabs():
@@ -658,13 +710,14 @@ with gr.Blocks(title='Meta Prompt') as demo:
                     acceptance_criteria_input, initial_system_message_input],
             value='Clear All')
         with gr.Column():
-            system_message_output = gr.Textbox(
-                label="System Message", show_copy_button=True)
-            with gr.Row():
-                evaluate_system_message_button = gr.Button(
-                    value="Evaluate", variant="secondary")
-                copy_to_initial_system_message_button = gr.Button(
-                    value="Copy to Initial System Message", variant="secondary")
+            with gr.Group():
+                system_message_output = gr.Textbox(
+                    label="System Message", show_copy_button=True)
+                with gr.Row():
+                    evaluate_system_message_button = gr.Button(
+                        value="Evaluate", variant="secondary")
+                    copy_to_initial_system_message_button = gr.Button(
+                        value="Copy to Initial System Message", variant="secondary")
             output_output = gr.Textbox(label="Output", show_copy_button=True)
             analysis_output = gr.Textbox(
                 label="Analysis", show_copy_button=True)
@@ -693,6 +746,12 @@ with gr.Blocks(title='Meta Prompt') as demo:
     advanced_llm_tab.select(on_model_tab_select)
     expert_llm_tab.select(on_model_tab_select)
 
+    generate_acceptance_criteria_button.click(
+        generate_acceptance_criteria,
+        inputs=[user_message_input, expected_output_input, simple_model_name_input],
+        outputs=[acceptance_criteria_input]
+    )
+
    evaluate_initial_system_message_button.click(
        evaluate_system_message,
        inputs=[initial_system_message_input, user_message_input,
```
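The new button is wired with Gradio's standard `.click()` pattern: the callback's return value is written into the `outputs` component. Below is a minimal, self-contained sketch of that pattern with a stub in place of the real LLM call; the component names mirror the diff, while the stub body and the `demo.launch()` scaffolding are illustrative only.

```python
import gradio as gr

# Stub standing in for the committed generate_acceptance_criteria(), which
# builds a prompt and invokes an LLM; here we just echo a canned criterion.
def generate_acceptance_criteria(user_message: str, expected_output: str) -> str:
    return f"* Output should match the intent of: {expected_output!r}"

with gr.Blocks() as demo:
    user_message_input = gr.Textbox(label="User Message")
    expected_output_input = gr.Textbox(label="Expected Output")
    with gr.Row():
        acceptance_criteria_input = gr.Textbox(label="Acceptance Criteria", scale=4)
        generate_button = gr.Button("Generate", variant="secondary", scale=1)

    # The callback's return value populates the component listed in `outputs`.
    generate_button.click(
        generate_acceptance_criteria,
        inputs=[user_message_input, expected_output_input],
        outputs=[acceptance_criteria_input],
    )

if __name__ == "__main__":
    demo.launch()
```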
meta_prompt/consts.py CHANGED
````diff
@@ -1,5 +1,7 @@
 from langchain_core.prompts import ChatPromptTemplate
 
+NODE_TASK_BRIEF_DEVELOPER = "task_brief_developer"
+NODE_ACCEPTANCE_CRITERIA_DEVELOPER = "acceptance_criteria_developer"
 NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
 NODE_PROMPT_DEVELOPER = "prompt_developer"
 NODE_PROMPT_EXECUTOR = "prompt_executor"
@@ -8,6 +10,8 @@ NODE_PROMPT_ANALYZER = "prompt_analyzer"
 NODE_PROMPT_SUGGESTER = "prompt_suggester"
 
 META_PROMPT_NODES = [
+    NODE_TASK_BRIEF_DEVELOPER,
+    NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
     NODE_PROMPT_INITIAL_DEVELOPER,
     NODE_PROMPT_DEVELOPER,
     NODE_PROMPT_EXECUTOR,
@@ -17,6 +21,76 @@ META_PROMPT_NODES = [
 ]
 
 DEFAULT_PROMPT_TEMPLATES = {
+    NODE_TASK_BRIEF_DEVELOPER: ChatPromptTemplate.from_messages([
+        ("system", """# Task Brief Developer
+
+You are a task brief developer. You will receive a specific example to create a task brief. You will respond directly with the brief for the task type.
+
+## Instructions
+
+The user will provide you a specific example with User Message (input) and Expected Output (output) of a task type. You will respond with a brief for the task type in the following format:
+
+```
+# Task Description
+
+[Task description]
+```
+
+"""),
+        ("human", """# User Message
+
+{user_message}
+
+# Expected Output
+
+{expected_output}
+
+# Task Brief
+
+""")
+    ]),
+    NODE_ACCEPTANCE_CRITERIA_DEVELOPER: ChatPromptTemplate.from_messages([
+        ("system", """# Acceptance Criteria Developer
+
+You are an acceptance criteria developer. You will receive a specific example of a task type to create acceptance criteria. You will respond directly with the acceptance criteria.
+
+## Instructions
+
+The user will provide you a specific example with User Message (input) and Expected Output (output) of a task type. You will respond with acceptance criteria for the task type includes the following:
+
+* What the output should include
+* What the output should not include
+* Any specific formatting or structure requirements
+
+## Output
+
+Create acceptance criteria in the following format:
+
+```
+# Acceptance Criteria
+
+* [Criteria 1]
+* [Criteria 2]
+* [Criteria 3]
+```
+
+"""),
+        ("human", """# Task Brief
+
+{system_message}
+
+# User Message
+
+{user_message}
+
+# Expected Output
+
+{expected_output}
+
+# Acceptance Criteria
+
+""")
+    ]),
     NODE_PROMPT_INITIAL_DEVELOPER: ChatPromptTemplate.from_messages([
         ("system", """# Expert Prompt Engineer
 
@@ -28,7 +102,7 @@ The user will provide you a specific example to create the GPT. You will respond
 
 ## Output
 
-Create a [name], Heres the descriptions [description]. Start with GPT Description:”
+Create a [name], Here's the descriptions [description]. Start with "GPT Description:"
 """),
         ("human", """# User Message
 
@@ -56,7 +130,7 @@ The user will provide you a specific example (`User Message` and `Expected Outpu
 
 ## Output
 
-Create a [name], Heres the descriptions [description]. Start with GPT Description:”
+Create a [name], Here's the descriptions [description]. Start with "GPT Description:"
 """),
         ("human", """# Current System Message
 
@@ -216,4 +290,4 @@ Provide your analysis in the following format:
 <|End_Analysis|>
 """)
     ])
- }
+}
````
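For reference, here is a sketch of how the new `NODE_ACCEPTANCE_CRITERIA_DEVELOPER` template is consumed: `ChatPromptTemplate.format_messages` substitutes the `{system_message}`, `{user_message}`, and `{expected_output}` placeholders in the human message. The sample values are made up; the imports match the constants added above.

```python
from meta_prompt.consts import (
    DEFAULT_PROMPT_TEMPLATES,
    NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
)

# Fill the acceptance-criteria template's placeholders with sample values.
messages = DEFAULT_PROMPT_TEMPLATES[NODE_ACCEPTANCE_CRITERIA_DEVELOPER].format_messages(
    system_message="Answer Python questions concisely.",
    user_message="How do I reverse a list in Python?",
    expected_output="Use the `reverse()` method.",
)
for message in messages:
    print(message.type, repr(message.content[:60]))
```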
meta_prompt/meta_prompt.py CHANGED
```diff
@@ -8,6 +8,7 @@ from langchain_core.prompts import ChatPromptTemplate
 from langgraph.graph import StateGraph, END
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.errors import GraphRecursionError
+from langchain_core.runnables.base import RunnableLike
 from pydantic import BaseModel
 from .consts import *
 
@@ -46,26 +47,30 @@ class MetaPromptGraph:
     """
     This class represents a graph for meta-prompting in a conversational AI system.
 
-    It manages the state of the conversation, including the user's message, expected output,
-    acceptance criteria, system message, output, suggestions, and analysis. The graph
-    consists of nodes that represent different stages of the conversation, such as
-    prompting the developer, executing the output, analyzing the output history, and
-    suggesting new prompts. The class provides methods to create the workflow,
-    initialize the graph, and invoke the graph with a given state.
+    It manages the state of the conversation, including the user's message, expected
+    output, acceptance criteria, system message, output, suggestions, and analysis.
 
-    The MetaPromptGraph class is responsible for orchestrating the conversation flow
-    and deciding the next step based on the current state of the conversation. It uses
-    language models and prompt templates to generate responses and analyze the output.
+    The graph consists of nodes that represent different stages of the conversation,
+    such as prompting the developer, executing the output, analyzing the output
+    history, and suggesting new prompts.
+
+    The class provides methods to create the workflow, initialize the graph, and
+    invoke the graph with a given state.
+
+    The MetaPromptGraph class is responsible for orchestrating the conversation
+    flow and deciding the next step based on the current state of the
+    conversation. It uses language models and prompt templates to generate
+    responses and analyze the output.
     """
     @classmethod
     def get_node_names(cls):
         """
         Returns a list of node names in the meta-prompt graph.
 
-        This method is used to initialize the language models and prompt templates for each node in the graph.
+        This method initializes language models and prompt templates for each node.
 
         Returns:
-            list: A list of node names.
+            list: List of node names.
         """
         return META_PROMPT_NODES
 
@@ -79,12 +84,18 @@ class MetaPromptGraph:
         Initializes the MetaPromptGraph instance.
 
         Args:
-        - llms (Union[BaseLanguageModel, Dict[str, BaseLanguageModel]], optional): The language models for the graph nodes. Defaults to {}.
-        - prompts (Dict[str, ChatPromptTemplate], optional): The custom prompt templates for the graph nodes. Defaults to {}.
-        - logger (Optional[logging.Logger], optional): The logger for the graph. Defaults to None.
-        - verbose (bool, optional): Whether to set the logger level to DEBUG. Defaults to False.
-
-        Initializes the logger, sets the language models and prompt templates for the graph nodes, and updates the prompt templates with custom ones if provided.
+        - llms (Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
+          optional): The language models for the graph nodes. Defaults to {}.
+        - prompts (Dict[str, ChatPromptTemplate], optional): The custom
+          prompt templates for the graph nodes. Defaults to {}.
+        - logger (Optional[logging.Logger], optional): The logger for
+          the graph. Defaults to None.
+        - verbose (bool, optional): Whether to set the logger level to
+          DEBUG. Defaults to False.
+
+        Initializes the logger, sets the language models and prompt
+        templates for the graph nodes, and updates the prompt templates
+        with custom ones if provided.
         """
         self.logger = logger or logging.getLogger(__name__)
         if self.logger is not None:
@@ -94,6 +105,7 @@ class MetaPromptGraph:
             self.logger.setLevel(logging.INFO)
 
         if isinstance(llms, BaseLanguageModel):
+            # if llms is a single language model, wrap it in a dictionary
             self.llms: Dict[str, BaseLanguageModel] = {
                 node: llms for node in self.get_node_names()}
         else:
@@ -102,7 +114,29 @@ class MetaPromptGraph:
             ChatPromptTemplate] = DEFAULT_PROMPT_TEMPLATES.copy()
         self.prompt_templates.update(prompts)
 
+
+    def _create_acceptance_criteria_workflow(self) -> StateGraph:
+        workflow = StateGraph(AgentState)
+        workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+                          lambda x: self._prompt_node(
+                              NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+                              "acceptance_criteria",
+                              x))
+        workflow.add_edge(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, END)
+        workflow.set_entry_point(NODE_ACCEPTANCE_CRITERIA_DEVELOPER)
+        return workflow
+
+
     def _create_workflow(self, including_initial_developer: bool = True) -> StateGraph:
+        """Create a workflow state graph.
+
+        Args:
+            including_initial_developer: Flag indicating whether to include the
+                initial developer node in the workflow.
+
+        Returns:
+            StateGraph: A state graph representing the workflow.
+        """
         workflow = StateGraph(AgentState)
 
         workflow.add_node(NODE_PROMPT_DEVELOPER,
@@ -125,10 +159,12 @@ class MetaPromptGraph:
                               "suggestions",
                               x))
 
+        # Connect nodes
         workflow.add_edge(NODE_PROMPT_DEVELOPER, NODE_PROMPT_EXECUTOR)
         workflow.add_edge(NODE_PROMPT_EXECUTOR, NODE_OUTPUT_HISTORY_ANALYZER)
         workflow.add_edge(NODE_PROMPT_SUGGESTER, NODE_PROMPT_DEVELOPER)
 
+        # Add conditional edges
         workflow.add_conditional_edges(
             NODE_OUTPUT_HISTORY_ANALYZER,
             lambda x: self._should_exit_on_max_age(x),
@@ -148,6 +184,7 @@ class MetaPromptGraph:
             }
         )
 
+        # Set entry point based on including_initial_developer flag
         if including_initial_developer:
             workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
                               lambda x: self._prompt_node(
@@ -161,8 +198,40 @@ class MetaPromptGraph:
             workflow.set_entry_point(NODE_PROMPT_EXECUTOR)
 
         return workflow
+
+    def run_acceptance_criteria_graph(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        self.logger.debug("Creating acceptance criteria workflow")
+        workflow = self._create_acceptance_criteria_workflow()
+        self.logger.debug("Compiling workflow with memory saver")
+        memory = MemorySaver()
+        graph = workflow.compile(checkpointer=memory)
+        self.logger.debug("Configuring graph with recursion limit %s", recursion_limit)
+        config = {"configurable": {"thread_id": "1"},
+                  "recursion_limit": recursion_limit}
+        self.logger.debug("Invoking graph with state: %s", pprint.pformat(state))
+        output_state = graph.invoke(state, config)
+        self.logger.debug("Output state: %s", pprint.pformat(output_state))
+        return output_state
+
 
-    def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+    def run_meta_prompt_graph(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        """
+        Invoke the meta-prompt workflow with the given state and recursion limit.
+
+        This method creates a workflow based on the presence of an initial system
+        message, compiles the workflow with a memory saver, and invokes the graph
+        with the given state. If a recursion limit is reached, it returns the best
+        state found so far.
+
+        Parameters:
+            state (AgentState): The current state of the agent, containing
+                necessary context for message formatting.
+            recursion_limit (int): The maximum number of recursive calls
+                allowed. Defaults to 25.
+
+        Returns:
+            AgentState: The output state of the agent after invoking the workflow.
+        """
         workflow = self._create_workflow(including_initial_developer=(
             state.system_message is None or state.system_message == ""))
 
@@ -195,20 +264,48 @@ class MetaPromptGraph:
 
         return state
 
+
+    def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        return self.run_meta_prompt_graph(state, recursion_limit)
+
+
+    def _optional_action(
+        self, target_attribute: str,
+        action: Optional[RunnableLike],
+        state: AgentState
+    ) -> AgentState:
+        """
+        Optionally invokes an action if the target attribute is not set or empty.
+
+        Args:
+            node (str): Node identifier.
+            target_attribute (str): State attribute to be updated.
+            action (Optional[RunnableLike]): Action to be invoked. Defaults to None.
+            state (AgentState): Current agent state.
+
+        Returns:
+            AgentState: Updated state.
+        """
+        if not getattr(state, target_attribute, None) or getattr(state, target_attribute) == "":
+            if action:
+                state = action(state)
+        return state
+
+
     def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
         """
         Prompt a specific node with the given state and update the state with the response.
 
-        This method formats messages using the prompt template associated with the node, logs the invocation and response,
-        and updates the state with the response content.
+        This method formats messages using the prompt template associated with the node,
+        logs the invocation and response, and updates the state with the response content.
 
         Parameters:
-        node (str): The identifier of the node to be prompted.
-        target_attribute (str): The attribute of the state to be updated with the response content.
-        state (AgentState): The current state of the agent, containing necessary context for message formatting.
+            node (str): Node identifier to be prompted.
+            target_attribute (str): State attribute to be updated with response content.
+            state (AgentState): Current agent state with necessary context for message formatting.
 
         Returns:
-        AgentState: The updated state of the agent with the response content set to the target attribute.
+            AgentState: Updated state with response content set to the target attribute.
         """
 
         logger = self.logger.getChild(node)
@@ -234,47 +331,77 @@ class MetaPromptGraph:
         return state
 
     def _output_history_analyzer(self, state: AgentState) -> AgentState:
+        """
+        Analyzes the output history and updates the best output and its age.
+
+        This method checks if the best output is initialized, formats the prompt for
+        the output history analyzer, invokes the language model, and updates the best
+        output and its age based on the response.
+
+        Parameters:
+            state (AgentState): Current state of the agent with necessary context
+                for message formatting.
+
+        Returns:
+            AgentState: Updated state with the best output and its age.
+        """
         logger = self.logger.getChild(NODE_OUTPUT_HISTORY_ANALYZER)
 
         if state.best_output is None:
             state.best_output = state.output
             state.best_system_message = state.system_message
             state.best_output_age = 0
-
             logger.debug(
                 "Best output initialized to the current output:\n%s", state.output)
-
             return state
 
         prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
             **state.model_dump())
 
         for message in prompt:
-            logger.debug({'node': NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'invoke',
-                          'type': message.type, 'message': message.content})
+            logger.debug({
+                'node': NODE_OUTPUT_HISTORY_ANALYZER,
+                'action': 'invoke',
+                'type': message.type,
+                'message': message.content
+            })
 
         response = self.llms[NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
-        logger.debug({'node': NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'response',
-                      'type': response.type, 'message': response.content})
+        logger.debug({
+            'node': NODE_OUTPUT_HISTORY_ANALYZER,
+            'action': 'response',
+            'type': response.type,
+            'message': response.content
+        })
 
         analysis = response.content
 
-        if state.best_output is None or "# Output ID closer to Expected Output: B" in analysis:
+        if state.best_output is None or (
+                "# Output ID closer to Expected Output: B" in analysis):
             state.best_output = state.output
             state.best_system_message = state.system_message
             state.best_output_age = 0
-
             logger.debug(
                 "Best output updated to the current output:\n%s", state.output)
         else:
             state.best_output_age += 1
-
-            logger.debug("Best output age incremented to %s",
-                         state.best_output_age)
+            logger.debug("Best output age incremented to %s", state.best_output_age)
 
         return state
 
     def _prompt_analyzer(self, state: AgentState) -> AgentState:
+        """
+        Analyzes the prompt and updates the state with the analysis and
+        acceptance status.
+
+        Args:
+            state (AgentState): The current state of the agent, containing
+                necessary context for message formatting.
+
+        Returns:
+            AgentState: The updated state of the agent with the analysis
+            and acceptance status.
+        """
         logger = self.logger.getChild(NODE_PROMPT_ANALYZER)
         prompt = self.prompt_templates[NODE_PROMPT_ANALYZER].format_messages(
             **state.model_dump())
@@ -295,6 +422,15 @@ class MetaPromptGraph:
         return state
 
     def _should_exit_on_max_age(self, state: AgentState) -> str:
+        """
+        Determines whether to exit the workflow based on the maximum output age.
+
+        Args:
+            state (AgentState): The current state of the agent.
+
+        Returns:
+            str: The decision to continue, rerun, or end the workflow.
+        """
         if state.max_output_age <= 0:
             # always continue if max age is 0
             return "continue"
@@ -309,4 +445,13 @@ class MetaPromptGraph:
         return "continue"
 
     def _should_exit_on_acceptable_output(self, state: AgentState) -> str:
+        """
+        Determines whether to exit the workflow based on the acceptance status of the output.
+
+        Args:
+            state (AgentState): The current state of the agent.
+
+        Returns:
+            str: The decision to continue or end the workflow.
+        """
         return "continue" if not state.accepted else END
```
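A sketch of driving the new `run_acceptance_criteria_graph` end to end with a fake LLM, so it runs offline; it mirrors `test_run_acceptance_criteria_graph` in the test changes below. Importing `MetaPromptGraph` and `AgentState` from the package top level is an assumption based on the tests' `from meta_prompt import *`.

```python
from unittest.mock import MagicMock

from meta_prompt import MetaPromptGraph, AgentState  # assumed top-level exports
from meta_prompt.consts import NODE_ACCEPTANCE_CRITERIA_DEVELOPER

# Fake chat model: .invoke() returns an object with .type and .content,
# which is all that _prompt_node reads from the response.
fake_llm = MagicMock(
    invoke=lambda prompt: MagicMock(type="ai", content="* Output mentions `reverse()`"))

graph = MetaPromptGraph(llms={NODE_ACCEPTANCE_CRITERIA_DEVELOPER: fake_llm})
state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="The output should use the `reverse()` method.",
)

# LangGraph's invoke() returns the final state as a dict-like mapping.
output_state = graph.run_acceptance_criteria_graph(state)
print(output_state["acceptance_criteria"])
```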
tests/meta_prompt_graph_test.py CHANGED
```diff
@@ -8,6 +8,8 @@ from langchain_openai import ChatOpenAI
 
 # Assuming the necessary imports are made for the classes and functions used in meta_prompt_graph.py
 from meta_prompt import *
+from meta_prompt.consts import NODE_ACCEPTANCE_CRITERIA_DEVELOPER
+from langgraph.graph import StateGraph, END
 
 class TestMetaPromptGraph(unittest.TestCase):
     def setUp(self):
@@ -274,6 +276,69 @@ class TestMetaPromptGraph(unittest.TestCase):
 
         pprint.pp(output_state["best_output"])
 
+    def test_create_acceptance_criteria_workflow(self):
+        """
+        Test the _create_acceptance_criteria_workflow method of MetaPromptGraph.
+
+        This test case verifies that the workflow created by the _create_acceptance_criteria_workflow method
+        contains the correct node and edge.
+        """
+
+        llms = {
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: ChatOpenAI(model_name="deepseek/deepseek-chat")
+        }
+        meta_prompt_graph = MetaPromptGraph(llms=llms)
+        workflow = meta_prompt_graph._create_acceptance_criteria_workflow()
+
+        # Check if the workflow contains the correct node
+        self.assertIn(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, workflow.nodes)
+
+        # Check if the workflow contains the correct edge
+        self.assertIn((NODE_ACCEPTANCE_CRITERIA_DEVELOPER, END), workflow.edges)
+
+        # compile the workflow
+        graph = workflow.compile()
+        print(graph)
+
+        # invoke the workflow
+        state = AgentState(
+            user_message="How do I reverse a list in Python?",
+            expected_output="The output should use the `reverse()` method.",
+            # system_message="Create acceptance criteria for the task of reversing a list in Python."
+        )
+        output_state = graph.invoke(state)
+
+        # check if the output state contains the acceptance criteria
+        self.assertIsNotNone(output_state['acceptance_criteria'])
+
+        # check if the acceptance criteria includes string '`reverse()`'
+        self.assertIn('`reverse()`', output_state['acceptance_criteria'])
+
+        pprint.pp(output_state["acceptance_criteria"])
+
+    def test_run_acceptance_criteria_graph(self):
+        """
+        Test the run_acceptance_criteria_graph method of MetaPromptGraph.
+
+        This test case verifies that the run_acceptance_criteria_graph method returns a state with acceptance criteria.
+        """
+        llms = {
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: MagicMock(
+                invoke=lambda prompt: MagicMock(content="Acceptance criteria: ..."))
+        }
+        meta_prompt_graph = MetaPromptGraph(llms=llms)
+        state = AgentState(
+            user_message="How do I reverse a list in Python?",
+            expected_output="The output should use the `reverse()` method.",
+        )
+        output_state = meta_prompt_graph.run_acceptance_criteria_graph(state)
+
+        # Check if the output state contains the acceptance criteria
+        self.assertIsNotNone(output_state['acceptance_criteria'])
+
+        # Check if the acceptance criteria includes the expected content
+        self.assertIn("Acceptance criteria: ...", output_state['acceptance_criteria'])
+
 
 if __name__ == '__main__':
     unittest.main()
```