Refactored code.

- app/gradio_meta_prompt.py  +111 -56
- meta_prompt/__init__.py    +10 -1
- meta_prompt/consts.py      +220 -0
- meta_prompt/meta_prompt.py +78 -301
app/gradio_meta_prompt.py
CHANGED
@@ -3,7 +3,6 @@ import io
 import json
 import logging
 from pathlib import Path
-import pprint
 from typing import Any, Dict, Union
 import gradio as gr
 from gradio import CSVLogger, Button, utils
@@ -13,8 +12,10 @@ from confz import BaseConfig, CLArgSource, EnvSource, FileSource
 from app.config import MetaPromptConfig
 from langchain_core.language_models import BaseLanguageModel
 from langchain_openai import ChatOpenAI
-from meta_prompt import
+from meta_prompt import *
 from pythonjsonlogger import jsonlogger
+import pprint
+

 class SimplifiedCSVLogger(CSVLogger):
     """
@@ -65,15 +66,19 @@ class SimplifiedCSVLogger(CSVLogger):
         line_count = len(list(csv.reader(csvfile))) - 1
         return line_count

+
 class LLMModelFactory:
-
-
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(LLMModelFactory, cls).__new__(cls)
+        return cls._instance

     def create(self, model_type: str, **kwargs):
         model_class = globals()[model_type]
         return model_class(**kwargs)
-
-llm_model_factory = LLMModelFactory()
+

 def chat_log_2_chatbot_list(chat_log: str):
     chatbot_list = []
@@ -95,8 +100,10 @@ def chat_log_2_chatbot_list(chat_log: str):
             print(line)
     return chatbot_list

-
-
+
+def process_message(user_message, expected_output, acceptance_criteria,
+                    initial_system_message, recursion_limit: int,
+                    max_output_age: int,
                     llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]]):
     # Create the input state
     input_state = AgentState(
@@ -106,7 +113,7 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
         system_message=initial_system_message,
         max_output_age=max_output_age
     )
-
+
     # Get the output state from MetaPromptGraph
     log_stream = io.StringIO()
     log_handler = None
@@ -114,10 +121,12 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
     if config.verbose:
         log_handler = logging.StreamHandler(log_stream)
         logger = logging.getLogger(MetaPromptGraph.__name__)
-        log_handler.setFormatter(jsonlogger.JsonFormatter(
+        log_handler.setFormatter(jsonlogger.JsonFormatter(
+            '%(asctime)s %(name)s %(levelname)s %(message)s'))
         logger.addHandler(log_handler)

-    meta_prompt_graph = MetaPromptGraph(
+    meta_prompt_graph = MetaPromptGraph(
+        llms=llms, verbose=config.verbose, logger=logger)
     output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)

     if config.verbose:
@@ -125,7 +134,7 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
         log_output = log_stream.getvalue()
     else:
         log_output = None
-
+
     # Validate the output state
     system_message = ''
     output = ''
@@ -146,7 +155,8 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
     else:
         analysis = "Error: The output state does not contain a valid 'analysis'"

-    return system_message, output, analysis,
+    return (system_message, output, analysis,
+            chat_log_2_chatbot_list(log_output))


 def process_message_with_single_llm(user_message, expected_output, acceptance_criteria, initial_system_message,
@@ -155,31 +165,33 @@ def process_message_with_single_llm(user_message, expected_output, acceptance_cr
     # Get the output state from MetaPromptGraph
     type = config.llms[model_name].type
     args = config.llms[model_name].model_dump(exclude={'type'})
-    llm =
+    llm = LLMModelFactory().create(type, **args)

     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
                            recursion_limit, max_output_age, llm)

+
 def process_message_with_2_llms(user_message, expected_output, acceptance_criteria, initial_system_message,
-
-
+                                recursion_limit: int, max_output_age: int,
+                                optimizer_model_name: str, executor_model_name: str,):
     # Get the output state from MetaPromptGraph
-    optimizer_model =
+    optimizer_model = LLMModelFactory().create(config.llms[optimizer_model_name].type,
                                                **config.llms[optimizer_model_name].model_dump(exclude={'type'}))
-    executor_model =
+    executor_model = LLMModelFactory().create(config.llms[executor_model_name].type,
                                               **config.llms[executor_model_name].model_dump(exclude={'type'}))
     llms = {
-
-
-
-
-
-
+        NODE_PROMPT_INITIAL_DEVELOPER: optimizer_model,
+        NODE_PROMPT_DEVELOPER: optimizer_model,
+        NODE_PROMPT_EXECUTOR: executor_model,
+        NODE_OUTPUT_HISTORY_ANALYZER: optimizer_model,
+        NODE_PROMPT_ANALYZER: optimizer_model,
+        NODE_PROMPT_SUGGESTER: optimizer_model
     }

     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
                            recursion_limit, max_output_age, llms)

+
 class FileConfig(BaseConfig):
     config_file: str = 'config.yml'  # default path

@@ -216,10 +228,12 @@ with gr.Blocks(title='Meta Prompt') as demo:
                 label="Acceptance Criteria", show_copy_button=True)
             initial_system_message_input = gr.Textbox(
                 label="Initial System Message", show_copy_button=True, value="")
-            recursion_limit_input = gr.Number(
-
-
-
+            recursion_limit_input = gr.Number(
+                label="Recursion Limit", value=config.recursion_limit,
+                precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)
+            max_output_age = gr.Number(
+                label="Max Output Age", value=config.max_output_age,
+                precision=0, minimum=1, maximum=config.max_output_age_max, step=1)
             with gr.Row():
                 with gr.Tab('Simple'):
                     model_name_input = gr.Dropdown(
@@ -229,7 +243,8 @@ with gr.Blocks(title='Meta Prompt') as demo:
                     )
                     # Connect the inputs and outputs to the function
                     with gr.Row():
-                        submit_button = gr.Button(
+                        submit_button = gr.Button(
+                            value="Submit", variant="primary")
                         clear_button = gr.ClearButton(
                             [user_message_input, expected_output_input,
                              acceptance_criteria_input, initial_system_message_input],
@@ -247,45 +262,85 @@ with gr.Blocks(title='Meta Prompt') as demo:
                    )
                    # Connect the inputs and outputs to the function
                    with gr.Row():
-                        multiple_submit_button = gr.Button(
-
-
-
-
-
+                        multiple_submit_button = gr.Button(
+                            value="Submit", variant="primary")
+                        multiple_clear_button = gr.ClearButton(
+                            components=[user_message_input, expected_output_input,
+                                        acceptance_criteria_input, initial_system_message_input],
+                            value='Clear All')
        with gr.Column():
-            system_message_output = gr.Textbox(
-                label="System Message", show_copy_button=True)
+            system_message_output = gr.Textbox(label="System Message", show_copy_button=True)
            output_output = gr.Textbox(label="Output", show_copy_button=True)
-            analysis_output = gr.Textbox(
-                label="Analysis", show_copy_button=True)
+            analysis_output = gr.Textbox(label="Analysis", show_copy_button=True)
            flag_button = gr.Button(value="Flag", variant="secondary", visible=config.allow_flagging)
            with gr.Accordion("Details", open=False, visible=config.verbose):
-                logs_chatbot = gr.Chatbot(
-
+                logs_chatbot = gr.Chatbot(
+                    label='Messages', show_copy_button=True, layout='bubble',
+                    bubble_full_width=False, render_markdown=False
+                )
                clear_logs_button = gr.ClearButton([logs_chatbot], value='Clear Logs')

    clear_button.add([system_message_output, output_output,
                      analysis_output, logs_chatbot])
    multiple_clear_button.add([system_message_output, output_output,
-
-
-    submit_button.click(
-
-
-
-
-
-
-
-
-
+                               analysis_output, logs_chatbot])
+
+    submit_button.click(
+        process_message_with_single_llm,
+        inputs=[
+            user_message_input,
+            expected_output_input,
+            acceptance_criteria_input,
+            initial_system_message_input,
+            recursion_limit_input,
+            max_output_age,
+            model_name_input
+        ],
+        outputs=[
+            system_message_output,
+            output_output,
+            analysis_output,
+            logs_chatbot
+        ]
+    )
+
+    multiple_submit_button.click(
+        process_message_with_2_llms,
+        inputs=[
+            user_message_input,
+            expected_output_input,
+            acceptance_criteria_input,
+            initial_system_message_input,
+            recursion_limit_input,
+            max_output_age,
+            optimizer_model_name_input,
+            executor_model_name_input
+        ],
+        outputs=[
+            system_message_output,
+            output_output,
+            analysis_output,
+            logs_chatbot
+        ]
+    )

    # Load examples
    examples = config.examples_path
-    gr.Examples(examples, inputs=[
-
-
+    gr.Examples(examples, inputs=[
+        user_message_input,
+        expected_output_input,
+        acceptance_criteria_input,
+        initial_system_message_input,
+        recursion_limit_input,
+        model_name_input
+    ])
+
+    flagging_inputs = [
+        user_message_input,
+        expected_output_input,
+        acceptance_criteria_input,
+        initial_system_message_input
+    ]

    # Configure flagging
    if config.allow_flagging:
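
The factory above resolves a class by name from the module globals, so any model class imported into app/gradio_meta_prompt.py (such as ChatOpenAI) can be instantiated from a config entry's `type` string plus its remaining keyword arguments. A minimal sketch of that behaviour follows; the temperature value and the placeholder API key are illustrative assumptions, not values taken from the commit.

    # Sketch (not part of the commit): singleton factory resolving a class by name.
    from langchain_openai import ChatOpenAI  # must be in this module's globals()


    class LLMModelFactory:
        _instance = None

        def __new__(cls):
            if not cls._instance:
                cls._instance = super(LLMModelFactory, cls).__new__(cls)
            return cls._instance

        def create(self, model_type: str, **kwargs):
            model_class = globals()[model_type]  # look the class up by its name
            return model_class(**kwargs)


    factory_a = LLMModelFactory()
    factory_b = LLMModelFactory()
    assert factory_a is factory_b  # singleton: both names refer to one instance

    # Hypothetical config-style entry: type "ChatOpenAI", the rest are kwargs.
    # The api_key is a placeholder so construction works without environment setup.
    llm = factory_a.create("ChatOpenAI", temperature=0.1, api_key="sk-placeholder")
    print(type(llm).__name__)  # -> ChatOpenAI
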
meta_prompt/__init__.py
CHANGED
@@ -1,4 +1,13 @@
 __version__ = '0.1.0'

 from .meta_prompt import AgentState, MetaPromptGraph
-
+from .consts import (
+    META_PROMPT_NODES,
+    NODE_PROMPT_INITIAL_DEVELOPER,
+    NODE_PROMPT_DEVELOPER,
+    NODE_PROMPT_EXECUTOR,
+    NODE_OUTPUT_HISTORY_ANALYZER,
+    NODE_PROMPT_ANALYZER,
+    NODE_PROMPT_SUGGESTER,
+    DEFAULT_PROMPT_TEMPLATES,
+)
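
With the expanded package root, both the graph classes and the node-name constants resolve directly from `meta_prompt`, which is what the `from meta_prompt import *` line in app/gradio_meta_prompt.py relies on. A small sketch, assuming the package is importable from the current environment:

    # Sketch: names re-exported by meta_prompt/__init__.py after this commit.
    from meta_prompt import (
        AgentState,
        MetaPromptGraph,
        META_PROMPT_NODES,
        NODE_PROMPT_EXECUTOR,
        DEFAULT_PROMPT_TEMPLATES,
    )

    print(NODE_PROMPT_EXECUTOR)              # "prompt_executor"
    print(META_PROMPT_NODES)                  # the six node names
    print(sorted(DEFAULT_PROMPT_TEMPLATES))   # one ChatPromptTemplate per node
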
meta_prompt/consts.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
2 |
+
|
3 |
+
NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
|
4 |
+
NODE_PROMPT_DEVELOPER = "prompt_developer"
|
5 |
+
NODE_PROMPT_EXECUTOR = "prompt_executor"
|
6 |
+
NODE_OUTPUT_HISTORY_ANALYZER = "output_history_analyzer"
|
7 |
+
NODE_PROMPT_ANALYZER = "prompt_analyzer"
|
8 |
+
NODE_PROMPT_SUGGESTER = "prompt_suggester"
|
9 |
+
|
10 |
+
META_PROMPT_NODES = [
|
11 |
+
NODE_PROMPT_INITIAL_DEVELOPER,
|
12 |
+
NODE_PROMPT_DEVELOPER,
|
13 |
+
NODE_PROMPT_EXECUTOR,
|
14 |
+
NODE_OUTPUT_HISTORY_ANALYZER,
|
15 |
+
NODE_PROMPT_ANALYZER,
|
16 |
+
NODE_PROMPT_SUGGESTER
|
17 |
+
]
|
18 |
+
|
19 |
+
DEFAULT_PROMPT_TEMPLATES = {
|
20 |
+
NODE_PROMPT_INITIAL_DEVELOPER: ChatPromptTemplate.from_messages([
|
21 |
+
("system", """# Expert Prompt Engineer
|
22 |
+
|
23 |
+
You are an expert prompt engineer tasked with creating system messages for AI assistants.
|
24 |
+
|
25 |
+
## Instructions
|
26 |
+
|
27 |
+
1. Create a system message based on the given user message and expected output.
|
28 |
+
2. Ensure the system message can handle similar user messages.
|
29 |
+
3. Output only the system message, without any additional content.
|
30 |
+
4. Expected Output text should not appear in System Message as an example. But it's OK to use some similar text as an example instead.
|
31 |
+
5. Format the system message well, with no more than 80 characters per line (except for raw text).
|
32 |
+
|
33 |
+
## Output
|
34 |
+
|
35 |
+
Provide only the system message, adhering to the above guidelines.
|
36 |
+
"""),
|
37 |
+
("human", """# User Message
|
38 |
+
|
39 |
+
{user_message}
|
40 |
+
|
41 |
+
# Expected Output
|
42 |
+
|
43 |
+
{expected_output}
|
44 |
+
""")
|
45 |
+
]),
|
46 |
+
NODE_PROMPT_DEVELOPER: ChatPromptTemplate.from_messages([
|
47 |
+
("system", """# Expert Prompt Engineer
|
48 |
+
|
49 |
+
You are an expert prompt engineer tasked with updating system messages for AI assistants. You Update System Message according to Suggestions, to improve Output and match Expected Output more closely.
|
50 |
+
|
51 |
+
## Instructions
|
52 |
+
|
53 |
+
1. Update the system message based on the given Suggestion, User Message, and Expected Output.
|
54 |
+
2. Ensure the updated system message can handle similar user messages.
|
55 |
+
3. Modify only the content mentioned in the Suggestion. Do not change the parts that are not related to the Suggestion.
|
56 |
+
4. Output only the updated system message, without any additional content.
|
57 |
+
5. Avoiding the behavior should be explicitly requested (e.g. `Don't ...`) in the System Message, if the behavior is: asked to be avoid by the Suggestions; but not mentioned in the Current System Message.
|
58 |
+
6. Expected Output text should not appear in System Message as an example. But it's OK to use some similar text as an example instead.
|
59 |
+
* Remove the Expected Output text or text highly similar to Expected Output from System Message, if it's present.
|
60 |
+
7. Format the system message well, with no more than 80 characters per line (except for raw text).
|
61 |
+
|
62 |
+
## Output
|
63 |
+
|
64 |
+
Provide only the updated System Message, adhering to the above guidelines.
|
65 |
+
"""),
|
66 |
+
("human", """# Current system message
|
67 |
+
|
68 |
+
{system_message}
|
69 |
+
|
70 |
+
# User Message
|
71 |
+
|
72 |
+
{user_message}
|
73 |
+
|
74 |
+
# Expected Output
|
75 |
+
|
76 |
+
{expected_output}
|
77 |
+
|
78 |
+
# Suggestions
|
79 |
+
|
80 |
+
{suggestions}
|
81 |
+
""")
|
82 |
+
]),
|
83 |
+
NODE_PROMPT_EXECUTOR: ChatPromptTemplate.from_messages([
|
84 |
+
("system", "{system_message}"),
|
85 |
+
("human", "{user_message}")
|
86 |
+
]),
|
87 |
+
NODE_OUTPUT_HISTORY_ANALYZER: ChatPromptTemplate.from_messages([
|
88 |
+
("system", """You are a text comparing program. You read the Acceptance Criteria, compare the compare the exptected output with two different outputs, and decide which one is more consistent with the expected output. When comparing the outputs, ignore the differences which are acceptable or ignorable according to the Acceptance Criteria.
|
89 |
+
|
90 |
+
You output the following analysis according to the Acceptance Criteria:
|
91 |
+
|
92 |
+
* Your analysis in a Markdown list.
|
93 |
+
* The ID of the output that is more consistent with the Expected Output as Preferred Output ID, with the following format:
|
94 |
+
|
95 |
+
```
|
96 |
+
# Analysis
|
97 |
+
|
98 |
+
...
|
99 |
+
|
100 |
+
# Preferred Output ID: [ID]
|
101 |
+
```
|
102 |
+
|
103 |
+
If both outputs are equally similar to the expected output, output the following:
|
104 |
+
|
105 |
+
```
|
106 |
+
# Analysis
|
107 |
+
|
108 |
+
...
|
109 |
+
|
110 |
+
# Draw
|
111 |
+
```
|
112 |
+
"""),
|
113 |
+
("human", """
|
114 |
+
# Output ID: A
|
115 |
+
|
116 |
+
```
|
117 |
+
{best_output}
|
118 |
+
```
|
119 |
+
|
120 |
+
# Output ID: B
|
121 |
+
|
122 |
+
```
|
123 |
+
{output}
|
124 |
+
```
|
125 |
+
|
126 |
+
# Acceptance Criteria
|
127 |
+
|
128 |
+
{acceptance_criteria}
|
129 |
+
|
130 |
+
# Expected Output
|
131 |
+
|
132 |
+
```
|
133 |
+
{expected_output}
|
134 |
+
```
|
135 |
+
""")
|
136 |
+
]),
|
137 |
+
NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
|
138 |
+
("system", """You are a text comparing program. You compare the following output texts, analysis the System Message and provide a detailed analysis according to `Acceptance Criteria`. Then you decide whether `Actual Output` is acceptable.
|
139 |
+
|
140 |
+
Provide your analysis in the following format:
|
141 |
+
|
142 |
+
```
|
143 |
+
- Acceptable Differences: [List acceptable differences succinctly]
|
144 |
+
- Unacceptable Differences: [List unacceptable differences succinctly]
|
145 |
+
- Accept: [Yes/No]
|
146 |
+
```
|
147 |
+
|
148 |
+
* Compare Expected Output and Actual Output with the guidance of Accept Criteria.
|
149 |
+
* Only set 'Accept' to 'Yes', if Accept Criteria are all met. Otherwise, set 'Accept' to 'No'.
|
150 |
+
* List only the acceptable differences according to Accept Criteria in 'acceptable Differences' section.
|
151 |
+
* List only the unacceptable differences according to Accept Criteria in 'Unacceptable Differences' section.
|
152 |
+
|
153 |
+
# Acceptance Criteria
|
154 |
+
|
155 |
+
```
|
156 |
+
{acceptance_criteria}
|
157 |
+
```
|
158 |
+
"""),
|
159 |
+
("human", """
|
160 |
+
# System Message
|
161 |
+
|
162 |
+
```
|
163 |
+
{system_message}
|
164 |
+
```
|
165 |
+
|
166 |
+
# Expected Output
|
167 |
+
|
168 |
+
```
|
169 |
+
{expected_output}
|
170 |
+
```
|
171 |
+
|
172 |
+
# Actual Output
|
173 |
+
|
174 |
+
```
|
175 |
+
{output}
|
176 |
+
```
|
177 |
+
""")
|
178 |
+
]),
|
179 |
+
NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
|
180 |
+
("system", """Read the following inputs and outputs of an LLM prompt, and also analysis about them. Then suggest how to improve System Message.
|
181 |
+
|
182 |
+
* The goal is to improve the System Message to match the Expected Output better.
|
183 |
+
* Ignore all Acceptable Differences and focus on Unacceptable Differences.
|
184 |
+
* Suggest formal changes first, then semantic changes.
|
185 |
+
* Provide your suggestions in a Markdown list, nothing else. Output only the suggestions related with Unacceptable Differences.
|
186 |
+
* Start every suggestion with `The System Message should ...`.
|
187 |
+
* Figue out the contexts of the System Message that conflict with the suggestions, and suggest modification or deletion.
|
188 |
+
* Avoiding the behavior should be explicitly requested (e.g. `The System Message should explicitly state that the output shoud not ...`) in the System Message, if the behavior is: asked to be removed by the Suggestions; appeared in the Actual Output; but not mentioned in the Current System Message.
|
189 |
+
* Expected Output text should not appear in System Message as an example. But it's OK to use some similar but distinct text as an example instead.
|
190 |
+
* Ask to remove the Expected Output text or text highly similar to Expected Output from System Message, if it's present.
|
191 |
+
* Provide format examples or detected format name, if System Message does not.
|
192 |
+
* Specify the detected format name (e.g. XML, JSON, etc.) of Expected Output, if System Message does not mention it.
|
193 |
+
"""),
|
194 |
+
("human", """
|
195 |
+
<|Start_System_Message|>
|
196 |
+
{system_message}
|
197 |
+
<|End_System_Message|>
|
198 |
+
|
199 |
+
<|Start_User_Message|>
|
200 |
+
{user_message}
|
201 |
+
<|End_User_Message|>
|
202 |
+
|
203 |
+
<|Start_Expected_Output|>
|
204 |
+
{expected_output}
|
205 |
+
<|End_Expected_Output|>
|
206 |
+
|
207 |
+
<|Start_Actual_Output|>
|
208 |
+
{output}
|
209 |
+
<|End_Actual_Output|>
|
210 |
+
|
211 |
+
<|Start_Acceptance Criteria|>
|
212 |
+
{acceptance_criteria}
|
213 |
+
<|End_Acceptance Criteria|>
|
214 |
+
|
215 |
+
<|Start_Analysis|>
|
216 |
+
{analysis}
|
217 |
+
<|End_Analysis|>
|
218 |
+
""")
|
219 |
+
])
|
220 |
+
}
|
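
Each entry in DEFAULT_PROMPT_TEMPLATES is an ordinary ChatPromptTemplate, so it can be rendered outside the graph as well. A minimal sketch using the executor template, whose only placeholders are {system_message} and {user_message}; the two example strings below are made up for illustration:

    # Sketch: render the executor template into chat messages.
    from meta_prompt.consts import DEFAULT_PROMPT_TEMPLATES, NODE_PROMPT_EXECUTOR

    messages = DEFAULT_PROMPT_TEMPLATES[NODE_PROMPT_EXECUTOR].format_messages(
        system_message="You are a terse assistant.",   # placeholder value
        user_message="Say hello.",                      # placeholder value
    )
    for message in messages:
        print(message.type, ":", message.content)
    # -> system : You are a terse assistant.
    # -> human  : Say hello.
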
meta_prompt/meta_prompt.py
CHANGED
The node-name constants and the DEFAULT_PROMPT_TEMPLATES dictionary move out of MetaPromptGraph; their bodies are the six ChatPromptTemplate definitions now in meta_prompt/consts.py (shown above).

@@ -9,6 +9,7 @@ from langgraph.graph import StateGraph, END
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.errors import GraphRecursionError
 from pydantic import BaseModel
+from .consts import *

 class AgentState(BaseModel):
     max_output_age: int = 0
@@ -25,261 +26,16 @@ class AgentState(BaseModel):
     best_output_age: int = 0

 class MetaPromptGraph:
-    NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
-    NODE_PROMPT_DEVELOPER = "prompt_developer"
-    NODE_PROMPT_EXECUTOR = "prompt_executor"
-    NODE_OUTPUT_HISTORY_ANALYZER = "output_history_analyzer"
-    NODE_PROMPT_ANALYZER = "prompt_analyzer"
-    NODE_PROMPT_SUGGESTER = "prompt_suggester"
-
-    DEFAULT_PROMPT_TEMPLATES = {
-        ...
-    }
-
     @classmethod
     def get_node_names(cls):
-        return [
-            cls.NODE_PROMPT_INITIAL_DEVELOPER,
-            cls.NODE_PROMPT_DEVELOPER,
-            cls.NODE_PROMPT_EXECUTOR,
-            cls.NODE_OUTPUT_HISTORY_ANALYZER,
-            cls.NODE_PROMPT_ANALYZER,
-            cls.NODE_PROMPT_SUGGESTER
-        ]
+        return META_PROMPT_NODES

     def __init__(self,
-                 llms: Union[BaseLanguageModel,
+                 llms: Union[BaseLanguageModel,
+                             Dict[str, BaseLanguageModel]] = {},
                  prompts: Dict[str, ChatPromptTemplate] = {},
                  logger: Optional[logging.Logger] = None,
-                 verbose
+                 verbose=False):
         self.logger = logger or logging.getLogger(__name__)
         if self.logger is not None:
             if verbose:
@@ -288,80 +44,86 @@
                self.logger.setLevel(logging.INFO)

        if isinstance(llms, BaseLanguageModel):
+            self.llms: Dict[str, BaseLanguageModel] = {
+                node: llms for node in self.get_node_names()}
        else:
            self.llms: Dict[str, BaseLanguageModel] = llms
+        self.prompt_templates: Dict[str,
+                                    ChatPromptTemplate] = DEFAULT_PROMPT_TEMPLATES.copy()
        self.prompt_templates.update(prompts)

    def _create_workflow(self, including_initial_developer: bool = True) -> StateGraph:
        workflow = StateGraph(AgentState)
+
+        workflow.add_node(NODE_PROMPT_DEVELOPER,
                          lambda x: self._prompt_node(
+                              NODE_PROMPT_DEVELOPER,
                              "system_message",
                              x))
+        workflow.add_node(NODE_PROMPT_EXECUTOR,
                          lambda x: self._prompt_node(
+                              NODE_PROMPT_EXECUTOR,
                              "output",
                              x))
+        workflow.add_node(NODE_OUTPUT_HISTORY_ANALYZER,
                          lambda x: self._output_history_analyzer(x))
+        workflow.add_node(NODE_PROMPT_ANALYZER,
                          lambda x: self._prompt_analyzer(x))
+        workflow.add_node(NODE_PROMPT_SUGGESTER,
                          lambda x: self._prompt_node(
+                              NODE_PROMPT_SUGGESTER,
                              "suggestions",
                              x))

+        workflow.add_edge(NODE_PROMPT_DEVELOPER, NODE_PROMPT_EXECUTOR)
+        workflow.add_edge(NODE_PROMPT_EXECUTOR, NODE_OUTPUT_HISTORY_ANALYZER)
+        workflow.add_edge(NODE_PROMPT_SUGGESTER, NODE_PROMPT_DEVELOPER)

        workflow.add_conditional_edges(
+            NODE_OUTPUT_HISTORY_ANALYZER,
            lambda x: self._should_exit_on_max_age(x),
            {
+                "continue": NODE_PROMPT_ANALYZER,
+                "rerun": NODE_PROMPT_SUGGESTER,
                END: END
            }
        )

        workflow.add_conditional_edges(
+            NODE_PROMPT_ANALYZER,
            lambda x: self._should_exit_on_acceptable_output(x),
            {
+                "continue": NODE_PROMPT_SUGGESTER,
                END: END
            }
        )

        if including_initial_developer:
+            workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
+                              lambda x: self._prompt_node(
+                                  NODE_PROMPT_INITIAL_DEVELOPER,
+                                  "system_message",
+                                  x))
+            workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER,
+                              NODE_PROMPT_EXECUTOR)
+            workflow.set_entry_point(NODE_PROMPT_INITIAL_DEVELOPER)
        else:
+            workflow.set_entry_point(NODE_PROMPT_EXECUTOR)

        return workflow

    def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
+        workflow = self._create_workflow(including_initial_developer=(
+            state.system_message is None or state.system_message == ""))

        memory = MemorySaver()
        graph = workflow.compile(checkpointer=memory)
+        config = {"configurable": {"thread_id": "1"},
+                  "recursion_limit": recursion_limit}

        try:
+            self.logger.debug("Invoking graph with state: %s",
+                              pprint.pformat(state))

            output_state = graph.invoke(state, config)

@@ -369,7 +131,8 @@
            return output_state
        except GraphRecursionError as e:
+            self.logger.info(
+                "Recursion limit reached. Returning the best state found so far.")
            checkpoint_states = graph.get_state(config)

            # if the length of states is bigger than 0, print the best system message and output
@@ -377,41 +140,50 @@
                output_state = checkpoint_states[0]
                return output_state
            else:
+                self.logger.info(
+                    "No checkpoint states found. Returning the input state.")
+
                return state

    def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
        logger = self.logger.getChild(node)
+        prompt = self.prompt_templates[node].format_messages(
+            **state.model_dump())

        for message in prompt:
+            logger.debug({'node': node, 'action': 'invoke',
+                          'type': message.type, 'message': message.content})
+        response = self.llms[node].invoke(
+            self.prompt_templates[node].format_messages(**state.model_dump()))
+        logger.debug({'node': node, 'action': 'response',
+                      'type': response.type, 'message': response.content})
+
        setattr(state, target_attribute, response.content)
        return state

    def _output_history_analyzer(self, state: AgentState) -> AgentState:
+        logger = self.logger.getChild(NODE_OUTPUT_HISTORY_ANALYZER)

        if state.best_output is None:
            state.best_output = state.output
            state.best_system_message = state.system_message
            state.best_output_age = 0

+            logger.debug(
+                "Best output initialized to the current output:\n%s", state.output)

            return state

+        prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
+            **state.model_dump())

        for message in prompt:
+            logger.debug({'node': NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'invoke',
+                          'type': message.type, 'message': message.content})

+        response = self.llms[NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
+        logger.debug({'node': NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'response',
+                      'type': response.type, 'message': response.content})

        analysis = response.content

@@ -420,23 +192,28 @@
            state.best_system_message = state.system_message
            state.best_output_age = 0

+            logger.debug(
+                "Best output updated to the current output:\n%s", state.output)
        else:
            state.best_output_age += 1

+            logger.debug("Best output age incremented to %s",
+                         state.best_output_age)

        return state

    def _prompt_analyzer(self, state: AgentState) -> AgentState:
+        logger = self.logger.getChild(NODE_PROMPT_ANALYZER)
+        prompt = self.prompt_templates[NODE_PROMPT_ANALYZER].format_messages(
+            **state.model_dump())

        for message in prompt:
+            logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'invoke',
+                          'type': message.type, 'message': message.content})

+        response = self.llms[NODE_PROMPT_ANALYZER].invoke(prompt)
+        logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
+                      'type': response.type, 'message': response.content})

        state.analysis = response.content
        state.accepted = "Accept: Yes" in response.content

@@ -446,7 +223,7 @@
        return state

    def _should_exit_on_max_age(self, state: AgentState) -> str:
-        if state.max_output_age <=0:
+        if state.max_output_age <= 0:
            # always continue if max age is 0
            return "continue"
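
Putting the pieces together, the refactored MetaPromptGraph is driven the same way process_message() in app/gradio_meta_prompt.py drives it: build an AgentState, construct the graph with one or more LLMs, and call it. The sketch below is illustrative, not quoted from the commit: the model name, temperature, API key, and field values are placeholder assumptions, and the user_message/expected_output/acceptance_criteria field names are inferred from the template placeholders and the app code.

    # Sketch: single-LLM use of the refactored graph.
    from langchain_openai import ChatOpenAI
    from meta_prompt import AgentState, MetaPromptGraph

    # Placeholder credentials; a real key is needed for the invocation to succeed.
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.1, api_key="sk-placeholder")

    # Field names assumed from the {user_message}/{expected_output}/
    # {acceptance_criteria} placeholders in DEFAULT_PROMPT_TEMPLATES.
    input_state = AgentState(
        user_message="How do I reverse a list in Python?",
        expected_output="Use reversed() or slicing: my_list[::-1].",
        acceptance_criteria="Any correct, concise answer is acceptable.",
        system_message="",   # empty -> the initial-developer node is included
        max_output_age=2,
    )

    # A single BaseLanguageModel is fanned out to every node by __init__.
    graph = MetaPromptGraph(llms=llm, verbose=False)
    output_state = graph(input_state, recursion_limit=25)
    print(output_state)
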