Displays logs. Optimized prompts.
- Dockerfile +4 -2
- app/config.py +4 -1
- app/gradio_meta_prompt.py +81 -31
- config.yml +3 -2
- meta_prompt/meta_prompt.py +78 -57
- poetry.lock +12 -1
- pyproject.toml +1 -0
Dockerfile
CHANGED
@@ -6,13 +6,15 @@ WORKDIR /app
 
 # Copy all files from the current directory to the working directory in the container
 COPY config.yml poetry.lock pyproject.toml /app/
-COPY app /app/app/
-COPY meta_prompt /app/meta_prompt/
 
 RUN pip install --no-cache-dir -U poetry
 RUN poetry config virtualenvs.create false
 RUN poetry install --with=dev
 
+COPY meta_prompt /app/meta_prompt/
+COPY app /app/app/
+RUN poetry install --with=dev
+
 # Expose the port (if necessary)
 EXPOSE 7860
 
app/config.py
CHANGED
@@ -16,4 +16,7 @@ class MetaPromptConfig(BaseConfig):
     server_port: Optional[int] = None
     recursion_limit: Optional[int] = 25
     recursion_limit_max: Optional[int] = 50
-    allow_flagging: Optional[bool] = False
+    allow_flagging: Optional[bool] = False
+    verbose: Optional[bool] = False
+    max_output_age: Optional[int] = 3
+    max_output_age_max: Optional[int] = 8
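As a quick orientation (not part of the commit), the new fields surface on the `MetaPromptConfig` instance that app/gradio_meta_prompt.py builds; a minimal sketch, assuming the `config_sources` list defined in that file:

```python
# Sketch only: `config_sources` is the FileSource/EnvSource/CLArgSource list
# defined in app/gradio_meta_prompt.py.
config = MetaPromptConfig(config_sources=config_sources)

print(config.allow_flagging)      # False unless config.yml enables it
print(config.verbose)             # False by default; gates the new "Details" log panel
print(config.max_output_age)      # 3 by default
print(config.max_output_age_max)  # 8; used as the upper bound of the UI number input
```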
app/gradio_meta_prompt.py
CHANGED
@@ -1,15 +1,20 @@
 import csv
+import io
+import json
+import logging
 from pathlib import Path
+import pprint
 from typing import Any, Dict, Union
 import gradio as gr
-from gradio import CSVLogger,
+from gradio import CSVLogger, Button, utils
 from gradio.flagging import FlagMethod
 from gradio_client import utils as client_utils
 from confz import BaseConfig, CLArgSource, EnvSource, FileSource
-from meta_prompt import MetaPromptGraph, AgentState
-from langchain_openai import ChatOpenAI
 from app.config import MetaPromptConfig
 from langchain_core.language_models import BaseLanguageModel
+from langchain_openai import ChatOpenAI
+from meta_prompt import MetaPromptGraph, AgentState
+from pythonjsonlogger import jsonlogger
 
 class SimplifiedCSVLogger(CSVLogger):
     """
@@ -70,19 +75,56 @@ class LLMModelFactory:
 
 llm_model_factory = LLMModelFactory()
 
+def chat_log_2_chatbot_list(chat_log: str):
+    chatbot_list = []
+    if chat_log is None or chat_log == '':
+        return chatbot_list
+    for line in chat_log.splitlines():
+        try:
+            json_line = json.loads(line)
+            if 'action' in json_line:
+                if json_line['action'] == 'invoke':
+                    chatbot_list.append([json_line['message'],None])
+                if json_line['action'] == 'response':
+                    chatbot_list.append([None,json_line['message']])
+        except json.decoder.JSONDecodeError as e:
+            print(f"Error decoding JSON log output: {e}")
+            print(line)
+        except KeyError as e:
+            print(f"Error accessing key in JSON log output: {e}")
+            print(line)
+    return chatbot_list
+
 def process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                    recursion_limit: int,
+                    recursion_limit: int, max_output_age: int,
+                    llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]]):
     # Create the input state
     input_state = AgentState(
         user_message=user_message,
         expected_output=expected_output,
         acceptance_criteria=acceptance_criteria,
-        system_message=initial_system_message
+        system_message=initial_system_message,
+        max_output_age=max_output_age
     )
 
     # Get the output state from MetaPromptGraph
-
+    log_stream = io.StringIO()
+    log_handler = None
+    logger = None
+    if config.verbose:
+        log_handler = logging.StreamHandler(log_stream)
+        logger = logging.getLogger(MetaPromptGraph.__name__)
+        log_handler.setFormatter(jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s'))
+        logger.addHandler(log_handler)
+
+    meta_prompt_graph = MetaPromptGraph(llms=llms, verbose=config.verbose, logger=logger)
     output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
+
+    if config.verbose:
+        log_handler.close()
+        log_output = log_stream.getvalue()
+    else:
+        log_output = None
 
     # Validate the output state
     system_message = ''
@@ -104,21 +146,23 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
     else:
         analysis = "Error: The output state does not contain a valid 'analysis'"
 
-    return system_message, output, analysis
+    return system_message, output, analysis, chat_log_2_chatbot_list(log_output)
 
 
 def process_message_with_single_llm(user_message, expected_output, acceptance_criteria, initial_system_message,
-                                    recursion_limit: int,
+                                    recursion_limit: int, max_output_age: int,
+                                    model_name: str):
     # Get the output state from MetaPromptGraph
     type = config.llms[model_name].type
     args = config.llms[model_name].model_dump(exclude={'type'})
     llm = llm_model_factory.create(type, **args)
 
     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                           recursion_limit, llm)
+                           recursion_limit, max_output_age, llm)
 
 def process_message_with_2_llms(user_message, expected_output, acceptance_criteria, initial_system_message,
-                                recursion_limit: int,
+                                recursion_limit: int, max_output_age: int,
+                                optimizer_model_name: str, executor_model_name: str,):
     # Get the output state from MetaPromptGraph
     optimizer_model = llm_model_factory.create(config.llms[optimizer_model_name].type,
                                                **config.llms[optimizer_model_name].model_dump(exclude={'type'}))
@@ -134,7 +178,7 @@ def process_message_with_2_llms(user_message, expected_output, acceptance_criter
     }
 
     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                           recursion_limit, llms)
+                           recursion_limit, max_output_age, llms)
 
 class FileConfig(BaseConfig):
     config_file: str = 'config.yml' # default path
@@ -151,20 +195,12 @@ config_sources = [
     CLArgSource()
 ]
 
-# Add event handlers
-def handle_submit(user_message, expected_output, acceptance_criteria, initial_system_message, recursion_limit, model_name):
-    return process_message(user_message, expected_output, acceptance_criteria, initial_system_message, recursion_limit, model_name)
-
-# Define clear function
-def clear_inputs():
-    return "", "", "", "", "", ""
-
 config = MetaPromptConfig(config_sources=config_sources)
 
 flagging_callback = SimplifiedCSVLogger()
 
 # Create a Gradio Blocks context
-with gr.Blocks() as demo:
+with gr.Blocks(title='Meta Prompt') as demo:
     # Define the layout
     with gr.Row():
         gr.Markdown(f"""<h1 style='text-align: left; margin-bottom: 1rem'>Meta Prompt</h1>
@@ -182,7 +218,8 @@ with gr.Blocks() as demo:
                 label="Initial System Message", show_copy_button=True, value="")
             recursion_limit_input = gr.Number(label="Recursion Limit", value=config.recursion_limit,
                                               precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)
-
+            max_output_age = gr.Number(label="Max Output Age", value=config.max_output_age,
+                                       precision=0, minimum=1, maximum=config.max_output_age_max, step=1)
             with gr.Row():
                 with gr.Tab('Simple'):
                     model_name_input = gr.Dropdown(
@@ -193,7 +230,10 @@ with gr.Blocks() as demo:
                     # Connect the inputs and outputs to the function
                     with gr.Row():
                         submit_button = gr.Button(value="Submit", variant="primary")
-                        clear_button = gr.
+                        clear_button = gr.ClearButton(
+                            [user_message_input, expected_output_input,
+                             acceptance_criteria_input, initial_system_message_input],
+                            value='Clear All')
                 with gr.Tab('Advanced'):
                     optimizer_model_name_input = gr.Dropdown(
                         label="Optimizer Model Name",
@@ -208,7 +248,11 @@ with gr.Blocks() as demo:
                    # Connect the inputs and outputs to the function
                    with gr.Row():
                        multiple_submit_button = gr.Button(value="Submit", variant="primary")
-                       multiple_clear_button = gr.
+                       multiple_clear_button = gr.ClearButton(components=[user_message_input,
+                                                                          expected_output_input,
+                                                                          acceptance_criteria_input,
+                                                                          initial_system_message_input],
+                                                              value='Clear All')
        with gr.Column():
            system_message_output = gr.Textbox(
                label="System Message", show_copy_button=True)
@@ -216,20 +260,26 @@ with gr.Blocks() as demo:
            analysis_output = gr.Textbox(
                label="Analysis", show_copy_button=True)
            flag_button = gr.Button(value="Flag", variant="secondary", visible=config.allow_flagging)
+           with gr.Accordion("Details", open=False, visible=config.verbose):
+               logs_chatbot = gr.Chatbot(label='Messages', show_copy_button=True, layout='bubble',
+                                         bubble_full_width=False, render_markdown=False)
+               clear_logs_button = gr.ClearButton([logs_chatbot], value='Clear Logs')
+
+    clear_button.add([system_message_output, output_output,
+                      analysis_output, logs_chatbot])
+    multiple_clear_button.add([system_message_output, output_output,
+                               analysis_output, logs_chatbot])
 
     submit_button.click(process_message_with_single_llm,
                         inputs=[user_message_input, expected_output_input, acceptance_criteria_input,
-                                initial_system_message_input, recursion_limit_input,
-
-
-                        outputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input])
+                                initial_system_message_input, recursion_limit_input, max_output_age,
+                                model_name_input],
+                        outputs=[system_message_output, output_output, analysis_output, logs_chatbot])
     multiple_submit_button.click(process_message_with_2_llms,
                                  inputs=[user_message_input, expected_output_input, acceptance_criteria_input,
-                                         initial_system_message_input, recursion_limit_input,
+                                         initial_system_message_input, recursion_limit_input, max_output_age,
                                          optimizer_model_name_input, executor_model_name_input],
-                                 outputs=[system_message_output, output_output, analysis_output])
-    multiple_clear_button.click(clear_inputs,
-                                outputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input])
+                                 outputs=[system_message_output, output_output, analysis_output, logs_chatbot])
 
     # Load examples
     examples = config.examples_path
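To illustrate the new log plumbing (the log lines below are made up, not from the commit), `chat_log_2_chatbot_list` turns the JSON-lines output captured in `process_message` into `[user, assistant]` pairs for the new `logs_chatbot` component:

```python
# Hypothetical JSON-lines log in the shape emitted by the jsonlogger handler.
sample_log = "\n".join([
    '{"action": "invoke", "message": "prompt sent to the node"}',
    '{"action": "response", "message": "text returned by the LLM"}',
    "not json",  # malformed lines are printed and skipped
])

print(chat_log_2_chatbot_list(sample_log))
# [['prompt sent to the node', None], [None, 'text returned by the LLM']]
```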
config.yml
CHANGED
@@ -27,7 +27,7 @@ llms:
     model_name: "mixtral-8x7b-32768"
     openai_api_base: "https://api.groq.com/openai/v1"
     max_tokens: 8192
-    verbose: true
+    verbose: true
   # anthropic/claude-3-haiku:
   #   type: ChatOpenAI
   #   temperature: 0.1
@@ -58,4 +58,5 @@ server_name: 0.0.0.0
 # server_port: 7860
 recursion_limit: 16
 recursion_limit_max: 20
-
+allow_flagging: false
+verbose: false
meta_prompt/meta_prompt.py
CHANGED
@@ -77,11 +77,14 @@ Output and match Expected Output more closely.
 3. Modify only the content mentioned in the Suggestion. Do not change the
    parts that are not related to the Suggestion.
 4. Output only the updated system message, without any additional content.
-5.
+5. Avoiding the behavior should be explicitly requested (e.g. `Don't ...`) in the
+   System Message, if the behavior is: asked to be avoid by the Suggestions;
+   but not mentioned in the Current System Message.
+6. Expected Output text should not appear in System Message as an example. But
    it's OK to use some similar text as an example instead.
    * Remove the Expected Output text or text highly similar to Expected Output
      from System Message, if it's present.
-
+7. Format the system message well, with no more than 80 characters per line
    (except for raw text).
 
 ## Output
@@ -112,12 +115,13 @@ Provide only the updated System Message, adhering to the above guidelines.
     NODE_OUTPUT_HISTORY_ANALYZER: ChatPromptTemplate.from_messages([
         ("system", """You are a text comparing program. You read the Acceptance Criteria, compare the
 compare the exptected output with two different outputs, and decide which one is
-more
+more consistent with the expected output. When comparing the outputs, ignore the
+differences which are acceptable or ignorable according to the Acceptance Criteria.
 
 You output the following analysis according to the Acceptance Criteria:
 
 * Your analysis in a Markdown list.
-* The ID of the output that is more
+* The ID of the output that is more consistent with the Expected Output as Preferred
   Output ID, with the following format:
 
 ```
@@ -210,50 +214,52 @@ Provide your analysis in the following format:
     NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
         ("system", """
 Read the following inputs and outputs of an LLM prompt, and also analysis about them.
-Then suggest how to improve System
+Then suggest how to improve System Message.
 
-System
-```
-{system_message}
-```
-User Message:
-```
-{user_message}
-```
-Expected Output:
-```
-{expected_output}
-```
-Actual Output:
-```
-{output}
-```
-
-Acceptance Criteria:
-```
-{acceptance_criteria}
-```
-
-Analysis:
-```
-{analysis}
-```
-
-* The goal is to improve the System Prompt to match the Expected Output better.
+* The goal is to improve the System Message to match the Expected Output better.
 * Ignore all Acceptable Differences and focus on Unacceptable Differences.
 * Suggest formal changes first, then semantic changes.
 * Provide your suggestions in a Markdown list, nothing else. Output only the
   suggestions related with Unacceptable Differences.
-*
+* Start every suggestion with `The System Message should ...`.
 * Figue out the contexts of the System Message that conflict with the suggestions,
   and suggest modification or deletion.
+* Avoiding the behavior should be explicitly requested (e.g. `The System Message
+  should explicitly state that the output shoud not ...`) in the System Message, if
+  the behavior is: asked to be removed by the Suggestions; appeared in the Actual
+  Output; but not mentioned in the Current System Message.
 * Expected Output text should not appear in System Message as an example. But
-  it's OK to use some similar text as an example instead.
+  it's OK to use some similar but distinct text as an example instead.
 * Ask to remove the Expected Output text or text highly similar to Expected Output
   from System Message, if it's present.
 * Provide format examples or detected format name, if System Message does not.
 * Specify the detected format name (e.g. XML, JSON, etc.) of Expected Output, if
   System Message does not mention it.
+"""),
+        ("human", """
+<|Start_System_Message|>
+{system_message}
+<|End_System_Message|>
+
+<|Start_User_Message|>
+{user_message}
+<|End_User_Message|>
+
+<|Start_Expected_Output|>
+{expected_output}
+<|End_Expected_Output|>
+
+<|Start_Actual_Output|>
+{output}
+<|End_Actual_Output|>
+
+<|Start_Acceptance Criteria|>
+{acceptance_criteria}
+<|End_Acceptance Criteria|>
+
+<|Start_Analysis|>
+{analysis}
+<|End_Analysis|>
 """)
     ])
 }
@@ -272,12 +278,14 @@ Analysis:
     def __init__(self,
                  llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]] = {},
                  prompts: Dict[str, ChatPromptTemplate] = {},
+                 logger: Optional[logging.Logger] = None,
                  verbose = False):
-        self.logger = logging.getLogger(__name__)
-        if
-
-
-
+        self.logger = logger or logging.getLogger(__name__)
+        if self.logger is not None:
+            if verbose:
+                self.logger.setLevel(logging.DEBUG)
+            else:
+                self.logger.setLevel(logging.INFO)
 
         if isinstance(llms, BaseLanguageModel):
             self.llms: Dict[str, BaseLanguageModel] = {node: llms for node in self.get_node_names()}
@@ -374,32 +382,36 @@ Analysis:
         return state
 
     def _prompt_node(self, node, target_attribute: str, state: AgentState) -> AgentState:
+        logger = self.logger.getChild(node)
         prompt = self.prompt_templates[node].format_messages(**state.model_dump())
 
-
+        for message in prompt:
+            logger.debug({'node': node, 'action': 'invoke', 'type': message.type, 'message': message.content})
         response = self.llms[node].invoke(self.prompt_templates[node].format_messages(**state.model_dump()))
-
+        logger.debug({'node': node, 'action': 'response', 'type': response.type, 'message': response.content})
 
         setattr(state, target_attribute, response.content)
         return state
 
     def _output_history_analyzer(self, state: AgentState) -> AgentState:
+        logger = self.logger.getChild(self.NODE_OUTPUT_HISTORY_ANALYZER)
+
        if state.best_output is None:
            state.best_output = state.output
            state.best_system_message = state.system_message
            state.best_output_age = 0
 
-
+           logger.debug("Best output initialized to the current output: \n %s", state.output)
 
            return state
 
        prompt = self.prompt_templates[self.NODE_OUTPUT_HISTORY_ANALYZER].format_messages(**state.model_dump())
 
-
-
-
+       for message in prompt:
+           logger.debug({'node': self.NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'invoke', 'type': message.type, 'message': message.content})
+
        response = self.llms[self.NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
-
+       logger.debug({'node': self.NODE_OUTPUT_HISTORY_ANALYZER, 'action': 'response', 'type': response.type, 'message': response.content})
 
        analysis = response.content
 
@@ -408,35 +420,44 @@ Analysis:
            state.best_system_message = state.system_message
            state.best_output_age = 0
 
-
+           logger.debug("Best output updated to the current output: \n %s", state.output)
        else:
            state.best_output_age += 1
 
-
+           logger.debug("Best output age incremented to %s", state.best_output_age)
 
        return state
 
    def _prompt_analyzer(self, state: AgentState) -> AgentState:
+       logger = self.logger.getChild(self.NODE_PROMPT_ANALYZER)
       prompt = self.prompt_templates[self.NODE_PROMPT_ANALYZER].format_messages(**state.model_dump())
 
-
-
-
+       for message in prompt:
+           logger.debug({'node': self.NODE_PROMPT_ANALYZER, 'action': 'invoke', 'type': message.type, 'message': message.content})
+
       response = self.llms[self.NODE_PROMPT_ANALYZER].invoke(prompt)
-
+       logger.debug({'node': self.NODE_PROMPT_ANALYZER, 'action': 'response', 'type': response.type, 'message': response.content})
 
       state.analysis = response.content
       state.accepted = "Accept: Yes" in response.content
 
-
+       logger.debug("Accepted: %s", state.accepted)
 
       return state
 
    def _should_exit_on_max_age(self, state: AgentState) -> str:
-       if state.max_output_age <=
+       if state.max_output_age <=0:
+           # always continue if max age is 0
           return "continue"
-
-
+
+       if state.best_output_age >= state.max_output_age:
+           return END
+
+       if state.best_output_age > 0:
+           # skip prompt_analyzer and prompt_suggester, goto prompt_developer
+           return "rerun"
+
+       return "continue"
 
    def _should_exit_on_acceptable_output(self, state: AgentState) -> str:
       return "continue" if not state.accepted else END
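The per-node logging added above pairs with the log capture in `process_message`; a condensed sketch of that wiring (assuming `llm` is any `BaseLanguageModel` and `input_state` an `AgentState`, as constructed in the app code):

```python
import io
import logging

from pythonjsonlogger import jsonlogger
from meta_prompt import MetaPromptGraph

# Collect the 'invoke'/'response' debug records emitted by _prompt_node,
# _output_history_analyzer and _prompt_analyzer into an in-memory stream.
log_stream = io.StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setFormatter(jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s'))

logger = logging.getLogger(MetaPromptGraph.__name__)
logger.addHandler(log_handler)

# verbose=True makes __init__ set the logger to DEBUG, so the records are emitted.
meta_prompt_graph = MetaPromptGraph(llms=llm, verbose=True, logger=logger)
output_state = meta_prompt_graph(input_state, recursion_limit=16)

log_handler.close()
print(log_stream.getvalue())  # one JSON line per logged message
```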
poetry.lock
CHANGED
@@ -2160,6 +2160,17 @@ files = [
 [package.extras]
 cli = ["click (>=5.0)"]
 
+[[package]]
+name = "python-json-logger"
+version = "2.0.7"
+description = "A python library adding a json log formatter"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"},
+    {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"},
+]
+
 [[package]]
 name = "python-multipart"
 version = "0.0.9"
@@ -3290,4 +3301,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "
+content-hash = "dca41322d7cd0e10cb5cd3e5748b9f505a2d29f00ca78816e397c1c1f5e9d693"
pyproject.toml
CHANGED
@@ -11,6 +11,7 @@ langgraph = "^0.1.5"
 langchain = "^0.2.6"
 langchain-openai = "^0.1.14"
 pydantic = "^2.8.2"
+python-json-logger = "^2.0.7"
 
 [tool.poetry.dev-dependencies]
 gradio = "^4.37.2"