yaleh committed
Commit 3b1cdbf · Parent(s): 1fcb0dd

Refactored code.
app/gradio_meta_prompt.py CHANGED
@@ -3,7 +3,6 @@ import io
 import json
 import logging
 from pathlib import Path
-import pprint
 from typing import Any, Dict, Union
 import gradio as gr
 from gradio import CSVLogger, Button, utils
@@ -13,8 +12,10 @@ from confz import BaseConfig, CLArgSource, EnvSource, FileSource
 from app.config import MetaPromptConfig
 from langchain_core.language_models import BaseLanguageModel
 from langchain_openai import ChatOpenAI
-from meta_prompt import MetaPromptGraph, AgentState
+from meta_prompt import *
 from pythonjsonlogger import jsonlogger
+import pprint
+

 class SimplifiedCSVLogger(CSVLogger):
     """
@@ -65,15 +66,19 @@ class SimplifiedCSVLogger(CSVLogger):
             line_count = len(list(csv.reader(csvfile))) - 1
         return line_count

+
 class LLMModelFactory:
-    def __init__(self):
-        pass
+    _instance = None
+
+    def __new__(cls):
+        if not cls._instance:
+            cls._instance = super(LLMModelFactory, cls).__new__(cls)
+        return cls._instance

     def create(self, model_type: str, **kwargs):
         model_class = globals()[model_type]
         return model_class(**kwargs)
-
-llm_model_factory = LLMModelFactory()
+

 def chat_log_2_chatbot_list(chat_log: str):
     chatbot_list = []
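The module-level `llm_model_factory` instance above is replaced by a `__new__`-based singleton. A minimal, self-contained sketch of that pattern, with a stand-in model class since `ChatOpenAI` is not imported here:

```python
# Sketch only: FakeModel stands in for a real class such as ChatOpenAI.
class FakeModel:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

class LLMModelFactory:
    _instance = None

    def __new__(cls):
        # First call creates the shared instance; later calls reuse it.
        if not cls._instance:
            cls._instance = super(LLMModelFactory, cls).__new__(cls)
        return cls._instance

    def create(self, model_type: str, **kwargs):
        # Name-based lookup: the class must be visible in this module's globals().
        model_class = globals()[model_type]
        return model_class(**kwargs)

assert LLMModelFactory() is LLMModelFactory()   # same object every time
model = LLMModelFactory().create("FakeModel", temperature=0.1)
```

One consequence of the `globals()` lookup: `create` can only build classes imported into the factory's own module, which is why `ChatOpenAI` stays imported in `app/gradio_meta_prompt.py` even though nothing references it by name.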
@@ -95,8 +100,10 @@ def chat_log_2_chatbot_list(chat_log: str):
             print(line)
     return chatbot_list

-def process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                    recursion_limit: int, max_output_age: int,
+
+def process_message(user_message, expected_output, acceptance_criteria,
+                    initial_system_message, recursion_limit: int,
+                    max_output_age: int,
                     llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]]):
     # Create the input state
     input_state = AgentState(
@@ -106,7 +113,7 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
         system_message=initial_system_message,
         max_output_age=max_output_age
     )
-
+
     # Get the output state from MetaPromptGraph
     log_stream = io.StringIO()
     log_handler = None
@@ -114,10 +121,12 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
     if config.verbose:
         log_handler = logging.StreamHandler(log_stream)
         logger = logging.getLogger(MetaPromptGraph.__name__)
-        log_handler.setFormatter(jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s'))
+        log_handler.setFormatter(jsonlogger.JsonFormatter(
+            '%(asctime)s %(name)s %(levelname)s %(message)s'))
         logger.addHandler(log_handler)

-    meta_prompt_graph = MetaPromptGraph(llms=llms, verbose=config.verbose, logger=logger)
+    meta_prompt_graph = MetaPromptGraph(
+        llms=llms, verbose=config.verbose, logger=logger)
     output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)

     if config.verbose:
@@ -125,7 +134,7 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
         log_output = log_stream.getvalue()
     else:
         log_output = None
-
+
     # Validate the output state
     system_message = ''
     output = ''
@@ -146,7 +155,8 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
     else:
         analysis = "Error: The output state does not contain a valid 'analysis'"

-    return system_message, output, analysis, chat_log_2_chatbot_list(log_output)
+    return (system_message, output, analysis,
+            chat_log_2_chatbot_list(log_output))


 def process_message_with_single_llm(user_message, expected_output, acceptance_criteria, initial_system_message,
@@ -155,31 +165,33 @@ def process_message_with_single_llm(user_message, expected_output, acceptance_cr
     # Get the output state from MetaPromptGraph
     type = config.llms[model_name].type
     args = config.llms[model_name].model_dump(exclude={'type'})
-    llm = llm_model_factory.create(type, **args)
+    llm = LLMModelFactory().create(type, **args)

     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
                            recursion_limit, max_output_age, llm)

+
 def process_message_with_2_llms(user_message, expected_output, acceptance_criteria, initial_system_message,
-                                recursion_limit: int, max_output_age: int,
-                                optimizer_model_name: str, executor_model_name: str,):
+                                recursion_limit: int, max_output_age: int,
+                                optimizer_model_name: str, executor_model_name: str,):
     # Get the output state from MetaPromptGraph
-    optimizer_model = llm_model_factory.create(config.llms[optimizer_model_name].type,
+    optimizer_model = LLMModelFactory().create(config.llms[optimizer_model_name].type,
                                                **config.llms[optimizer_model_name].model_dump(exclude={'type'}))
-    executor_model = llm_model_factory.create(config.llms[executor_model_name].type,
+    executor_model = LLMModelFactory().create(config.llms[executor_model_name].type,
                                               **config.llms[executor_model_name].model_dump(exclude={'type'}))
     llms = {
-        MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer_model,
-        MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer_model,
-        MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor_model,
-        MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer_model,
-        MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer_model,
-        MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer_model
+        NODE_PROMPT_INITIAL_DEVELOPER: optimizer_model,
+        NODE_PROMPT_DEVELOPER: optimizer_model,
+        NODE_PROMPT_EXECUTOR: executor_model,
+        NODE_OUTPUT_HISTORY_ANALYZER: optimizer_model,
+        NODE_PROMPT_ANALYZER: optimizer_model,
+        NODE_PROMPT_SUGGESTER: optimizer_model
     }

     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
                            recursion_limit, max_output_age, llms)

+
 class FileConfig(BaseConfig):
     config_file: str = 'config.yml'  # default path
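With the node names promoted from `MetaPromptGraph` class attributes to module-level constants (defined in `meta_prompt/consts.py` below and pulled in via `from meta_prompt import *`), the per-node model mapping can be written without naming the class at all. A hedged sketch of the same wiring, with `object()` placeholders where real `BaseLanguageModel` instances would go:

```python
from meta_prompt import META_PROMPT_NODES, NODE_PROMPT_EXECUTOR

optimizer_model = object()  # placeholder for LLMModelFactory().create(...)
executor_model = object()   # placeholder for LLMModelFactory().create(...)

# Default every node to the optimizer, then point the executor elsewhere;
# equivalent to the literal dict in process_message_with_2_llms above.
llms = {node: optimizer_model for node in META_PROMPT_NODES}
llms[NODE_PROMPT_EXECUTOR] = executor_model
```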
@@ -216,10 +228,12 @@ with gr.Blocks(title='Meta Prompt') as demo:
                     label="Acceptance Criteria", show_copy_button=True)
                 initial_system_message_input = gr.Textbox(
                     label="Initial System Message", show_copy_button=True, value="")
-                recursion_limit_input = gr.Number(label="Recursion Limit", value=config.recursion_limit,
-                                                  precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)
-                max_output_age = gr.Number(label="Max Output Age", value=config.max_output_age,
-                                           precision=0, minimum=1, maximum=config.max_output_age_max, step=1)
+                recursion_limit_input = gr.Number(
+                    label="Recursion Limit", value=config.recursion_limit,
+                    precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)
+                max_output_age = gr.Number(
+                    label="Max Output Age", value=config.max_output_age,
+                    precision=0, minimum=1, maximum=config.max_output_age_max, step=1)
         with gr.Row():
             with gr.Tab('Simple'):
                 model_name_input = gr.Dropdown(
@@ -229,7 +243,8 @@ with gr.Blocks(title='Meta Prompt') as demo:
                 )
                 # Connect the inputs and outputs to the function
                 with gr.Row():
-                    submit_button = gr.Button(value="Submit", variant="primary")
+                    submit_button = gr.Button(
+                        value="Submit", variant="primary")
                     clear_button = gr.ClearButton(
                         [user_message_input, expected_output_input,
                          acceptance_criteria_input, initial_system_message_input],
@@ -247,45 +262,85 @@ with gr.Blocks(title='Meta Prompt') as demo:
                 )
                 # Connect the inputs and outputs to the function
                 with gr.Row():
-                    multiple_submit_button = gr.Button(value="Submit", variant="primary")
-                    multiple_clear_button = gr.ClearButton(components=[user_message_input,
-                                                                       expected_output_input,
-                                                                       acceptance_criteria_input,
-                                                                       initial_system_message_input],
-                                                           value='Clear All')
+                    multiple_submit_button = gr.Button(
+                        value="Submit", variant="primary")
+                    multiple_clear_button = gr.ClearButton(
+                        components=[user_message_input, expected_output_input,
+                                    acceptance_criteria_input, initial_system_message_input],
+                        value='Clear All')
         with gr.Column():
-            system_message_output = gr.Textbox(
-                label="System Message", show_copy_button=True)
+            system_message_output = gr.Textbox(label="System Message", show_copy_button=True)
             output_output = gr.Textbox(label="Output", show_copy_button=True)
-            analysis_output = gr.Textbox(
-                label="Analysis", show_copy_button=True)
+            analysis_output = gr.Textbox(label="Analysis", show_copy_button=True)
             flag_button = gr.Button(value="Flag", variant="secondary", visible=config.allow_flagging)
             with gr.Accordion("Details", open=False, visible=config.verbose):
-                logs_chatbot = gr.Chatbot(label='Messages', show_copy_button=True, layout='bubble',
-                                          bubble_full_width=False, render_markdown=False)
+                logs_chatbot = gr.Chatbot(
+                    label='Messages', show_copy_button=True, layout='bubble',
+                    bubble_full_width=False, render_markdown=False
+                )
                 clear_logs_button = gr.ClearButton([logs_chatbot], value='Clear Logs')

     clear_button.add([system_message_output, output_output,
                       analysis_output, logs_chatbot])
     multiple_clear_button.add([system_message_output, output_output,
-                               analysis_output, logs_chatbot])
-
-    submit_button.click(process_message_with_single_llm,
-                        inputs=[user_message_input, expected_output_input, acceptance_criteria_input,
-                                initial_system_message_input, recursion_limit_input, max_output_age,
-                                model_name_input],
-                        outputs=[system_message_output, output_output, analysis_output, logs_chatbot])
-    multiple_submit_button.click(process_message_with_2_llms,
-                                 inputs=[user_message_input, expected_output_input, acceptance_criteria_input,
-                                         initial_system_message_input, recursion_limit_input, max_output_age,
-                                         optimizer_model_name_input, executor_model_name_input],
-                                 outputs=[system_message_output, output_output, analysis_output, logs_chatbot])
+                               analysis_output, logs_chatbot])
+
+    submit_button.click(
+        process_message_with_single_llm,
+        inputs=[
+            user_message_input,
+            expected_output_input,
+            acceptance_criteria_input,
+            initial_system_message_input,
+            recursion_limit_input,
+            max_output_age,
+            model_name_input
+        ],
+        outputs=[
+            system_message_output,
+            output_output,
+            analysis_output,
+            logs_chatbot
+        ]
+    )
+
+    multiple_submit_button.click(
+        process_message_with_2_llms,
+        inputs=[
+            user_message_input,
+            expected_output_input,
+            acceptance_criteria_input,
+            initial_system_message_input,
+            recursion_limit_input,
+            max_output_age,
+            optimizer_model_name_input,
+            executor_model_name_input
+        ],
+        outputs=[
+            system_message_output,
+            output_output,
+            analysis_output,
+            logs_chatbot
+        ]
+    )

     # Load examples
     examples = config.examples_path
-    gr.Examples(examples, inputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input, recursion_limit_input, model_name_input])
-
-    flagging_inputs = [user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input]
+    gr.Examples(examples, inputs=[
+        user_message_input,
+        expected_output_input,
+        acceptance_criteria_input,
+        initial_system_message_input,
+        recursion_limit_input,
+        model_name_input
+    ])
+
+    flagging_inputs = [
+        user_message_input,
+        expected_output_input,
+        acceptance_criteria_input,
+        initial_system_message_input
+    ]

     # Configure flagging
     if config.allow_flagging:
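The reflowed `click` wiring above keeps Gradio's usual contract: the handler's positional parameters line up with `inputs`, and its return tuple fills `outputs` in order. A toy, self-contained illustration of that contract (component names here are invented, not the app's):

```python
import gradio as gr

def handler(message: str, limit: float):
    # One return value per output component, in order.
    return f"echo: {message}", f"limit was {int(limit)}"

with gr.Blocks() as toy_demo:
    message_box = gr.Textbox(label="Message")
    limit_box = gr.Number(label="Limit", value=3)
    out_a = gr.Textbox(label="A")
    out_b = gr.Textbox(label="B")
    run = gr.Button("Run")
    run.click(handler, inputs=[message_box, limit_box], outputs=[out_a, out_b])
```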
 
meta_prompt/__init__.py CHANGED
@@ -1,4 +1,13 @@
 __version__ = '0.1.0'

 from .meta_prompt import AgentState, MetaPromptGraph
-
+from .consts import (
+    META_PROMPT_NODES,
+    NODE_PROMPT_INITIAL_DEVELOPER,
+    NODE_PROMPT_DEVELOPER,
+    NODE_PROMPT_EXECUTOR,
+    NODE_OUTPUT_HISTORY_ANALYZER,
+    NODE_PROMPT_ANALYZER,
+    NODE_PROMPT_SUGGESTER,
+    DEFAULT_PROMPT_TEMPLATES,
+)
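After this change the constants and default templates are importable straight from the package, which is what the `from meta_prompt import *` in `app/gradio_meta_prompt.py` relies on. Assuming the package is installed:

```python
from meta_prompt import (
    AgentState,
    MetaPromptGraph,
    META_PROMPT_NODES,
    DEFAULT_PROMPT_TEMPLATES,
)

# Every node name has a default template, and nothing else does.
assert set(DEFAULT_PROMPT_TEMPLATES) == set(META_PROMPT_NODES)
```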
meta_prompt/consts.py ADDED
@@ -0,0 +1,220 @@
+from langchain_core.prompts import ChatPromptTemplate
+
+NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
+NODE_PROMPT_DEVELOPER = "prompt_developer"
+NODE_PROMPT_EXECUTOR = "prompt_executor"
+NODE_OUTPUT_HISTORY_ANALYZER = "output_history_analyzer"
+NODE_PROMPT_ANALYZER = "prompt_analyzer"
+NODE_PROMPT_SUGGESTER = "prompt_suggester"
+
+META_PROMPT_NODES = [
+    NODE_PROMPT_INITIAL_DEVELOPER,
+    NODE_PROMPT_DEVELOPER,
+    NODE_PROMPT_EXECUTOR,
+    NODE_OUTPUT_HISTORY_ANALYZER,
+    NODE_PROMPT_ANALYZER,
+    NODE_PROMPT_SUGGESTER
+]
+
+DEFAULT_PROMPT_TEMPLATES = {
+    NODE_PROMPT_INITIAL_DEVELOPER: ChatPromptTemplate.from_messages([
+        ("system", """# Expert Prompt Engineer
+
+You are an expert prompt engineer tasked with creating system messages for AI assistants.
+
+## Instructions
+
+1. Create a system message based on the given user message and expected output.
+2. Ensure the system message can handle similar user messages.
+3. Output only the system message, without any additional content.
+4. Expected Output text should not appear in System Message as an example. But it's OK to use some similar text as an example instead.
+5. Format the system message well, with no more than 80 characters per line (except for raw text).
+
+## Output
+
+Provide only the system message, adhering to the above guidelines.
+"""),
+        ("human", """# User Message
+
+{user_message}
+
+# Expected Output
+
+{expected_output}
+""")
+    ]),
+    NODE_PROMPT_DEVELOPER: ChatPromptTemplate.from_messages([
+        ("system", """# Expert Prompt Engineer
+
+You are an expert prompt engineer tasked with updating system messages for AI assistants. You Update System Message according to Suggestions, to improve Output and match Expected Output more closely.
+
+## Instructions
+
+1. Update the system message based on the given Suggestion, User Message, and Expected Output.
+2. Ensure the updated system message can handle similar user messages.
+3. Modify only the content mentioned in the Suggestion. Do not change the parts that are not related to the Suggestion.
+4. Output only the updated system message, without any additional content.
+5. Avoiding the behavior should be explicitly requested (e.g. `Don't ...`) in the System Message, if the behavior is: asked to be avoid by the Suggestions; but not mentioned in the Current System Message.
+6. Expected Output text should not appear in System Message as an example. But it's OK to use some similar text as an example instead.
+   * Remove the Expected Output text or text highly similar to Expected Output from System Message, if it's present.
+7. Format the system message well, with no more than 80 characters per line (except for raw text).
+
+## Output
+
+Provide only the updated System Message, adhering to the above guidelines.
+"""),
+        ("human", """# Current system message
+
+{system_message}
+
+# User Message
+
+{user_message}
+
+# Expected Output
+
+{expected_output}
+
+# Suggestions
+
+{suggestions}
+""")
+    ]),
+    NODE_PROMPT_EXECUTOR: ChatPromptTemplate.from_messages([
+        ("system", "{system_message}"),
+        ("human", "{user_message}")
+    ]),
+    NODE_OUTPUT_HISTORY_ANALYZER: ChatPromptTemplate.from_messages([
+        ("system", """You are a text comparing program. You read the Acceptance Criteria, compare the compare the exptected output with two different outputs, and decide which one is more consistent with the expected output. When comparing the outputs, ignore the differences which are acceptable or ignorable according to the Acceptance Criteria.
+
+You output the following analysis according to the Acceptance Criteria:
+
+* Your analysis in a Markdown list.
+* The ID of the output that is more consistent with the Expected Output as Preferred Output ID, with the following format:
+
+```
+# Analysis
+
+...
+
+# Preferred Output ID: [ID]
+```
+
+If both outputs are equally similar to the expected output, output the following:
+
+```
+# Analysis
+
+...
+
+# Draw
+```
+"""),
+        ("human", """
+# Output ID: A
+
+```
+{best_output}
+```
+
+# Output ID: B
+
+```
+{output}
+```
+
+# Acceptance Criteria
+
+{acceptance_criteria}
+
+# Expected Output
+
+```
+{expected_output}
+```
+""")
+    ]),
+    NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
+        ("system", """You are a text comparing program. You compare the following output texts, analysis the System Message and provide a detailed analysis according to `Acceptance Criteria`. Then you decide whether `Actual Output` is acceptable.
+
+Provide your analysis in the following format:
+
+```
+- Acceptable Differences: [List acceptable differences succinctly]
+- Unacceptable Differences: [List unacceptable differences succinctly]
+- Accept: [Yes/No]
+```
+
+* Compare Expected Output and Actual Output with the guidance of Accept Criteria.
+* Only set 'Accept' to 'Yes', if Accept Criteria are all met. Otherwise, set 'Accept' to 'No'.
+* List only the acceptable differences according to Accept Criteria in 'acceptable Differences' section.
+* List only the unacceptable differences according to Accept Criteria in 'Unacceptable Differences' section.
+
+# Acceptance Criteria
+
+```
+{acceptance_criteria}
+```
+"""),
+        ("human", """
+# System Message
+
+```
+{system_message}
+```
+
+# Expected Output
+
+```
+{expected_output}
+```
+
+# Actual Output
+
+```
+{output}
+```
+""")
+    ]),
+    NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
+        ("system", """Read the following inputs and outputs of an LLM prompt, and also analysis about them. Then suggest how to improve System Message.
+
+* The goal is to improve the System Message to match the Expected Output better.
+* Ignore all Acceptable Differences and focus on Unacceptable Differences.
+* Suggest formal changes first, then semantic changes.
+* Provide your suggestions in a Markdown list, nothing else. Output only the suggestions related with Unacceptable Differences.
+* Start every suggestion with `The System Message should ...`.
+* Figue out the contexts of the System Message that conflict with the suggestions, and suggest modification or deletion.
+* Avoiding the behavior should be explicitly requested (e.g. `The System Message should explicitly state that the output shoud not ...`) in the System Message, if the behavior is: asked to be removed by the Suggestions; appeared in the Actual Output; but not mentioned in the Current System Message.
+* Expected Output text should not appear in System Message as an example. But it's OK to use some similar but distinct text as an example instead.
+* Ask to remove the Expected Output text or text highly similar to Expected Output from System Message, if it's present.
+* Provide format examples or detected format name, if System Message does not.
+* Specify the detected format name (e.g. XML, JSON, etc.) of Expected Output, if System Message does not mention it.
+"""),
+        ("human", """
+<|Start_System_Message|>
+{system_message}
+<|End_System_Message|>
+
+<|Start_User_Message|>
+{user_message}
+<|End_User_Message|>
+
+<|Start_Expected_Output|>
+{expected_output}
+<|End_Expected_Output|>
+
+<|Start_Actual_Output|>
+{output}
+<|End_Actual_Output|>
+
+<|Start_Acceptance Criteria|>
+{acceptance_criteria}
+<|End_Acceptance Criteria|>
+
+<|Start_Analysis|>
+{analysis}
+<|End_Analysis|>
+""")
+    ])
+}
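Because `MetaPromptGraph.__init__` copies `DEFAULT_PROMPT_TEMPLATES` and then applies the `prompts` argument on top (see `meta_prompt/meta_prompt.py` below), a caller can now override a single node's template without touching the others. A hedged sketch; the terse executor prompt is invented for illustration:

```python
from langchain_core.prompts import ChatPromptTemplate
from meta_prompt import MetaPromptGraph, NODE_PROMPT_EXECUTOR

terse_executor = ChatPromptTemplate.from_messages([
    ("system", "{system_message}\n\nAnswer in one short paragraph."),
    ("human", "{user_message}"),
])

# llms={} is enough at construction time; models are only needed on invocation.
graph = MetaPromptGraph(llms={}, prompts={NODE_PROMPT_EXECUTOR: terse_executor})
```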
meta_prompt/meta_prompt.py CHANGED
@@ -9,6 +9,7 @@ from langgraph.graph import StateGraph, END
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.errors import GraphRecursionError
 from pydantic import BaseModel
+from .consts import *

 class AgentState(BaseModel):
     max_output_age: int = 0
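For readers unfamiliar with LangGraph, the hunks below wire node functions together with `add_edge` and `add_conditional_edges`, where a router callable returns a label that is looked up in a mapping to pick the next node. A toy, self-contained sketch of that contract under the same API this file uses (the state and nodes are stand-ins, not `MetaPromptGraph`'s):

```python
from pydantic import BaseModel
from langgraph.graph import StateGraph, END

class ToyState(BaseModel):
    hops: int = 0

def step(state: ToyState) -> dict:
    # Nodes return partial state updates; LangGraph merges them in.
    return {"hops": state.hops + 1}

def router(state: ToyState) -> str:
    # The returned label is looked up in the mapping given below.
    return "again" if state.hops < 3 else END

g = StateGraph(ToyState)
g.add_node("step", step)
g.set_entry_point("step")
g.add_conditional_edges("step", router, {"again": "step", END: END})
print(g.compile().invoke(ToyState()))  # final state has hops == 3
```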
@@ -25,261 +26,16 @@ class AgentState(BaseModel):
     best_output_age: int = 0

 class MetaPromptGraph:
-    NODE_PROMPT_INITIAL_DEVELOPER = "prompt_initial_developer"
-    NODE_PROMPT_DEVELOPER = "prompt_developer"
-    NODE_PROMPT_EXECUTOR = "prompt_executor"
-    NODE_OUTPUT_HISTORY_ANALYZER = "output_history_analyzer"
-    NODE_PROMPT_ANALYZER = "prompt_analyzer"
-    NODE_PROMPT_SUGGESTER = "prompt_suggester"
-
-    DEFAULT_PROMPT_TEMPLATES = {
-        NODE_PROMPT_INITIAL_DEVELOPER: ChatPromptTemplate.from_messages([
-            ("system", """# Expert Prompt Engineer
-
-You are an expert prompt engineer tasked with creating system messages for AI
-assistants.
-
-## Instructions
-
-1. Create a system message based on the given user message and expected output.
-2. Ensure the system message can handle similar user messages.
-3. Output only the system message, without any additional content.
-4. Expected Output text should not appear in System Message as an example. But
-it's OK to use some similar text as an example instead.
-5. Format the system message well, with no more than 80 characters per line
-(except for raw text).
-
-## Output
-
-Provide only the system message, adhering to the above guidelines.
-"""),
-            ("human", """# User Message
-
-{user_message}
-
-# Expected Output
-
-{expected_output}
-""")
-        ]),
-        NODE_PROMPT_DEVELOPER: ChatPromptTemplate.from_messages([
-            ("system", """# Expert Prompt Engineer
-
-You are an expert prompt engineer tasked with updating system messages for AI
-assistants. You Update System Message according to Suggestions, to improve
-Output and match Expected Output more closely.
-
-## Instructions
-
-1. Update the system message based on the given Suggestion, User Message, and
-Expected Output.
-2. Ensure the updated system message can handle similar user messages.
-3. Modify only the content mentioned in the Suggestion. Do not change the
-parts that are not related to the Suggestion.
-4. Output only the updated system message, without any additional content.
-5. Avoiding the behavior should be explicitly requested (e.g. `Don't ...`) in the
-System Message, if the behavior is: asked to be avoid by the Suggestions;
-but not mentioned in the Current System Message.
-6. Expected Output text should not appear in System Message as an example. But
-it's OK to use some similar text as an example instead.
-* Remove the Expected Output text or text highly similar to Expected Output
-from System Message, if it's present.
-7. Format the system message well, with no more than 80 characters per line
-(except for raw text).
-
-## Output
-
-Provide only the updated System Message, adhering to the above guidelines.
-"""),
-            ("human", """# Current system message
-
-{system_message}
-
-# User Message
-
-{user_message}
-
-# Expected Output
-
-{expected_output}
-
-# Suggestions
-
-{suggestions}
-""")
-        ]),
-        NODE_PROMPT_EXECUTOR: ChatPromptTemplate.from_messages([
-            ("system", "{system_message}"),
-            ("human", "{user_message}")
-        ]),
-        NODE_OUTPUT_HISTORY_ANALYZER: ChatPromptTemplate.from_messages([
-            ("system", """You are a text comparing program. You read the Acceptance Criteria, compare the
-compare the exptected output with two different outputs, and decide which one is
-more consistent with the expected output. When comparing the outputs, ignore the
-differences which are acceptable or ignorable according to the Acceptance Criteria.
-
-You output the following analysis according to the Acceptance Criteria:
-
-* Your analysis in a Markdown list.
-* The ID of the output that is more consistent with the Expected Output as Preferred
-Output ID, with the following format:
-
-```
-# Analysis
-
-...
-
-# Preferred Output ID: [ID]
-```
-
-If both outputs are equally similar to the expected output, output the following:
-
-```
-# Analysis
-
-...
-
-# Draw
-```
-"""),
-            ("human", """
-# Output ID: A
-
-```
-{best_output}
-```
-
-# Output ID: B
-
-```
-{output}
-```
-
-# Acceptance Criteria
-
-{acceptance_criteria}
-
-# Expected Output
-
-```
-{expected_output}
-```
-""")
-        ]),
-        NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
-            ("system", """
-You are a text comparing program. You compare the following output texts,
-analysis the System Message and provide a detailed analysis according to
-`Acceptance Criteria`. Then you decide whether `Actual Output` is acceptable.
-
-Provide your analysis in the following format:
-
-```
-- Acceptable Differences: [List acceptable differences succinctly]
-- Unacceptable Differences: [List unacceptable differences succinctly]
-- Accept: [Yes/No]
-```
-
-* Compare Expected Output and Actual Output with the guidance of Accept Criteria.
-* Only set 'Accept' to 'Yes', if Accept Criteria are all met. Otherwise, set 'Accept' to 'No'.
-* List only the acceptable differences according to Accept Criteria in 'acceptable Differences' section.
-* List only the unacceptable differences according to Accept Criteria in 'Unacceptable Differences' section.
-
-# Acceptance Criteria
-
-```
-{acceptance_criteria}
-```
-"""),
-            ("human", """
-# System Message
-
-```
-{system_message}
-```
-
-# Expected Output
-
-```
-{expected_output}
-```
-
-# Actual Output
-
-```
-{output}
-```
-""")
-        ]),
-        NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
-            ("system", """
-Read the following inputs and outputs of an LLM prompt, and also analysis about them.
-Then suggest how to improve System Message.
-
-* The goal is to improve the System Message to match the Expected Output better.
-* Ignore all Acceptable Differences and focus on Unacceptable Differences.
-* Suggest formal changes first, then semantic changes.
-* Provide your suggestions in a Markdown list, nothing else. Output only the
-suggestions related with Unacceptable Differences.
-* Start every suggestion with `The System Message should ...`.
-* Figue out the contexts of the System Message that conflict with the suggestions,
-and suggest modification or deletion.
-* Avoiding the behavior should be explicitly requested (e.g. `The System Message
-should explicitly state that the output shoud not ...`) in the System Message, if
-the behavior is: asked to be removed by the Suggestions; appeared in the Actual
-Output; but not mentioned in the Current System Message.
-* Expected Output text should not appear in System Message as an example. But
-it's OK to use some similar but distinct text as an example instead.
-* Ask to remove the Expected Output text or text highly similar to Expected Output
-from System Message, if it's present.
-* Provide format examples or detected format name, if System Message does not.
-* Specify the detected format name (e.g. XML, JSON, etc.) of Expected Output, if
-System Message does not mention it.
-"""),
-            ("human", """
-<|Start_System_Message|>
-{system_message}
-<|End_System_Message|>
-
-<|Start_User_Message|>
-{user_message}
-<|End_User_Message|>
-
-<|Start_Expected_Output|>
-{expected_output}
-<|End_Expected_Output|>
-
-<|Start_Actual_Output|>
-{output}
-<|End_Actual_Output|>
-
-<|Start_Acceptance Criteria|>
-{acceptance_criteria}
-<|End_Acceptance Criteria|>
-
-<|Start_Analysis|>
-{analysis}
-<|End_Analysis|>
-""")
-        ])
-    }
-
     @classmethod
     def get_node_names(cls):
-        return [
-            cls.NODE_PROMPT_INITIAL_DEVELOPER,
-            cls.NODE_PROMPT_DEVELOPER,
-            cls.NODE_PROMPT_EXECUTOR,
-            cls.NODE_OUTPUT_HISTORY_ANALYZER,
-            cls.NODE_PROMPT_ANALYZER,
-            cls.NODE_PROMPT_SUGGESTER
-        ]
+        return META_PROMPT_NODES

     def __init__(self,
-                 llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]] = {},
+                 llms: Union[BaseLanguageModel,
+                             Dict[str, BaseLanguageModel]] = {},
                  prompts: Dict[str, ChatPromptTemplate] = {},
                  logger: Optional[logging.Logger] = None,
-                 verbose = False):
+                 verbose=False):
         self.logger = logger or logging.getLogger(__name__)
         if self.logger is not None:
             if verbose:
@@ -288,80 +44,86 @@ Then suggest how to improve System Message.
                 self.logger.setLevel(logging.INFO)

         if isinstance(llms, BaseLanguageModel):
-            self.llms: Dict[str, BaseLanguageModel] = {node: llms for node in self.get_node_names()}
+            self.llms: Dict[str, BaseLanguageModel] = {
+                node: llms for node in self.get_node_names()}
         else:
             self.llms: Dict[str, BaseLanguageModel] = llms
-        self.prompt_templates: Dict[str, ChatPromptTemplate] = self.DEFAULT_PROMPT_TEMPLATES.copy()
+        self.prompt_templates: Dict[str,
+                                    ChatPromptTemplate] = DEFAULT_PROMPT_TEMPLATES.copy()
         self.prompt_templates.update(prompts)

     def _create_workflow(self, including_initial_developer: bool = True) -> StateGraph:
         workflow = StateGraph(AgentState)
-
-        workflow.add_node(self.NODE_PROMPT_DEVELOPER,
+
+        workflow.add_node(NODE_PROMPT_DEVELOPER,
                           lambda x: self._prompt_node(
-                              self.NODE_PROMPT_DEVELOPER,
+                              NODE_PROMPT_DEVELOPER,
                               "system_message",
                               x))
-        workflow.add_node(self.NODE_PROMPT_EXECUTOR,
+        workflow.add_node(NODE_PROMPT_EXECUTOR,
                           lambda x: self._prompt_node(
-                              self.NODE_PROMPT_EXECUTOR,
+                              NODE_PROMPT_EXECUTOR,
                               "output",
                               x))
-        workflow.add_node(self.NODE_OUTPUT_HISTORY_ANALYZER,
+        workflow.add_node(NODE_OUTPUT_HISTORY_ANALYZER,
                           lambda x: self._output_history_analyzer(x))
-        workflow.add_node(self.NODE_PROMPT_ANALYZER,
+        workflow.add_node(NODE_PROMPT_ANALYZER,
                           lambda x: self._prompt_analyzer(x))
-        workflow.add_node(self.NODE_PROMPT_SUGGESTER,
+        workflow.add_node(NODE_PROMPT_SUGGESTER,
                           lambda x: self._prompt_node(
-                              self.NODE_PROMPT_SUGGESTER,
+                              NODE_PROMPT_SUGGESTER,
                               "suggestions",
                               x))

-        workflow.add_edge(self.NODE_PROMPT_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
-        workflow.add_edge(self.NODE_PROMPT_EXECUTOR, self.NODE_OUTPUT_HISTORY_ANALYZER)
-        workflow.add_edge(self.NODE_PROMPT_SUGGESTER, self.NODE_PROMPT_DEVELOPER)
+        workflow.add_edge(NODE_PROMPT_DEVELOPER, NODE_PROMPT_EXECUTOR)
+        workflow.add_edge(NODE_PROMPT_EXECUTOR, NODE_OUTPUT_HISTORY_ANALYZER)
+        workflow.add_edge(NODE_PROMPT_SUGGESTER, NODE_PROMPT_DEVELOPER)

         workflow.add_conditional_edges(
-            self.NODE_OUTPUT_HISTORY_ANALYZER,
+            NODE_OUTPUT_HISTORY_ANALYZER,
             lambda x: self._should_exit_on_max_age(x),
             {
-                "continue": self.NODE_PROMPT_ANALYZER,
-                "rerun": self.NODE_PROMPT_SUGGESTER,
+                "continue": NODE_PROMPT_ANALYZER,
+                "rerun": NODE_PROMPT_SUGGESTER,
                 END: END
             }
         )

         workflow.add_conditional_edges(
-            self.NODE_PROMPT_ANALYZER,
+            NODE_PROMPT_ANALYZER,
             lambda x: self._should_exit_on_acceptable_output(x),
             {
-                "continue": self.NODE_PROMPT_SUGGESTER,
+                "continue": NODE_PROMPT_SUGGESTER,
                 END: END
             }
         )

         if including_initial_developer:
-            workflow.add_node(self.NODE_PROMPT_INITIAL_DEVELOPER,
-                              lambda x: self._prompt_node(
-                                  self.NODE_PROMPT_INITIAL_DEVELOPER,
-                                  "system_message",
-                                  x))
-            workflow.add_edge(self.NODE_PROMPT_INITIAL_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
-            workflow.set_entry_point(self.NODE_PROMPT_INITIAL_DEVELOPER)
+            workflow.add_node(NODE_PROMPT_INITIAL_DEVELOPER,
+                              lambda x: self._prompt_node(
+                                  NODE_PROMPT_INITIAL_DEVELOPER,
+                                  "system_message",
+                                  x))
+            workflow.add_edge(NODE_PROMPT_INITIAL_DEVELOPER,
+                              NODE_PROMPT_EXECUTOR)
+            workflow.set_entry_point(NODE_PROMPT_INITIAL_DEVELOPER)
         else:
-            workflow.set_entry_point(self.NODE_PROMPT_EXECUTOR)
+            workflow.set_entry_point(NODE_PROMPT_EXECUTOR)

         return workflow

     def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
-        workflow = self._create_workflow(including_initial_developer=(state.system_message is None or state.system_message == ""))
+        workflow = self._create_workflow(including_initial_developer=(
+            state.system_message is None or state.system_message == ""))

         memory = MemorySaver()
         graph = workflow.compile(checkpointer=memory)
-        config = {"configurable": {"thread_id": "1"}, "recursion_limit": recursion_limit}
+        config = {"configurable": {"thread_id": "1"},
+                  "recursion_limit": recursion_limit}

         try:
-            self.logger.debug("Invoking graph with state: %s", pprint.pformat(state))
+            self.logger.debug("Invoking graph with state: %s",
+                              pprint.pformat(state))

             output_state = graph.invoke(state, config)
@@ -369,7 +131,8 @@ Then suggest how to improve System Message.

             return output_state
         except GraphRecursionError as e:
-            self.logger.info("Recursion limit reached. Returning the best state found so far.")
+            self.logger.info(
+                "Recursion limit reached. Returning the best state found so far.")
             checkpoint_states = graph.get_state(config)

             # if the length of states is bigger than 0, print the best system message and output
@@ -377,41 +140,50 @@ Then suggest how to improve System Message.
                 output_state = checkpoint_states[0]
                 return output_state
             else:
-                self.logger.info("No checkpoint states found. Returning the input state.")
-
+                self.logger.info(
+                    "No checkpoint states found. Returning the input state.")
+
                 return state

     def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
         logger = self.logger.getChild(node)
-        prompt = self.prompt_templates[node].format_messages(**state.model_dump())
+        prompt = self.prompt_templates[node].format_messages(
+            **state.model_dump())

         for message in prompt:
-            logger.debug({'node': node, 'action': 'invoke', 'type': message.type, 'message': message.content})
-        response = self.llms[node].invoke(self.prompt_templates[node].format_messages(**state.model_dump()))
-        logger.debug({'node': node, 'action': 'response', 'type': response.type, 'message': response.content})
-
+            logger.debug({'node': node, 'action': 'invoke',
+                          'type': message.type, 'message': message.content})
+        response = self.llms[node].invoke(
+            self.prompt_templates[node].format_messages(**state.model_dump()))
+        logger.debug({'node': node, 'action': 'response',
+                      'type': response.type, 'message': response.content})
+
         setattr(state, target_attribute, response.content)
         return state

     def _output_history_analyzer(self, state: AgentState) -> AgentState:
-        logger = self.logger.getChild(self.NODE_OUTPUT_HISTORY_ANALYZER)
+        logger = self.logger.getChild(NODE_OUTPUT_HISTORY_ANALYZER)

         if state.best_output is None:
             state.best_output = state.output
             state.best_system_message = state.system_message
             state.best_output_age = 0

-            logger.debug("Best output initialized to the current output: \n %s", state.output)
+            logger.debug(
+                "Best output initialized to the current output:\n%s", state.output)

             return state

-        prompt = self.prompt_templates[self.NODE_OUTPUT_HISTORY_ANALYZER].format_messages(**state.model_dump())
+        prompt = self.prompt_templates[NODE_OUTPUT_HISTORY_ANALYZER].format_messages(
+            **state.model_dump())

         for message in prompt:
-            logger.debug({'node': self.NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'invoke', 'type': message.type, 'message': message.content})
+            logger.debug({'node': NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'invoke',
+                          'type': message.type, 'message': message.content})

-        response = self.llms[self.NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
-        logger.debug({'node': self.NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'response', 'type': response.type, 'message': response.content})
+        response = self.llms[NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
+        logger.debug({'node': NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'response',
+                      'type': response.type, 'message': response.content})

         analysis = response.content
@@ -420,23 +192,28 @@ Then suggest how to improve System Message.
             state.best_system_message = state.system_message
             state.best_output_age = 0

-            logger.debug("Best output updated to the current output: \n %s", state.output)
+            logger.debug(
+                "Best output updated to the current output:\n%s", state.output)
         else:
             state.best_output_age += 1

-            logger.debug("Best output age incremented to %s", state.best_output_age)
+            logger.debug("Best output age incremented to %s",
+                         state.best_output_age)

         return state

     def _prompt_analyzer(self, state: AgentState) -> AgentState:
-        logger = self.logger.getChild(self.NODE_PROMPT_ANALYZER)
-        prompt = self.prompt_templates[self.NODE_PROMPT_ANALYZER].format_messages(**state.model_dump())
+        logger = self.logger.getChild(NODE_PROMPT_ANALYZER)
+        prompt = self.prompt_templates[NODE_PROMPT_ANALYZER].format_messages(
+            **state.model_dump())

         for message in prompt:
-            logger.debug({'node': self.NODE_PROMPT_ANALYZER, 'action': 'invoke', 'type': message.type, 'message': message.content})
+            logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'invoke',
+                          'type': message.type, 'message': message.content})

-        response = self.llms[self.NODE_PROMPT_ANALYZER].invoke(prompt)
-        logger.debug({'node': self.NODE_PROMPT_ANALYZER, 'action': 'response', 'type': response.type, 'message': response.content})
+        response = self.llms[NODE_PROMPT_ANALYZER].invoke(prompt)
+        logger.debug({'node': NODE_PROMPT_ANALYZER, 'action': 'response',
+                      'type': response.type, 'message': response.content})

         state.analysis = response.content
         state.accepted = "Accept: Yes" in response.content
@@ -446,7 +223,7 @@ Then suggest how to improve System Message.
         return state

     def _should_exit_on_max_age(self, state: AgentState) -> str:
-        if state.max_output_age <=0:
+        if state.max_output_age <= 0:
             # always continue if max age is 0
             return "continue"
 
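Taken together, call sites end up looking like the hedged end-to-end sketch below. The model name is a placeholder, and the `AgentState` field names beyond `system_message` and `max_output_age` are inferred from `process_message`'s parameters rather than shown in this diff:

```python
from langchain_openai import ChatOpenAI
from meta_prompt import AgentState, MetaPromptGraph

# A single model is fanned out to every node in META_PROMPT_NODES
# by MetaPromptGraph.__init__ when given a bare BaseLanguageModel.
llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
graph = MetaPromptGraph(llms=llm)

input_state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use list.reverse() or reversed().",
    acceptance_criteria="Mentions an in-place and a copying option.",
    system_message="",
    max_output_age=2,
)
output_state = graph(input_state, recursion_limit=25)
print(output_state)
```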