yaleh committed on
Commit
b10c78f
·
1 Parent(s): 68c6b73

Gradio demo works with Confz now.

Browse files
.gitignore CHANGED
@@ -2,3 +2,4 @@
2
  .vscode
3
  __pycache__
4
  .env
 
 
2
  .vscode
3
  __pycache__
4
  .env
5
+ config.yml.debug
config.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ llms:
2
+ anthropic/claude-3-haiku:
3
+ type: ChatOpenAI
4
+ temperature: 0.1
5
+ model_name: "anthropic/claude-3-haiku:beta"
6
+ openai_api_key: ""
7
+ openai_api_base: "https://openrouter.ai/api/v1"
8
+ max_tokens: 8192
9
+ verbose: true
10
+ anthropic/claude-3-sonnet:
11
+ type: ChatOpenAI
12
+ temperature: 0.1
13
+ model_name: "anthropic/claude-3-sonnet:beta"
14
+ openai_api_key: ""
15
+ openai_api_base: "https://openrouter.ai/api/v1"
16
+ max_tokens: 8192
17
+ verbose: true
18
+ anthropic/deepseek-chat:
19
+ type: ChatOpenAI
20
+ temperature: 0.1
21
+ model_name: "deepseek/deepseek-chat"
22
+ openai_api_key: ""
23
+ openai_api_base: "https://openrouter.ai/api/v1"
24
+ max_tokens: 8192
25
+ verbose: true
26
+ groq/llama3-70b-8192:
27
+ type: ChatOpenAI
28
+ temperature: 0.1
29
+ model_name: "llama3-70b-8192"
30
+ openai_api_key: ""
31
+ openai_api_base: "https://api.groq.com/openai/v1"
32
+ max_tokens: 8192
33
+ verbose: true
34
+
35
+ examples_path: "demo/examples"
36
+ server_name: 0.0.0.0
37
+ server_port: 7870
demo/config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+ from confz import BaseConfig
3
+ from pydantic import BaseModel, Extra
4
+ from typing import Optional
5
+
6
+ class LLMConfig(BaseModel):
7
+ type: str
8
+
9
+ class Config:
10
+ extra = Extra.allow
11
+
12
+ class MetaPromptConfig(BaseConfig):
13
+ llms: Optional[dict[str, LLMConfig]]
14
+ examples_path: Optional[str]
15
+ server_name: Optional[str] = '127.0.0.1'
16
+ server_port: Optional[int] = 7878
demo/examples/log.csv CHANGED
@@ -188,10 +188,12 @@ What is the meaning of life?,"[
188
  * Data types and formats of all JSON fields
189
  * Top layer sections
190
  * Acceptable differences:
191
- * Differences in field values
 
192
  * Extra or missing spaces
193
  * Extra or missing line breaks at the beginning or end of the output
194
- * JSON wrapped in backquotes"
 
195
  "<?php
196
  $username = $_POST['username'];
197
  $password = $_POST['password'];
 
188
  * Data types and formats of all JSON fields
189
  * Top layer sections
190
  * Acceptable differences:
191
+ * Different personas or prompts
192
+ * Different numbers of personas
193
  * Extra or missing spaces
194
  * Extra or missing line breaks at the beginning or end of the output
195
+ * Unacceptable:
196
+ * Showing the personas in Expected Output in System Message"
197
  "<?php
198
  $username = $_POST['username'];
199
  $password = $_POST['password'];
demo/gradio_meta_prompt.py CHANGED
@@ -1,31 +1,40 @@
1
  import gradio as gr
 
2
  from meta_prompt import MetaPromptGraph, AgentState
3
  from langchain_openai import ChatOpenAI
 
4
 
5
- # Initialize the MetaPromptGraph with the required LLMs
6
- MODEL_NAME = "anthropic/claude-3.5-sonnet:haiku"
7
- # MODEL_NAME = "meta-llama/llama-3-70b-instruct"
8
- # MODEL_NAME = "deepseek/deepseek-chat"
9
- # MODEL_NAME = "google/gemma-2-9b-it"
10
- # MODEL_NAME = "recursal/eagle-7b"
11
- # MODEL_NAME = "meta-llama/llama-3-8b-instruct"
12
- llm = ChatOpenAI(model_name=MODEL_NAME)
13
- meta_prompt_graph = MetaPromptGraph(llms=llm)
14
-
15
- def process_message(user_message, expected_output, acceptance_criteria, recursion_limit: int=25):
 
16
  # Create the input state
17
  input_state = AgentState(
18
  user_message=user_message,
19
  expected_output=expected_output,
20
- acceptance_criteria=acceptance_criteria
 
21
  )
22
 
23
  # Get the output state from MetaPromptGraph
 
 
 
 
24
  output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
25
 
26
  # Validate the output state
27
  system_message = ''
28
  output = ''
 
29
 
30
  if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
31
  system_message = output_state['best_system_message']
@@ -37,22 +46,58 @@ def process_message(user_message, expected_output, acceptance_criteria, recursio
37
  else:
38
  output = "Error: The output state does not contain a valid 'best_output'"
39
 
40
- return system_message, output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Create the Gradio interface
43
  iface = gr.Interface(
44
  fn=process_message,
45
  inputs=[
46
- gr.Textbox(label="User Message"),
47
- gr.Textbox(label="Expected Output"),
48
- gr.Textbox(label="Acceptance Criteria"),
49
- gr.Number(label="Recursion Limit", value=25, precision=0, minimum=1, maximum=100, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  ],
51
- outputs=[gr.Textbox(label="System Message"), gr.Textbox(label="Output")],
52
  title="MetaPromptGraph Chat Interface",
53
  description="A chat interface for MetaPromptGraph to process user inputs and generate system messages.",
54
- examples="demo/examples"
55
  )
56
 
57
  # Launch the Gradio app
58
- iface.launch()
 
1
  import gradio as gr
2
+ from confz import BaseConfig, CLArgSource, EnvSource, FileSource
3
  from meta_prompt import MetaPromptGraph, AgentState
4
  from langchain_openai import ChatOpenAI
5
+ from config import MetaPromptConfig
6
 
7
+ class LLMModelFactory:
8
+ def __init__(self):
9
+ pass
10
+
11
+ def create(self, model_type: str, **kwargs):
12
+ model_class = globals()[model_type]
13
+ return model_class(**kwargs)
14
+
15
+ llm_model_factory = LLMModelFactory()
16
+
17
+ def process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
18
+ recursion_limit: int, model_name: str):
19
  # Create the input state
20
  input_state = AgentState(
21
  user_message=user_message,
22
  expected_output=expected_output,
23
+ acceptance_criteria=acceptance_criteria,
24
+ system_message=initial_system_message
25
  )
26
 
27
  # Get the output state from MetaPromptGraph
28
+ type = config.llms[model_name].type
29
+ args = config.llms[model_name].model_dump(exclude={'type'})
30
+ llm = llm_model_factory.create(type, **args)
31
+ meta_prompt_graph = MetaPromptGraph(llms=llm)
32
  output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
33
 
34
  # Validate the output state
35
  system_message = ''
36
  output = ''
37
+ analysis = ''
38
 
39
  if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
40
  system_message = output_state['best_system_message']
 
46
  else:
47
  output = "Error: The output state does not contain a valid 'best_output'"
48
 
49
+ if 'analysis' in output_state and output_state['analysis'] is not None:
50
+ analysis = output_state['analysis']
51
+ else:
52
+ analysis = "Error: The output state does not contain a valid 'analysis'"
53
+
54
+ return system_message, output, analysis
55
+
56
+ class FileConfig(BaseConfig):
57
+ config_file: str = 'config.yml' # default path
58
+
59
+ pre_config_sources = [
60
+ EnvSource(prefix='METAPROMPT_', allow_all=True),
61
+ CLArgSource()
62
+ ]
63
+ pre_config = FileConfig(config_sources=pre_config_sources)
64
+
65
+ config_sources = [
66
+ FileSource(file=pre_config.config_file, optional=True),
67
+ EnvSource(prefix='METAPROMPT_', allow_all=True),
68
+ CLArgSource()
69
+ ]
70
+
71
+ config = MetaPromptConfig(config_sources=config_sources)
72
 
73
  # Create the Gradio interface
74
  iface = gr.Interface(
75
  fn=process_message,
76
  inputs=[
77
+ gr.Textbox(label="User Message", show_copy_button=True),
78
+ gr.Textbox(label="Expected Output", show_copy_button=True),
79
+ gr.Textbox(label="Acceptance Criteria", show_copy_button=True),
80
+ ],
81
+ outputs=[
82
+ gr.Textbox(label="System Message", show_copy_button=True),
83
+ gr.Textbox(label="Output", show_copy_button=True),
84
+ gr.Textbox(label="Analysis", show_copy_button=True)
85
+ ],
86
+ additional_inputs=[
87
+ gr.Textbox(label="Initial System Message", show_copy_button=True, value=""),
88
+ gr.Number(label="Recursion Limit", value=25,
89
+ precision=0, minimum=1, maximum=100, step=1),
90
+ gr.Dropdown(
91
+ label="Model Name",
92
+ choices=config.llms.keys(),
93
+ value=list(config.llms.keys())[0],
94
+ )
95
  ],
96
+ # stop_btn = gr.Button("Stop", variant="stop", visible=True),
97
  title="MetaPromptGraph Chat Interface",
98
  description="A chat interface for MetaPromptGraph to process user inputs and generate system messages.",
99
+ examples=config.examples_path
100
  )
101
 
102
  # Launch the Gradio app
103
+ iface.launch(server_name=config.server_name, server_port=config.server_port)
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  aiofiles==23.2.1
2
- aiohttp==3.8.5
3
  aiosignal==1.3.1
4
  altair==5.1.1
5
  annotated-types==0.5.0
@@ -11,6 +11,7 @@ certifi==2023.7.22
11
  charset-normalizer==3.2.0
12
  click==8.1.7
13
  comm==0.2.2
 
14
  contourpy==1.1.1
15
  cycler==0.11.0
16
  dataclasses-json==0.6.0
@@ -25,16 +26,17 @@ filelock==3.12.4
25
  fonttools==4.42.1
26
  frozenlist==1.4.0
27
  fsspec==2023.9.2
28
- gradio==3.44.4
29
- gradio_client==0.5.1
30
  greenlet==2.0.2
31
  h11==0.14.0
32
  httpcore==0.18.0
33
  httpx==0.25.0
34
- huggingface-hub==0.17.2
35
  idna==3.4
36
  importlib-resources==6.1.0
37
  ipykernel==6.29.4
 
38
  jedi==0.19.1
39
  Jinja2==3.1.2
40
  joblib==1.3.2
@@ -51,10 +53,12 @@ langchain-openai==0.1.13
51
  langchain-text-splitters==0.2.2
52
  langgraph==0.1.4
53
  langsmith==0.1.82
 
54
  MarkupSafe==2.1.3
55
  marshmallow==3.20.1
56
  matplotlib==3.8.0
57
  matplotlib-inline==0.1.7
 
58
  multidict==6.0.4
59
  mypy-extensions==1.0.0
60
  nest-asyncio==1.6.0
@@ -78,17 +82,21 @@ pydub==0.25.1
78
  Pygments==2.18.0
79
  pyparsing==3.1.1
80
  python-dateutil==2.9.0.post0
81
- python-multipart==0.0.6
 
82
  pytz==2023.3.post1
83
  PyYAML==6.0.1
84
  pyzmq==26.0.3
85
  referencing==0.30.2
86
  regex==2024.5.15
87
  requests==2.31.0
 
88
  rpds-py==0.10.3
 
89
  scikit-learn==1.3.1
90
  scipy==1.11.3
91
  semantic-version==2.10.0
 
92
  six==1.16.0
93
  sniffio==1.3.0
94
  SQLAlchemy==2.0.21
@@ -97,10 +105,13 @@ starlette==0.27.0
97
  tenacity==8.2.3
98
  threadpoolctl==3.2.0
99
  tiktoken==0.7.0
 
 
100
  toolz==0.12.0
101
  tornado==6.4.1
102
  tqdm==4.66.1
103
  traitlets==5.14.3
 
104
  typing-inspect==0.9.0
105
  typing_extensions==4.12.2
106
  tzdata==2023.3
 
1
  aiofiles==23.2.1
2
+ aiohttp==3.9.5
3
  aiosignal==1.3.1
4
  altair==5.1.1
5
  annotated-types==0.5.0
 
11
  charset-normalizer==3.2.0
12
  click==8.1.7
13
  comm==0.2.2
14
+ confz==2.0.1
15
  contourpy==1.1.1
16
  cycler==0.11.0
17
  dataclasses-json==0.6.0
 
26
  fonttools==4.42.1
27
  frozenlist==1.4.0
28
  fsspec==2023.9.2
29
+ gradio==4.37.2
30
+ gradio_client==1.0.2
31
  greenlet==2.0.2
32
  h11==0.14.0
33
  httpcore==0.18.0
34
  httpx==0.25.0
35
+ huggingface-hub==0.23.4
36
  idna==3.4
37
  importlib-resources==6.1.0
38
  ipykernel==6.29.4
39
+ ipython==8.26.0
40
  jedi==0.19.1
41
  Jinja2==3.1.2
42
  joblib==1.3.2
 
53
  langchain-text-splitters==0.2.2
54
  langgraph==0.1.4
55
  langsmith==0.1.82
56
+ markdown-it-py==3.0.0
57
  MarkupSafe==2.1.3
58
  marshmallow==3.20.1
59
  matplotlib==3.8.0
60
  matplotlib-inline==0.1.7
61
+ mdurl==0.1.2
62
  multidict==6.0.4
63
  mypy-extensions==1.0.0
64
  nest-asyncio==1.6.0
 
82
  Pygments==2.18.0
83
  pyparsing==3.1.1
84
  python-dateutil==2.9.0.post0
85
+ python-dotenv==1.0.1
86
+ python-multipart==0.0.9
87
  pytz==2023.3.post1
88
  PyYAML==6.0.1
89
  pyzmq==26.0.3
90
  referencing==0.30.2
91
  regex==2024.5.15
92
  requests==2.31.0
93
+ rich==13.7.1
94
  rpds-py==0.10.3
95
+ ruff==0.5.0
96
  scikit-learn==1.3.1
97
  scipy==1.11.3
98
  semantic-version==2.10.0
99
+ shellingham==1.5.4
100
  six==1.16.0
101
  sniffio==1.3.0
102
  SQLAlchemy==2.0.21
 
105
  tenacity==8.2.3
106
  threadpoolctl==3.2.0
107
  tiktoken==0.7.0
108
+ toml==0.10.2
109
+ tomlkit==0.12.0
110
  toolz==0.12.0
111
  tornado==6.4.1
112
  tqdm==4.66.1
113
  traitlets==5.14.3
114
+ typer==0.12.3
115
  typing-inspect==0.9.0
116
  typing_extensions==4.12.2
117
  tzdata==2023.3
src/meta_prompt/meta_prompt.py CHANGED
@@ -164,23 +164,9 @@ If both outputs are equally similar to the expected output, output the following
164
  ]),
165
  NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
166
  ("system", """
167
- You are a text comparing program. You compare the following output texts and provide a
168
- detailed analysis according to `Acceptance Criteria`. Then you decide whether `Actual Output`
169
- is acceptable.
170
-
171
- # Expected Output
172
-
173
- ```
174
- {expected_output}
175
- ```
176
-
177
- # Actual Output
178
-
179
- ```
180
- {output}
181
- ```
182
-
183
- ----
184
 
185
  Provide your analysis in the following format:
186
 
@@ -200,6 +186,25 @@ Provide your analysis in the following format:
200
  ```
201
  {acceptance_criteria}
202
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  """)
204
  ]),
205
  NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
@@ -281,42 +286,34 @@ Analysis:
281
  self.prompt_templates: Dict[str, ChatPromptTemplate] = self.DEFAULT_PROMPT_TEMPLATES.copy()
282
  self.prompt_templates.update(prompts)
283
 
284
- # create workflow
285
- self.workflow = StateGraph(AgentState)
286
-
287
- self.workflow.add_node(self.NODE_PROMPT_INITIAL_DEVELOPER,
288
- lambda x: self._prompt_node(
289
- self.NODE_PROMPT_INITIAL_DEVELOPER,
290
- "system_message",
291
- x))
292
- self.workflow.add_node(self.NODE_PROMPT_DEVELOPER,
293
- lambda x: self._prompt_node(
294
- self.NODE_PROMPT_DEVELOPER,
295
- "system_message",
296
- x))
297
- self.workflow.add_node(self.NODE_PROMPT_EXECUTOR,
298
- lambda x: self._prompt_node(
299
- self.NODE_PROMPT_EXECUTOR,
300
- "output",
301
- x))
302
- self.workflow.add_node(self.NODE_OUTPUT_HISTORY_ANALYZER,
303
- lambda x: self._output_history_analyzer(x))
304
- self.workflow.add_node(self.NODE_PROMPT_ANALYZER,
305
- lambda x: self._prompt_analyzer(x))
306
- self.workflow.add_node(self.NODE_PROMPT_SUGGESTER,
307
- lambda x: self._prompt_node(
308
- self.NODE_PROMPT_SUGGESTER,
309
- "suggestions",
310
- x))
311
-
312
- self.workflow.set_entry_point(self.NODE_PROMPT_INITIAL_DEVELOPER)
313
-
314
- self.workflow.add_edge(self.NODE_PROMPT_INITIAL_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
315
- self.workflow.add_edge(self.NODE_PROMPT_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
316
- self.workflow.add_edge(self.NODE_PROMPT_EXECUTOR, self.NODE_OUTPUT_HISTORY_ANALYZER)
317
- self.workflow.add_edge(self.NODE_PROMPT_SUGGESTER, self.NODE_PROMPT_DEVELOPER)
318
-
319
- self.workflow.add_conditional_edges(
320
  self.NODE_OUTPUT_HISTORY_ANALYZER,
321
  lambda x: self._should_exit_on_max_age(x),
322
  {
@@ -326,7 +323,7 @@ Analysis:
326
  }
327
  )
328
 
329
- self.workflow.add_conditional_edges(
330
  self.NODE_PROMPT_ANALYZER,
331
  lambda x: self._should_exit_on_acceptable_output(x),
332
  {
@@ -335,9 +332,24 @@ Analysis:
335
  }
336
  )
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
 
 
339
  memory = MemorySaver()
340
- graph = self.workflow.compile(checkpointer=memory)
341
  config = {"configurable": {"thread_id": "1"}, "recursion_limit": recursion_limit}
342
 
343
  try:
 
164
  ]),
165
  NODE_PROMPT_ANALYZER: ChatPromptTemplate.from_messages([
166
  ("system", """
167
+ You are a text comparing program. You compare the following output texts,
168
+ analysis the System Message and provide a detailed analysis according to
169
+ `Acceptance Criteria`. Then you decide whether `Actual Output` is acceptable.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  Provide your analysis in the following format:
172
 
 
186
  ```
187
  {acceptance_criteria}
188
  ```
189
+ """),
190
+ ("human", """
191
+ # System Message
192
+
193
+ ```
194
+ {system_message}
195
+ ```
196
+
197
+ # Expected Output
198
+
199
+ ```
200
+ {expected_output}
201
+ ```
202
+
203
+ # Actual Output
204
+
205
+ ```
206
+ {output}
207
+ ```
208
  """)
209
  ]),
210
  NODE_PROMPT_SUGGESTER: ChatPromptTemplate.from_messages([
 
286
  self.prompt_templates: Dict[str, ChatPromptTemplate] = self.DEFAULT_PROMPT_TEMPLATES.copy()
287
  self.prompt_templates.update(prompts)
288
 
289
+ def _create_workflow(self, including_initial_developer: bool = True) -> StateGraph:
290
+ workflow = StateGraph(AgentState)
291
+
292
+ workflow.add_node(self.NODE_PROMPT_DEVELOPER,
293
+ lambda x: self._prompt_node(
294
+ self.NODE_PROMPT_DEVELOPER,
295
+ "system_message",
296
+ x))
297
+ workflow.add_node(self.NODE_PROMPT_EXECUTOR,
298
+ lambda x: self._prompt_node(
299
+ self.NODE_PROMPT_EXECUTOR,
300
+ "output",
301
+ x))
302
+ workflow.add_node(self.NODE_OUTPUT_HISTORY_ANALYZER,
303
+ lambda x: self._output_history_analyzer(x))
304
+ workflow.add_node(self.NODE_PROMPT_ANALYZER,
305
+ lambda x: self._prompt_analyzer(x))
306
+ workflow.add_node(self.NODE_PROMPT_SUGGESTER,
307
+ lambda x: self._prompt_node(
308
+ self.NODE_PROMPT_SUGGESTER,
309
+ "suggestions",
310
+ x))
311
+
312
+ workflow.add_edge(self.NODE_PROMPT_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
313
+ workflow.add_edge(self.NODE_PROMPT_EXECUTOR, self.NODE_OUTPUT_HISTORY_ANALYZER)
314
+ workflow.add_edge(self.NODE_PROMPT_SUGGESTER, self.NODE_PROMPT_DEVELOPER)
315
+
316
+ workflow.add_conditional_edges(
 
 
 
 
 
 
 
 
317
  self.NODE_OUTPUT_HISTORY_ANALYZER,
318
  lambda x: self._should_exit_on_max_age(x),
319
  {
 
323
  }
324
  )
325
 
326
+ workflow.add_conditional_edges(
327
  self.NODE_PROMPT_ANALYZER,
328
  lambda x: self._should_exit_on_acceptable_output(x),
329
  {
 
332
  }
333
  )
334
 
335
+ if including_initial_developer:
336
+ workflow.add_node(self.NODE_PROMPT_INITIAL_DEVELOPER,
337
+ lambda x: self._prompt_node(
338
+ self.NODE_PROMPT_INITIAL_DEVELOPER,
339
+ "system_message",
340
+ x))
341
+ workflow.add_edge(self.NODE_PROMPT_INITIAL_DEVELOPER, self.NODE_PROMPT_EXECUTOR)
342
+ workflow.set_entry_point(self.NODE_PROMPT_INITIAL_DEVELOPER)
343
+ else:
344
+ workflow.set_entry_point(self.NODE_PROMPT_EXECUTOR)
345
+
346
+ return workflow
347
+
348
  def __call__(self, state: AgentState, recursion_limit: int = 25) -> AgentState:
349
+ workflow = self._create_workflow(including_initial_developer=(state.system_message is None or state.system_message == ""))
350
+
351
  memory = MemorySaver()
352
+ graph = workflow.compile(checkpointer=memory)
353
  config = {"configurable": {"thread_id": "1"}, "recursion_limit": recursion_limit}
354
 
355
  try: