Commit: Dict state works.

Files changed:
- meta_prompt/meta_prompt.py (+93 -44)
- tests/meta_prompt_graph_test.py (+2 -1)
meta_prompt/meta_prompt.py (CHANGED):
```diff
@@ -8,7 +8,7 @@ from langgraph.errors import GraphRecursionError
 from langgraph.graph import StateGraph, START, END
 from langchain_core.runnables.base import RunnableLike
 from pydantic import BaseModel
-from typing import Annotated, Dict, Optional, Union
+from typing import Annotated, Dict, Optional, Union, TypedDict
 from .consts import *

 def first_non_empty(a, b):
```
```diff
@@ -22,6 +22,17 @@ def last_non_empty(a, b):
 def assign(a, b):
     return b

+class InitialAgentState(TypedDict):
+    """
+    Represents the state of an agent in a conversation.
+    """
+    max_output_age: int
+    user_message: Optional[str]
+    expected_output: Optional[str]
+    acceptance_criteria: Annotated[Optional[str], last_non_empty]
+    system_message: Annotated[Optional[str], last_non_empty]
+
+
 class AgentState(BaseModel):
     """
     Represents the state of an agent in a conversation.
```
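The `Annotated[..., last_non_empty]` fields mark LangGraph reducers: when a node returns a partial state update, each annotated key is merged with the existing value through its reducer instead of being overwritten. A minimal, self-contained sketch of that pattern, assuming `langgraph` is installed (`DemoState` and `noisy_node` are illustrative names, and this local `last_non_empty` mirrors what the repo's helper plausibly does; it is not the repo's code):

```python
from typing import Annotated, Optional, TypedDict
from langgraph.graph import StateGraph, START, END

def last_non_empty(a, b):
    # Reducer: keep the previous value when the update is empty.
    return b if b else a

class DemoState(TypedDict):
    system_message: Annotated[Optional[str], last_non_empty]

def noisy_node(state: DemoState) -> dict:
    # An empty update leaves system_message untouched,
    # because the reducer rejects empty values.
    return {"system_message": ""}

workflow = StateGraph(DemoState)
workflow.add_node("noisy_node", noisy_node)
workflow.add_edge(START, "noisy_node")
workflow.add_edge("noisy_node", END)
graph = workflow.compile()

print(graph.invoke({"system_message": "keep me"}))
# {'system_message': 'keep me'}
```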
```diff
@@ -40,18 +51,18 @@ class AgentState(BaseModel):
     - best_system_message (str, optional): The best system message.
     - best_output_age (int): The age of the best output.
     """
-    max_output_age: …
-    user_message: …
-    expected_output: …
+    max_output_age: int = 0
+    user_message: Optional[str] = None
+    expected_output: Optional[str] = None
     acceptance_criteria: Annotated[Optional[str], last_non_empty] = None
     system_message: Annotated[Optional[str], last_non_empty] = None
-    output: …
-    suggestions: …
-    accepted: …
-    analysis: …
-    best_output: …
-    best_system_message: …
-    best_output_age: …
+    output: Optional[str] = None
+    suggestions: Optional[str] = None
+    accepted: bool = False
+    analysis: Optional[str] = None
+    best_output: Optional[str] = None
+    best_system_message: Optional[str] = None
+    best_output_age: int = 0

 class MetaPromptGraph:
     """
```

(The `…` lines mark removed content that the page rendering truncated; only the field names survive.)
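With a default on every field, an `AgentState` now constructs without arguments, so the graph can treat it as an incrementally filled record. A trimmed, self-contained sanity sketch (only a few of the fields above, for brevity):

```python
from typing import Optional
from pydantic import BaseModel

class AgentState(BaseModel):  # trimmed to a few fields for illustration
    user_message: Optional[str] = None
    accepted: bool = False
    best_output_age: int = 0

state = AgentState()
print(state.accepted, state.best_output_age)  # False 0
print(state.model_dump())  # plain dict, handy for dict-style updates
```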
```diff
@@ -229,7 +240,7 @@ class MetaPromptGraph:
                               NODE_PROMPT_INITIAL_DEVELOPER,
                               "system_message",
                               x),
-                          x))
+                          InitialAgentState(**(x.model_dump()))))
         workflow.add_node(NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
                           lambda x: self._optional_action(
                               "acceptance_criteria",
```
```diff
@@ -237,7 +248,7 @@
                               NODE_ACCEPTANCE_CRITERIA_DEVELOPER,
                               "acceptance_criteria",
                               x),
-                          x))
+                          InitialAgentState(**(x.model_dump()))))
         # workflow.add_node(START)

         workflow.add_edge(START, NODE_PROMPT_INITIAL_DEVELOPER)
```
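In both nodes above, the state handed to `_optional_action` is now built with `InitialAgentState(**(x.model_dump()))`. Worth noting: a `TypedDict` does no runtime validation or key filtering, so this call simply turns the pydantic state into a plain dict, extra keys included. A self-contained sketch (field set trimmed for brevity):

```python
from typing import Optional, TypedDict
from pydantic import BaseModel

class InitialAgentState(TypedDict):
    user_message: Optional[str]

class AgentState(BaseModel):  # trimmed for illustration
    user_message: Optional[str] = None
    output: Optional[str] = None

x = AgentState(user_message="hi", output="draft")
converted = InitialAgentState(**(x.model_dump()))
print(converted)  # {'user_message': 'hi', 'output': 'draft'}; no key filtering
```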
```diff
@@ -331,8 +342,8 @@
     def _optional_action(
             self, target_attribute: str,
             action: RunnableLike,
-            state: AgentState
-    ) -> AgentState:
+            state: Union[AgentState, InitialAgentState]
+    ) -> Union[AgentState, InitialAgentState]:
         """
         Optionally invokes an action if the target attribute is not set or empty.

```
```diff
@@ -345,13 +356,22 @@
         Returns:
             AgentState: Updated state.
         """
-        …
+        result = {
+            target_attribute: state.get(target_attribute, "")
+            if isinstance(state, dict)
+            else getattr(state, target_attribute, "")
+        }
+
+        if action is not None and not result[target_attribute]:
+            result = action(state)
+
+        return result


-    def _prompt_node(…
+    def _prompt_node(
+            self, node: str, target_attribute: str,
+            state: Union[AgentState, InitialAgentState]
+    ) -> Union[AgentState, InitialAgentState]:
         """
         Prompt a specific node with the given state and update the state with the response.

```

(The removed `_optional_action` body and the old `_prompt_node` signature were truncated by the page rendering; `…` marks the lost lines.)
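`_optional_action` now returns a partial update instead of a mutated state: if the target attribute is already populated it echoes it back; otherwise it delegates to the action. A standalone re-implementation to show the behavior (illustrative only, mirroring the hunk above):

```python
def optional_action(target_attribute, action, state):
    # Works for both dict-style and attribute-style states.
    result = {
        target_attribute: state.get(target_attribute, "")
        if isinstance(state, dict)
        else getattr(state, target_attribute, "")
    }
    if action is not None and not result[target_attribute]:
        result = action(state)
    return result

def fill(_state):
    return {"system_message": "generated"}

print(optional_action("system_message", fill, {"system_message": "existing"}))
# {'system_message': 'existing'}  (action skipped)
print(optional_action("system_message", fill, {}))
# {'system_message': 'generated'}
```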
```diff
@@ -368,26 +388,37 @@
         """

         logger = self.logger.getChild(node)
-        formatted_messages = …
+        formatted_messages = (
+            self.prompt_templates[node].format_messages(
+                **(state.model_dump() if isinstance(state, BaseModel) else state)
+            )
+        )
+
         for message in formatted_messages:
             logger.debug({
                 'node': node,
                 'action': 'invoke',
                 'type': message.type,
                 'message': message.content
             })

         response = self.llms[node].invoke(formatted_messages)
         logger.debug({
             'node': node,
             'action': 'response',
             'type': response.type,
             'message': response.content
         })
-        …
+
+        # if isinstance(state, dict):
+        #     # state[target_attribute] = response.content
+        #     # Create a dict with the target key only
+        #     state = {target_attribute: response.content}
+        # else:
+        #     setattr(state, target_attribute, response.content)
+        # return state
+
+        return {target_attribute: response.content}

     def _output_history_analyzer(self, state: AgentState) -> AgentState:
         """
```

(The removed assignment lines were truncated by the rendering; `…` marks them.)
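`_prompt_node` now feeds the whole state into the node's prompt template and returns only the key it changed, which is LangGraph's partial-update convention. How `format_messages(**state)` consumes state keys, in a self-contained sketch (the template text is illustrative, not the repo's):

```python
from langchain_core.prompts import ChatPromptTemplate

template = ChatPromptTemplate.from_messages([
    ("system", "You are a prompt engineer."),
    ("human", "{user_message}"),
])
state = {"user_message": "How do I reverse a list in Python?"}
messages = template.format_messages(**state)
for m in messages:
    print(m.type, "=>", m.content)
# system => You are a prompt engineer.
# human => How do I reverse a list in Python?
```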
```diff
@@ -439,20 +470,30 @@
                 "# Output ID closer to Expected Output: B" in analysis or
                 (self.aggressive_exploration and
                  "# Output ID closer to Expected Output: A" not in analysis)):
-            state.best_output = state.output
-            state.best_system_message = state.system_message
-            state.best_output_age = 0
+            # state.best_output = state.output
+            # state.best_system_message = state.system_message
+            # state.best_output_age = 0
+            result_dict = {
+                "best_output": state.output,
+                "best_system_message": state.system_message,
+                "best_output_age": 0
+            }
             logger.debug("Best output updated to the current output:\n%s",
-                         …)
+                         result_dict["best_output"])
         else:
-            state.best_output_age += 1
-            # rollback output and system message
-            state.output = state.best_output
-            state.system_message = state.best_system_message
+            # state.best_output_age += 1
+            # # rollback output and system message
+            # state.output = state.best_output
+            # state.system_message = state.best_system_message
+            result_dict = {
+                "output": state.best_output,
+                "system_message": state.best_system_message,
+                "best_output_age": state.best_output_age + 1
+            }
             logger.debug("Best output age incremented to %s",
-                         …)
+                         result_dict["best_output_age"])

-        return …
+        return result_dict

     def _prompt_analyzer(self, state: AgentState) -> AgentState:
         """
```
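The else-branch implements rollback purely through a partial update: the node returns the keys to overwrite and the graph's merge does the rest. Reduced to plain dict semantics (a sketch; the default last-value merge behaves like `dict.update` here):

```python
state = {"output": "worse attempt", "system_message": "v2",
         "best_output": "good attempt", "best_system_message": "v1",
         "best_output_age": 0}

# What the else-branch returns:
result_dict = {
    "output": state["best_output"],
    "system_message": state["best_system_message"],
    "best_output_age": state["best_output_age"] + 1,
}

state.update(result_dict)  # the merge step, in miniature
print(state["output"], state["best_output_age"])  # good attempt 1
```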
```diff
@@ -479,12 +520,20 @@
         logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
                       'type': response.type, 'message': response.content})

-        state.analysis = response.content
-        state.accepted = "Accept: Yes" in response.content
+        # state.analysis = response.content
+        # state.accepted = "Accept: Yes" in response.content

-        logger.debug("Accepted: %s", state.accepted)
+        # logger.debug("Accepted: %s", state.accepted)

-        return state
+        # return state
+
+        result_dict = {
+            "analysis": response.content,
+            "accepted": "Accept: Yes" in response.content
+        }
+        logger.debug("Accepted: %s", result_dict["accepted"])
+
+        return result_dict

     def _should_exit_on_max_age(self, state: AgentState) -> str:
         """
```
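`_prompt_analyzer` follows the same pattern: the old in-place mutations stay behind as comments and the node returns a two-key update. The acceptance test itself is a plain substring check on the analyzer's reply (sketch with an illustrative reply):

```python
response_content = "The output matches the criteria.\nAccept: Yes"
result_dict = {
    "analysis": response_content,
    "accepted": "Accept: Yes" in response_content,
}
print(result_dict["accepted"])  # True
```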
tests/meta_prompt_graph_test.py (CHANGED):
```diff
@@ -180,6 +180,7 @@ class TestMetaPromptGraph(unittest.TestCase):

         llms = {
             NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
+            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: optimizer_llm,
             NODE_PROMPT_DEVELOPER: optimizer_llm,
             NODE_PROMPT_EXECUTOR: executor_llm,
             NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
```
```diff
@@ -192,7 +193,7 @@
             user_message="How do I reverse a list in Python?",
             expected_output="Use the `[::-1]` slicing technique or the "
                             "`list.reverse()` method.",
-            acceptance_criteria="Similar in meaning, text length and style."
+            # acceptance_criteria="Similar in meaning, text length and style."
         )
         output_state = meta_prompt_graph(input_state, recursion_limit=25)
```
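With `acceptance_criteria` commented out of the input and `NODE_ACCEPTANCE_CRITERIA_DEVELOPER` wired to the optimizer LLM, the test now exercises the graph's ability to generate its own acceptance criteria. The input side of that setup, as a self-contained sketch (the model is trimmed to the fields the test sets; the real test uses the `AgentState` defined above):

```python
from typing import Optional
from pydantic import BaseModel

class AgentState(BaseModel):  # trimmed to the fields the test sets
    user_message: Optional[str] = None
    expected_output: Optional[str] = None
    acceptance_criteria: Optional[str] = None

input_state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use the `[::-1]` slicing technique or the "
                    "`list.reverse()` method.",
)
# Left unset on purpose; the acceptance-criteria node must fill it in.
assert input_state.acceptance_criteria is None
```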