Updated unit tests.
- meta_prompt/meta_prompt.py +36 -88
- tests/meta_prompt_graph_test.py +16 -9
meta_prompt/meta_prompt.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-import operator
 import pprint
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import ChatPromptTemplate
@@ -19,21 +18,7 @@ def last_non_empty(a, b):
     # return the last non-none value
     return next((s for s in (b, a) if s), None)
 
-
-    return b
-
-class InitialAgentState(TypedDict):
-    """
-    Represents the state of an agent in a conversation.
-    """
-    max_output_age: int
-    user_message: Optional[str]
-    expected_output: Optional[str]
-    acceptance_criteria: Annotated[Optional[str], last_non_empty]
-    system_message: Annotated[Optional[str], last_non_empty]
-
-
-class AgentState(BaseModel):
+class AgentState(TypedDict):
     """
     Represents the state of an agent in a conversation.
 
@@ -51,18 +36,18 @@ class AgentState(BaseModel):
     - best_system_message (str, optional): The best system message.
     - best_output_age (int): The age of the best output.
     """
-    max_output_age: int
-    user_message: Optional[str]
-    expected_output: Optional[str]
-    acceptance_criteria: Annotated[Optional[str], last_non_empty]
-    system_message: Annotated[Optional[str], last_non_empty]
-    output: Optional[str]
-    suggestions: Optional[str]
-    accepted: bool
-    analysis: Optional[str]
-    best_output: Optional[str]
-    best_system_message: Optional[str]
-    best_output_age: int
+    max_output_age: Optional[int]
+    user_message: Optional[str]
+    expected_output: Optional[str]
+    acceptance_criteria: Annotated[Optional[str], last_non_empty]
+    system_message: Annotated[Optional[str], last_non_empty]
+    output: Optional[str]
+    suggestions: Optional[str]
+    accepted: Optional[bool]
+    analysis: Optional[str]
+    best_output: Optional[str]
+    best_system_message: Optional[str]
+    best_output_age: Optional[int]
 
 class MetaPromptGraph:
     """
@@ -220,19 +205,6 @@ class MetaPromptGraph:
             }
         )
 
-        # # Set entry point based on including_initial_developer flag
-        # if including_initial_developer:
-        #     workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
-        #                       lambda x: self._prompt_node(
-        #                           NODE_PROMPT_INITIAL_DEVELOPER,
-        #                           "system_message",
-        #                           x))
-        #     workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER,
-        #                       NODE_PROMPT_EXECUTOR)
-        #     workflow.set_entry_point(NODE_PROMPT_INITIAL_DEVELOPER)
-        # else:
-        #     workflow.set_entry_point(NODE_PROMPT_EXECUTOR)
-
         workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
                           lambda x: self._optional_action(
                               "system_message",
@@ -240,7 +212,7 @@ class MetaPromptGraph:
                                   NODE_PROMPT_INITIAL_DEVELOPER,
                                   "system_message",
                                   x),
-
+                              x))
         workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
                           lambda x: self._optional_action(
                               "acceptance_criteria",
@@ -248,15 +220,13 @@ class MetaPromptGraph:
                                   NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
                                   "acceptance_criteria",
                                   x),
-
-        # workflow.add_node(START)
+                              x))
 
         workflow.add_edge(START, NODE_PROMPT_INITIAL_DEVELOPER)
         workflow.add_edge(START, NODE_ACCEPTANCE_CRITERIA_DEVELOPER)
 
         workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER, NODE_PROMPT_EXECUTOR)
         workflow.add_edge(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, NODE_PROMPT_EXECUTOR)
-        # workflow.set_entry_point(START)
 
         return workflow
 
@@ -342,8 +312,8 @@ class MetaPromptGraph:
     def _optional_action(
             self, target_attribute: str,
             action: RunnableLike,
-            state:
-    ) ->
+            state: AgentState
+    ) -> AgentState:
         """
         Optionally invokes an action if the target attribute is not set or empty.
 
@@ -370,8 +340,8 @@ class MetaPromptGraph:
 
     def _prompt_node(
             self, node: str, target_attribute: str,
-            state:
-    ) ->
+            state: AgentState
+    ) -> AgentState:
         """
         Prompt a specific node with the given state and update the state with the response.
 
@@ -410,14 +380,6 @@ class MetaPromptGraph:
                 'message': response.content
             })
 
-        # if isinstance(state, dict):
-        #     # state[target_attribute] = response.content
-        #     # Create a dict with the target key only
-        #     state = {target_attribute: response.content}
-        # else:
-        #     setattr(state, target_attribute, response.content)
-        # return state
-
         return {target_attribute: response.content}
 
     def _output_history_analyzer(self, state: AgentState) -> AgentState:
@@ -437,16 +399,16 @@ class MetaPromptGraph:
         """
        logger = self.logger.getChild(NODE_OUTPUT_HISTORY_ANALYZER)
 
-        if state
-            state
-            state
-            state
+        if state["best_output"] is None:
+            state["best_output"] = state["output"]
+            state["best_system_message"] = state["system_message"]
+            state["best_output_age"] = 0
             logger.debug(
-                "Best output initialized to the current output:\n%s", state
+                "Best output initialized to the current output:\n%s", state["output"])
             return state
 
         prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
-            **state
+            **state)
 
         for message in prompt:
             logger.debug({
@@ -466,29 +428,22 @@ class MetaPromptGraph:
 
         analysis = response.content
 
-        if (state
+        if (state["best_output"] is None or
             "# Output ID closer to Expected Output: B" in analysis or
             (self.aggressive_exploration and
             "# Output ID closer to Expected Output: A" not in analysis)):
-            # state.best_output = state.output
-            # state.best_system_message = state.system_message
-            # state.best_output_age = 0
             result_dict = {
-                "best_output": state
-                "best_system_message": state
+                "best_output": state["output"],
+                "best_system_message": state["system_message"],
                 "best_output_age": 0
             }
             logger.debug("Best output updated to the current output:\n%s",
                          result_dict["best_output"])
         else:
-            # state.best_output_age += 1
-            # # rollback output and system message
-            # state.output = state.best_output
-            # state.system_message = state.best_system_message
             result_dict = {
-                "output": state
-                "system_message": state
-                "best_output_age": state
+                "output": state["best_output"],
+                "system_message": state["best_system_message"],
+                "best_output_age": state["best_output_age"] + 1
            }
             logger.debug("Best output age incremented to %s",
                          result_dict["best_output_age"])
@@ -510,7 +465,7 @@ class MetaPromptGraph:
         """
         logger = self.logger.getChild(NODE_PROMPT_ANALYZER)
         prompt = self.prompt_templates[NODE_PROMPT_ANALYZER].format_messages(
-            **state
+            **state)
 
         for message in prompt:
             logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'invoke',
@@ -520,13 +475,6 @@ class MetaPromptGraph:
         logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
                       'type': response.type, 'message': response.content})
 
-        # state.analysis = response.content
-        # state.accepted = "Accept: Yes" in response.content
-
-        # logger.debug("Accepted: %s", state.accepted)
-
-        # return state
-
         result_dict = {
             "analysis": response.content,
             "accepted": "Accept: Yes" in response.content
@@ -545,14 +493,14 @@ class MetaPromptGraph:
         Returns:
             str: The decision to continue, rerun, or end the workflow.
         """
-        if state
+        if state["max_output_age"] <= 0:
            # always continue if max age is 0
            return "continue"
 
-        if state
+        if state["best_output_age"] >= state["max_output_age"]:
            return END
 
-        if state
+        if state["best_output_age"] > 0:
            # skip prompt_analyzer and prompt_suggester, goto prompt_developer
            return "rerun"
 
@@ -568,4 +516,4 @@ class MetaPromptGraph:
         Returns:
             str: The decision to continue or end the workflow.
         """
-        return "continue" if not state
+        return "continue" if not state["accepted"] else END
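Note on the change: the graph now enters through two parallel branches (START -> initial developer and START -> acceptance criteria developer), and each node returns a partial dict such as {target_attribute: response.content}. LangGraph merges such concurrent updates per key using the reducer attached in the Annotated type, which is why acceptance_criteria and system_message carry the last_non_empty reducer. A minimal sketch of that reducer behavior, separate from the commit itself (DemoState and the sample strings are illustrative only):

from typing import Annotated, Optional, TypedDict

def last_non_empty(a, b):
    # same reducer as in meta_prompt.py: prefer the newer value, fall back to
    # the older one, and return None only if both are empty
    return next((s for s in (b, a) if s), None)

class DemoState(TypedDict):
    # illustrative two-key state; the real AgentState carries more keys
    system_message: Annotated[Optional[str], last_non_empty]
    acceptance_criteria: Annotated[Optional[str], last_non_empty]

# When two branches report back, an empty update cannot clobber a real value.
print(last_non_empty("existing system message", None))    # -> "existing system message"
print(last_non_empty(None, "fresh acceptance criteria"))  # -> "fresh acceptance criteria"
print(last_non_empty("old", "new"))                       # -> "new"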
tests/meta_prompt_graph_test.py
CHANGED
@@ -39,7 +39,7 @@ class TestMetaPromptGraph(unittest.TestCase):
         )
 
         assert (
-            updated_state
+            updated_state['output'] == "Mocked response content"
         ), "The output attribute should be updated with the mocked response content"
 
 
@@ -78,13 +78,13 @@ class TestMetaPromptGraph(unittest.TestCase):
         updated_state = meta_prompt_graph._output_history_analyzer(state)
 
         assert (
-            updated_state
+            updated_state['best_output'] == state['output']
         ), "Best output should be updated to the current output."
         assert (
-            updated_state
+            updated_state['best_system_message'] == state['system_message']
         ), "Best system message should be updated to the current system message."
         assert (
-            updated_state
+            updated_state['best_output_age'] == 0
         ), "Best output age should be reset to 0."
 
 
@@ -104,10 +104,13 @@ class TestMetaPromptGraph(unittest.TestCase):
         }
         meta_prompt_graph = MetaPromptGraph(llms=llms)
         state = AgentState(
-            output="Test output", expected_output="Expected output"
+            output="Test output", expected_output="Expected output",
+            acceptance_criteria="Acceptance criteria: ...",
+            system_message="System message: ...",
+            max_output_age=2
         )
         updated_state = meta_prompt_graph._prompt_analyzer(state)
-        assert updated_state
+        assert updated_state['accepted'] is True
 
 
     def test_get_node_names(self):
@@ -138,7 +141,8 @@ class TestMetaPromptGraph(unittest.TestCase):
             user_message="How do I reverse a list in Python?",
             expected_output="Use the `[::-1]` slicing technique or the "
             "`list.reverse()` method.",
-            acceptance_criteria="Similar in meaning, text length and style."
+            acceptance_criteria="Similar in meaning, text length and style.",
+            max_output_age=2
         )
         output_state = meta_prompt_graph(input_state, recursion_limit=25)
 
@@ -190,6 +194,7 @@ class TestMetaPromptGraph(unittest.TestCase):
 
         meta_prompt_graph = MetaPromptGraph(llms=llms)
         input_state = AgentState(
+            max_output_age=2,
             user_message="How do I reverse a list in Python?",
             expected_output="Use the `[::-1]` slicing technique or the "
             "`list.reverse()` method.",
@@ -239,7 +244,8 @@ class TestMetaPromptGraph(unittest.TestCase):
         input_state = AgentState(
             user_message="How do I reverse a list in Python?",
             expected_output="The output should use the `reverse()` method.",
-            acceptance_criteria="The output should be correct and efficient."
+            acceptance_criteria="The output should be correct and efficient.",
+            max_output_age=2
         )
 
         output_state = meta_prompt_graph(input_state)
@@ -277,7 +283,8 @@ class TestMetaPromptGraph(unittest.TestCase):
         input_state = AgentState(
             user_message="How do I reverse a list in Python?",
             expected_output="The output should use the `reverse()` method.",
-            acceptance_criteria="The output should be correct and efficient."
+            acceptance_criteria="The output should be correct and efficient.",
+            max_output_age=2
        )
 
         output_state = meta_prompt_graph(input_state)
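With AgentState now a TypedDict, the tests build the input state with keyword arguments (always supplying acceptance_criteria and max_output_age) and read results by key instead of attribute. A rough usage sketch along the lines of the tests above; the import path and the llms placeholder are assumptions, and the printed keys are only examples:

from meta_prompt.meta_prompt import AgentState, MetaPromptGraph  # import path assumed

llms = ...  # node-name -> BaseLanguageModel mapping, built as in the tests (elided here)
meta_prompt_graph = MetaPromptGraph(llms=llms)

input_state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use the `[::-1]` slicing technique or the "
                    "`list.reverse()` method.",
    acceptance_criteria="Similar in meaning, text length and style.",
    max_output_age=2,
)
output_state = meta_prompt_graph(input_state, recursion_limit=25)

# Dict-style access replaces attribute access on the old Pydantic model.
print(output_state["best_system_message"])
print(output_state["best_output"])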