yaleh commited on
Commit
c8b5135
·
1 Parent(s): b01579f

Moved stateless funcs from gradio_meta_prompt.py to gradio_meta_prompt_utils.py .

Browse files
app/gradio_meta_prompt.py CHANGED
@@ -1,503 +1,11 @@
1
- import csv
2
- import io
3
- import json
4
- import logging
5
- from pathlib import Path
6
- from typing import Any, Dict, List, Optional, Union
7
-
8
  import gradio as gr
9
- from gradio import CSVLogger, Button, utils
10
  from gradio.flagging import FlagMethod
11
- from gradio_client import utils as client_utils
12
 
13
- from confz import BaseConfig, CLArgSource, EnvSource, FileSource
14
- from app.config import MetaPromptConfig, RoleMessage
15
- from langchain_core.language_models import BaseLanguageModel
16
- from langchain_core.prompts import ChatPromptTemplate
17
- from langchain_openai import ChatOpenAI # Don't remove this import
18
  from meta_prompt import *
19
- from pythonjsonlogger import jsonlogger
20
-
21
- def prompt_templates_confz2langchain(
22
- prompt_templates: Dict[str, Dict[str, List[RoleMessage]]]
23
- ) -> Dict[str, ChatPromptTemplate]:
24
- """
25
- Convert a dictionary of prompt templates from the configuration format to
26
- the language chain format.
27
-
28
- This function takes a dictionary of prompt templates in the configuration
29
- format and converts them to the language chain format. Each prompt template
30
- is converted to a ChatPromptTemplate object, which is then stored in a new
31
- dictionary with the same keys.
32
-
33
- Args:
34
- prompt_templates (Dict[str, Dict[str, List[RoleMessage]]]):
35
- A dictionary of prompt templates in the configuration format.
36
-
37
- Returns:
38
- Dict[str, ChatPromptTemplate]:
39
- A dictionary of prompt templates in the language chain format.
40
- """
41
- return {
42
- node: ChatPromptTemplate.from_messages(
43
- [
44
- (role_message.role, role_message.message)
45
- for role_message in role_messages
46
- ]
47
- )
48
- for node, role_messages in prompt_templates.items()
49
- }
50
-
51
- class SimplifiedCSVLogger(CSVLogger):
52
- """
53
- A subclass of CSVLogger that logs only the components data to a CSV file,
54
- excluding flag, username, and timestamp information.
55
- """
56
-
57
- def flag(
58
- self,
59
- flag_data: list[Any],
60
- flag_option: str = "",
61
- username: str | None = None,
62
- ) -> int:
63
- flagging_dir = self.flagging_dir
64
- log_filepath = Path(flagging_dir) / "log.csv"
65
- is_new = not Path(log_filepath).exists()
66
- headers = [
67
- getattr(component, "label", None) or f"component {idx}"
68
- for idx, component in enumerate(self.components)
69
- ]
70
-
71
- csv_data = []
72
- for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
73
- save_dir = Path(flagging_dir) / client_utils.strip_invalid_filename_characters(
74
- getattr(component, "label", None) or f"component {idx}"
75
- )
76
- if utils.is_prop_update(sample):
77
- csv_data.append(str(sample))
78
- else:
79
- data = component.flag(sample, flag_dir=save_dir) if sample is not None else ""
80
- if self.simplify_file_data:
81
- data = utils.simplify_file_data_in_str(data)
82
- csv_data.append(data)
83
-
84
- with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
85
- writer = csv.writer(csvfile)
86
- if is_new:
87
- writer.writerow(utils.sanitize_list_for_csv(headers))
88
- writer.writerow(utils.sanitize_list_for_csv(csv_data))
89
-
90
- with open(log_filepath, encoding="utf-8") as csvfile:
91
- line_count = len(list(csv.reader(csvfile))) - 1
92
-
93
- return line_count
94
-
95
-
96
- class LLMModelFactory:
97
- """A factory class for creating instances of LLM models.
98
-
99
- This class follows the Singleton pattern, ensuring that only one instance is created.
100
- The `create` method dynamically instantiates a model based on the provided `model_type`.
101
-
102
- Attributes:
103
- _instance (LLMModelFactory): A private class variable to store the singleton instance.
104
-
105
- Methods:
106
- create(model_type: str, **kwargs) -> BaseLanguageModel:
107
- Dynamically creates and returns an instance of a model based on `model_type`.
108
-
109
- """
110
-
111
- _instance = None
112
-
113
- def __new__(cls):
114
- if not cls._instance:
115
- cls._instance = super(LLMModelFactory, cls).__new__(cls)
116
- return cls._instance
117
-
118
- def create(self, model_type: str, **kwargs) -> BaseLanguageModel:
119
- """Creates and returns an instance of a model based on `model_type`.
120
-
121
- Args:
122
- model_type (str): The name of the model class to instantiate.
123
- **kwargs: Additional keyword arguments to pass to the model constructor.
124
-
125
- Returns:
126
- BaseLanguageModel: An instance of a model that inherits from BaseLanguageModel.
127
-
128
- """
129
- model_class = globals()[model_type]
130
- return model_class(**kwargs)
131
-
132
-
133
- def chat_log_2_chatbot_list(chat_log: str) -> List[List[str]]:
134
- """
135
- Convert a chat log string into a list of dialogues for the Chatbot format.
136
-
137
- Args:
138
- chat_log (str): A JSON formatted chat log where each line represents an
139
- action with its message. Expected actions are 'invoke'
140
- and 'response'.
141
-
142
- Returns:
143
- List[List[str]]: A list of dialogue pairs where the first element is a
144
- user input and the second element is a bot response.
145
- If the action was 'invoke', the first element will be
146
- the message, and the second element will be None. If
147
- the action was 'response', the first element will be
148
- None, and the second element will be the message.
149
- """
150
- chatbot_list = []
151
- if chat_log is None or chat_log == '':
152
- return chatbot_list
153
- for line in chat_log.splitlines():
154
- try:
155
- json_line = json.loads(line)
156
- if 'action' in json_line:
157
- if json_line['action'] == 'invoke':
158
- chatbot_list.append([json_line['message'], None])
159
- if json_line['action'] == 'response':
160
- chatbot_list.append([None, json_line['message']])
161
- except json.decoder.JSONDecodeError as e:
162
- print(f"Error decoding JSON log output: {e}")
163
- print(line)
164
- except KeyError as e:
165
- print(f"Error accessing key in JSON log output: {e}")
166
- print(line)
167
- return chatbot_list
168
-
169
- def on_model_tab_select(simple_model_name,
170
- advanced_optimizer_model_name, advanced_executor_model_name,
171
- expert_prompt_initial_developer_model_name,
172
- expert_prompt_acceptance_criteria_developer_model_name,
173
- expert_prompt_developer_model_name,
174
- expert_prompt_executor_model_name,
175
- expert_prompt_history_analyzer_model_name,
176
- expert_prompt_analyzer_model_name,
177
- expert_prompt_suggester_model_name,
178
- event: gr.SelectData):
179
- if event.value == 'Simple':
180
- return simple_model_name, \
181
- simple_model_name, \
182
- simple_model_name, \
183
- simple_model_name, \
184
- simple_model_name, \
185
- simple_model_name, \
186
- simple_model_name
187
- elif event.value == 'Advanced':
188
- return advanced_optimizer_model_name, \
189
- advanced_optimizer_model_name, \
190
- advanced_optimizer_model_name, \
191
- advanced_executor_model_name, \
192
- advanced_optimizer_model_name, \
193
- advanced_optimizer_model_name, \
194
- advanced_optimizer_model_name
195
- elif event.value == 'Expert':
196
- return expert_prompt_initial_developer_model_name, \
197
- expert_prompt_acceptance_criteria_developer_model_name, \
198
- expert_prompt_developer_model_name, \
199
- expert_prompt_executor_model_name, \
200
- expert_prompt_history_analyzer_model_name, \
201
- expert_prompt_analyzer_model_name, \
202
- expert_prompt_suggester_model_name
203
- else:
204
- raise ValueError(f"Invalid model tab selected: {event.value}")
205
-
206
- def evaluate_system_message(system_message, user_message, executor_model_name):
207
- """
208
- Evaluate a system message by using it to generate a response from an
209
- executor model based on the current active tab and provided user message.
210
-
211
- This function retrieves the appropriate language model (LLM) for the
212
- current active model tab, formats a chat prompt template with the system
213
- message and user message, invokes the LLM using this formatted prompt, and
214
- returns the content of the output if it exists.
215
-
216
- Args:
217
- system_message (str): The system message to use when evaluating the
218
- response.
219
- user_message (str): The user's input message for which a response will
220
- be generated.
221
- executor_model_state (gr.State): The state object containing the name
222
- of the executor model to use.
223
-
224
- Returns:
225
- str: The content of the output generated by the LLM based on the system
226
- message and user message, if it exists; otherwise, an empty string.
227
-
228
- Raises:
229
- gr.Error: If there is a Gradio-specific error during the execution of
230
- this function.
231
- Exception: For any other unexpected errors that occur during the
232
- execution of this function.
233
- """
234
- llm = initialize_llm(executor_model_name)
235
- template = ChatPromptTemplate.from_messages([
236
- ("system", "{system_message}"),
237
- ("human", "{user_message}")
238
- ])
239
- try:
240
- output = llm.invoke(template.format(
241
- system_message=system_message, user_message=user_message))
242
- return output.content if hasattr(output, 'content') else ""
243
- except gr.Error as e:
244
- raise e
245
- except Exception as e:
246
- raise gr.Error(f"Error: {e}")
247
-
248
-
249
- def generate_acceptance_criteria(user_message, expected_output, acceptance_criteria_model_name, prompt_template_group):
250
- """
251
- Generate acceptance criteria based on the user message and expected output.
252
-
253
- This function uses the MetaPromptGraph's run_acceptance_criteria_graph method
254
- to generate acceptance criteria.
255
-
256
- Args:
257
- user_message (str): The user's input message.
258
- expected_output (str): The anticipated response or outcome from the language
259
- model based on the user's message.
260
- acceptance_criteria_model_name (str): The name of the acceptance criteria model to use.
261
- prompt_template_group (Optional[str], optional): The group of prompt templates
262
- to use. Defaults to None.
263
-
264
- Returns:
265
- tuple: A tuple containing the generated acceptance criteria and the chat log.
266
- """
267
-
268
- log_stream = io.StringIO()
269
- logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
270
- log_handler = logging.StreamHandler(log_stream) if logger else None
271
-
272
- if log_handler:
273
- log_handler.setFormatter(
274
- jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
275
- )
276
- logger.addHandler(log_handler)
277
-
278
- llm = initialize_llm(acceptance_criteria_model_name)
279
- if prompt_template_group is None:
280
- prompt_template_group = 'default'
281
- prompt_templates = prompt_templates_confz2langchain(
282
- config.prompt_templates[prompt_template_group]
283
- )
284
- acceptance_criteria_graph = MetaPromptGraph(llms={
285
- NODE_ACCEPTANCE_CRITERIA_DEVELOPER: llm
286
- }, prompts=prompt_templates,
287
- verbose=config.verbose, logger=logger)
288
- state = AgentState(
289
- user_message=user_message,
290
- expected_output=expected_output
291
- )
292
- output_state = acceptance_criteria_graph.run_acceptance_criteria_graph(state)
293
-
294
- if log_handler:
295
- log_handler.close()
296
- log_output = log_stream.getvalue()
297
- else:
298
- log_output = None
299
- return output_state.get('acceptance_criteria', ""), chat_log_2_chatbot_list(log_output)
300
-
301
-
302
- def generate_initial_system_message(
303
- user_message: str,
304
- expected_output: str,
305
- initial_developer_model_name: str,
306
- prompt_template_group: Optional[str] = None
307
- ) -> tuple:
308
- """
309
- Generate an initial system message based on the user message and expected output.
310
-
311
- Args:
312
- user_message (str): The user's input message.
313
- expected_output (str): The anticipated response or outcome from the language model.
314
- initial_developer_model_name (str): The name of the initial developer model to use.
315
- prompt_template_group (Optional[str], optional):
316
- The group of prompt templates to use. Defaults to None.
317
-
318
- Returns:
319
- tuple: A tuple containing the initial system message and the chat log.
320
- """
321
-
322
- log_stream = io.StringIO()
323
- logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
324
- log_handler = logging.StreamHandler(log_stream) if logger else None
325
-
326
- if log_handler:
327
- log_handler.setFormatter(
328
- jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
329
- )
330
- logger.addHandler(log_handler)
331
-
332
- llm = initialize_llm(initial_developer_model_name)
333
-
334
- if prompt_template_group is None:
335
- prompt_template_group = 'default'
336
- prompt_templates = prompt_templates_confz2langchain(
337
- config.prompt_templates[prompt_template_group]
338
- )
339
-
340
- initial_system_message_graph = MetaPromptGraph(
341
- llms={NODE_PROMPT_INITIAL_DEVELOPER: llm},
342
- prompts=prompt_templates,
343
- verbose=config.verbose,
344
- logger=logger
345
- )
346
-
347
- state = AgentState(
348
- user_message=user_message,
349
- expected_output=expected_output
350
- )
351
-
352
- output_state = initial_system_message_graph.run_prompt_initial_developer_graph(state)
353
-
354
- if log_handler:
355
- log_handler.close()
356
- log_output = log_stream.getvalue()
357
- else:
358
- log_output = None
359
-
360
- system_message = output_state.get('system_message', "")
361
- return system_message, chat_log_2_chatbot_list(log_output)
362
-
363
-
364
- def process_message_with_models(
365
- user_message: str, expected_output: str, acceptance_criteria: str,
366
- initial_system_message: str, recursion_limit: int, max_output_age: int,
367
- initial_developer_model_name: str, acceptance_criteria_model_name: str,
368
- developer_model_name: str, executor_model_name: str, history_analyzer_model_name: str,
369
- analyzer_model_name: str, suggester_model_name: str,
370
- prompt_template_group: Optional[str] = None,
371
- aggressive_exploration: bool = False
372
- ) -> tuple:
373
- """
374
- Process a user message by executing the MetaPromptGraph with provided language models and input state.
375
-
376
- This function sets up the initial state of the conversation, logs the execution if verbose mode is enabled,
377
- and extracts the best system message, output, and analysis from the output state of the MetaPromptGraph.
378
-
379
- Args:
380
- user_message (str): The user's input message to be processed by the language model(s).
381
- expected_output (str): The anticipated response or outcome from the language model(s) based on the user's message.
382
- acceptance_criteria (str): Criteria that determines whether the output is acceptable or not.
383
- initial_system_message (str): Initial instruction given to the language model(s) before processing the user's message.
384
- recursion_limit (int): The maximum number of times the MetaPromptGraph can call itself recursively.
385
- max_output_age (int): The maximum age of output messages that should be considered in the conversation history.
386
- initial_developer_model_name (str): The name of the initial developer model to use.
387
- acceptance_criteria_model_name (str): The name of the acceptance criteria model to use.
388
- developer_model_name (str): The name of the developer model to use.
389
- executor_model_name (str): The name of the executor model to use.
390
- history_analyzer_model_name (str): The name of the history analyzer model to use.
391
- analyzer_model_name (str): The name of the analyzer model to use.
392
- suggester_model_name (str): The name of the suggester model to use.
393
- prompt_template_group (Optional[str], optional): The group of prompt templates to use. Defaults to None.
394
- aggressive_exploration (bool, optional): Whether to use aggressive exploration. Defaults to False.
395
-
396
- Returns:
397
- tuple: A tuple containing the best system message, output, analysis, acceptance criteria, and chat log in JSON format.
398
- """
399
- input_state = AgentState(
400
- user_message=user_message,
401
- expected_output=expected_output,
402
- acceptance_criteria=acceptance_criteria,
403
- system_message=initial_system_message,
404
- max_output_age=max_output_age
405
- )
406
-
407
- log_stream = io.StringIO()
408
- logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
409
- log_handler = logging.StreamHandler(log_stream) if logger else None
410
- if log_handler:
411
- log_handler.setFormatter(jsonlogger.JsonFormatter(
412
- '%(asctime)s %(name)s %(levelname)s %(message)s'))
413
- logger.addHandler(log_handler)
414
-
415
- if prompt_template_group is None:
416
- prompt_template_group = 'default'
417
- prompt_templates = prompt_templates_confz2langchain(config.prompt_templates[prompt_template_group])
418
- llms = {
419
- NODE_PROMPT_INITIAL_DEVELOPER: initialize_llm(initial_developer_model_name),
420
- NODE_ACCEPTANCE_CRITERIA_DEVELOPER: initialize_llm(acceptance_criteria_model_name),
421
- NODE_PROMPT_DEVELOPER: initialize_llm(developer_model_name),
422
- NODE_PROMPT_EXECUTOR: initialize_llm(executor_model_name),
423
- NODE_OUTPUT_HISTORY_ANALYZER: initialize_llm(history_analyzer_model_name),
424
- NODE_PROMPT_ANALYZER: initialize_llm(analyzer_model_name),
425
- NODE_PROMPT_SUGGESTER: initialize_llm(suggester_model_name)
426
- }
427
- meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompt_templates,
428
- aggressive_exploration=aggressive_exploration,
429
- verbose=config.verbose, logger=logger)
430
- try:
431
- output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
432
- except Exception as e:
433
- if isinstance(e, gr.Error):
434
- raise e
435
- else:
436
- raise gr.Error(f"Error: {e}")
437
-
438
- if log_handler:
439
- log_handler.close()
440
- log_output = log_stream.getvalue()
441
- else:
442
- log_output = None
443
-
444
- system_message = output_state.get(
445
- 'best_system_message', "Error: The output state does not contain a valid 'best_system_message'")
446
- output = output_state.get(
447
- 'best_output', "Error: The output state does not contain a valid 'best_output'")
448
- analysis = output_state.get(
449
- 'analysis', "Error: The output state does not contain a valid 'analysis'")
450
- acceptance_criteria = output_state.get(
451
- 'acceptance_criteria', "Error: The output state does not contain a valid 'acceptance_criteria'")
452
-
453
- return (system_message, output, analysis, acceptance_criteria, chat_log_2_chatbot_list(log_output))
454
-
455
-
456
- def initialize_llm(model_name: str, model_config: Optional[Dict[str, Any]] = None) -> Any:
457
- """
458
- Initialize and return a language model (LLM) based on its name.
459
-
460
- This function retrieves the configuration for the specified language model
461
- from the application's configuration, creates an instance of the appropriate
462
- type of language model using that configuration, and returns it.
463
-
464
- Args:
465
- model_name (str): The name of the language model to initialize. This
466
- should correspond to a key in the 'llms' section of the application's
467
- configuration.
468
- model_config (Optional[Dict[str, Any]], optional): Optional model
469
- configurations. Defaults to None.
470
-
471
- Returns:
472
- Any: An instance of the specified type of language model, initialized
473
- with its configured settings.
474
-
475
- Raises:
476
- KeyError: If no configuration exists for the specified model name.
477
- NotImplementedError: If an unrecognized type is configured for the
478
- language model. This should not occur under normal circumstances
479
- because the LLMModelFactory class checks and validates the type when
480
- creating a new language model.
481
- """
482
- try:
483
- llm_config = config.llms[model_name]
484
- model_type = llm_config.type
485
- dumped_config = llm_config.model_dump(exclude={'type'})
486
-
487
- if model_config:
488
- dumped_config.update(model_config)
489
-
490
- return LLMModelFactory().create(model_type, **dumped_config)
491
- except KeyError:
492
- raise KeyError(f"No configuration exists for the model name: {model_name}")
493
- except NotImplementedError:
494
- raise NotImplementedError(
495
- f"Unrecognized type configured for the language model: {model_type}"
496
- )
497
-
498
-
499
- class FileConfig(BaseConfig):
500
- config_file: str = 'config.yml' # default path
501
 
502
  pre_config_sources = [
503
  EnvSource(prefix='METAPROMPT_', allow_all=True),
@@ -739,15 +247,17 @@ with gr.Blocks(title='Meta Prompt') as demo:
739
  ])
740
 
741
  model_states = {
742
- "initial_developer": gr.State(value=None), # None | str
743
- "acceptance_criteria": gr.State(value=None), # None | str
744
- "developer": gr.State(value=None), # None | str
745
- "executor": gr.State(value=None), # None | str
746
- "history_analyzer": gr.State(value=None), # None | str
747
- "analyzer": gr.State(value=None), # None | str
748
- "suggester": gr.State(value=None) # None | str
749
  }
750
 
 
 
751
  # set up event handlers
752
  simple_llm_tab.select(
753
  on_model_tab_select,
@@ -824,14 +334,14 @@ with gr.Blocks(title='Meta Prompt') as demo:
824
 
825
  generate_acceptance_criteria_button.click(
826
  generate_acceptance_criteria,
827
- inputs=[user_message_input, expected_output_input,
828
  model_states["acceptance_criteria"],
829
  prompt_template_group],
830
  outputs=[acceptance_criteria_input, logs_chatbot]
831
  )
832
  generate_initial_system_message_button.click(
833
  generate_initial_system_message,
834
- inputs=[user_message_input, expected_output_input,
835
  model_states["initial_developer"],
836
  prompt_template_group],
837
  outputs=[initial_system_message_input, logs_chatbot]
@@ -840,6 +350,7 @@ with gr.Blocks(title='Meta Prompt') as demo:
840
  evaluate_initial_system_message_button.click(
841
  evaluate_system_message,
842
  inputs=[
 
843
  initial_system_message_input,
844
  user_message_input,
845
  model_states["executor"]
@@ -849,6 +360,7 @@ with gr.Blocks(title='Meta Prompt') as demo:
849
  evaluate_system_message_button.click(
850
  evaluate_system_message,
851
  inputs=[
 
852
  system_message_output,
853
  user_message_input,
854
  model_states["executor"]
@@ -869,6 +381,7 @@ with gr.Blocks(title='Meta Prompt') as demo:
869
  simple_submit_button.click(
870
  process_message_with_models,
871
  inputs=[
 
872
  user_message_input,
873
  expected_output_input,
874
  acceptance_criteria_input,
@@ -897,6 +410,7 @@ with gr.Blocks(title='Meta Prompt') as demo:
897
  advanced_submit_button.click(
898
  process_message_with_models,
899
  inputs=[
 
900
  user_message_input,
901
  expected_output_input,
902
  acceptance_criteria_input,
@@ -925,6 +439,7 @@ with gr.Blocks(title='Meta Prompt') as demo:
925
  expert_submit_button.click(
926
  process_message_with_models,
927
  inputs=[
 
928
  user_message_input,
929
  expected_output_input,
930
  acceptance_criteria_input,
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from gradio import Button, utils
3
  from gradio.flagging import FlagMethod
 
4
 
5
+ from confz import CLArgSource, EnvSource, FileSource
6
+ from app.config import MetaPromptConfig
 
 
 
7
  from meta_prompt import *
8
+ from app.gradio_meta_prompt_utils import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  pre_config_sources = [
11
  EnvSource(prefix='METAPROMPT_', allow_all=True),
 
247
  ])
248
 
249
  model_states = {
250
+ "initial_developer": gr.State(value=simple_model_name_input.value), # None | str
251
+ "acceptance_criteria": gr.State(value=simple_model_name_input.value), # None | str
252
+ "developer": gr.State(value=simple_model_name_input.value), # None | str
253
+ "executor": gr.State(value=simple_model_name_input.value), # None | str
254
+ "history_analyzer": gr.State(value=simple_model_name_input.value), # None | str
255
+ "analyzer": gr.State(value=simple_model_name_input.value), # None | str
256
+ "suggester": gr.State(value=simple_model_name_input.value) # None | str
257
  }
258
 
259
+ config_state = gr.State(value=config)
260
+
261
  # set up event handlers
262
  simple_llm_tab.select(
263
  on_model_tab_select,
 
334
 
335
  generate_acceptance_criteria_button.click(
336
  generate_acceptance_criteria,
337
+ inputs=[config_state, user_message_input, expected_output_input,
338
  model_states["acceptance_criteria"],
339
  prompt_template_group],
340
  outputs=[acceptance_criteria_input, logs_chatbot]
341
  )
342
  generate_initial_system_message_button.click(
343
  generate_initial_system_message,
344
+ inputs=[config_state, user_message_input, expected_output_input,
345
  model_states["initial_developer"],
346
  prompt_template_group],
347
  outputs=[initial_system_message_input, logs_chatbot]
 
350
  evaluate_initial_system_message_button.click(
351
  evaluate_system_message,
352
  inputs=[
353
+ config_state,
354
  initial_system_message_input,
355
  user_message_input,
356
  model_states["executor"]
 
360
  evaluate_system_message_button.click(
361
  evaluate_system_message,
362
  inputs=[
363
+ config_state,
364
  system_message_output,
365
  user_message_input,
366
  model_states["executor"]
 
381
  simple_submit_button.click(
382
  process_message_with_models,
383
  inputs=[
384
+ config_state,
385
  user_message_input,
386
  expected_output_input,
387
  acceptance_criteria_input,
 
410
  advanced_submit_button.click(
411
  process_message_with_models,
412
  inputs=[
413
+ config_state,
414
  user_message_input,
415
  expected_output_input,
416
  acceptance_criteria_input,
 
439
  expert_submit_button.click(
440
  process_message_with_models,
441
  inputs=[
442
+ config_state,
443
  user_message_input,
444
  expected_output_input,
445
  acceptance_criteria_input,
app/gradio_meta_prompt_utils.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ import csv
7
+ import io
8
+
9
+ import gradio as gr
10
+ from gradio import CSVLogger, utils
11
+ from gradio_client import utils as client_utils
12
+
13
+ from confz import BaseConfig
14
+ from langchain_core.language_models import BaseLanguageModel
15
+ from langchain_core.prompts import ChatPromptTemplate
16
+ from langchain_openai import ChatOpenAI # Don't remove this import
17
+ from pythonjsonlogger import jsonlogger
18
+
19
+ from app.config import MetaPromptConfig, RoleMessage
20
+ from meta_prompt import *
21
+
22
+
23
+ def prompt_templates_confz2langchain(
24
+ prompt_templates: Dict[str, Dict[str, List[RoleMessage]]]
25
+ ) -> Dict[str, ChatPromptTemplate]:
26
+ """
27
+ Convert a dictionary of prompt templates from the configuration format to
28
+ the language chain format.
29
+
30
+ This function takes a dictionary of prompt templates in the configuration
31
+ format and converts them to the language chain format. Each prompt template
32
+ is converted to a ChatPromptTemplate object, which is then stored in a new
33
+ dictionary with the same keys.
34
+
35
+ Args:
36
+ prompt_templates (Dict[str, Dict[str, List[RoleMessage]]]):
37
+ A dictionary of prompt templates in the configuration format.
38
+
39
+ Returns:
40
+ Dict[str, ChatPromptTemplate]:
41
+ A dictionary of prompt templates in the language chain format.
42
+ """
43
+ return {
44
+ node: ChatPromptTemplate.from_messages(
45
+ [
46
+ (role_message.role, role_message.message)
47
+ for role_message in role_messages
48
+ ]
49
+ )
50
+ for node, role_messages in prompt_templates.items()
51
+ }
52
+
53
+ class SimplifiedCSVLogger(CSVLogger):
54
+ """
55
+ A subclass of CSVLogger that logs only the components data to a CSV file,
56
+ excluding flag, username, and timestamp information.
57
+ """
58
+
59
+ def flag(
60
+ self,
61
+ flag_data: list[Any],
62
+ flag_option: str = "",
63
+ username: str | None = None,
64
+ ) -> int:
65
+ flagging_dir = self.flagging_dir
66
+ log_filepath = Path(flagging_dir) / "log.csv"
67
+ is_new = not Path(log_filepath).exists()
68
+ headers = [
69
+ getattr(component, "label", None) or f"component {idx}"
70
+ for idx, component in enumerate(self.components)
71
+ ]
72
+
73
+ csv_data = []
74
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
75
+ save_dir = Path(flagging_dir) / client_utils.strip_invalid_filename_characters(
76
+ getattr(component, "label", None) or f"component {idx}"
77
+ )
78
+ if utils.is_prop_update(sample):
79
+ csv_data.append(str(sample))
80
+ else:
81
+ data = component.flag(sample, flag_dir=save_dir) if sample is not None else ""
82
+ if self.simplify_file_data:
83
+ data = utils.simplify_file_data_in_str(data)
84
+ csv_data.append(data)
85
+
86
+ with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
87
+ writer = csv.writer(csvfile)
88
+ if is_new:
89
+ writer.writerow(utils.sanitize_list_for_csv(headers))
90
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
91
+
92
+ with open(log_filepath, encoding="utf-8") as csvfile:
93
+ line_count = len(list(csv.reader(csvfile))) - 1
94
+
95
+ return line_count
96
+
97
+
98
+ class LLMModelFactory:
99
+ """A factory class for creating instances of LLM models.
100
+
101
+ This class follows the Singleton pattern, ensuring that only one instance is created.
102
+ The `create` method dynamically instantiates a model based on the provided `model_type`.
103
+
104
+ Attributes:
105
+ _instance (LLMModelFactory): A private class variable to store the singleton instance.
106
+
107
+ Methods:
108
+ create(model_type: str, **kwargs) -> BaseLanguageModel:
109
+ Dynamically creates and returns an instance of a model based on `model_type`.
110
+
111
+ """
112
+
113
+ _instance = None
114
+
115
+ def __new__(cls):
116
+ if not cls._instance:
117
+ cls._instance = super(LLMModelFactory, cls).__new__(cls)
118
+ return cls._instance
119
+
120
+ def create(self, model_type: str, **kwargs) -> BaseLanguageModel:
121
+ """Creates and returns an instance of a model based on `model_type`.
122
+
123
+ Args:
124
+ model_type (str): The name of the model class to instantiate.
125
+ **kwargs: Additional keyword arguments to pass to the model constructor.
126
+
127
+ Returns:
128
+ BaseLanguageModel: An instance of a model that inherits from BaseLanguageModel.
129
+
130
+ """
131
+ model_class = globals()[model_type]
132
+ return model_class(**kwargs)
133
+
134
+
135
+ def chat_log_2_chatbot_list(chat_log: str) -> List[List[str]]:
136
+ """
137
+ Convert a chat log string into a list of dialogues for the Chatbot format.
138
+
139
+ Args:
140
+ chat_log (str): A JSON formatted chat log where each line represents an
141
+ action with its message. Expected actions are 'invoke'
142
+ and 'response'.
143
+
144
+ Returns:
145
+ List[List[str]]: A list of dialogue pairs where the first element is a
146
+ user input and the second element is a bot response.
147
+ If the action was 'invoke', the first element will be
148
+ the message, and the second element will be None. If
149
+ the action was 'response', the first element will be
150
+ None, and the second element will be the message.
151
+ """
152
+ chatbot_list = []
153
+ if chat_log is None or chat_log == '':
154
+ return chatbot_list
155
+ for line in chat_log.splitlines():
156
+ try:
157
+ json_line = json.loads(line)
158
+ if 'action' in json_line:
159
+ if json_line['action'] == 'invoke':
160
+ chatbot_list.append([json_line['message'], None])
161
+ if json_line['action'] == 'response':
162
+ chatbot_list.append([None, json_line['message']])
163
+ except json.decoder.JSONDecodeError as e:
164
+ print(f"Error decoding JSON log output: {e}")
165
+ print(line)
166
+ except KeyError as e:
167
+ print(f"Error accessing key in JSON log output: {e}")
168
+ print(line)
169
+ return chatbot_list
170
+
171
+ def on_model_tab_select(simple_model_name,
172
+ advanced_optimizer_model_name, advanced_executor_model_name,
173
+ expert_prompt_initial_developer_model_name,
174
+ expert_prompt_acceptance_criteria_developer_model_name,
175
+ expert_prompt_developer_model_name,
176
+ expert_prompt_executor_model_name,
177
+ expert_prompt_history_analyzer_model_name,
178
+ expert_prompt_analyzer_model_name,
179
+ expert_prompt_suggester_model_name,
180
+ event: gr.SelectData):
181
+ if event.value == 'Simple':
182
+ return simple_model_name, \
183
+ simple_model_name, \
184
+ simple_model_name, \
185
+ simple_model_name, \
186
+ simple_model_name, \
187
+ simple_model_name, \
188
+ simple_model_name
189
+ elif event.value == 'Advanced':
190
+ return advanced_optimizer_model_name, \
191
+ advanced_optimizer_model_name, \
192
+ advanced_optimizer_model_name, \
193
+ advanced_executor_model_name, \
194
+ advanced_optimizer_model_name, \
195
+ advanced_optimizer_model_name, \
196
+ advanced_optimizer_model_name
197
+ elif event.value == 'Expert':
198
+ return expert_prompt_initial_developer_model_name, \
199
+ expert_prompt_acceptance_criteria_developer_model_name, \
200
+ expert_prompt_developer_model_name, \
201
+ expert_prompt_executor_model_name, \
202
+ expert_prompt_history_analyzer_model_name, \
203
+ expert_prompt_analyzer_model_name, \
204
+ expert_prompt_suggester_model_name
205
+ else:
206
+ raise ValueError(f"Invalid model tab selected: {event.value}")
207
+
208
+ def evaluate_system_message(config, system_message, user_message, executor_model_name):
209
+ """
210
+ Evaluate a system message by using it to generate a response from an
211
+ executor model based on the current active tab and provided user message.
212
+
213
+ This function retrieves the appropriate language model (LLM) for the
214
+ current active model tab, formats a chat prompt template with the system
215
+ message and user message, invokes the LLM using this formatted prompt, and
216
+ returns the content of the output if it exists.
217
+
218
+ Args:
219
+ system_message (str): The system message to use when evaluating the
220
+ response.
221
+ user_message (str): The user's input message for which a response will
222
+ be generated.
223
+ executor_model_state (gr.State): The state object containing the name
224
+ of the executor model to use.
225
+
226
+ Returns:
227
+ str: The content of the output generated by the LLM based on the system
228
+ message and user message, if it exists; otherwise, an empty string.
229
+
230
+ Raises:
231
+ gr.Error: If there is a Gradio-specific error during the execution of
232
+ this function.
233
+ Exception: For any other unexpected errors that occur during the
234
+ execution of this function.
235
+ """
236
+ llm = initialize_llm(config, executor_model_name)
237
+ template = ChatPromptTemplate.from_messages([
238
+ ("system", "{system_message}"),
239
+ ("human", "{user_message}")
240
+ ])
241
+ try:
242
+ output = llm.invoke(template.format(
243
+ system_message=system_message, user_message=user_message))
244
+ return output.content if hasattr(output, 'content') else ""
245
+ except gr.Error as e:
246
+ raise e
247
+ except Exception as e:
248
+ raise gr.Error(f"Error: {e}")
249
+
250
+
251
+ def generate_acceptance_criteria(config, user_message, expected_output, acceptance_criteria_model_name, prompt_template_group):
252
+ """
253
+ Generate acceptance criteria based on the user message and expected output.
254
+
255
+ This function uses the MetaPromptGraph's run_acceptance_criteria_graph method
256
+ to generate acceptance criteria.
257
+
258
+ Args:
259
+ user_message (str): The user's input message.
260
+ expected_output (str): The anticipated response or outcome from the language
261
+ model based on the user's message.
262
+ acceptance_criteria_model_name (str): The name of the acceptance criteria model to use.
263
+ prompt_template_group (Optional[str], optional): The group of prompt templates
264
+ to use. Defaults to None.
265
+
266
+ Returns:
267
+ tuple: A tuple containing the generated acceptance criteria and the chat log.
268
+ """
269
+
270
+ log_stream = io.StringIO()
271
+ logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
272
+ log_handler = logging.StreamHandler(log_stream) if logger else None
273
+
274
+ if log_handler:
275
+ log_handler.setFormatter(
276
+ jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
277
+ )
278
+ logger.addHandler(log_handler)
279
+
280
+ llm = initialize_llm(config, acceptance_criteria_model_name)
281
+ if prompt_template_group is None:
282
+ prompt_template_group = 'default'
283
+ prompt_templates = prompt_templates_confz2langchain(
284
+ config.prompt_templates[prompt_template_group]
285
+ )
286
+ acceptance_criteria_graph = MetaPromptGraph(llms={
287
+ NODE_ACCEPTANCE_CRITERIA_DEVELOPER: llm
288
+ }, prompts=prompt_templates,
289
+ verbose=config.verbose, logger=logger)
290
+ state = AgentState(
291
+ user_message=user_message,
292
+ expected_output=expected_output
293
+ )
294
+ output_state = acceptance_criteria_graph.run_acceptance_criteria_graph(state)
295
+
296
+ if log_handler:
297
+ log_handler.close()
298
+ log_output = log_stream.getvalue()
299
+ else:
300
+ log_output = None
301
+ return output_state.get('acceptance_criteria', ""), chat_log_2_chatbot_list(log_output)
302
+
303
+
304
+ def generate_initial_system_message(
305
+ config,
306
+ user_message: str,
307
+ expected_output: str,
308
+ initial_developer_model_name: str,
309
+ prompt_template_group: Optional[str] = None
310
+ ) -> tuple:
311
+ """
312
+ Generate an initial system message based on the user message and expected output.
313
+
314
+ Args:
315
+ user_message (str): The user's input message.
316
+ expected_output (str): The anticipated response or outcome from the language model.
317
+ initial_developer_model_name (str): The name of the initial developer model to use.
318
+ prompt_template_group (Optional[str], optional):
319
+ The group of prompt templates to use. Defaults to None.
320
+
321
+ Returns:
322
+ tuple: A tuple containing the initial system message and the chat log.
323
+ """
324
+
325
+ log_stream = io.StringIO()
326
+ logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
327
+ log_handler = logging.StreamHandler(log_stream) if logger else None
328
+
329
+ if log_handler:
330
+ log_handler.setFormatter(
331
+ jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
332
+ )
333
+ logger.addHandler(log_handler)
334
+
335
+ llm = initialize_llm(config, initial_developer_model_name)
336
+
337
+ if prompt_template_group is None:
338
+ prompt_template_group = 'default'
339
+ prompt_templates = prompt_templates_confz2langchain(
340
+ config.prompt_templates[prompt_template_group]
341
+ )
342
+
343
+ initial_system_message_graph = MetaPromptGraph(
344
+ llms={NODE_PROMPT_INITIAL_DEVELOPER: llm},
345
+ prompts=prompt_templates,
346
+ verbose=config.verbose,
347
+ logger=logger
348
+ )
349
+
350
+ state = AgentState(
351
+ user_message=user_message,
352
+ expected_output=expected_output
353
+ )
354
+
355
+ output_state = initial_system_message_graph.run_prompt_initial_developer_graph(state)
356
+
357
+ if log_handler:
358
+ log_handler.close()
359
+ log_output = log_stream.getvalue()
360
+ else:
361
+ log_output = None
362
+
363
+ system_message = output_state.get('system_message', "")
364
+ return system_message, chat_log_2_chatbot_list(log_output)
365
+
366
+
367
+ def process_message_with_models(
368
+ config,
369
+ user_message: str, expected_output: str, acceptance_criteria: str,
370
+ initial_system_message: str, recursion_limit: int, max_output_age: int,
371
+ initial_developer_model_name: str, acceptance_criteria_model_name: str,
372
+ developer_model_name: str, executor_model_name: str, history_analyzer_model_name: str,
373
+ analyzer_model_name: str, suggester_model_name: str,
374
+ prompt_template_group: Optional[str] = None,
375
+ aggressive_exploration: bool = False
376
+ ) -> tuple:
377
+ """
378
+ Process a user message by executing the MetaPromptGraph with provided language models and input state.
379
+
380
+ This function sets up the initial state of the conversation, logs the execution if verbose mode is enabled,
381
+ and extracts the best system message, output, and analysis from the output state of the MetaPromptGraph.
382
+
383
+ Args:
384
+ user_message (str): The user's input message to be processed by the language model(s).
385
+ expected_output (str): The anticipated response or outcome from the language model(s) based on the user's message.
386
+ acceptance_criteria (str): Criteria that determines whether the output is acceptable or not.
387
+ initial_system_message (str): Initial instruction given to the language model(s) before processing the user's message.
388
+ recursion_limit (int): The maximum number of times the MetaPromptGraph can call itself recursively.
389
+ max_output_age (int): The maximum age of output messages that should be considered in the conversation history.
390
+ initial_developer_model_name (str): The name of the initial developer model to use.
391
+ acceptance_criteria_model_name (str): The name of the acceptance criteria model to use.
392
+ developer_model_name (str): The name of the developer model to use.
393
+ executor_model_name (str): The name of the executor model to use.
394
+ history_analyzer_model_name (str): The name of the history analyzer model to use.
395
+ analyzer_model_name (str): The name of the analyzer model to use.
396
+ suggester_model_name (str): The name of the suggester model to use.
397
+ prompt_template_group (Optional[str], optional): The group of prompt templates to use. Defaults to None.
398
+ aggressive_exploration (bool, optional): Whether to use aggressive exploration. Defaults to False.
399
+
400
+ Returns:
401
+ tuple: A tuple containing the best system message, output, analysis, acceptance criteria, and chat log in JSON format.
402
+ """
403
+ input_state = AgentState(
404
+ user_message=user_message,
405
+ expected_output=expected_output,
406
+ acceptance_criteria=acceptance_criteria,
407
+ system_message=initial_system_message,
408
+ max_output_age=max_output_age
409
+ )
410
+
411
+ log_stream = io.StringIO()
412
+ logger = logging.getLogger(MetaPromptGraph.__name__) if config.verbose else None
413
+ log_handler = logging.StreamHandler(log_stream) if logger else None
414
+ if log_handler:
415
+ log_handler.setFormatter(jsonlogger.JsonFormatter(
416
+ '%(asctime)s %(name)s %(levelname)s %(message)s'))
417
+ logger.addHandler(log_handler)
418
+
419
+ if prompt_template_group is None:
420
+ prompt_template_group = 'default'
421
+ prompt_templates = prompt_templates_confz2langchain(config.prompt_templates[prompt_template_group])
422
+ llms = {
423
+ NODE_PROMPT_INITIAL_DEVELOPER: initialize_llm(config, initial_developer_model_name),
424
+ NODE_ACCEPTANCE_CRITERIA_DEVELOPER: initialize_llm(config, acceptance_criteria_model_name),
425
+ NODE_PROMPT_DEVELOPER: initialize_llm(config, developer_model_name),
426
+ NODE_PROMPT_EXECUTOR: initialize_llm(config, executor_model_name),
427
+ NODE_OUTPUT_HISTORY_ANALYZER: initialize_llm(config, history_analyzer_model_name),
428
+ NODE_PROMPT_ANALYZER: initialize_llm(config, analyzer_model_name),
429
+ NODE_PROMPT_SUGGESTER: initialize_llm(config, suggester_model_name)
430
+ }
431
+ meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompt_templates,
432
+ aggressive_exploration=aggressive_exploration,
433
+ verbose=config.verbose, logger=logger)
434
+ try:
435
+ output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
436
+ except Exception as e:
437
+ if isinstance(e, gr.Error):
438
+ raise e
439
+ else:
440
+ raise gr.Error(f"Error: {e}")
441
+
442
+ if log_handler:
443
+ log_handler.close()
444
+ log_output = log_stream.getvalue()
445
+ else:
446
+ log_output = None
447
+
448
+ system_message = output_state.get(
449
+ 'best_system_message', "Error: The output state does not contain a valid 'best_system_message'")
450
+ output = output_state.get(
451
+ 'best_output', "Error: The output state does not contain a valid 'best_output'")
452
+ analysis = output_state.get(
453
+ 'analysis', "Error: The output state does not contain a valid 'analysis'")
454
+ acceptance_criteria = output_state.get(
455
+ 'acceptance_criteria', "Error: The output state does not contain a valid 'acceptance_criteria'")
456
+
457
+ return (system_message, output, analysis, acceptance_criteria, chat_log_2_chatbot_list(log_output))
458
+
459
+
460
+ def initialize_llm(config: MetaPromptConfig, model_name: str, model_config: Optional[Dict[str, Any]] = None) -> Any:
461
+ """
462
+ Initialize and return a language model (LLM) based on its name.
463
+
464
+ This function retrieves the configuration for the specified language model
465
+ from the application's configuration, creates an instance of the appropriate
466
+ type of language model using that configuration, and returns it.
467
+
468
+ Args:
469
+ model_name (str): The name of the language model to initialize. This
470
+ should correspond to a key in the 'llms' section of the application's
471
+ configuration.
472
+ model_config (Optional[Dict[str, Any]], optional): Optional model
473
+ configurations. Defaults to None.
474
+
475
+ Returns:
476
+ Any: An instance of the specified type of language model, initialized
477
+ with its configured settings.
478
+
479
+ Raises:
480
+ KeyError: If no configuration exists for the specified model name.
481
+ NotImplementedError: If an unrecognized type is configured for the
482
+ language model. This should not occur under normal circumstances
483
+ because the LLMModelFactory class checks and validates the type when
484
+ creating a new language model.
485
+ """
486
+ try:
487
+ llm_config = config.llms[model_name]
488
+ model_type = llm_config.type
489
+ dumped_config = llm_config.model_dump(exclude={'type'})
490
+
491
+ if model_config:
492
+ dumped_config.update(model_config)
493
+
494
+ return LLMModelFactory().create(model_type, **dumped_config)
495
+ except KeyError:
496
+ raise KeyError(f"No configuration exists for the model name: {model_name}")
497
+ except NotImplementedError:
498
+ raise NotImplementedError(
499
+ f"Unrecognized type configured for the language model: {model_type}"
500
+ )
501
+
502
+
503
+ class FileConfig(BaseConfig):
504
+ config_file: str = 'config.yml' # default path