Create the meta_prompt_graph library and its unit test script.
- meta_prompt_graph.py +427 -0
- meta_prompt_graph_test.py +110 -0
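For context, a minimal usage sketch of the new library, mirroring `test_workflow_execution` below. The model name is taken from the test; running it assumes an OpenAI-compatible endpoint and credentials are configured in the environment.

```python
from langchain_openai import ChatOpenAI

from meta_prompt_graph import MetaPromptGraph, AgentState

# Reuse one chat model for every node; any per-node mapping of BaseLanguageModel works.
llm = ChatOpenAI(model_name="anthropic/claude-3.5-sonnet:haiku")
llms = {node_name: llm for node_name in MetaPromptGraph.get_node_names()}

meta_prompt_graph = MetaPromptGraph(llms=llms, verbose=True)
input_state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
    acceptance_criteria="Similar in meaning, text length and style.",
)

# Iterates developer -> executor -> analyzers -> suggester until the output is
# accepted or the recursion limit is reached, then returns the resulting state.
output_state = meta_prompt_graph(input_state, recursion_limit=25)
print(output_state["best_system_message"])
```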
meta_prompt_graph.py
ADDED
@@ -0,0 +1,427 @@
import typing
import pprint
import logging
from typing import Dict, Any, Callable, List, Union, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError
from pydantic import BaseModel

class AgentState(BaseModel):
    max_output_age: int = 0
    user_message: Optional[str] = None
    expected_output: Optional[str] = None
    acceptance_criteria: Optional[str] = None
    system_message: Optional[str] = None
    output: Optional[str] = None
    suggestions: Optional[str] = None
    accepted: bool = False
    analysis: Optional[str] = None
    best_output: Optional[str] = None
    best_system_message: Optional[str] = None
    best_output_age: int = 0

class MetaPromptGraph:
    NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
    NODE_PROMPT_DEVELOPER = "prompt_developer"
    NODE_PROMPT_EXECUTOR = "prompt_executor"
    NODE_OUTPUT_HISTORY_ANALYZER = "output_history_analyzer"
    NODE_PROMPT_ANALYZER = "prompt_analyzer"
    NODE_PROMPT_SUGGESTER = "prompt_suggester"

    DEFAULT_PROMPT_TEMPLATES = {
        NODE_PROMPT_INITIAL_DEVELOPER: ChatPromptTemplate.from_messages([
            ("system", """# Expert Prompt Engineer

You are an expert prompt engineer tasked with creating system messages for AI
assistants.

## Instructions

1. Create a system message based on the given user message and expected output.
2. Ensure the system message can handle similar user messages.
3. Output only the system message, without any additional content.
4. Expected Output text should not appear in System Message as an example. But
   it's OK to use some similar text as an example instead.
5. Format the system message well, with no more than 80 characters per line
   (except for raw text).

## Output

Provide only the system message, adhering to the above guidelines.
"""),
            ("human", """# User Message

{user_message}

# Expected Output

{expected_output}
""")
        ]),
        NODE_PROMPT_DEVELOPER: ChatPromptTemplate.from_messages([
            ("system", """# Expert Prompt Engineer

You are an expert prompt engineer tasked with updating system messages for AI
assistants. You update the System Message according to Suggestions, to improve
Output and match Expected Output more closely.

## Instructions

1. Update the system message based on the given Suggestion, User Message, and
   Expected Output.
2. Ensure the updated system message can handle similar user messages.
3. Modify only the content mentioned in the Suggestion. Do not change the
   parts that are not related to the Suggestion.
4. Output only the updated system message, without any additional content.
5. Expected Output text should not appear in System Message as an example. But
   it's OK to use some similar text as an example instead.
   * Remove the Expected Output text or text highly similar to Expected Output
     from System Message, if it's present.
6. Format the system message well, with no more than 80 characters per line
   (except for raw text).

## Output

Provide only the updated System Message, adhering to the above guidelines.
"""),
            ("human", """# Current system message

{system_message}

# User Message

{user_message}

# Expected Output

{expected_output}

# Suggestions

{suggestions}
""")
        ]),
        NODE_PROMPT_EXECUTOR: ChatPromptTemplate.from_messages([
            ("system", "{system_message}"),
            ("human", "{user_message}")
        ]),
        NODE_OUTPUT_HISTORY_ANALYZER: ChatPromptTemplate.from_messages([
            ("system", """You are a text comparing program. You read the Acceptance Criteria, compare the
expected output with two different outputs, and decide which one is more similar
to the expected output.

You output the following analysis according to the Acceptance Criteria:

* Your analysis in a Markdown list.
* The ID of the output that is more similar to the Expected Output as Preferred
  Output ID, with the following format:

```
# Analysis

...

# Preferred Output ID: [ID]
```

If both outputs are equally similar to the expected output, output the following:

```
# Analysis

...

# Draw
```
"""),
            ("human", """
# Output ID: A

```
{best_output}
```

# Output ID: B

```
{output}
```

# Acceptance Criteria

{acceptance_criteria}

# Expected Output

```
{expected_output}
```
""")
        ]),
        NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
            ("system", """
You are a text comparing program. You compare the following output texts and provide a
detailed analysis according to `Acceptance Criteria`. Then you decide whether `Actual Output`
is acceptable.

# Expected Output

```
{expected_output}
```

# Actual Output

```
{output}
```

----

Provide your analysis in the following format:

```
- Acceptable Differences: [List acceptable differences succinctly]
- Unacceptable Differences: [List unacceptable differences succinctly]
- Accept: [Yes/No]
```

* Compare Expected Output and Actual Output with the guidance of the Acceptance Criteria.
* Only set 'Accept' to 'Yes' if the Acceptance Criteria are all met. Otherwise, set 'Accept' to 'No'.
* List only the acceptable differences according to the Acceptance Criteria in the 'Acceptable Differences' section.
* List only the unacceptable differences according to the Acceptance Criteria in the 'Unacceptable Differences' section.

# Acceptance Criteria

```
{acceptance_criteria}
```
""")
        ]),
        NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
            ("system", """
Read the following inputs and outputs of an LLM prompt, and also the analysis about them.
Then suggest how to improve the System Prompt.

System Prompt:
```
{system_message}
```
User Message:
```
{user_message}
```
Expected Output:
```
{expected_output}
```
Actual Output:
```
{output}
```

Acceptance Criteria:
```
{acceptance_criteria}
```

Analysis:
```
{analysis}
```

* The goal is to improve the System Prompt to match the Expected Output better.
* Ignore all Acceptable Differences and focus on Unacceptable Differences.
* Suggest formal changes first, then semantic changes.
* Provide your suggestions in a Markdown list, nothing else. Output only the
  suggestions related to Unacceptable Differences.
* Use `... should ...` to clearly state the desired output.
* Figure out the contexts of the System Message that conflict with the suggestions,
  and suggest modification or deletion.
* Expected Output text should not appear in System Message as an example. But
  it's OK to use some similar text as an example instead.
* Ask to remove the Expected Output text or text highly similar to Expected Output
  from System Message, if it's present.
* Provide format examples or the detected format name, if System Message does not.
* Specify the detected format name (e.g. XML, JSON, etc.) of Expected Output, if
  System Message does not mention it.
""")
        ])
    }

    @classmethod
    def get_node_names(cls):
        return [
            cls.NODE_PROMPT_INITIAL_DEVELOPER,
            cls.NODE_PROMPT_DEVELOPER,
            cls.NODE_PROMPT_EXECUTOR,
            cls.NODE_OUTPUT_HISTORY_ANALYZER,
            cls.NODE_PROMPT_ANALYZER,
            cls.NODE_PROMPT_SUGGESTER
        ]

    def __init__(self,
                 llms: Dict[str, BaseLanguageModel] = {},
                 prompts: Dict[str, ChatPromptTemplate] = {},
                 verbose=False):
        self.logger = logging.getLogger(__name__)
        if verbose:
            self.logger.setLevel(logging.DEBUG)
        else:
            self.logger.setLevel(logging.INFO)

        self.llms: Dict[str, BaseLanguageModel] = llms
        self.prompt_templates: Dict[str, ChatPromptTemplate] = self.DEFAULT_PROMPT_TEMPLATES.copy()
        self.prompt_templates.update(prompts)

        # Create the workflow graph.
        self.workflow = StateGraph(AgentState)

        self.workflow.add_node(self.NODE_PROMPT_INITIAL_DEVELOPER,
                               lambda x: self._prompt_node(
                                   self.NODE_PROMPT_INITIAL_DEVELOPER,
                                   "system_message",
                                   x))
        self.workflow.add_node(self.NODE_PROMPT_DEVELOPER,
                               lambda x: self._prompt_node(
                                   self.NODE_PROMPT_DEVELOPER,
                                   "system_message",
                                   x))
        self.workflow.add_node(self.NODE_PROMPT_EXECUTOR,
                               lambda x: self._prompt_node(
                                   self.NODE_PROMPT_EXECUTOR,
                                   "output",
                                   x))
        self.workflow.add_node(self.NODE_OUTPUT_HISTORY_ANALYZER,
                               lambda x: self._output_history_analyzer(x))
        self.workflow.add_node(self.NODE_PROMPT_ANALYZER,
                               lambda x: self._prompt_analyzer(x))
        self.workflow.add_node(self.NODE_PROMPT_SUGGESTER,
                               lambda x: self._prompt_node(
                                   self.NODE_PROMPT_SUGGESTER,
                                   "suggestions",
                                   x))

        self.workflow.set_entry_point(self.NODE_PROMPT_INITIAL_DEVELOPER)

        self.workflow.add_edge(self.NODE_PROMPT_INITIAL_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
        self.workflow.add_edge(self.NODE_PROMPT_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
        self.workflow.add_edge(self.NODE_PROMPT_EXECUTOR, self.NODE_OUTPUT_HISTORY_ANALYZER)
        self.workflow.add_edge(self.NODE_PROMPT_SUGGESTER, self.NODE_PROMPT_DEVELOPER)

        self.workflow.add_conditional_edges(
            self.NODE_OUTPUT_HISTORY_ANALYZER,
            lambda x: self._should_exit_on_max_age(x),
            {
                "continue": self.NODE_PROMPT_ANALYZER,
                "rerun": self.NODE_PROMPT_SUGGESTER,
                END: END
            }
        )

        self.workflow.add_conditional_edges(
            self.NODE_PROMPT_ANALYZER,
            lambda x: self._should_exit_on_acceptable_output(x),
            {
                "continue": self.NODE_PROMPT_SUGGESTER,
                END: END
            }
        )

    def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
        memory = MemorySaver()
        graph = self.workflow.compile(checkpointer=memory)
        config = {"configurable": {"thread_id": "1"}, "recursion_limit": recursion_limit}

        try:
            self.logger.debug("Invoking graph with state: %s", pprint.pformat(state))

            output_state = graph.invoke(state, config)

            self.logger.debug("Output state: %s", pprint.pformat(output_state))

            return output_state
        except GraphRecursionError as e:
            self.logger.info("Recursion limit reached. Returning the best state found so far.")
            checkpoint_states = graph.get_state(config)

            # If any checkpoint state was saved, return the most recent one as the best state found.
            if len(checkpoint_states) > 0:
                output_state = checkpoint_states[0]
                return output_state
            else:
                self.logger.info("No checkpoint states found. Returning the input state.")

            return state

    def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
        prompt = self.prompt_templates[node].format_messages(**state.model_dump())

        self.logger.debug("Invoking %s with prompt: %s", node, pprint.pformat(prompt))
        response = self.llms[node].invoke(prompt)
        self.logger.debug("Response: %s", pprint.pformat(response.content))

        setattr(state, target_attribute, response.content)
        return state

    def _output_history_analyzer(self, state: AgentState) -> AgentState:
        if state.best_output is None:
            state.best_output = state.output
            state.best_system_message = state.system_message
            state.best_output_age = 0

            self.logger.debug("Best output initialized to the current output: \n %s", state.output)

            return state

        prompt = self.prompt_templates[self.NODE_OUTPUT_HISTORY_ANALYZER].format_messages(**state.model_dump())

        self.logger.debug("Invoking %s with prompt: %s",
                          self.NODE_OUTPUT_HISTORY_ANALYZER,
                          pprint.pformat(prompt))
        response = self.llms[self.NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
        self.logger.debug("Response: %s", pprint.pformat(response.content))

        analysis = response.content

        if state.best_output is None or "# Preferred Output ID: B" in analysis:
            state.best_output = state.output
            state.best_system_message = state.system_message
            state.best_output_age = 0

            self.logger.debug("Best output updated to the current output: \n %s", state.output)
        else:
            state.best_output_age += 1

            self.logger.debug("Best output age incremented to %s", state.best_output_age)

        return state

    def _prompt_analyzer(self, state: AgentState) -> AgentState:
        prompt = self.prompt_templates[self.NODE_PROMPT_ANALYZER].format_messages(**state.model_dump())

        self.logger.debug("Invoking %s with prompt: %s",
                          self.NODE_PROMPT_ANALYZER,
                          pprint.pformat(prompt))
        response = self.llms[self.NODE_PROMPT_ANALYZER].invoke(prompt)
        self.logger.debug("Response: %s", pprint.pformat(response.content))

        state.analysis = response.content
        state.accepted = "Accept: Yes" in response.content

        self.logger.debug("Accepted: %s", state.accepted)

        return state

    def _should_exit_on_max_age(self, state: AgentState) -> str:
        # Keep analyzing while no age limit is set or the best output is still fresh;
        # otherwise route to the suggester to rework the prompt.
        if state.max_output_age <= 0 or state.best_output_age < state.max_output_age:
            return "continue"
        else:
            return "rerun"

    def _should_exit_on_acceptable_output(self, state: AgentState) -> str:
        # Stop once the analyzer accepts the output; otherwise ask for suggestions.
        return "continue" if not state.accepted else END
meta_prompt_graph_test.py
ADDED
@@ -0,0 +1,110 @@
import unittest
import pprint
import logging
from unittest.mock import MagicMock
from unittest.mock import patch

# Assuming the necessary imports are made for the classes and functions used in meta_prompt_graph.py
from meta_prompt_graph import MetaPromptGraph, AgentState

from langchain_openai import ChatOpenAI

class TestMetaPromptGraph(unittest.TestCase):
    def setUp(self):
        # Mocking the BaseLanguageModel and ChatPromptTemplate for testing
        logging.basicConfig(level=logging.DEBUG)

    def test_prompt_node(self):
        llms = {
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
                invoke=MagicMock(return_value=MagicMock(content="Mocked response content"))
            )
        }

        # Create an instance of MetaPromptGraph with the mocked language model and template
        graph = MetaPromptGraph(llms=llms)

        # Create a mock AgentState
        state = AgentState(user_message="Test message", expected_output="Expected output")

        # Invoke the _prompt_node method with the mock node, target attribute, and state
        updated_state = graph._prompt_node(
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER, "output", state
        )

        # Assertions
        assert updated_state.output == "Mocked response content", \
            "The output attribute should be updated with the mocked response content"

    def test_output_history_analyzer(self):
        # Setup
        llms = {
            "output_history_analyzer": MagicMock(invoke=lambda prompt: MagicMock(content="""# Analysis

This analysis compares two outputs to the expected output based on specific criteria.

# Preferred Output ID: B"""))
        }
        prompts = {}
        meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
        state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            output="To reverse a list in Python, you can use the `[::-1]` slicing.",
            system_message="To reverse a list, use slicing or the reverse method.",
            best_output="To reverse a list in Python, use the `reverse()` method.",
            best_system_message="To reverse a list, use the `reverse()` method.",
            acceptance_criteria="The output should correctly describe how to reverse a list in Python."
        )

        # Invoke the output history analyzer node
        updated_state = meta_prompt_graph._output_history_analyzer(state)

        # Assertions
        assert updated_state.best_output == state.output, \
            "Best output should be updated to the current output."
        assert updated_state.best_system_message == state.system_message, \
            "Best system message should be updated to the current system message."
        assert updated_state.best_output_age == 0, \
            "Best output age should be reset to 0."

    def test_prompt_analyzer_accept(self):
        llms = {
            MetaPromptGraph.NODE_PROMPT_ANALYZER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Accept: Yes"))
        }
        meta_prompt_graph = MetaPromptGraph(llms)
        state = AgentState(output="Test output", expected_output="Expected output")
        updated_state = meta_prompt_graph._prompt_analyzer(state)
        assert updated_state.accepted

    def test_workflow_execution(self):
        # MODEL_NAME = "google/gemma-2-9b-it"
        MODEL_NAME = "anthropic/claude-3.5-sonnet:haiku"
        llm = ChatOpenAI(model_name=MODEL_NAME)

        node_names = MetaPromptGraph.get_node_names()
        llms = {}
        for node_name in node_names:
            llms[node_name] = llm

        meta_prompt_graph = MetaPromptGraph(llms=llms, verbose=True)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )
        output_state = meta_prompt_graph(input_state, recursion_limit=25)

        pprint.pp(output_state)
        # if output_state has key 'best_system_message', print it
        assert 'best_system_message' in output_state, \
            "The output state should contain the key 'best_system_message'"
        assert output_state['best_system_message'] is not None, \
            "The best system message should not be None"
        if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
            print(output_state['best_system_message'])

if __name__ == '__main__':
    unittest.main()
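Note: all tests except `test_workflow_execution` rely only on `MagicMock` and can run offline (for example with `python -m unittest meta_prompt_graph_test`). `test_workflow_execution` calls a live model through `ChatOpenAI` with an OpenRouter-style model name, so it presumably expects an OpenAI-compatible API key and base URL to be configured in the environment; the exact setup is an assumption, not something this diff specifies.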