yaleh committed
Commit 48f5e34 · 1 Parent(s): 9a76340

Streamlit tab app works.

Files changed (2)
  1. app/meta_prompt_utils.py +382 -0
  2. app/streamlit_tab_app.py +730 -0
app/meta_prompt_utils.py ADDED
@@ -0,0 +1,382 @@
+ # meta_prompt_utils.py
+
+ import json
+ import logging
+ import io
+ from typing import Any, Dict, List, Optional, Union
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_openai import ChatOpenAI
+ from meta_prompt import *
+ from meta_prompt.sample_generator import TaskDescriptionGenerator
+ from pythonjsonlogger import jsonlogger
+ from app.config import MetaPromptConfig, RoleMessage
+ from confz import BaseConfig, CLArgSource, EnvSource, FileSource
+
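+ # Convert confz RoleMessage prompt templates into LangChain
+ # ChatPromptTemplates, keyed by graph node name.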
+ def prompt_templates_confz2langchain(
+     prompt_templates: Dict[str, List[RoleMessage]]
+ ) -> Dict[str, ChatPromptTemplate]:
+     return {
+         node: ChatPromptTemplate.from_messages(
+             [
+                 (role_message.role, role_message.message)
+                 for role_message in role_messages
+             ]
+         )
+         for node, role_messages in prompt_templates.items()
+     }
+
+ # Singleton factory that builds LangChain chat models from a type name.
+ class LLMModelFactory:
+     _instance = None
+
+     def __new__(cls):
+         if not cls._instance:
+             cls._instance = super(LLMModelFactory, cls).__new__(cls)
+         return cls._instance
+
+     def create(self, model_type: str, **kwargs) -> BaseLanguageModel:
+         # Resolve the model class by name among the imported symbols
+         # (e.g. ChatOpenAI); raises KeyError for unknown types.
+         model_class = globals()[model_type]
+         return model_class(**kwargs)
+
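+ # Parse a JSONL chat log into [user, assistant] pairs for a chatbot widget.
+ # Each line is expected to look like {"action": "invoke", "message": "..."}
+ # or {"action": "response", "message": "..."}.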
+ def chat_log_2_chatbot_list(chat_log: str) -> List[List[str]]:
+     chatbot_list = []
+     if chat_log is None or chat_log == '':
+         return chatbot_list
+     for line in chat_log.splitlines():
+         try:
+             json_line = json.loads(line)
+             if 'action' in json_line:
+                 if json_line['action'] == 'invoke':
+                     chatbot_list.append([json_line['message'], None])
+                 if json_line['action'] == 'response':
+                     chatbot_list.append([None, json_line['message']])
+         except json.decoder.JSONDecodeError as e:
+             print(f"Error decoding JSON log output: {e}")
+             print(line)
+         except KeyError as e:
+             print(f"Error accessing key in JSON log output: {e}")
+             print(line)
+     return chatbot_list
+
+ def get_current_model(simple_model_name: str,
+                       advanced_model_name: str,
+                       expert_model_name: str,
+                       expert_model_config: Optional[Dict[str, Any]] = None,
+                       config: MetaPromptConfig = None,
+                       active_model_tab: str = "Simple") -> BaseLanguageModel:
+     model_mapping = {
+         "Simple": simple_model_name,
+         "Advanced": advanced_model_name,
+         "Expert": expert_model_name
+     }
+
+     try:
+         model_name = model_mapping.get(active_model_tab, simple_model_name)
+         model = config.llms[model_name]
+         model_type = model.type
+         model_config = model.model_dump(exclude={'type'})
+
+         if active_model_tab == "Expert" and expert_model_config:
+             model_config.update(expert_model_config)
+
+         return LLMModelFactory().create(model_type, **model_config)
+
+     except KeyError as e:
+         logging.error(f"Configuration key error: {e}")
+         raise ValueError(f"Invalid model name or configuration: {e}")
+
+     except Exception as e:
+         logging.error(f"An unexpected error occurred: {e}")
+         raise RuntimeError(f"Failed to retrieve the model: {e}")
+
+ def evaluate_system_message(system_message, user_message, simple_model,
+                             advanced_executor_model, expert_executor_model,
+                             expert_executor_model_temperature=0.1,
+                             config: MetaPromptConfig = None,
+                             active_model_tab: str = "Simple"):
+     llm = get_current_model(simple_model, advanced_executor_model,
+                             expert_executor_model,
+                             {"temperature": expert_executor_model_temperature},
+                             config, active_model_tab)
+     template = ChatPromptTemplate.from_messages([
+         ("system", "{system_message}"),
+         ("human", "{user_message}")
+     ])
+     try:
+         output = llm.invoke(template.format(
+             system_message=system_message, user_message=user_message))
+         return output.content if hasattr(output, 'content') else ""
+     except Exception as e:
+         raise RuntimeError(f"Error: {e}") from e
+
+ def generate_acceptance_criteria(user_message, expected_output,
+                                  simple_model, advanced_executor_model,
+                                  expert_prompt_acceptance_criteria_model,
+                                  expert_prompt_acceptance_criteria_temperature=0.1,
+                                  prompt_template_group: Optional[str] = None,
+                                  config: MetaPromptConfig = None,
+                                  active_model_tab: str = "Simple"):
+     log_stream = io.StringIO()
+     logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
+     log_handler = logging.StreamHandler(log_stream) if logger else None
+
+     if log_handler:
+         log_handler.setFormatter(
+             jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
+         )
+         logger.addHandler(log_handler)
+
+     llm = get_current_model(simple_model, advanced_executor_model,
+                             expert_prompt_acceptance_criteria_model,
+                             {"temperature": expert_prompt_acceptance_criteria_temperature},
+                             config, active_model_tab)
+     if prompt_template_group is None:
+         prompt_template_group = 'default'
+     prompt_templates = prompt_templates_confz2langchain(
+         config.prompt_templates[prompt_template_group]
+     )
+     acceptance_criteria_graph = MetaPromptGraph(
+         llms={NODE_ACCEPTANCE_CRITERIA_DEVELOPER: llm},
+         prompts=prompt_templates,
+         verbose=config.verbose, logger=logger)
+     state = AgentState(
+         user_message=user_message,
+         expected_output=expected_output
+     )
+     output_state = acceptance_criteria_graph.run_acceptance_criteria_graph(state)
+
+     if log_handler:
+         logger.removeHandler(log_handler)  # avoid leaking handlers across calls
+         log_handler.close()
+         log_output = log_stream.getvalue()
+     else:
+         log_output = None
+     return output_state.get('acceptance_criteria', ""), chat_log_2_chatbot_list(log_output)
+
+ def generate_initial_system_message(
+     user_message: str,
+     expected_output: str,
+     simple_model: str,
+     advanced_executor_model: str,
+     expert_prompt_initial_developer_model: str,
+     expert_prompt_initial_developer_temperature: float = 0.1,
+     prompt_template_group: Optional[str] = None,
+     config: MetaPromptConfig = None,
+     active_model_tab: str = "Simple"
+ ) -> tuple:
+     log_stream = io.StringIO()
+     logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
+     log_handler = logging.StreamHandler(log_stream) if logger else None
+
+     if log_handler:
+         log_handler.setFormatter(
+             jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
+         )
+         logger.addHandler(log_handler)
+
+     llm = get_current_model(
+         simple_model,
+         advanced_executor_model,
+         expert_prompt_initial_developer_model,
+         {"temperature": expert_prompt_initial_developer_temperature},
+         config,
+         active_model_tab
+     )
+
+     if prompt_template_group is None:
+         prompt_template_group = 'default'
+     prompt_templates = prompt_templates_confz2langchain(
+         config.prompt_templates[prompt_template_group]
+     )
+
+     initial_system_message_graph = MetaPromptGraph(
+         llms={NODE_PROMPT_INITIAL_DEVELOPER: llm},
+         prompts=prompt_templates,
+         verbose=config.verbose,
+         logger=logger
+     )
+
+     state = AgentState(
+         user_message=user_message,
+         expected_output=expected_output
+     )
+
+     output_state = initial_system_message_graph.run_prompt_initial_developer_graph(state)
+
+     if log_handler:
+         logger.removeHandler(log_handler)  # avoid leaking handlers across calls
+         log_handler.close()
+         log_output = log_stream.getvalue()
+     else:
+         log_output = None
+
+     system_message = output_state.get('system_message', "")
+     return system_message, chat_log_2_chatbot_list(log_output)
+
+ def process_message(
+     user_message: str, expected_output: str, acceptance_criteria: str,
+     initial_system_message: str, recursion_limit: int, max_output_age: int,
+     llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
+     prompt_template_group: Optional[str] = None,
+     aggressive_exploration: bool = False,
+     config: MetaPromptConfig = None
+ ) -> tuple:
+     input_state = AgentState(
+         user_message=user_message,
+         expected_output=expected_output,
+         acceptance_criteria=acceptance_criteria,
+         system_message=initial_system_message,
+         max_output_age=max_output_age
+     )
+
+     log_stream = io.StringIO()
+     logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
+     log_handler = logging.StreamHandler(log_stream) if logger else None
+     if log_handler:
+         log_handler.setFormatter(jsonlogger.JsonFormatter(
+             '%(asctime)s %(name)s %(levelname)s %(message)s'))
+         logger.addHandler(log_handler)
+
+     if prompt_template_group is None:
+         prompt_template_group = 'default'
+     prompt_templates = prompt_templates_confz2langchain(
+         config.prompt_templates[prompt_template_group])
+     meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompt_templates,
+                                         aggressive_exploration=aggressive_exploration,
+                                         verbose=config.verbose, logger=logger)
+     try:
+         output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
+     except Exception as e:
+         raise RuntimeError(f"Error: {e}") from e
+
+     if log_handler:
+         logger.removeHandler(log_handler)  # avoid leaking handlers across calls
+         log_handler.close()
+         log_output = log_stream.getvalue()
+     else:
+         log_output = None
+
+     system_message = output_state.get(
+         'best_system_message', "Error: The output state does not contain a valid 'best_system_message'")
+     output = output_state.get(
+         'best_output', "Error: The output state does not contain a valid 'best_output'")
+     analysis = output_state.get(
+         'analysis', "Error: The output state does not contain a valid 'analysis'")
+     acceptance_criteria = output_state.get(
+         'acceptance_criteria', "Error: The output state does not contain a valid 'acceptance_criteria'")
+
+     return (system_message, output, analysis, acceptance_criteria,
+             chat_log_2_chatbot_list(log_output))
+
+ def initialize_llm(model_name: str, model_config: Optional[Dict[str, Any]] = None,
+                    config: MetaPromptConfig = None) -> Any:
+     try:
+         llm_config = config.llms[model_name]
+         model_type = llm_config.type
+         dumped_config = llm_config.model_dump(exclude={'type'})
+
+         if model_config:
+             dumped_config.update(model_config)
+
+         return LLMModelFactory().create(model_type, **dumped_config)
+     except KeyError:
+         raise KeyError(f"No configuration exists for the model name: {model_name}")
+     except NotImplementedError:
+         raise NotImplementedError(
+             f"Unrecognized type configured for the language model: {model_type}"
+         )
+
+ # Sample generator functions
+
+ def process_json(input_json, model_name, generating_batch_size, temperature,
+                  config: MetaPromptConfig = None):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.process(input_json, generating_batch_size)
+         description = result["description"]
+         suggestions = result["suggestions"]
+         examples_directly = [[example["input"], example["output"]]
+                              for example in result["examples_directly"]["examples"]]
+         input_analysis = result["examples_from_briefs"]["input_analysis"]
+         new_example_briefs = result["examples_from_briefs"]["new_example_briefs"]
+         examples_from_briefs = [[example["input"], example["output"]]
+                                 for example in result["examples_from_briefs"]["examples"]]
+         examples = [[example["input"], example["output"]]
+                     for example in result["additional_examples"]]
+         return (description, suggestions, examples_directly, input_analysis,
+                 new_example_briefs, examples_from_briefs, examples)
+     except Exception as e:
+         raise RuntimeError(f"An error occurred: {str(e)}") from e
+
+ def generate_description_only(input_json, model_name, temperature,
+                               config: MetaPromptConfig = None):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_description(input_json)
+         description = result["description"]
+         suggestions = result["suggestions"]
+         return description, suggestions
+     except Exception as e:
+         raise RuntimeError(f"An error occurred: {str(e)}") from e
+
+ def analyze_input(description, model_name, temperature, config: MetaPromptConfig = None):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         input_analysis = generator.analyze_input(description)
+         return input_analysis
+     except Exception as e:
+         raise RuntimeError(f"An error occurred: {str(e)}") from e
+
+ def generate_briefs(description, input_analysis, generating_batch_size,
+                     model_name, temperature, config: MetaPromptConfig = None):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         briefs = generator.generate_briefs(
+             description, input_analysis, generating_batch_size)
+         return briefs
+     except Exception as e:
+         raise RuntimeError(f"An error occurred: {str(e)}") from e
+
+ def generate_examples_from_briefs(description, new_example_briefs, input_str,
+                                   generating_batch_size, model_name, temperature,
+                                   config: MetaPromptConfig = None):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_examples_from_briefs(
+             description, new_example_briefs, input_str, generating_batch_size)
+         examples = [[example["input"], example["output"]]
+                     for example in result["examples"]]
+         return examples
+     except Exception as e:
+         raise RuntimeError(f"An error occurred: {str(e)}") from e
+
+ def generate_examples_directly(description, raw_example, generating_batch_size,
+                                model_name, temperature, config: MetaPromptConfig = None):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_examples_directly(
+             description, raw_example, generating_batch_size)
+         examples = [[example["input"], example["output"]]
+                     for example in result["examples"]]
+         return examples
+     except Exception as e:
+         raise RuntimeError(f"An error occurred: {str(e)}") from e
+
+ class FileConfig(BaseConfig):
+     config_file: str = 'config.yml'  # default path
+
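+ # Two-pass configuration load: first read only the config-file path from
+ # environment variables and CLI arguments, then load the full config with
+ # that file (if present) plus the same env and CLI sources layered on top.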
+ def load_config():
+     pre_config_sources = [
+         EnvSource(prefix='METAPROMPT_', allow_all=True),
+         CLArgSource()
+     ]
+     pre_config = FileConfig(config_sources=pre_config_sources)
+
+     config_sources = [
+         FileSource(file=pre_config.config_file, optional=True),
+         EnvSource(prefix='METAPROMPT_', allow_all=True),
+         CLArgSource()
+     ]
+
+     return MetaPromptConfig(config_sources=config_sources)
+
+ # Add any additional utility functions here if needed
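+
+ # Minimal usage sketch (untested; assumes config.yml defines an LLM under a
+ # key such as "gpt-4o-mini" in its `llms` section):
+ #
+ #     config = load_config()
+ #     llm = initialize_llm("gpt-4o-mini", config=config)
+ #     best_prompt, best_output, analysis, criteria, log = process_message(
+ #         user_message="How do I reverse a list in Python?",
+ #         expected_output="Use list.reverse() or reversed().",
+ #         acceptance_criteria="", initial_system_message="",
+ #         recursion_limit=16, max_output_age=2, llms=llm, config=config)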
app/streamlit_tab_app.py ADDED
@@ -0,0 +1,730 @@
+ import pandas as pd
+ import streamlit as st
+ import json
+ from app.meta_prompt_utils import *
+ from meta_prompt.sample_generator import TaskDescriptionGenerator
+
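+ # Launch sketch (assumption: run from the repository root so the `app`
+ # package is importable): streamlit run app/streamlit_tab_app.py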
+ # Initialize session state
+ def init_session_state():
+     defaults = {
+         'shared_input_data': pd.DataFrame(columns=["Input", "Output"]),
+         'initial_system_message': "",
+         'initial_acceptance_criteria': "",
+         'system_message_output': "",
+         'output': "",
+         'analysis': "",
+         'acceptance_criteria_output': "",
+         'chat_log': [],
+         'description_output_text': '',
+         'suggestions': [],
+         'input_analysis_output_text': '',
+         'example_briefs_output_text': '',
+         'examples_from_briefs_dataframe': pd.DataFrame(columns=["Input", "Output"]),
+         'examples_directly_dataframe': pd.DataFrame(columns=["Input", "Output"]),
+         'examples_dataframe': pd.DataFrame(columns=["Input", "Output"]),
+         'selected_example': None,
+     }
+     for key, value in defaults.items():
+         if key not in st.session_state:
+             st.session_state[key] = value
+
+ # UI helper functions
+ def sync_input_data():
+     # Mirrors editor edits back into shared state; requires the data editor
+     # to be created with key="data_editor_data".
+     st.session_state.shared_input_data = st.session_state.data_editor_data.copy()
+
+ # Sample Generator Functions
+ # These override the same-named helpers star-imported from
+ # app.meta_prompt_utils, reporting errors via st.warning instead of raising.
+
+ def process_json(input_json, model_name, generating_batch_size, temperature):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.process(input_json, generating_batch_size)
+         description = result["description"]
+         suggestions = result["suggestions"]
+         examples_directly = [[example["input"], example["output"]]
+                              for example in result["examples_directly"]["examples"]]
+         input_analysis = result["examples_from_briefs"]["input_analysis"]
+         new_example_briefs = result["examples_from_briefs"]["new_example_briefs"]
+         examples_from_briefs = [[example["input"], example["output"]]
+                                 for example in result["examples_from_briefs"]["examples"]]
+         examples = [[example["input"], example["output"]]
+                     for example in result["additional_examples"]]
+         return (description, suggestions, examples_directly, input_analysis,
+                 new_example_briefs, examples_from_briefs, examples)
+     except Exception as e:
+         st.warning(f"An error occurred: {str(e)}. Returning default values.")
+         return "", [], [], "", [], [], []
+
+
+ def generate_description_only(input_json, model_name, temperature):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_description(input_json)
+         description = result["description"]
+         suggestions = result["suggestions"]
+         return description, suggestions
+     except Exception as e:
+         st.warning(f"An error occurred: {str(e)}")
+         return "", []
+
+
+ def analyze_input(description, model_name, temperature):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         input_analysis = generator.analyze_input(description)
+         return input_analysis
+     except Exception as e:
+         st.warning(f"An error occurred: {str(e)}")
+         return ""
+
+
+ def generate_briefs(description, input_analysis, generating_batch_size, model_name, temperature):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         briefs = generator.generate_briefs(
+             description, input_analysis, generating_batch_size)
+         return briefs
+     except Exception as e:
+         st.warning(f"An error occurred: {str(e)}")
+         return ""
+
+
+ def generate_examples_from_briefs(description, new_example_briefs, input_str, generating_batch_size, model_name, temperature):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_examples_from_briefs(
+             description, new_example_briefs, input_str, generating_batch_size)
+         examples = [[example["input"], example["output"]]
+                     for example in result["examples"]]
+         return examples
+     except Exception as e:
+         st.warning(f"An error occurred: {str(e)}")
+         return []
+
+
+ def generate_examples_directly(description, raw_example, generating_batch_size, model_name, temperature):
+     try:
+         model = ChatOpenAI(
+             model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_examples_directly(
+             description, raw_example, generating_batch_size)
+         examples = [[example["input"], example["output"]]
+                     for example in result["examples"]]
+         return examples
+     except Exception as e:
+         st.warning(f"An error occurred: {str(e)}")
+         return []
+
+
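+ # Selection callbacks for st.dataframe: the widget stores its selection state
+ # under its key as {"selection": {"rows": [...]}}; map the selected row
+ # indices back onto the source DataFrame.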
+ def example_directly_selected():
+     if 'selected_example_directly_id' in st.session_state:
+         try:
+             selected_example_ids = st.session_state.selected_example_directly_id[
+                 'selection']['rows']
+             # set selected examples to the selected rows if there are any
+             if selected_example_ids:
+                 selected_examples = st.session_state.examples_directly_dataframe.iloc[
+                     selected_example_ids].to_dict('records')
+                 st.session_state.selected_example = pd.DataFrame(selected_examples)
+             else:
+                 st.session_state.selected_example = None
+         except Exception:
+             st.session_state.selected_example = None
+
+
+ def example_from_briefs_selected():
+     if 'selected_example_from_briefs_id' in st.session_state:
+         try:
+             selected_example_ids = st.session_state.selected_example_from_briefs_id[
+                 'selection']['rows']
+             # set selected examples to the selected rows if there are any
+             if selected_example_ids:
+                 selected_examples = st.session_state.examples_from_briefs_dataframe.iloc[
+                     selected_example_ids].to_dict('records')
+                 st.session_state.selected_example = pd.DataFrame(selected_examples)
+             else:
+                 st.session_state.selected_example = None
+         except Exception:
+             st.session_state.selected_example = None
+
+
+ def example_selected():
+     if 'selected_example_id' in st.session_state:
+         try:
+             selected_example_ids = st.session_state.selected_example_id['selection']['rows']
+             # set selected examples to the selected rows if there are any
+             if selected_example_ids:
+                 selected_examples = st.session_state.examples_dataframe.iloc[
+                     selected_example_ids].to_dict('records')
+                 st.session_state.selected_example = pd.DataFrame(selected_examples)
+             else:
+                 st.session_state.selected_example = None
+         except Exception:
+             st.session_state.selected_example = None
+
+ def update_description_output_text():
+     input_json = package_input_data()
+     result = generate_description_only(input_json, model_name, temperature)
+     st.session_state.description_output_text = result[0]
+     st.session_state.suggestions = result[1]
+
+
+ def update_input_analysis_output_text():
+     st.session_state.input_analysis_output_text = analyze_input(
+         description_output, model_name, temperature)
+
+
+ def update_example_briefs_output_text():
+     st.session_state.example_briefs_output_text = generate_briefs(
+         description_output, input_analysis_output, generating_batch_size, model_name, temperature)
+
+
+ def update_examples_from_briefs_dataframe():
+     input_json = package_input_data()
+     examples = generate_examples_from_briefs(
+         description_output, example_briefs_output, input_json, generating_batch_size, model_name, temperature)
+     st.session_state.examples_from_briefs_dataframe = pd.DataFrame(
+         examples, columns=["Input", "Output"])
+
+
+ def update_examples_directly_dataframe():
+     input_json = package_input_data()
+     examples = generate_examples_directly(
+         description_output, input_json, generating_batch_size, model_name, temperature)
+     st.session_state.examples_directly_dataframe = pd.DataFrame(
+         examples, columns=["Input", "Output"])
+
+
+ def generate_examples_dataframe():
+     input_json = package_input_data()
+     result = process_json(input_json, model_name,
+                           generating_batch_size, temperature)
+     (description, suggestions, examples_directly, input_analysis,
+      new_example_briefs, examples_from_briefs, examples) = result
+     st.session_state.description_output_text = description
+     st.session_state.suggestions = suggestions
+     st.session_state.examples_directly_dataframe = pd.DataFrame(
+         examples_directly, columns=["Input", "Output"])
+     st.session_state.input_analysis_output_text = input_analysis
+     st.session_state.example_briefs_output_text = new_example_briefs
+     st.session_state.examples_from_briefs_dataframe = pd.DataFrame(
+         examples_from_briefs, columns=["Input", "Output"])
+     st.session_state.examples_dataframe = pd.DataFrame(
+         examples, columns=["Input", "Output"])
+     st.session_state.selected_example = None
+
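+ # Serialize the editor rows to a JSON string with lowercased keys, e.g.
+ # '[{"input": "...", "output": "..."}]'.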
+ def package_input_data():
+     data = data_editor_data.to_dict(orient='records')
+     lowered_data = [{k.lower(): v for k, v in d.items()} for d in data]
+     return json.dumps(lowered_data, ensure_ascii=False)
+
+ def export_input_data_to_json():
+     input_data_json = package_input_data()
+     st.download_button(
+         label="Download input data as JSON",
+         data=input_data_json,
+         file_name="input_data.json",
+         mime="application/json"
+     )
+
+ def import_input_data_from_json():
+     try:
+         if 'input_file' in st.session_state and st.session_state.input_file is not None:
+             data = st.session_state.input_file.getvalue()
+             data = json.loads(data)
+             data = [{k.capitalize(): v for k, v in d.items()} for d in data]
+             st.session_state.shared_input_data = pd.DataFrame(data)
+     except Exception as e:
+         st.warning(f"Failed to import JSON: {str(e)}")
+
+ def apply_suggestions():
+     try:
+         generator = TaskDescriptionGenerator(
+             ChatOpenAI(model=model_name, temperature=temperature, max_retries=3))
+         result = generator.update_description(
+             package_input_data(), st.session_state.description_output_text,
+             st.session_state.selected_suggestions)
+         st.session_state.description_output_text = result["description"]
+         st.session_state.suggestions = result["suggestions"]
+     except Exception as e:
+         st.warning(f"Failed to update description: {str(e)}")
+
+ def generate_suggestions():
+     try:
+         description = st.session_state.description_output_text
+         input_json = package_input_data()
+
+         model = ChatOpenAI(model=model_name, temperature=temperature, max_retries=3)
+         generator = TaskDescriptionGenerator(model)
+         result = generator.generate_suggestions(input_json, description)
+         st.session_state.suggestions = result["suggestions"]
+     except Exception as e:
+         st.warning(f"Failed to generate suggestions: {str(e)}")
+
+ # Add a new suggestion to the list and clear the input field
+ def add_new_suggestion():
+     if st.session_state.new_suggestion:
+         st.session_state.suggestions.append(st.session_state.new_suggestion)
+         st.session_state.new_suggestion = ""  # Clear the input field
+
+ def append_selected_to_input_data():
+     if st.session_state.selected_example is not None:
+         st.session_state.shared_input_data = pd.concat(
+             [data_editor_data, st.session_state.selected_example], ignore_index=True)
+         st.session_state.selected_example = None
+
+ def show_scoping_sidebar():
+     if st.session_state.selected_example is not None:
+         with st.sidebar:
+             st.dataframe(st.session_state.selected_example, hide_index=False)
+             st.button("Append to Input Data", on_click=append_selected_to_input_data)
+
+ # Meta Prompt Functions
+ def process_message_with_single_llm(
+     user_message: str, expected_output: str, acceptance_criteria: str,
+     initial_system_message: str, recursion_limit: int, max_output_age: int,
+     model_name: str, prompt_template_group: Optional[str] = None,
+     aggressive_exploration: bool = False, config: MetaPromptConfig = None
+ ) -> tuple:
+     llm = initialize_llm(model_name, config=config)
+     return process_message(
+         user_message, expected_output, acceptance_criteria, initial_system_message,
+         recursion_limit, max_output_age, llm, prompt_template_group,
+         aggressive_exploration, config
+     )
+
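+ # Two-LLM split: the optimizer model drives every prompt-engineering node,
+ # while the executor model only runs the candidate prompt itself.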
+ def process_message_with_2_llms(
+     user_message: str, expected_output: str, acceptance_criteria: str,
+     initial_system_message: str, recursion_limit: int, max_output_age: int,
+     optimizer_model_name: str, executor_model_name: str,
+     prompt_template_group: Optional[str] = None,
+     aggressive_exploration: bool = False, config: MetaPromptConfig = None
+ ) -> tuple:
+     optimizer_model = initialize_llm(optimizer_model_name, config=config)
+     executor_model = initialize_llm(executor_model_name, config=config)
+     llms = {
+         NODE_ACCEPTANCE_CRITERIA_DEVELOPER: optimizer_model,
+         NODE_PROMPT_INITIAL_DEVELOPER: optimizer_model,
+         NODE_PROMPT_DEVELOPER: optimizer_model,
+         NODE_PROMPT_EXECUTOR: executor_model,
+         NODE_OUTPUT_HISTORY_ANALYZER: optimizer_model,
+         NODE_PROMPT_ANALYZER: optimizer_model,
+         NODE_PROMPT_SUGGESTER: optimizer_model
+     }
+     return process_message(
+         user_message, expected_output, acceptance_criteria,
+         initial_system_message, recursion_limit, max_output_age, llms,
+         prompt_template_group, aggressive_exploration, config
+     )
+
+ def process_message_with_expert_llms(
+     user_message: str, expected_output: str, acceptance_criteria: str,
+     initial_system_message: str, recursion_limit: int, max_output_age: int,
+     initial_developer_model_name: str, initial_developer_temperature: float,
+     acceptance_criteria_model_name: str, acceptance_criteria_temperature: float,
+     developer_model_name: str, developer_temperature: float,
+     executor_model_name: str, executor_temperature: float,
+     output_history_analyzer_model_name: str, output_history_analyzer_temperature: float,
+     analyzer_model_name: str, analyzer_temperature: float,
+     suggester_model_name: str, suggester_temperature: float,
+     prompt_template_group: Optional[str] = None, aggressive_exploration: bool = False,
+     config: MetaPromptConfig = None
+ ) -> tuple:
+     # One dedicated model (and temperature) per graph node
+     llms = {
+         NODE_PROMPT_INITIAL_DEVELOPER: initialize_llm(
+             initial_developer_model_name, {"temperature": initial_developer_temperature}, config
+         ),
+         NODE_ACCEPTANCE_CRITERIA_DEVELOPER: initialize_llm(
+             acceptance_criteria_model_name, {"temperature": acceptance_criteria_temperature}, config
+         ),
+         NODE_PROMPT_DEVELOPER: initialize_llm(
+             developer_model_name, {"temperature": developer_temperature}, config
+         ),
+         NODE_PROMPT_EXECUTOR: initialize_llm(
+             executor_model_name, {"temperature": executor_temperature}, config
+         ),
+         NODE_OUTPUT_HISTORY_ANALYZER: initialize_llm(
+             output_history_analyzer_model_name,
+             {"temperature": output_history_analyzer_temperature},
+             config
+         ),
+         NODE_PROMPT_ANALYZER: initialize_llm(
+             analyzer_model_name, {"temperature": analyzer_temperature}, config
+         ),
+         NODE_PROMPT_SUGGESTER: initialize_llm(
+             suggester_model_name, {"temperature": suggester_temperature}, config
+         )
+     }
+     return process_message(
+         user_message, expected_output, acceptance_criteria, initial_system_message,
+         recursion_limit, max_output_age, llms, prompt_template_group,
+         aggressive_exploration, config
+     )
+
+ def copy_system_message():
+     st.session_state.initial_system_message = system_message_output
+
+ def copy_acceptance_criteria():
+     st.session_state.initial_acceptance_criteria = acceptance_criteria_output
+
+ def clear_session_state():
+     st.session_state.shared_input_data = pd.DataFrame(columns=["Input", "Output"])
+     st.session_state.initial_system_message = ""
+     st.session_state.initial_acceptance_criteria = ""
+     st.session_state.system_message_output = ""
+     st.session_state.output = ""
+     st.session_state.analysis = ""
+     st.session_state.acceptance_criteria_output = ""
+     st.session_state.chat_log = []
+
+ def pull_sample_description():
+     st.session_state.initial_system_message = description_output
+
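+ # Callback for the Generate button: takes the first valid Input/Output row as
+ # the example pair and dispatches to the Simple, Advanced, or Expert pipeline.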
+ def generate_callback():
+     try:
+         first_input_key = data_editor_data["Input"].first_valid_index()
+         first_output_key = data_editor_data["Output"].first_valid_index()
+         user_message = data_editor_data["Input"][first_input_key].strip()
+         expected_output = data_editor_data["Output"][first_output_key].strip()
+
+         input_acceptance_criteria = initial_acceptance_criteria.strip() \
+             if 'initial_acceptance_criteria' in st.session_state else ""
+         input_system_message = initial_system_message.strip() \
+             if 'initial_system_message' in st.session_state else ""
+
+         if model_tab == "Simple":
+             system_message, output, analysis, acceptance_criteria, chat_log = process_message_with_single_llm(
+                 user_message,
+                 expected_output,
+                 input_acceptance_criteria,
+                 input_system_message,
+                 recursion_limit_input,
+                 max_output_age_input,
+                 simple_model_name_input,
+                 prompt_template_group_input,
+                 aggressive_exploration_input,
+                 config=config
+             )
+         elif model_tab == "Advanced":
+             system_message, output, analysis, acceptance_criteria, chat_log = process_message_with_2_llms(
+                 user_message,
+                 expected_output,
+                 input_acceptance_criteria,
+                 input_system_message,
+                 recursion_limit_input,
+                 max_output_age_input,
+                 advanced_optimizer_model_name_input,
+                 advanced_executor_model_name_input,
+                 prompt_template_group_input,
+                 aggressive_exploration_input,
+                 config=config
+             )
+         else:  # Expert
+             system_message, output, analysis, acceptance_criteria, chat_log = process_message_with_expert_llms(
+                 user_message,
+                 expected_output,
+                 input_acceptance_criteria,
+                 input_system_message,
+                 recursion_limit_input,
+                 max_output_age_input,
+                 expert_prompt_initial_developer_model_name_input,
+                 expert_prompt_initial_developer_temperature_input,
+                 expert_prompt_acceptance_criteria_model_name_input,
+                 expert_prompt_acceptance_criteria_temperature_input,
+                 expert_prompt_developer_model_name_input,
+                 expert_prompt_developer_temperature_input,
+                 expert_prompt_executor_model_name_input,
+                 expert_prompt_executor_temperature_input,
+                 expert_prompt_output_history_analyzer_model_name_input,
+                 expert_prompt_output_history_analyzer_temperature_input,
+                 expert_prompt_analyzer_model_name_input,
+                 expert_prompt_analyzer_temperature_input,
+                 expert_prompt_suggester_model_name_input,
+                 expert_prompt_suggester_temperature_input,
+                 prompt_template_group_input,
+                 aggressive_exploration_input,
+                 config=config
+             )
+
+         st.session_state.system_message_output = system_message
+         st.session_state.output = output
+         st.session_state.analysis = analysis
+         st.session_state.acceptance_criteria_output = acceptance_criteria
+         st.session_state.chat_log = chat_log
+
+     except Exception as e:
+         st.error(f"Error: {e}")
+
+ # Meta Prompt Config
+
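+ # Same two-pass load as load_config() in app.meta_prompt_utils: read the
+ # config-file path from env/CLI first, then layer file, env, and CLI sources.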
+ pre_config_sources = [
+     EnvSource(prefix='METAPROMPT_', allow_all=True),
+     CLArgSource()
+ ]
+ pre_config = FileConfig(config_sources=pre_config_sources)
+
+ # Load configuration
+ config = MetaPromptConfig(config_sources=[
+     FileSource(file=pre_config.config_file, optional=True),
+     EnvSource(prefix='METAPROMPT_', allow_all=True),
+     CLArgSource()
+ ])
+
+ # Initialize session state
+ init_session_state()
+
+ # Streamlit UI
+
+ st.title("Meta Prompt")
+ st.markdown("Enter input-output pairs as examples for the prompt.")
+ data_editor_data = st.data_editor(
+     st.session_state.shared_input_data,
+     # key="meta_prompt_input_data",
+     num_rows="dynamic",
+     column_config={
+         "Input": st.column_config.TextColumn("Input", width="large"),
+         "Output": st.column_config.TextColumn("Output", width="large"),
+     },
+     hide_index=False,
+     use_container_width=True,
+ )
+
+
+ with st.expander("Data Management"):
+     # col1, col2 = st.columns(2)
+     # with col1:
+     input_file = st.file_uploader(
+         label="Import Input Data from JSON",
+         type="json",
+         key="input_file",
+         on_change=import_input_data_from_json
+     )
+     # with col2:
+     export_button = st.button(
+         "Export Input Data to JSON", on_click=export_input_data_to_json
+     )
+
+ tab_scoping, tab_prompting = st.tabs(["Scope", "Prompt"])
+
+ with tab_scoping:
+     st.markdown("Define the task scope using the above input-output pairs.")
+
+     submit_button = st.button(
+         "Go", type="primary", on_click=generate_examples_dataframe,
+         use_container_width=True)
+
+     with st.expander("Model Settings"):
+         model_name = st.selectbox(
+             "Model Name",
+             ["llama3-70b-8192", "llama3-8b-8192", "llama-3.1-70b-versatile",
+              "llama-3.1-8b-instant", "gemma2-9b-it"],
+             index=0
+         )
+         temperature = st.slider("Temperature", 0.0, 1.0, 1.0, 0.1)
+         generating_batch_size = st.slider("Generating Batch Size", 1, 10, 3, 1)
+
+     with st.expander("Description and Analysis"):
+         generate_description_button = st.button(
+             "Generate Description", on_click=update_description_output_text)
+
+         description_output = st.text_area(
+             "Description", value=st.session_state.description_output_text, height=100)
+
+         col3, col4, col5 = st.columns(3)
+         with col3:
+             generate_suggestions_button = st.button("Generate Suggestions", on_click=generate_suggestions)
+         with col4:
+             generate_examples_directly_button = st.button(
+                 "Generate Examples Directly", on_click=update_examples_directly_dataframe)
+         with col5:
+             analyze_input_button = st.button(
+                 "Analyze Input", on_click=update_input_analysis_output_text)
+
+         # Add multiselect for suggestions
+         selected_suggestions = st.multiselect(
+             "Suggestions", options=st.session_state.suggestions, key="selected_suggestions")
+
+         # Add button to apply suggestions
+         apply_suggestions_button = st.button("Apply Suggestions", on_click=apply_suggestions)
+
+         # Add text input for adding new suggestions
+         new_suggestion = st.text_input("Add New Suggestion", key="new_suggestion", on_change=add_new_suggestion)
+
+         examples_directly_output = st.dataframe(st.session_state.examples_directly_dataframe, use_container_width=True,
+                                                 selection_mode="multi-row", key="selected_example_directly_id",
+                                                 on_select=example_directly_selected, hide_index=False)
+         input_analysis_output = st.text_area(
+             "Input Analysis", value=st.session_state.input_analysis_output_text, height=100)
+         generate_briefs_button = st.button(
+             "Generate Briefs", on_click=update_example_briefs_output_text)
+         example_briefs_output = st.text_area(
+             "Example Briefs", value=st.session_state.example_briefs_output_text, height=100)
+         generate_examples_from_briefs_button = st.button(
+             "Generate Examples from Briefs", on_click=update_examples_from_briefs_dataframe)
+         examples_from_briefs_output = st.dataframe(st.session_state.examples_from_briefs_dataframe, use_container_width=True,
+                                                    selection_mode="multi-row", key="selected_example_from_briefs_id",
+                                                    on_select=example_from_briefs_selected, hide_index=False)
+
+     examples_output = st.dataframe(st.session_state.examples_dataframe, use_container_width=True,
+                                    selection_mode="multi-row", key="selected_example_id",
+                                    on_select=example_selected, hide_index=True)
+
+     show_scoping_sidebar()
+
+ with tab_prompting:
+     # Prompting UI
+     st.markdown("Generate the prompt with the above input-output pairs.")
+
+     generate_button_clicked = st.button("Generate", key="generate_button",
+                                         on_click=generate_callback,
+                                         type="primary", use_container_width=True)
+
+     col1, col2 = st.columns(2)
+
+     with col1:
+         with st.expander("Advanced Inputs"):
+             initial_system_message = st.text_area(
+                 "Initial System Message",
+                 key="initial_system_message"
+             )
+
+             col1_1, col1_2 = st.columns(2)
+             with col1_1:
+                 pull_sample_description_button = st.button("Pull Sample Description", key="pull_sample_description",
+                                                            on_click=pull_sample_description)
+             with col1_2:
+                 st.button("Pull Output", key="copy_system_message",
+                           on_click=copy_system_message)
+             initial_acceptance_criteria = st.text_area(
+                 "Acceptance Criteria",
+                 key="initial_acceptance_criteria"
+             )
+             st.button("Pull Output", key="copy_acceptance_criteria",
+                       on_click=copy_acceptance_criteria)
+
+         # Expander for model settings
+         with st.expander("Model Settings"):
+             model_tab = st.selectbox("Select Model Type", ["Simple", "Advanced", "Expert"], key="model_tab")
+
+             if model_tab == "Simple":
+                 simple_model_name_input = st.selectbox(
+                     "Model Name",
+                     config.llms.keys(),
+                     index=0,
+                 )
+             elif model_tab == "Advanced":
+                 advanced_optimizer_model_name_input = st.selectbox(
+                     "Optimizer Model Name",
+                     config.llms.keys(),
+                     index=0,
+                 )
+                 advanced_executor_model_name_input = st.selectbox(
+                     "Executor Model Name",
+                     config.llms.keys(),
+                     index=1,
+                 )
+             else:  # Expert
+                 expert_prompt_initial_developer_model_name_input = st.selectbox(
+                     "Initial Developer Model Name",
+                     config.llms.keys(),
+                     index=0,
+                 )
+                 expert_prompt_initial_developer_temperature_input = st.slider(
+                     "Initial Developer Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+                 expert_prompt_acceptance_criteria_model_name_input = st.selectbox(
+                     "Acceptance Criteria Model Name",
+                     config.llms.keys(),
+                     index=0,
+                 )
+                 expert_prompt_acceptance_criteria_temperature_input = st.slider(
+                     "Acceptance Criteria Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+                 expert_prompt_developer_model_name_input = st.selectbox(
+                     "Developer Model Name", config.llms.keys(), index=0
+                 )
+                 expert_prompt_developer_temperature_input = st.slider(
+                     "Developer Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+                 expert_prompt_executor_model_name_input = st.selectbox(
+                     "Executor Model Name", config.llms.keys(), index=1
+                 )
+                 expert_prompt_executor_temperature_input = st.slider(
+                     "Executor Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+                 expert_prompt_output_history_analyzer_model_name_input = st.selectbox(
+                     "Output History Analyzer Model Name",
+                     config.llms.keys(),
+                     index=0,
+                 )
+                 expert_prompt_output_history_analyzer_temperature_input = st.slider(
+                     "Output History Analyzer Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+                 expert_prompt_analyzer_model_name_input = st.selectbox(
+                     "Analyzer Model Name", config.llms.keys(), index=0
+                 )
+                 expert_prompt_analyzer_temperature_input = st.slider(
+                     "Analyzer Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+                 expert_prompt_suggester_model_name_input = st.selectbox(
+                     "Suggester Model Name", config.llms.keys(), index=0
+                 )
+                 expert_prompt_suggester_temperature_input = st.slider(
+                     "Suggester Temperature", 0.0, 1.0, 0.1, 0.1
+                 )
+
+             # st.header("Prompt Template Settings")
+             prompt_template_group_input = st.selectbox(
+                 "Prompt Template Group", config.prompt_templates.keys(), index=0
+             )
+
+             # st.header("Advanced Settings")
+             recursion_limit_input = st.number_input("Recursion Limit", 1, 100, 16, 1)
+             max_output_age_input = st.number_input("Max Output Age", 1, 10, 2, 1)
+             aggressive_exploration_input = st.checkbox("Aggressive Exploration", False)
+
+     with col2:
+         system_message_output = st.text_area("System Message",
+                                              key="system_message_output",
+                                              height=100)
+
+         acceptance_criteria_output = st.text_area(
+             "Acceptance Criteria",
+             key="acceptance_criteria_output",
+             height=100)
+         st.text_area("Output", st.session_state.output, height=100)
+         st.text_area("Analysis", st.session_state.analysis, height=100)
+
+         st.json(st.session_state.chat_log, expanded=False)