Added functions for generating acceptance criteria.
- app/gradio_meta_prompt.py +79 -20
- meta_prompt/consts.py +77 -3
- meta_prompt/meta_prompt.py +180 -35
- tests/meta_prompt_graph_test.py +65 -0
app/gradio_meta_prompt.py
CHANGED
@@ -498,6 +498,25 @@ def process_message_with_expert_llms(user_message: str, expected_output: str,
        recursion_limit, max_output_age, llms, prompt_template_group=prompt_template_group)


+def generate_acceptance_criteria(user_message, expected_output, model_name):
+    """
+    Generate acceptance criteria based on the user message and expected output.
+    """
+    prompt = f"""Given the following user message and expected output, generate appropriate acceptance criteria:
+
+User Message: {user_message}
+Expected Output: {expected_output}
+
+Generate concise and specific acceptance criteria that can be used to evaluate the quality and relevance of the expected output in relation to the user message. The criteria should focus on key aspects such as relevance, accuracy, completeness, and clarity.
+
+Acceptance Criteria:
+"""
+
+    llm = initialize_llm(model_name)
+    response = llm.invoke(prompt)
+    return response.content if hasattr(response, 'content') else ""
+
+
class FileConfig(BaseConfig):
    config_file: str = 'config.yml'  # default path

@@ -527,25 +546,58 @@ with gr.Blocks(title='Meta Prompt') as demo:
    with gr.Row():
        with gr.Column():
            user_message_input = gr.Textbox(
+                label="User Message",
+                show_copy_button=True
+            )
            expected_output_input = gr.Textbox(
+                label="Expected Output",
+                show_copy_button=True
+            )
+            with gr.Group():
+                with gr.Row():
+                    acceptance_criteria_input = gr.Textbox(
+                        label="Acceptance Criteria (Compared with Expected Output [EO])",
+                        show_copy_button=True,
+                        scale=4  # This makes it take up 3/4 of the row width
+                    )
+                    generate_acceptance_criteria_button = gr.Button(
+                        value="Generate",
+                        variant="secondary",
+                        scale=1  # This makes it take up 1/4 of the row width
+                    )
+            with gr.Group():
+                with gr.Row():
+                    initial_system_message_input = gr.Textbox(
+                        label="Initial System Message",
+                        show_copy_button=True,
+                        value="",
+                        scale=4
+                    )
+                    evaluate_initial_system_message_button = gr.Button(
+                        value="Evaluate",
+                        variant="secondary",
+                        scale=1
+                    )
            recursion_limit_input = gr.Number(
+                label="Recursion Limit",
+                value=config.recursion_limit,
+                precision=0,
+                minimum=1,
+                maximum=config.recursion_limit_max,
+                step=1
+            )
            max_output_age = gr.Number(
+                label="Max Output Age",
+                value=config.max_output_age,
+                precision=0,
+                minimum=1,
+                maximum=config.max_output_age_max,
+                step=1
+            )
            prompt_template_group = gr.Dropdown(
                label="Prompt Template Group",
                choices=list(config.prompt_templates.keys()),
+                value=list(config.prompt_templates.keys())[0]
            )
    with gr.Row():
        with gr.Tabs():
@@ -658,13 +710,14 @@ with gr.Blocks(title='Meta Prompt') as demo:
                acceptance_criteria_input, initial_system_message_input],
                value='Clear All')
        with gr.Column():
+            with gr.Group():
+                system_message_output = gr.Textbox(
+                    label="System Message", show_copy_button=True)
+                with gr.Row():
+                    evaluate_system_message_button = gr.Button(
+                        value="Evaluate", variant="secondary")
+                    copy_to_initial_system_message_button = gr.Button(
+                        value="Copy to Initial System Message", variant="secondary")
            output_output = gr.Textbox(label="Output", show_copy_button=True)
            analysis_output = gr.Textbox(
                label="Analysis", show_copy_button=True)
@@ -693,6 +746,12 @@ with gr.Blocks(title='Meta Prompt') as demo:
    advanced_llm_tab.select(on_model_tab_select)
    expert_llm_tab.select(on_model_tab_select)

+    generate_acceptance_criteria_button.click(
+        generate_acceptance_criteria,
+        inputs=[user_message_input, expected_output_input, simple_model_name_input],
+        outputs=[acceptance_criteria_input]
+    )
+
    evaluate_initial_system_message_button.click(
        evaluate_system_message,
        inputs=[initial_system_message_input, user_message_input,
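The hunks above add the standalone generate_acceptance_criteria helper, new acceptance-criteria and initial-system-message rows in the UI, and a click handler that fills the acceptance-criteria textbox. Below is a minimal, self-contained sketch of that wiring; it is not part of the commit, the app's initialize_llm/model-name plumbing is replaced by a stub model so the pattern can run on its own, and all names in it are illustrative.

```python
# Hypothetical sketch of the new "Generate" flow, assuming only gradio is installed.
# The commit's real helper calls initialize_llm(model_name) and llm.invoke(prompt).
from types import SimpleNamespace

import gradio as gr


class FakeLLM:
    """Stand-in for the chat model the app obtains from initialize_llm(model_name)."""

    def invoke(self, prompt: str):
        # Return an object with a .content attribute, like a LangChain chat response.
        return SimpleNamespace(content="* Covers the user message\n* Matches the expected format")


def generate_acceptance_criteria_demo(user_message: str, expected_output: str) -> str:
    # Same shape as the committed helper: build a prompt, invoke a model, return its text.
    prompt = (
        "Given the following user message and expected output, "
        "generate appropriate acceptance criteria:\n\n"
        f"User Message: {user_message}\n"
        f"Expected Output: {expected_output}\n\n"
        "Acceptance Criteria:\n"
    )
    response = FakeLLM().invoke(prompt)
    return getattr(response, "content", "")


with gr.Blocks() as demo:
    user_message = gr.Textbox(label="User Message", show_copy_button=True)
    expected_output = gr.Textbox(label="Expected Output", show_copy_button=True)
    with gr.Row():
        acceptance_criteria = gr.Textbox(label="Acceptance Criteria", show_copy_button=True, scale=4)
        generate_btn = gr.Button(value="Generate", variant="secondary", scale=1)
    # Same wiring pattern as the commit: the click feeds both inputs to the helper
    # and writes the result back into the acceptance-criteria textbox.
    generate_btn.click(
        generate_acceptance_criteria_demo,
        inputs=[user_message, expected_output],
        outputs=[acceptance_criteria],
    )

if __name__ == "__main__":
    demo.launch()
```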
meta_prompt/consts.py
CHANGED
@@ -1,5 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate

+NODE_TASK_BRIEF_DEVELOPER = "task_brief_developer"
+NODE_ACCEPTANCE_CRITERIA_DEVELOPER = "acceptance_criteria_developer"
NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
NODE_PROMPT_DEVELOPER = "prompt_developer"
NODE_PROMPT_EXECUTOR = "prompt_executor"
@@ -8,6 +10,8 @@ NODE_PROMPT_ANALYZER = "prompt_analyzer"
NODE_PROMPT_SUGGESTER = "prompt_suggester"

META_PROMPT_NODES = [
+    NODE_TASK_BRIEF_DEVELOPER,
+    NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
    NODE_PROMPT_INITIAL_DEVELOPER,
    NODE_PROMPT_DEVELOPER,
    NODE_PROMPT_EXECUTOR,
@@ -17,6 +21,76 @@ META_PROMPT_NODES = [
]

DEFAULT_PROMPT_TEMPLATES = {
+    NODE_TASK_BRIEF_DEVELOPER: ChatPromptTemplate.from_messages([
+        ("system", """# Task Brief Developer
+
+You are a task brief developer. You will receive a specific example to create a task brief. You will respond directly with the brief for the task type.
+
+## Instructions
+
+The user will provide you a specific example with User Message (input) and Expected Output (output) of a task type. You will respond with a brief for the task type in the following format:
+
+```
+# Task Description
+
+[Task description]
+```
+
+"""),
+        ("human", """# User Message
+
+{user_message}
+
+# Expected Output
+
+{expected_output}
+
+# Task Brief
+
+""")
+    ]),
+    NODE_ACCEPTANCE_CRITERIA_DEVELOPER: ChatPromptTemplate.from_messages([
+        ("system", """# Acceptance Criteria Developer
+
+You are an acceptance criteria developer. You will receive a specific example of a task type to create acceptance criteria. You will respond directly with the acceptance criteria.
+
+## Instructions
+
+The user will provide you a specific example with User Message (input) and Expected Output (output) of a task type. You will respond with acceptance criteria for the task type includes the following:
+
+* What the output should include
+* What the output should not include
+* Any specific formatting or structure requirements
+
+## Output
+
+Create acceptance criteria in the following format:
+
+```
+# Acceptance Criteria
+
+* [Criteria 1]
+* [Criteria 2]
+* [Criteria 3]
+```
+
+"""),
+        ("human", """# Task Brief
+
+{system_message}
+
+# User Message
+
+{user_message}
+
+# Expected Output
+
+{expected_output}
+
+# Acceptance Criteria
+
+""")
+    ]),
    NODE_PROMPT_INITIAL_DEVELOPER: ChatPromptTemplate.from_messages([
        ("system", """# Expert Prompt Engineer
@@ -28,7 +102,7 @@ The user will provide you a specific example to create the GPT. You will respond

## Output

+Create a [name], Here's the descriptions [description]. Start with "GPT Description:"
"""),
        ("human", """# User Message

@@ -56,7 +130,7 @@ The user will provide you a specific example (`User Message` and `Expected Outpu

## Output

+Create a [name], Here's the descriptions [description]. Start with "GPT Description:"
"""),
        ("human", """# Current System Message

@@ -216,4 +290,4 @@ Provide your analysis in the following format:
<|End_Analysis|>
""")
    ])
+}
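The two new templates are ordinary ChatPromptTemplate objects whose human messages carry {system_message}, {user_message} and {expected_output} placeholders. The short sketch below (not from the commit, with the template text abbreviated) shows how such a template is rendered into the system/human messages that the acceptance_criteria_developer node sends to its LLM.

```python
# Minimal rendering sketch using langchain_core; the template text here is a
# shortened stand-in for the commit's full prompt.
from langchain_core.prompts import ChatPromptTemplate

template = ChatPromptTemplate.from_messages([
    ("system", "# Acceptance Criteria Developer\n\nRespond directly with acceptance criteria."),
    ("human",
     "# Task Brief\n\n{system_message}\n\n"
     "# User Message\n\n{user_message}\n\n"
     "# Expected Output\n\n{expected_output}\n\n"
     "# Acceptance Criteria\n\n"),
])

messages = template.format_messages(
    system_message="Answer Python how-to questions.",
    user_message="How do I reverse a list in Python?",
    expected_output="The output should use the `reverse()` method.",
)
for message in messages:
    # Each item is a system/human message ready to pass to llm.invoke(messages).
    print(message.type, ":", message.content[:60])
```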
meta_prompt/meta_prompt.py
CHANGED
@@ -8,6 +8,7 @@ from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError
+from langchain_core.runnables.base import RunnableLike
from pydantic import BaseModel
from .consts import *

@@ -46,26 +47,30 @@ class MetaPromptGraph:
    """
    This class represents a graph for meta-prompting in a conversational AI system.

+    It manages the state of the conversation, including the user's message, expected
+    output, acceptance criteria, system message, output, suggestions, and analysis.
+
+    The graph consists of nodes that represent different stages of the conversation,
+    such as prompting the developer, executing the output, analyzing the output
+    history, and suggesting new prompts.
+
+    The class provides methods to create the workflow, initialize the graph, and
+    invoke the graph with a given state.
+
+    The MetaPromptGraph class is responsible for orchestrating the conversation
+    flow and deciding the next step based on the current state of the
+    conversation. It uses language models and prompt templates to generate
+    responses and analyze the output.
    """
    @classmethod
    def get_node_names(cls):
        """
        Returns a list of node names in the meta-prompt graph.

+        This method initializes language models and prompt templates for each node.

        Returns:
+            list: List of node names.
        """
        return META_PROMPT_NODES

@@ -79,12 +84,18 @@ class MetaPromptGraph:
        Initializes the MetaPromptGraph instance.

        Args:
+            - llms (Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
+              optional): The language models for the graph nodes. Defaults to {}.
+            - prompts (Dict[str, ChatPromptTemplate], optional): The custom
+              prompt templates for the graph nodes. Defaults to {}.
+            - logger (Optional[logging.Logger], optional): The logger for
+              the graph. Defaults to None.
+            - verbose (bool, optional): Whether to set the logger level to
+              DEBUG. Defaults to False.
+
+        Initializes the logger, sets the language models and prompt
+        templates for the graph nodes, and updates the prompt templates
+        with custom ones if provided.
        """
        self.logger = logger or logging.getLogger(__name__)
        if self.logger is not None:
@@ -94,6 +105,7 @@ class MetaPromptGraph:
            self.logger.setLevel(logging.INFO)

        if isinstance(llms, BaseLanguageModel):
+            # if llms is a single language model, wrap it in a dictionary
            self.llms: Dict[str, BaseLanguageModel] = {
                node: llms for node in self.get_node_names()}
        else:
@@ -102,7 +114,29 @@ class MetaPromptGraph:
                ChatPromptTemplate] = DEFAULT_PROMPT_TEMPLATES.copy()
            self.prompt_templates.update(prompts)

+
+    def _create_acceptance_criteria_workflow(self) -> StateGraph:
+        workflow = StateGraph(AgentState)
+        workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+                          lambda x: self._prompt_node(
+                              NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
+                              "acceptance_criteria",
+                              x))
+        workflow.add_edge(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, END)
+        workflow.set_entry_point(NODE_ACCEPTANCE_CRITERIA_DEVELOPER)
+        return workflow
+
+
    def _create_workflow(self, including_initial_developer: bool = True) -> StateGraph:
+        """Create a workflow state graph.
+
+        Args:
+            including_initial_developer: Flag indicating whether to include the
+                initial developer node in the workflow.
+
+        Returns:
+            StateGraph: A state graph representing the workflow.
+        """
        workflow = StateGraph(AgentState)

        workflow.add_node(NODE_PROMPT_DEVELOPER,
@@ -125,10 +159,12 @@ class MetaPromptGraph:
                              "suggestions",
                              x))

+        # Connect nodes
        workflow.add_edge(NODE_PROMPT_DEVELOPER, NODE_PROMPT_EXECUTOR)
        workflow.add_edge(NODE_PROMPT_EXECUTOR, NODE_OUTPUT_HISTORY_ANALYZER)
        workflow.add_edge(NODE_PROMPT_SUGGESTER, NODE_PROMPT_DEVELOPER)

+        # Add conditional edges
        workflow.add_conditional_edges(
            NODE_OUTPUT_HISTORY_ANALYZER,
            lambda x: self._should_exit_on_max_age(x),
@@ -148,6 +184,7 @@ class MetaPromptGraph:
            }
        )

+        # Set entry point based on including_initial_developer flag
        if including_initial_developer:
            workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
                              lambda x: self._prompt_node(
@@ -161,8 +198,40 @@ class MetaPromptGraph:
            workflow.set_entry_point(NODE_PROMPT_EXECUTOR)

        return workflow
+
+    def run_acceptance_criteria_graph(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        self.logger.debug("Creating acceptance criteria workflow")
+        workflow = self._create_acceptance_criteria_workflow()
+        self.logger.debug("Compiling workflow with memory saver")
+        memory = MemorySaver()
+        graph = workflow.compile(checkpointer=memory)
+        self.logger.debug("Configuring graph with recursion limit %s", recursion_limit)
+        config = {"configurable": {"thread_id": "1"},
+                  "recursion_limit": recursion_limit}
+        self.logger.debug("Invoking graph with state: %s", pprint.pformat(state))
+        output_state = graph.invoke(state, config)
+        self.logger.debug("Output state: %s", pprint.pformat(output_state))
+        return output_state
+

+    def run_meta_prompt_graph(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        """
+        Invoke the meta-prompt workflow with the given state and recursion limit.
+
+        This method creates a workflow based on the presence of an initial system
+        message, compiles the workflow with a memory saver, and invokes the graph
+        with the given state. If a recursion limit is reached, it returns the best
+        state found so far.
+
+        Parameters:
+            state (AgentState): The current state of the agent, containing
+                necessary context for message formatting.
+            recursion_limit (int): The maximum number of recursive calls
+                allowed. Defaults to 25.
+
+        Returns:
+            AgentState: The output state of the agent after invoking the workflow.
+        """
        workflow = self._create_workflow(including_initial_developer=(
            state.system_message is None or state.system_message == ""))

@@ -195,20 +264,48 @@ class MetaPromptGraph:

        return state

+
+    def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        return self.run_meta_prompt_graph(state, recursion_limit)
+
+
+    def _optional_action(
+        self, target_attribute: str,
+        action: Optional[RunnableLike],
+        state: AgentState
+    ) -> AgentState:
+        """
+        Optionally invokes an action if the target attribute is not set or empty.
+
+        Args:
+            node (str): Node identifier.
+            target_attribute (str): State attribute to be updated.
+            action (Optional[RunnableLike]): Action to be invoked. Defaults to None.
+            state (AgentState): Current agent state.
+
+        Returns:
+            AgentState: Updated state.
+        """
+        if not getattr(state, target_attribute, None) or getattr(state, target_attribute) == "":
+            if action:
+                state = action(state)
+        return state
+
+
    def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
        """
        Prompt a specific node with the given state and update the state with the response.

+        This method formats messages using the prompt template associated with the node,
+        logs the invocation and response, and updates the state with the response content.

        Parameters:
+            node (str): Node identifier to be prompted.
+            target_attribute (str): State attribute to be updated with response content.
+            state (AgentState): Current agent state with necessary context for message formatting.

        Returns:
+            AgentState: Updated state with response content set to the target attribute.
        """

        logger = self.logger.getChild(node)
@@ -234,47 +331,77 @@ class MetaPromptGraph:
        return state

    def _output_history_analyzer(self, state: AgentState) -> AgentState:
+        """
+        Analyzes the output history and updates the best output and its age.
+
+        This method checks if the best output is initialized, formats the prompt for
+        the output history analyzer, invokes the language model, and updates the best
+        output and its age based on the response.
+
+        Parameters:
+            state (AgentState): Current state of the agent with necessary context
+                for message formatting.
+
+        Returns:
+            AgentState: Updated state with the best output and its age.
+        """
        logger = self.logger.getChild(NODE_OUTPUT_HISTORY_ANALYZER)

        if state.best_output is None:
            state.best_output = state.output
            state.best_system_message = state.system_message
            state.best_output_age = 0
            logger.debug(
                "Best output initialized to the current output:\n%s", state.output)
            return state

        prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
            **state.model_dump())

        for message in prompt:
+            logger.debug({
+                'node': NODE_OUTPUT_HISTORY_ANALYZER,
+                'action': 'invoke',
+                'type': message.type,
+                'message': message.content
+            })

        response = self.llms[NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
+        logger.debug({
+            'node': NODE_OUTPUT_HISTORY_ANALYZER,
+            'action': 'response',
+            'type': response.type,
+            'message': response.content
+        })

        analysis = response.content

+        if state.best_output is None or (
+                "# Output ID closer to Expected Output: B" in analysis):
            state.best_output = state.output
            state.best_system_message = state.system_message
            state.best_output_age = 0
            logger.debug(
                "Best output updated to the current output:\n%s", state.output)
        else:
            state.best_output_age += 1
+            logger.debug("Best output age incremented to %s", state.best_output_age)

        return state

    def _prompt_analyzer(self, state: AgentState) -> AgentState:
+        """
+        Analyzes the prompt and updates the state with the analysis and
+        acceptance status.
+
+        Args:
+            state (AgentState): The current state of the agent, containing
+                necessary context for message formatting.
+
+        Returns:
+            AgentState: The updated state of the agent with the analysis
+                and acceptance status.
+        """
        logger = self.logger.getChild(NODE_PROMPT_ANALYZER)
        prompt = self.prompt_templates[NODE_PROMPT_ANALYZER].format_messages(
            **state.model_dump())
@@ -295,6 +422,15 @@ class MetaPromptGraph:
        return state

    def _should_exit_on_max_age(self, state: AgentState) -> str:
+        """
+        Determines whether to exit the workflow based on the maximum output age.
+
+        Args:
+            state (AgentState): The current state of the agent.
+
+        Returns:
+            str: The decision to continue, rerun, or end the workflow.
+        """
        if state.max_output_age <= 0:
            # always continue if max age is 0
            return "continue"
@@ -309,4 +445,13 @@ class MetaPromptGraph:
        return "continue"

    def _should_exit_on_acceptable_output(self, state: AgentState) -> str:
+        """
+        Determines whether to exit the workflow based on the acceptance status of the output.
+
+        Args:
+            state (AgentState): The current state of the agent.
+
+        Returns:
+            str: The decision to continue or end the workflow.
+        """
        return "continue" if not state.accepted else END
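For reference, _create_acceptance_criteria_workflow follows the smallest possible LangGraph shape: a single node that writes acceptance_criteria, then END. The sketch below reproduces that shape with a simplified state and a canned node instead of an LLM call; it is an illustration under those assumptions (including a langgraph version that accepts a Pydantic model as the state schema, as this project already relies on), not the project's code.

```python
# One-node LangGraph workflow sketch, mirroring the acceptance-criteria graph added
# in this commit. DemoState and the node body are simplified stand-ins.
from langgraph.graph import StateGraph, END
from pydantic import BaseModel


class DemoState(BaseModel):
    user_message: str = ""
    expected_output: str = ""
    acceptance_criteria: str = ""


def acceptance_criteria_developer(state: DemoState) -> dict:
    # The real node calls an LLM via _prompt_node; here we return a canned update
    # to the acceptance_criteria field.
    return {"acceptance_criteria": f"* Must address: {state.user_message}"}


workflow = StateGraph(DemoState)
workflow.add_node("acceptance_criteria_developer", acceptance_criteria_developer)
workflow.add_edge("acceptance_criteria_developer", END)
workflow.set_entry_point("acceptance_criteria_developer")

graph = workflow.compile()
result = graph.invoke(DemoState(user_message="How do I reverse a list in Python?"))
print(result["acceptance_criteria"])
```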
tests/meta_prompt_graph_test.py
CHANGED
@@ -8,6 +8,8 @@ from langchain_openai import ChatOpenAI

# Assuming the necessary imports are made for the classes and functions used in meta_prompt_graph.py
from meta_prompt import *
+from meta_prompt.consts import NODE_ACCEPTANCE_CRITERIA_DEVELOPER
+from langgraph.graph import StateGraph, END

class TestMetaPromptGraph(unittest.TestCase):
    def setUp(self):
@@ -274,6 +276,69 @@ class TestMetaPromptGraph(unittest.TestCase):

        pprint.pp(output_state["best_output"])

+    def test_create_acceptance_criteria_workflow(self):
+        """
+        Test the _create_acceptance_criteria_workflow method of MetaPromptGraph.
+
+        This test case verifies that the workflow created by the _create_acceptance_criteria_workflow method
+        contains the correct node and edge.
+        """
+
+        llms = {
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: ChatOpenAI(model_name="deepseek/deepseek-chat")
+        }
+        meta_prompt_graph = MetaPromptGraph(llms=llms)
+        workflow = meta_prompt_graph._create_acceptance_criteria_workflow()
+
+        # Check if the workflow contains the correct node
+        self.assertIn(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, workflow.nodes)
+
+        # Check if the workflow contains the correct edge
+        self.assertIn((NODE_ACCEPTANCE_CRITERIA_DEVELOPER, END), workflow.edges)
+
+        # compile the workflow
+        graph = workflow.compile()
+        print(graph)
+
+        # invoke the workflow
+        state = AgentState(
+            user_message="How do I reverse a list in Python?",
+            expected_output="The output should use the `reverse()` method.",
+            # system_message="Create acceptance criteria for the task of reversing a list in Python."
+        )
+        output_state = graph.invoke(state)
+
+        # check if the output state contains the acceptance criteria
+        self.assertIsNotNone(output_state['acceptance_criteria'])
+
+        # check if the acceptance criteria includes string '`reverse()`'
+        self.assertIn('`reverse()`', output_state['acceptance_criteria'])
+
+        pprint.pp(output_state["acceptance_criteria"])
+
+    def test_run_acceptance_criteria_graph(self):
+        """
+        Test the run_acceptance_criteria_graph method of MetaPromptGraph.
+
+        This test case verifies that the run_acceptance_criteria_graph method returns a state with acceptance criteria.
+        """
+        llms = {
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: MagicMock(
+                invoke=lambda prompt: MagicMock(content="Acceptance criteria: ..."))
+        }
+        meta_prompt_graph = MetaPromptGraph(llms=llms)
+        state = AgentState(
+            user_message="How do I reverse a list in Python?",
+            expected_output="The output should use the `reverse()` method.",
+        )
+        output_state = meta_prompt_graph.run_acceptance_criteria_graph(state)
+
+        # Check if the output state contains the acceptance criteria
+        self.assertIsNotNone(output_state['acceptance_criteria'])
+
+        # Check if the acceptance criteria includes the expected content
+        self.assertIn("Acceptance criteria: ...", output_state['acceptance_criteria'])
+

if __name__ == '__main__':
    unittest.main()
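To exercise only the two new test cases, a small driver like the following can be used; this is an illustrative sketch, assuming the repository root is on PYTHONPATH, that tests/ is importable as a package, and that credentials are configured for the test that instantiates a real ChatOpenAI model (the second test uses MagicMock and needs no network access).

```python
# Hypothetical runner for just the two tests added in this commit.
import unittest

from tests.meta_prompt_graph_test import TestMetaPromptGraph

suite = unittest.TestSuite([
    TestMetaPromptGraph("test_create_acceptance_criteria_workflow"),
    TestMetaPromptGraph("test_run_acceptance_criteria_graph"),
])
unittest.TextTestRunner(verbosity=2).run(suite)
```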