yaleh committed
Commit c58f0cd · Parent: 115910a

Update config files with new options and models; add `Advanced` mode to the chat interface.

Files changed (3):
  1. app/examples/log.csv +1 -0
  2. app/gradio_meta_prompt.py +142 -50
  3. config.yml +21 -0
app/examples/log.csv CHANGED
@@ -235,3 +235,4 @@ if (mysqli_num_rows($result) > 0) {
 * Extra or missing line breaks at the beginning or end of the output
 * YAML wrapped in backquotes",""
 “老爸,老爸,我们去哪里呀?”,《爸爸去哪儿》,Exactly text match.,"查询歌词出处。"
+How do I reverse a list in Python?,Use the `[::-1]` slicing technique or the `list.reverse()` method.,"Similar in meaning, text length and style.",
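
Each row in log.csv is one flagged example — user message, expected output, acceptance criteria, then an optional initial system message — and `gr.Examples` loads this file to pre-fill the UI inputs. A minimal sketch of reading it with the standard csv module; the column order is inferred from the rows above and the path is the app's default:

```python
import csv

# Minimal sketch: iterate flagged examples the way gr.Examples would consume them.
# Column order is inferred from the rows shown in the diff above.
with open("app/examples/log.csv", newline="", encoding="utf-8") as f:
    for row in csv.reader(f):
        if len(row) < 3:
            continue  # guard against short rows
        user_message, expected_output, acceptance_criteria = row[:3]
        print(f"{user_message!r} -> {expected_output!r} ({acceptance_criteria})")
```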
app/gradio_meta_prompt.py CHANGED
@@ -1,13 +1,15 @@
 import csv
 from pathlib import Path
-from typing import Any
+from typing import Any, Dict, Union
 import gradio as gr
-from gradio import CSVLogger, utils
+from gradio import CSVLogger, utils, Button
+from gradio.flagging import FlagMethod
 from gradio_client import utils as client_utils
 from confz import BaseConfig, CLArgSource, EnvSource, FileSource
 from meta_prompt import MetaPromptGraph, AgentState
 from langchain_openai import ChatOpenAI
 from app.config import MetaPromptConfig
+from langchain_core.language_models import BaseLanguageModel

 class SimplifiedCSVLogger(CSVLogger):
     """
@@ -69,7 +71,7 @@ class LLMModelFactory:
 llm_model_factory = LLMModelFactory()

 def process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                    recursion_limit: int, model_name: str):
+                    recursion_limit: int, llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]]):
     # Create the input state
     input_state = AgentState(
         user_message=user_message,
@@ -79,10 +81,7 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_
     )

     # Get the output state from MetaPromptGraph
-    type = config.llms[model_name].type
-    args = config.llms[model_name].model_dump(exclude={'type'})
-    llm = llm_model_factory.create(type, **args)
-    meta_prompt_graph = MetaPromptGraph(llms=llm)
+    meta_prompt_graph = MetaPromptGraph(llms=llms)
     output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)

     # Validate the output state
@@ -107,6 +106,36 @@ def process_message(user_message, expected_output, acceptance_criteria, initial_

     return system_message, output, analysis

+
+def process_message_with_single_llm(user_message, expected_output, acceptance_criteria, initial_system_message,
+                                    recursion_limit: int, model_name: str):
+    # Get the output state from MetaPromptGraph
+    type = config.llms[model_name].type
+    args = config.llms[model_name].model_dump(exclude={'type'})
+    llm = llm_model_factory.create(type, **args)
+
+    return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
+                           recursion_limit, llm)
+
+def process_message_with_2_llms(user_message, expected_output, acceptance_criteria, initial_system_message,
+                                recursion_limit: int, optimizer_model_name: str, executor_model_name: str,):
+    # Get the output state from MetaPromptGraph
+    optimizer_model = llm_model_factory.create(config.llms[optimizer_model_name].type,
+                                               **config.llms[optimizer_model_name].model_dump(exclude={'type'}))
+    executor_model = llm_model_factory.create(config.llms[executor_model_name].type,
+                                              **config.llms[executor_model_name].model_dump(exclude={'type'}))
+    llms = {
+        MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer_model,
+        MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer_model,
+        MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor_model,
+        MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer_model,
+        MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer_model,
+        MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer_model
+    }
+
+    return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
+                           recursion_limit, llms)
+
 class FileConfig(BaseConfig):
     config_file: str = 'config.yml' # default path

@@ -122,54 +151,117 @@ config_sources = [
     CLArgSource()
 ]

+# Add event handlers
+def handle_submit(user_message, expected_output, acceptance_criteria, initial_system_message, recursion_limit, model_name):
+    return process_message(user_message, expected_output, acceptance_criteria, initial_system_message, recursion_limit, model_name)
+
+# Define clear function
+def clear_inputs():
+    return "", "", "", "", "", ""
+
 config = MetaPromptConfig(config_sources=config_sources)

-# Create the Gradio interface
-user_message_input = gr.Textbox(label="User Message", show_copy_button=True)
-expected_output_input = gr.Textbox(label="Expected Output", show_copy_button=True)
-acceptance_criteria_input = gr.Textbox(label="Acceptance Criteria", show_copy_button=True)
+flagging_callback = SimplifiedCSVLogger()

-system_message_output = gr.Textbox(label="System Message", show_copy_button=True)
-output_output = gr.Textbox(label="Output", show_copy_button=True)
-analysis_output = gr.Textbox(label="Analysis", show_copy_button=True)
+# Create a Gradio Blocks context
+with gr.Blocks() as demo:
+    # Define the layout
+    with gr.Row():
+        gr.Markdown(f"""<h1 style='text-align: left; margin-bottom: 1rem'>Meta Prompt</h1>
+                    <p style="text-align:left">A tool for generating and analyzing natural language prompts using multiple language models.</p>
+                    <a href="https://github.com/yaleh/meta-prompt"><img src="https://img.shields.io/badge/GitHub-blue?logo=github" alt="GitHub"></a>""")
+    with gr.Row():
+        with gr.Column():
+            user_message_input = gr.Textbox(
+                label="User Message", show_copy_button=True)
+            expected_output_input = gr.Textbox(
+                label="Expected Output", show_copy_button=True)
+            acceptance_criteria_input = gr.Textbox(
+                label="Acceptance Criteria", show_copy_button=True)
+            initial_system_message_input = gr.Textbox(
+                label="Initial System Message", show_copy_button=True, value="")
+            recursion_limit_input = gr.Number(label="Recursion Limit", value=config.recursion_limit,
+                                              precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)

-initial_system_message_input = gr.Textbox(label="Initial System Message", show_copy_button=True, value="")
-recursion_limit_input = gr.Number(label="Recursion Limit", value=config.recursion_limit,
-                                  precision=0, minimum=1, maximum=config.recursion_limit_max, step=1)
-model_name_input = gr.Dropdown(
-    label="Model Name",
-    choices=config.llms.keys(),
-    value=list(config.llms.keys())[0],
-)
+            with gr.Row():
+                with gr.Tab('Simple'):
+                    model_name_input = gr.Dropdown(
+                        label="Model Name",
+                        choices=config.llms.keys(),
+                        value=list(config.llms.keys())[0],
+                    )
+                    # Connect the inputs and outputs to the function
+                    with gr.Row():
+                        submit_button = gr.Button(value="Submit", variant="primary")
+                        clear_button = gr.Button(value="Clear", variant="secondary")
+                with gr.Tab('Advanced'):
+                    optimizer_model_name_input = gr.Dropdown(
+                        label="Optimizer Model Name",
+                        choices=config.llms.keys(),
+                        value=list(config.llms.keys())[0],
+                    )
+                    executor_model_name_input = gr.Dropdown(
+                        label="Executor Model Name",
+                        choices=config.llms.keys(),
+                        value=list(config.llms.keys())[0],
+                    )
+                    # Connect the inputs and outputs to the function
+                    with gr.Row():
+                        multiple_submit_button = gr.Button(value="Submit", variant="primary")
+                        multiple_clear_button = gr.Button(value="Clear", variant="secondary")
+        with gr.Column():
+            system_message_output = gr.Textbox(
+                label="System Message", show_copy_button=True)
+            output_output = gr.Textbox(label="Output", show_copy_button=True)
+            analysis_output = gr.Textbox(
+                label="Analysis", show_copy_button=True)
+            flag_button = gr.Button(value="Flag", variant="secondary")

-flagging_callback = SimplifiedCSVLogger()
+    submit_button.click(process_message_with_single_llm,
+                        inputs=[user_message_input, expected_output_input, acceptance_criteria_input,
+                                initial_system_message_input, recursion_limit_input, model_name_input],
+                        outputs=[system_message_output, output_output, analysis_output])
+    clear_button.click(clear_inputs,
+                       outputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input])
+    multiple_submit_button.click(process_message_with_2_llms,
+                                 inputs=[user_message_input, expected_output_input, acceptance_criteria_input,
+                                         initial_system_message_input, recursion_limit_input,
+                                         optimizer_model_name_input, executor_model_name_input],
+                                 outputs=[system_message_output, output_output, analysis_output])
+    multiple_clear_button.click(clear_inputs,
+                                outputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input])
+    flag_button.click(flagging_callback.flag,
+                      inputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input],
+                      outputs=[])
+
+    # Load examples
+    examples = config.examples_path
+    gr.Examples(examples, inputs=[user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input, recursion_limit_input, model_name_input])
+
+    flagging_inputs = [user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input]
+
+    # Configure flagging
+    if config.allow_flagging:
+        flag_method = FlagMethod(flagging_callback, "Flag", "")
+        flag_button.click(
+            utils.async_lambda(
+                lambda: Button(value="Saving...", interactive=False)
+            ),
+            None,
+            flag_button,
+            queue=False,
+            show_api=False,
+        )
+        flag_button.click(
+            flag_method,
+            inputs=flagging_inputs,
+            outputs=flag_button,
+            preprocess=False,
+            queue=False,
+            show_api=False,
+        )

-iface = gr.Interface(
-    fn=process_message,
-    inputs=[
-        user_message_input,
-        expected_output_input,
-        acceptance_criteria_input,
-    ],
-    outputs=[
-        system_message_output,
-        output_output,
-        analysis_output
-    ],
-    additional_inputs=[
-        initial_system_message_input,
-        recursion_limit_input,
-        model_name_input
-    ],
-    title="MetaPromptGraph Chat Interface",
-    description="A chat interface for MetaPromptGraph to process user inputs and generate system messages.",
-    examples=config.examples_path,
-    allow_flagging=config.allow_flagging,
-    flagging_dir=config.examples_path,
-    flagging_options=["Example"],
-    flagging_callback=flagging_callback
-)
-flagging_callback.setup([user_message_input, expected_output_input, acceptance_criteria_input, initial_system_message_input],config.examples_path)
+    flagging_callback.setup(flagging_inputs, config.examples_path)

 # Launch the Gradio app
-iface.launch(server_name=config.server_name, server_port=config.server_port)
+demo.launch(server_name=config.server_name, server_port=config.server_port)
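
The new `Advanced` tab is the substance of this commit: instead of one model driving every node, an optimizer model handles prompt development and analysis while a separate executor model runs the candidate prompt. A minimal sketch of that wiring outside the UI, mirroring `process_message_with_2_llms`; the model names, the `recursion_limit` value, and the `AgentState` fields beyond `user_message` are assumptions inferred from the handler signatures above:

```python
# Minimal sketch of the Advanced-mode call path, bypassing Gradio.
# Model names and most AgentState fields are illustrative assumptions;
# an OpenAI-compatible API key is expected in the environment.
from langchain_openai import ChatOpenAI
from meta_prompt import MetaPromptGraph, AgentState

optimizer = ChatOpenAI(model_name="llama3-70b-8192",  # assumed heavier model
                       openai_api_base="https://api.groq.com/openai/v1",
                       temperature=0.1, max_tokens=8192)
executor = ChatOpenAI(model_name="llama3-8b-8192",    # new config entry below
                      openai_api_base="https://api.groq.com/openai/v1",
                      temperature=0.1, max_tokens=8192)

# The executor node gets the lighter model; every other node uses the optimizer.
llms = {
    MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer,
    MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer,
    MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor,
    MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer,
    MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer,
    MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer,
}

input_state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
    acceptance_criteria="Similar in meaning, text length and style.",
)
output_state = MetaPromptGraph(llms=llms)(input_state, recursion_limit=16)
```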
config.yml CHANGED
@@ -7,6 +7,27 @@ llms:
     openai_api_base: "https://api.groq.com/openai/v1"
     max_tokens: 8192
     verbose: true
+  groq/llama3-8b-8192:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "llama3-8b-8192"
+    openai_api_base: "https://api.groq.com/openai/v1"
+    max_tokens: 8192
+    verbose: true
+  groq/gemma2-9b-it:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "gemma2-9b-it"
+    openai_api_base: "https://api.groq.com/openai/v1"
+    max_tokens: 8192
+    verbose: true
+  groq/mixtral-8x7b-32768:
+    type: ChatOpenAI
+    temperature: 0.1
+    model_name: "mixtral-8x7b-32768"
+    openai_api_base: "https://api.groq.com/openai/v1"
+    max_tokens: 8192
+    verbose: true
   # anthropic/claude-3-haiku:
   #   type: ChatOpenAI
   #   temperature: 0.1
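
Each `llms` entry maps a display key to constructor arguments: the app pops `type` and passes the rest straight through via `llm_model_factory.create(type, **args)`, with `args` coming from `model_dump(exclude={'type'})`. A minimal sketch of that hand-off, with a plain dict standing in for the parsed confz model; an API key is still expected in the environment:

```python
# Minimal sketch: turn one config.yml entry into a live model instance,
# mirroring config.llms[name].model_dump(exclude={'type'}) plus the factory call.
from langchain_openai import ChatOpenAI

entry = {  # parsed form of the groq/llama3-8b-8192 entry added above
    "type": "ChatOpenAI",
    "temperature": 0.1,
    "model_name": "llama3-8b-8192",
    "openai_api_base": "https://api.groq.com/openai/v1",
    "max_tokens": 8192,
    "verbose": True,
}

MODEL_TYPES = {"ChatOpenAI": ChatOpenAI}  # the factory's lookup, by type name
args = {k: v for k, v in entry.items() if k != "type"}  # exclude={'type'}
llm = MODEL_TYPES[entry["type"]](**args)
print(type(llm).__name__, llm.model_name)
```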