Gradio demo works with Confz now.
- .gitignore +1 -0
- config.yml +37 -0
- demo/config.py +16 -0
- demo/examples/log.csv +4 -2
- demo/gradio_meta_prompt.py +65 -20
- requirements.txt +16 -5
- src/meta_prompt/meta_prompt.py +67 -55
.gitignore
CHANGED
@@ -2,3 +2,4 @@
 .vscode
 __pycache__
 .env
+config.yml.debug

config.yml
ADDED
@@ -0,0 +1,37 @@
+llms:
+  anthropic/claude-3-haiku:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "anthropic/claude-3-haiku:beta"
+    openai_api_key: ""
+    openai_api_base: "https://openrouter.ai/api/v1"
+    max_tokens: 8192
+    verbose: true
+  anthropic/claude-3-sonnet:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "anthropic/claude-3-sonnet:beta"
+    openai_api_key: ""
+    openai_api_base: "https://openrouter.ai/api/v1"
+    max_tokens: 8192
+    verbose: true
+  anthropic/deepseek-chat:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "deepseek/deepseek-chat"
+    openai_api_key: ""
+    openai_api_base: "https://openrouter.ai/api/v1"
+    max_tokens: 8192
+    verbose: true
+  groq/llama3-70b-8192:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "llama3-70b-8192"
+    openai_api_key: ""
+    openai_api_base: "https://api.groq.com/openai/v1"
+    max_tokens: 8192
+    verbose: true
+
+examples_path: "demo/examples"
+server_name: 0.0.0.0
+server_port: 7870

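Each key under `llms:` becomes a model choice in the demo's dropdown; `type` names the LangChain chat-model class and every other key is passed through as a constructor argument. As a rough sketch, the first entry above amounts to the following (the empty API key is a placeholder from the config, not a working credential):

# Sketch of what one `llms:` entry expands to once instantiated. Every key
# besides `type` is forwarded verbatim as a ChatOpenAI constructor kwarg.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    temperature=0.1,
    model_name="anthropic/claude-3-haiku:beta",
    openai_api_key="",  # empty in the config; supply a real OpenRouter key
    openai_api_base="https://openrouter.ai/api/v1",
    max_tokens=8192,
    verbose=True,
)
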
demo/config.py
ADDED
@@ -0,0 +1,16 @@
+# config.py
+from confz import BaseConfig
+from pydantic import BaseModel, Extra
+from typing import Optional
+
+class LLMConfig(BaseModel):
+    type: str
+
+    class Config:
+        extra = Extra.allow
+
+class MetaPromptConfig(BaseConfig):
+    llms: Optional[dict[str, LLMConfig]]
+    examples_path: Optional[str]
+    server_name: Optional[str] = '127.0.0.1'
+    server_port: Optional[int] = 7878

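Together with config.yml this is the whole ConfZ surface: `extra = Extra.allow` lets each `llms:` entry carry arbitrary model kwargs past validation, and the optional fields fall back to localhost defaults. A minimal load, mirroring the sources the Gradio demo wires up below (the config.yml path and sys.path setup are assumptions about running it the way the demo does):

# Minimal sketch of loading the demo config with confz, using the same three
# sources as demo/gradio_meta_prompt.py below.
from confz import CLArgSource, EnvSource, FileSource
from config import MetaPromptConfig  # demo/config.py above; assumes demo/ is on sys.path

config = MetaPromptConfig(config_sources=[
    FileSource(file='config.yml', optional=True),
    EnvSource(prefix='METAPROMPT_', allow_all=True),  # e.g. METAPROMPT_SERVER_PORT=7871
    CLArgSource(),                                    # e.g. --server_port 7871
])

print(config.server_name, config.server_port)
for name, llm in (config.llms or {}).items():
    # everything except `type` becomes constructor kwargs for the model class
    print(name, llm.type, llm.model_dump(exclude={'type'}))
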
demo/examples/log.csv
CHANGED
@@ -188,10 +188,12 @@ What is the meaning of life?,"[
 * Data types and formats of all JSON fields
 * Top layer sections
 * Acceptable differences:
-  *
+  * Different personas or prompts
+  * Different numbers of personas
   * Extra or missing spaces
   * Extra or missing line breaks at the beginning or end of the output
-"
+* Unacceptable:
+  * Showing the personas in Expected Output in System Message"
 "<?php
 $username = $_POST['username'];
 $password = $_POST['password'];

demo/gradio_meta_prompt.py
CHANGED
@@ -1,31 +1,40 @@
 import gradio as gr
+from confz import BaseConfig, CLArgSource, EnvSource, FileSource
 from meta_prompt import MetaPromptGraph, AgentState
 from langchain_openai import ChatOpenAI
+from config import MetaPromptConfig
 
-
-
-
-
-
-
-
-
-
-
-def process_message(user_message, expected_output, acceptance_criteria,
+class LLMModelFactory:
+    def __init__(self):
+        pass
+
+    def create(self, model_type: str, **kwargs):
+        model_class = globals()[model_type]
+        return model_class(**kwargs)
+
+llm_model_factory = LLMModelFactory()
+
+def process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
+                    recursion_limit: int, model_name: str):
     # Create the input state
     input_state = AgentState(
         user_message=user_message,
         expected_output=expected_output,
-        acceptance_criteria=acceptance_criteria
+        acceptance_criteria=acceptance_criteria,
+        system_message=initial_system_message
     )
 
     # Get the output state from MetaPromptGraph
+    type = config.llms[model_name].type
+    args = config.llms[model_name].model_dump(exclude={'type'})
+    llm = llm_model_factory.create(type, **args)
+    meta_prompt_graph = MetaPromptGraph(llms=llm)
     output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
 
     # Validate the output state
     system_message = ''
     output = ''
+    analysis = ''
 
     if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
         system_message = output_state['best_system_message']
@@ -37,22 +46,58 @@ def process_message(user_message, expected_output, acceptance_criteria, recursio
     else:
         output = "Error: The output state does not contain a valid 'best_output'"
 
-
+    if 'analysis' in output_state and output_state['analysis'] is not None:
+        analysis = output_state['analysis']
+    else:
+        analysis = "Error: The output state does not contain a valid 'analysis'"
+
+    return system_message, output, analysis
+
+class FileConfig(BaseConfig):
+    config_file: str = 'config.yml' # default path
+
+pre_config_sources = [
+    EnvSource(prefix='METAPROMPT_', allow_all=True),
+    CLArgSource()
+]
+pre_config = FileConfig(config_sources=pre_config_sources)
+
+config_sources = [
+    FileSource(file=pre_config.config_file, optional=True),
+    EnvSource(prefix='METAPROMPT_', allow_all=True),
+    CLArgSource()
+]
+
+config = MetaPromptConfig(config_sources=config_sources)
 
 # Create the Gradio interface
 iface = gr.Interface(
     fn=process_message,
     inputs=[
-        gr.Textbox(label="User Message"),
-        gr.Textbox(label="Expected Output"),
-        gr.Textbox(label="Acceptance Criteria"),
-
+        gr.Textbox(label="User Message", show_copy_button=True),
+        gr.Textbox(label="Expected Output", show_copy_button=True),
+        gr.Textbox(label="Acceptance Criteria", show_copy_button=True),
+    ],
+    outputs=[
+        gr.Textbox(label="System Message", show_copy_button=True),
+        gr.Textbox(label="Output", show_copy_button=True),
+        gr.Textbox(label="Analysis", show_copy_button=True)
+    ],
+    additional_inputs=[
+        gr.Textbox(label="Initial System Message", show_copy_button=True, value=""),
+        gr.Number(label="Recursion Limit", value=25,
+                  precision=0, minimum=1, maximum=100, step=1),
+        gr.Dropdown(
+            label="Model Name",
+            choices=config.llms.keys(),
+            value=list(config.llms.keys())[0],
+        )
     ],
-
+    # stop_btn = gr.Button("Stop", variant="stop", visible=True),
     title="MetaPromptGraph Chat Interface",
     description="A chat interface for MetaPromptGraph to process user inputs and generate system messages.",
-    examples=
+    examples=config.examples_path
 )
 
 # Launch the Gradio app
-iface.launch()
+iface.launch(server_name=config.server_name, server_port=config.server_port)

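Two details in the new startup code are worth flagging: the configuration is loaded in two passes (a tiny `FileConfig` reads only `config_file` from env/CLI, so the YAML path itself is overridable before the full `MetaPromptConfig` is built), and `LLMModelFactory.create` resolves `type` via `globals()`, which only finds classes imported into this module. A hypothetical explicit registry would make that dependency visible:

# Hypothetical alternative to the globals() lookup in LLMModelFactory above.
# MODEL_REGISTRY and create_model are illustrative names, not demo APIs.
from langchain_openai import ChatOpenAI

MODEL_REGISTRY = {
    "ChatOpenAI": ChatOpenAI,
}

def create_model(model_type: str, **kwargs):
    try:
        return MODEL_REGISTRY[model_type](**kwargs)
    except KeyError:
        raise ValueError(f"Unsupported llm type: {model_type!r}") from None
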
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
 aiofiles==23.2.1
-aiohttp==3.
+aiohttp==3.9.5
 aiosignal==1.3.1
 altair==5.1.1
 annotated-types==0.5.0
@@ -11,6 +11,7 @@ certifi==2023.7.22
 charset-normalizer==3.2.0
 click==8.1.7
 comm==0.2.2
+confz==2.0.1
 contourpy==1.1.1
 cycler==0.11.0
 dataclasses-json==0.6.0
@@ -25,16 +26,17 @@ filelock==3.12.4
 fonttools==4.42.1
 frozenlist==1.4.0
 fsspec==2023.9.2
-gradio==
-gradio_client==0.
+gradio==4.37.2
+gradio_client==1.0.2
 greenlet==2.0.2
 h11==0.14.0
 httpcore==0.18.0
 httpx==0.25.0
-huggingface-hub==0.
+huggingface-hub==0.23.4
 idna==3.4
 importlib-resources==6.1.0
 ipykernel==6.29.4
+ipython==8.26.0
 jedi==0.19.1
 Jinja2==3.1.2
 joblib==1.3.2
@@ -51,10 +53,12 @@ langchain-openai==0.1.13
 langchain-text-splitters==0.2.2
 langgraph==0.1.4
 langsmith==0.1.82
+markdown-it-py==3.0.0
 MarkupSafe==2.1.3
 marshmallow==3.20.1
 matplotlib==3.8.0
 matplotlib-inline==0.1.7
+mdurl==0.1.2
 multidict==6.0.4
 mypy-extensions==1.0.0
 nest-asyncio==1.6.0
@@ -78,17 +82,21 @@ pydub==0.25.1
 Pygments==2.18.0
 pyparsing==3.1.1
 python-dateutil==2.9.0.post0
-python-
+python-dotenv==1.0.1
+python-multipart==0.0.9
 pytz==2023.3.post1
 PyYAML==6.0.1
 pyzmq==26.0.3
 referencing==0.30.2
 regex==2024.5.15
 requests==2.31.0
+rich==13.7.1
 rpds-py==0.10.3
+ruff==0.5.0
 scikit-learn==1.3.1
 scipy==1.11.3
 semantic-version==2.10.0
+shellingham==1.5.4
 six==1.16.0
 sniffio==1.3.0
 SQLAlchemy==2.0.21
@@ -97,10 +105,13 @@ starlette==0.27.0
 tenacity==8.2.3
 threadpoolctl==3.2.0
 tiktoken==0.7.0
+toml==0.10.2
+tomlkit==0.12.0
 toolz==0.12.0
 tornado==6.4.1
 tqdm==4.66.1
 traitlets==5.14.3
+typer==0.12.3
 typing-inspect==0.9.0
 typing_extensions==4.12.2
 tzdata==2023.3

src/meta_prompt/meta_prompt.py
CHANGED
@@ -164,23 +164,9 @@ If both outputs are equally similar to the expected output, output the following
 ]),
 NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
     ("system", """
-You are a text comparing program. You compare the following output texts
-
-is acceptable.
-
-# Expected Output
-
-```
-{expected_output}
-```
-
-# Actual Output
-
-```
-{output}
-```
-
-----
+You are a text comparing program. You compare the following output texts,
+analyze the System Message and provide a detailed analysis according to
+`Acceptance Criteria`. Then you decide whether `Actual Output` is acceptable.
 
 Provide your analysis in the following format:
 
@@ -200,6 +186,25 @@ Provide your analysis in the following format:
 ```
 {acceptance_criteria}
 ```
+"""),
+    ("human", """
+# System Message
+
+```
+{system_message}
+```
+
+# Expected Output
+
+```
+{expected_output}
+```
+
+# Actual Output
+
+```
+{output}
+```
 """)
 ]),
 NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
@@ -281,42 +286,34 @@ Analysis:
         self.prompt_templates: Dict[str, ChatPromptTemplate] = self.DEFAULT_PROMPT_TEMPLATES.copy()
         self.prompt_templates.update(prompts)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.workflow.set_entry_point(self.NODE_PROMPT_INITIAL_DEVELOPER)
-
-        self.workflow.add_edge(self.NODE_PROMPT_INITIAL_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
-        self.workflow.add_edge(self.NODE_PROMPT_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
-        self.workflow.add_edge(self.NODE_PROMPT_EXECUTOR, self.NODE_OUTPUT_HISTORY_ANALYZER)
-        self.workflow.add_edge(self.NODE_PROMPT_SUGGESTER, self.NODE_PROMPT_DEVELOPER)
-
-        self.workflow.add_conditional_edges(
+    def _create_workflow(self, including_initial_developer: bool = True) -> StateGraph:
+        workflow = StateGraph(AgentState)
+
+        workflow.add_node(self.NODE_PROMPT_DEVELOPER,
+                          lambda x: self._prompt_node(
+                              self.NODE_PROMPT_DEVELOPER,
+                              "system_message",
+                              x))
+        workflow.add_node(self.NODE_PROMPT_EXECUTOR,
+                          lambda x: self._prompt_node(
+                              self.NODE_PROMPT_EXECUTOR,
+                              "output",
+                              x))
+        workflow.add_node(self.NODE_OUTPUT_HISTORY_ANALYZER,
+                          lambda x: self._output_history_analyzer(x))
+        workflow.add_node(self.NODE_PROMPT_ANALYZER,
+                          lambda x: self._prompt_analyzer(x))
+        workflow.add_node(self.NODE_PROMPT_SUGGESTER,
+                          lambda x: self._prompt_node(
+                              self.NODE_PROMPT_SUGGESTER,
+                              "suggestions",
+                              x))
+
+        workflow.add_edge(self.NODE_PROMPT_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
+        workflow.add_edge(self.NODE_PROMPT_EXECUTOR, self.NODE_OUTPUT_HISTORY_ANALYZER)
+        workflow.add_edge(self.NODE_PROMPT_SUGGESTER, self.NODE_PROMPT_DEVELOPER)
+
+        workflow.add_conditional_edges(
             self.NODE_OUTPUT_HISTORY_ANALYZER,
             lambda x: self._should_exit_on_max_age(x),
             {
@@ -326,7 +323,7 @@ Analysis:
             }
         )
 
-
+        workflow.add_conditional_edges(
             self.NODE_PROMPT_ANALYZER,
             lambda x: self._should_exit_on_acceptable_output(x),
             {
@@ -335,9 +332,24 @@ Analysis:
             }
         )
 
+        if including_initial_developer:
+            workflow.add_node(self.NODE_PROMPT_INITIAL_DEVELOPER,
+                              lambda x: self._prompt_node(
+                                  self.NODE_PROMPT_INITIAL_DEVELOPER,
+                                  "system_message",
+                                  x))
+            workflow.add_edge(self.NODE_PROMPT_INITIAL_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
+            workflow.set_entry_point(self.NODE_PROMPT_INITIAL_DEVELOPER)
+        else:
+            workflow.set_entry_point(self.NODE_PROMPT_EXECUTOR)
+
+        return workflow
+
     def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        workflow = self._create_workflow(including_initial_developer=(state.system_message is None or state.system_message == ""))
+
         memory = MemorySaver()
-        graph =
+        graph = workflow.compile(checkpointer=memory)
         config = {"configurable": {"thread_id": "1"}, "recursion_limit": recursion_limit}
 
         try:

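The refactor above moves graph construction from `__init__` into `_create_workflow`, so the entry point is now chosen per call: an empty `state.system_message` routes through the initial-developer node to draft one, while a caller-supplied message (the new "Initial System Message" textbox in the demo) starts directly at the executor. A minimal usage sketch, assuming this repo's package layout; the model choice is illustrative:

# Usage sketch for the refactored graph (argument and field names are taken
# from this commit).
from langchain_openai import ChatOpenAI
from meta_prompt import AgentState, MetaPromptGraph

llm = ChatOpenAI(model_name="llama3-70b-8192", temperature=0.1)
graph = MetaPromptGraph(llms=llm)

state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use the `reversed()` function or `list[::-1]`.",
    acceptance_criteria="Similar in meaning; extra or missing spaces are acceptable.",
    system_message="",  # empty, so the workflow starts at the initial developer node
)
output_state = graph(state, recursion_limit=25)
print(output_state["best_system_message"])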