yaleh committed on
Commit f41c216 · 1 Parent(s): 1b5bff3

Reformatted code.

app/gradio_meta_prompt.py CHANGED
@@ -3,47 +3,55 @@ import io
3
  import json
4
  import logging
5
  from pathlib import Path
6
- from typing import Any, Dict, Optional, Union
 
7
  import gradio as gr
8
  from gradio import CSVLogger, Button, utils
9
  from gradio.flagging import FlagMethod
10
  from gradio_client import utils as client_utils
 
11
  from confz import BaseConfig, CLArgSource, EnvSource, FileSource
12
  from app.config import MetaPromptConfig, RoleMessage
13
  from langchain_core.language_models import BaseLanguageModel
14
  from langchain_core.prompts import ChatPromptTemplate
15
- from langchain_openai import ChatOpenAI
16
  from meta_prompt import *
17
  from pythonjsonlogger import jsonlogger
18
- import pprint
19
- from langchain_core.prompts import ChatPromptTemplate
20
- from typing import Optional, Dict, List
21
 
22
- def prompt_templates_confz2langchain(prompt_templates: Dict[str, Dict[str, List[RoleMessage]]]) -> Dict[str, ChatPromptTemplate]:
 
 
23
  """
24
- Convert a dictionary of prompt templates from the configuration format to the language chain format.
 
25
 
26
- This function takes a dictionary of prompt templates in the configuration format and converts them to the language chain format.
27
- Each prompt template is converted to a ChatPromptTemplate object, which is then stored in a new dictionary with the same keys.
 
 
28
 
29
  Args:
30
- prompt_templates (Dict[str, Dict[str, List[RoleMessage]]]): A dictionary of prompt templates in the configuration format.
 
31
 
32
  Returns:
33
- Dict[str, ChatPromptTemplate]: A dictionary of prompt templates in the language chain format.
 
34
  """
35
  return {
36
- node: ChatPromptTemplate.from_messages([
37
- (role_message.role, role_message.message)
38
- for role_message in role_messages
39
- ])
 
 
40
  for node, role_messages in prompt_templates.items()
41
  }
42
 
43
  class SimplifiedCSVLogger(CSVLogger):
44
  """
45
- A subclass of CSVLogger that logs only the components data to a CSV file, excluding
46
- flag, username, and timestamp information.
47
  """
48
 
49
  def flag(
@@ -62,19 +70,13 @@ class SimplifiedCSVLogger(CSVLogger):
62
 
63
  csv_data = []
64
  for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
65
- save_dir = Path(
66
- flagging_dir
67
- ) / client_utils.strip_invalid_filename_characters(
68
  getattr(component, "label", None) or f"component {idx}"
69
  )
70
  if utils.is_prop_update(sample):
71
  csv_data.append(str(sample))
72
  else:
73
- data = (
74
- component.flag(sample, flag_dir=save_dir)
75
- if sample is not None
76
- else ""
77
- )
78
  if self.simplify_file_data:
79
  data = utils.simplify_file_data_in_str(data)
80
  csv_data.append(data)
@@ -128,19 +130,23 @@ class LLMModelFactory:
128
  return model_class(**kwargs)
129
 
130
 
131
- def chat_log_2_chatbot_list(chat_log: str):
132
- """Convert a chat log string into a list of dialogues for the Chatbot format.
 
133
 
134
  Args:
135
- chat_log (str): A JSON formatted chat log where each line represents an action with its message.
136
- Expected actions are 'invoke' and 'response'.
 
137
 
138
  Returns:
139
- List[List[str]]: A list of dialogue pairs where the first element is a user input and the second element is a bot response.
140
- If the action was 'invoke', the first element will be the message, and the second element will be None.
141
- If the action was 'response', the first element will be None, and the second element will be the message.
 
 
 
142
  """
143
-
144
  chatbot_list = []
145
  if chat_log is None or chat_log == '':
146
  return chatbot_list
@@ -242,33 +248,48 @@ def get_current_model(simple_model_name: str,
242
  raise RuntimeError(f"Failed to retrieve the model: {e}")
243
 
244
 
245
- def evaluate_system_message(system_message, user_message,
246
- simple_model,
247
- advanced_executor_model,
248
- expert_executor_model, expert_execuor_model_temperature=0.1):
249
  """
250
- Evaluate a system message by using it to generate a response from an executor model based on the current active tab and provided user message.
 
251
 
252
- This function retrieves the appropriate language model (LLM) for the current active model tab, formats a chat prompt template with the system message and user message, invokes the LLM using this formatted prompt, and returns the content of the output if it exists.
 
 
 
253
 
254
  Args:
255
- system_message (str): The system message to use when evaluating the response.
256
- user_message (str): The user's input message for which a response will be generated.
257
- simple_model (str): The name of the simple language model. This should correspond to a key in the 'llms' section of the application's configuration.
258
- advanced_executor_model (str): The name of the advanced language model. This should correspond to a key in the 'llms' section of the application's configuration.
259
- expert_executor_model (str): The name of the expert language model. This should correspond to a key in the 'llms' section of the application's configuration.
260
- expert_execuor_model_temperature (float, optional): The temperature parameter for the expert executor model. Defaults to 0.1.
261
 
262
  Returns:
263
- str: The content of the output generated by the LLM based on the system message and user message, if it exists; otherwise, an empty string.
 
264
 
265
  Raises:
266
- gr.Error: If there is a Gradio-specific error during the execution of this function.
267
- Exception: For any other unexpected errors that occur during the execution of this function.
 
 
268
  """
269
- llm = get_current_model(simple_model,
270
- advanced_executor_model,
271
- expert_executor_model, {"temperature": expert_execuor_model_temperature})
272
  template = ChatPromptTemplate.from_messages([
273
  ("system", "{system_message}"),
274
  ("human", "{user_message}")
@@ -365,8 +386,10 @@ def generate_initial_system_message(
365
  simple_model (str): The name of the simple language model.
366
  advanced_executor_model (str): The name of the advanced language model.
367
  expert_prompt_initial_developer_model (str): The name of the expert language model.
368
- expert_prompt_initial_developer_temperature (float, optional): The temperature parameter for the expert model. Defaults to 0.1.
369
- prompt_template_group (Optional[str], optional): The group of prompt templates to use. Defaults to None.
 
 
370
 
371
  Returns:
372
  tuple: A tuple containing the initial system message and the chat log.
@@ -419,32 +442,45 @@ def generate_initial_system_message(
419
  return system_message, chat_log_2_chatbot_list(log_output)
420
 
421
 
422
- def process_message(user_message: str, expected_output: str,
423
- acceptance_criteria: str, initial_system_message: str,
424
- recursion_limit: int, max_output_age: int,
425
- llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
426
- prompt_template_group: Optional[str] = None,
427
- aggressive_exploration: bool = False) -> tuple:
 
428
  """
429
- Process a user message by executing the MetaPromptGraph with provided language models and input state.
430
- This function sets up the initial state of the conversation, logs the execution if verbose mode is enabled,
431
- and extracts the best system message, output, and analysis from the output state of the MetaPromptGraph.
 
 
 
432
 
433
  Args:
434
- user_message (str): The user's input message to be processed by the language model(s).
435
- expected_output (str): The anticipated response or outcome from the language model(s) based on the user's message.
436
- acceptance_criteria (str): Criteria that determines whether the output is acceptable or not.
437
- initial_system_message (str): Initial instruction given to the language model(s) before processing the user's message.
438
- recursion_limit (int): The maximum number of times the MetaPromptGraph can call itself recursively.
439
- max_output_age (int): The maximum age of output messages that should be considered in the conversation history.
440
- llms (Union[BaseLanguageModel, Dict[str, BaseLanguageModel]]): A single language model or a dictionary of language models to use for processing the user's message.
441
 
442
  Returns:
443
- tuple: A tuple containing the best system message, output, analysis, and chat log in JSON format.
444
- - best_system_message (str): The system message that resulted in the most appropriate response based on the acceptance criteria.
445
- - best_output (str): The output generated by the language model(s) that best meets the expected outcome and acceptance criteria.
446
- - analysis (str): An analysis of how well the generated output matches the expected output and acceptance criteria.
447
- - chat_log (list): A list containing JSON objects representing the conversation log, with each object containing a timestamp, logger name, levelname, and message.
448
  """
449
  input_state = AgentState(
450
  user_message=user_message,
@@ -498,112 +534,160 @@ def initialize_llm(model_name: str, model_config: Optional[Dict[str, Any]] = Non
498
  """
499
  Initialize and return a language model (LLM) based on its name.
500
 
501
- This function retrieves the configuration for the specified language model from the application's
502
- configuration, creates an instance of the appropriate type of language model using that
503
- configuration, and returns it.
504
 
505
  Args:
506
- model_name (str): The name of the language model to initialize.
507
- This should correspond to a key in the 'llms' section of the application's configuration.
508
- model_config (Optional[Dict[str, Any]], optional): Optional model configurations. Defaults to None.
 
 
509
 
510
  Returns:
511
- Any: An instance of the specified type of language model, initialized with its configured settings.
 
512
 
513
  Raises:
514
  KeyError: If no configuration exists for the specified model name.
515
- NotImplementedError: If an unrecognized type is configured for the language model.
516
- This should not occur under normal circumstances because the LLMModelFactory class
517
- checks and validates the type when creating a new language model.
 
518
  """
519
  try:
520
  llm_config = config.llms[model_name]
521
  model_type = llm_config.type
522
  dumped_config = llm_config.model_dump(exclude={'type'})
523
-
524
  if model_config:
525
  dumped_config.update(model_config)
526
-
527
  return LLMModelFactory().create(model_type, **dumped_config)
528
  except KeyError:
529
  raise KeyError(f"No configuration exists for the model name: {model_name}")
530
  except NotImplementedError:
531
- raise NotImplementedError(f"Unrecognized type configured for the language model: {model_type}")
 
 
532
 
533
 
534
- def process_message_with_single_llm(user_message: str, expected_output: str,
535
- acceptance_criteria: str, initial_system_message: str,
536
- recursion_limit: int, max_output_age: int,
537
- model_name: str,
538
- prompt_template_group: Optional[str] = None,
539
- aggressive_exploration: bool = False) -> tuple:
540
  """
541
  Process a user message using a single language model.
542
 
543
- This function initializes a language model based on the provided model name and
544
- uses it to process the user's message. The function takes in additional parameters
545
- such as the user's message, expected output, acceptance criteria, initial system
546
- message, recursion limit, and max output age. It then calls the `process_message`
547
- function with the initialized language model to obtain the best system message,
548
- output, analysis, and chat log.
549
 
550
  Parameters:
551
- user_message (str): The user's input message to be processed by the language model.
552
- expected_output (str): The anticipated response or outcome from the language model based on the user's message.
553
- acceptance_criteria (str): Criteria that determines whether the output is acceptable or not.
554
- initial_system_message (str): Initial instruction given to the language model before processing the user's message.
555
- recursion_limit (int): The maximum number of times the MetaPromptGraph can call itself recursively.
556
- max_output_age (int): The maximum age of output messages that should be considered in the conversation history.
557
- model_name (str): The name of the language model to initialize and use for processing the user's message.
558
- This should correspond to a key in the 'llms' section of the application's configuration.
559
- prompt_template_group (Optional[str], optional): The name of the prompt template group to use for processing the user's message. Defaults to None.
560
 
561
  Returns:
562
- tuple: A tuple containing the best system message, output, analysis, and chat log in JSON format.
563
- - best_system_message (str): The system message that resulted in the most appropriate response based on the acceptance criteria.
564
- - best_output (str): The output generated by the language model that best meets the expected outcome and acceptance criteria.
565
- - analysis (str): An analysis of how well the generated output matches the expected output and acceptance criteria.
566
- - chat_log (list): A list containing JSON objects representing the conversation log, with each object containing a timestamp, logger name, levelname, and message.
 
 
 
 
 
 
567
  """
568
  llm = initialize_llm(model_name)
569
- return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
570
- recursion_limit, max_output_age, llm, prompt_template_group, aggressive_exploration)
 
 
571
 
572
 
573
- def process_message_with_2_llms(user_message: str, expected_output: str,
574
- acceptance_criteria: str, initial_system_message: str,
575
- recursion_limit: int, max_output_age: int,
576
- optimizer_model_name: str, executor_model_name: str,
577
- prompt_template_group: Optional[str] = None,
578
- aggressive_exploration: bool = False) -> tuple:
 
579
  """
580
- Process a user message using two language models - one for optimization and another for execution.
 
581
 
582
- This function initializes the specified optimizer and executor language models and then uses them to process
583
- the user's message along with other provided input parameters such as expected output, acceptance criteria,
584
- initial system message, recursion limit, and max output age. The result is obtained by calling the `process_message`
585
- function with a dictionary of language models where all nodes except for NODE_PROMPT_EXECUTOR use the optimizer model
586
- and NODE_PROMPT_EXECUTOR uses the executor model.
 
 
587
 
588
  Args:
589
- user_message (str): The user's input message to be processed by the language models.
590
- expected_output (str): The anticipated response or outcome from the language models based on the user's message.
591
- acceptance_criteria (str): Criteria that determines whether the output is acceptable or not.
592
- initial_system_message (str): Initial instruction given to the language models before processing the user's message.
593
- recursion_limit (int): The maximum number of times the MetaPromptGraph can call itself recursively.
594
- max_output_age (int): The maximum age of output messages that should be considered in the conversation history.
595
- optimizer_model_name (str): The name of the language model to initialize and use for optimization tasks like prompt development, analysis, and suggestion.
596
- This should correspond to a key in the 'llms' section of the application's configuration.
597
- executor_model_name (str): The name of the language model to initialize and use for execution tasks like running code or providing final outputs.
598
- This should correspond to a key in the 'llms' section of the application's configuration.
599
- prompt_template_group (Optional[str], optional): The name of the prompt template group to use for processing the user's message. Defaults to None.
600
 
601
  Returns:
602
- tuple: A tuple containing the best system message, output, analysis, and chat log in JSON format.
603
- - best_system_message (str): The system message that resulted in the most appropriate response based on the acceptance criteria.
604
- - best_output (str): The output generated by the language models that best meets the expected outcome and acceptance criteria.
605
- - analysis (str): An analysis of how well the generated output matches the expected output and acceptance criteria.
606
- - chat_log (list): A list containing JSON objects representing the conversation log, with each object containing a timestamp, logger name, levelname, and message.
 
 
 
 
 
 
607
  """
608
  optimizer_model = initialize_llm(optimizer_model_name)
609
  executor_model = initialize_llm(executor_model_name)
@@ -616,34 +700,90 @@ def process_message_with_2_llms(user_message: str, expected_output: str,
616
  NODE_PROMPT_ANALYZER: optimizer_model,
617
  NODE_PROMPT_SUGGESTER: optimizer_model
618
  }
619
- return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
620
- recursion_limit, max_output_age, llms, prompt_template_group, aggressive_exploration)
621
-
622
-
623
- def process_message_with_expert_llms(user_message: str, expected_output: str,
624
- acceptance_criteria: str, initial_system_message: str,
625
- recursion_limit: int, max_output_age: int,
626
- initial_developer_model_name: str, initial_developer_temperature: float,
627
- acceptance_criteria_model_name: str, acceptance_criteria_temperature: float,
628
- developer_model_name: str, developer_temperature: float,
629
- executor_model_name: str, executor_temperature: float,
630
- output_history_analyzer_model_name: str, output_history_analyzer_temperature: float,
631
- analyzer_model_name: str, analyzer_temperature: float,
632
- suggester_model_name: str, suggester_temperature: float,
633
- prompt_template_group: Optional[str] = None,
634
- aggressive_exploration: bool = False) -> tuple:
635
636
  llms = {
637
- NODE_PROMPT_INITIAL_DEVELOPER: initialize_llm(initial_developer_model_name, {"temperature": initial_developer_temperature}),
638
- NODE_ACCEPTANCE_CRITERIA_DEVELOPER: initialize_llm(acceptance_criteria_model_name, {"temperature": acceptance_criteria_temperature}),
639
- NODE_PROMPT_DEVELOPER: initialize_llm(developer_model_name, {"temperature": developer_temperature}),
640
- NODE_PROMPT_EXECUTOR: initialize_llm(executor_model_name, {"temperature": executor_temperature}),
641
- NODE_OUTPUT_HISTORY_ANALYZER: initialize_llm(output_history_analyzer_model_name, {"temperature": output_history_analyzer_temperature}),
642
- NODE_PROMPT_ANALYZER: initialize_llm(analyzer_model_name, {"temperature": analyzer_temperature}),
643
- NODE_PROMPT_SUGGESTER: initialize_llm(suggester_model_name, {"temperature": suggester_temperature})
644
  }
645
- return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
646
- recursion_limit, max_output_age, llms, prompt_template_group, aggressive_exploration)
647
 
648
 
649
  class FileConfig(BaseConfig):
@@ -916,18 +1056,26 @@ with gr.Blocks(title='Meta Prompt') as demo:
916
 
917
  evaluate_initial_system_message_button.click(
918
  evaluate_system_message,
919
- inputs=[initial_system_message_input, user_message_input,
920
- simple_model_name_input,
921
- advanced_executor_model_name_input,
922
- expert_prompt_executor_model_name_input, expert_prompt_executor_temperature_input],
 
 
 
 
923
  outputs=[output_output]
924
  )
925
  evaluate_system_message_button.click(
926
  evaluate_system_message,
927
- inputs=[system_message_output, user_message_input,
928
- simple_model_name_input,
929
- advanced_executor_model_name_input,
930
- expert_prompt_executor_model_name_input, expert_prompt_executor_temperature_input],
 
 
 
 
931
  outputs=[output_output]
932
  )
933
  copy_to_initial_system_message_button.click(
@@ -995,13 +1143,20 @@ with gr.Blocks(title='Meta Prompt') as demo:
995
  initial_system_message_input,
996
  recursion_limit_input,
997
  max_output_age,
998
- expert_prompt_initial_developer_model_name_input, expert_prompt_initial_developer_temperature_input,
999
- expert_prompt_acceptance_criteria_model_name_input, expert_prompt_acceptance_criteria_temperature_input,
1000
- expert_prompt_developer_model_name_input, expert_prompt_developer_temperature_input,
1001
- expert_prompt_executor_model_name_input, expert_prompt_executor_temperature_input,
1002
- expert_output_history_analyzer_model_name_input, expert_output_history_analyzer_temperature_input,
1003
- expert_prompt_analyzer_model_name_input, expert_prompt_analyzer_temperature_input,
1004
- expert_prompt_suggester_model_name_input, expert_prompt_suggester_temperature_input,
 
 
 
 
 
 
 
1005
  prompt_template_group,
1006
  aggressive_exploration
1007
  ],
 
3
  import json
4
  import logging
5
  from pathlib import Path
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
  import gradio as gr
9
  from gradio import CSVLogger, Button, utils
10
  from gradio.flagging import FlagMethod
11
  from gradio_client import utils as client_utils
12
+
13
  from confz import BaseConfig, CLArgSource, EnvSource, FileSource
14
  from app.config import MetaPromptConfig, RoleMessage
15
  from langchain_core.language_models import BaseLanguageModel
16
  from langchain_core.prompts import ChatPromptTemplate
17
+ from langchain_openai import ChatOpenAI # Don't remove this import
18
  from meta_prompt import *
19
  from pythonjsonlogger import jsonlogger
 
 
 
20
 
21
+ def prompt_templates_confz2langchain(
22
+ prompt_templates: Dict[str, Dict[str, List[RoleMessage]]]
23
+ ) -> Dict[str, ChatPromptTemplate]:
24
  """
25
+ Convert a dictionary of prompt templates from the configuration format to
26
+ the language chain format.
27
 
28
+ This function takes a dictionary of prompt templates in the configuration
29
+ format and converts them to the language chain format. Each prompt template
30
+ is converted to a ChatPromptTemplate object, which is then stored in a new
31
+ dictionary with the same keys.
32
 
33
  Args:
34
+ prompt_templates (Dict[str, Dict[str, List[RoleMessage]]]):
35
+ A dictionary of prompt templates in the configuration format.
36
 
37
  Returns:
38
+ Dict[str, ChatPromptTemplate]:
39
+ A dictionary of prompt templates in the language chain format.
40
  """
41
  return {
42
+ node: ChatPromptTemplate.from_messages(
43
+ [
44
+ (role_message.role, role_message.message)
45
+ for role_message in role_messages
46
+ ]
47
+ )
48
  for node, role_messages in prompt_templates.items()
49
  }
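
A minimal usage sketch of the conversion above, assuming each template value is an iterable of objects exposing .role and .message, which is all the comprehension needs; the stand-in class and the "prompt_developer" key are illustrative, not taken from app.config:

from langchain_core.prompts import ChatPromptTemplate

class RoleMessageStub:
    # Hypothetical stand-in for app.config.RoleMessage: just .role and .message.
    def __init__(self, role, message):
        self.role = role
        self.message = message

raw_templates = {
    "prompt_developer": [  # illustrative node name
        RoleMessageStub("system", "You are a prompt engineer."),
        RoleMessageStub("human", "{user_message}"),
    ]
}
converted = prompt_templates_confz2langchain(raw_templates)
# Each value is now a ChatPromptTemplate; format it like any LangChain prompt.
messages = converted["prompt_developer"].format_messages(user_message="Summarize this text.")
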
50
 
51
  class SimplifiedCSVLogger(CSVLogger):
52
  """
53
+ A subclass of CSVLogger that logs only the components data to a CSV file,
54
+ excluding flag, username, and timestamp information.
55
  """
56
 
57
  def flag(
 
70
 
71
  csv_data = []
72
  for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
73
+ save_dir = Path(flagging_dir) / client_utils.strip_invalid_filename_characters(
 
 
74
  getattr(component, "label", None) or f"component {idx}"
75
  )
76
  if utils.is_prop_update(sample):
77
  csv_data.append(str(sample))
78
  else:
79
+ data = component.flag(sample, flag_dir=save_dir) if sample is not None else ""
 
 
 
 
80
  if self.simplify_file_data:
81
  data = utils.simplify_file_data_in_str(data)
82
  csv_data.append(data)
 
130
  return model_class(**kwargs)
131
 
132
 
133
+ def chat_log_2_chatbot_list(chat_log: str) -> List[List[str]]:
134
+ """
135
+ Convert a chat log string into a list of dialogues for the Chatbot format.
136
 
137
  Args:
138
+ chat_log (str): A JSON formatted chat log where each line represents an
139
+ action with its message. Expected actions are 'invoke'
140
+ and 'response'.
141
 
142
  Returns:
143
+ List[List[str]]: A list of dialogue pairs where the first element is a
144
+ user input and the second element is a bot response.
145
+ If the action was 'invoke', the first element will be
146
+ the message, and the second element will be None. If
147
+ the action was 'response', the first element will be
148
+ None, and the second element will be the message.
149
  """
 
150
  chatbot_list = []
151
  if chat_log is None or chat_log == '':
152
  return chatbot_list
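
For orientation, a toy illustration of the mapping the docstring above describes; the 'action' and 'message' field names are assumptions about the JSON log schema, which this diff does not show:

import json

log_lines = [
    '{"action": "invoke", "message": "Hi"}',
    '{"action": "response", "message": "Hello!"}',
]
pairs = []
for line in log_lines:
    entry = json.loads(line)
    if entry["action"] == "invoke":
        pairs.append([entry["message"], None])   # user turn, no bot reply yet
    elif entry["action"] == "response":
        pairs.append([None, entry["message"]])   # bot turn, no user input
# pairs == [["Hi", None], [None, "Hello!"]], the pair shape Gradio's Chatbot expects
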
 
248
  raise RuntimeError(f"Failed to retrieve the model: {e}")
249
 
250
 
251
+ def evaluate_system_message(system_message, user_message, simple_model,
252
+ advanced_executor_model, expert_executor_model,
253
+ expert_executor_model_temperature=0.1):
 
254
  """
255
+ Evaluate a system message by using it to generate a response from an
256
+ executor model based on the current active tab and provided user message.
257
 
258
+ This function retrieves the appropriate language model (LLM) for the
259
+ current active model tab, formats a chat prompt template with the system
260
+ message and user message, invokes the LLM using this formatted prompt, and
261
+ returns the content of the output if it exists.
262
 
263
  Args:
264
+ system_message (str): The system message to use when evaluating the
265
+ response.
266
+ user_message (str): The user's input message for which a response will
267
+ be generated.
268
+ simple_model (str): The name of the simple language model. This should
269
+ correspond to a key in the 'llms' section of the application's
270
+ configuration.
271
+ advanced_executor_model (str): The name of the advanced language model.
272
+ This should correspond to a key in the 'llms' section of the
273
+ application's configuration.
274
+ expert_executor_model (str): The name of the expert language model.
275
+ This should correspond to a key in the 'llms' section of the
276
+ application's configuration.
277
+ expert_executor_model_temperature (float, optional): The temperature
278
+ parameter for the expert executor model. Defaults to 0.1.
279
 
280
  Returns:
281
+ str: The content of the output generated by the LLM based on the system
282
+ message and user message, if it exists; otherwise, an empty string.
283
 
284
  Raises:
285
+ gr.Error: If there is a Gradio-specific error during the execution of
286
+ this function.
287
+ Exception: For any other unexpected errors that occur during the
288
+ execution of this function.
289
  """
290
+ llm = get_current_model(simple_model, advanced_executor_model,
291
+ expert_executor_model,
292
+ {"temperature": expert_executor_model_temperature})
293
  template = ChatPromptTemplate.from_messages([
294
  ("system", "{system_message}"),
295
  ("human", "{user_message}")
 
386
  simple_model (str): The name of the simple language model.
387
  advanced_executor_model (str): The name of the advanced language model.
388
  expert_prompt_initial_developer_model (str): The name of the expert language model.
389
+ expert_prompt_initial_developer_temperature (float, optional):
390
+ The temperature parameter for the expert model. Defaults to 0.1.
391
+ prompt_template_group (Optional[str], optional):
392
+ The group of prompt templates to use. Defaults to None.
393
 
394
  Returns:
395
  tuple: A tuple containing the initial system message and the chat log.
 
442
  return system_message, chat_log_2_chatbot_list(log_output)
443
 
444
 
445
+ def process_message(
446
+ user_message: str, expected_output: str, acceptance_criteria: str,
447
+ initial_system_message: str, recursion_limit: int, max_output_age: int,
448
+ llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
449
+ prompt_template_group: Optional[str] = None,
450
+ aggressive_exploration: bool = False
451
+ ) -> tuple:
452
  """
453
+ Process a user message by executing the MetaPromptGraph with provided
454
+ language models and input state.
455
+
456
+ This function sets up the initial state of the conversation, logs the
457
+ execution if verbose mode is enabled, and extracts the best system message,
458
+ output, and analysis from the output state of the MetaPromptGraph.
459
 
460
  Args:
461
+ user_message (str): The user's input message to be processed by the
462
+ language model(s).
463
+ expected_output (str): The anticipated response or outcome from the
464
+ language model(s) based on the user's message.
465
+ acceptance_criteria (str): Criteria that determines whether the output
466
+ is acceptable or not.
467
+ initial_system_message (str): Initial instruction given to the language
468
+ model(s) before processing the user's message.
469
+ recursion_limit (int): The maximum number of times the MetaPromptGraph
470
+ can call itself recursively.
471
+ max_output_age (int): The maximum age of output messages that should be
472
+ considered in the conversation history.
473
+ llms (Union[BaseLanguageModel, Dict[str, BaseLanguageModel]]): A single
474
+ language model or a dictionary of language models to use for
475
+ processing the user's message.
476
+ prompt_template_group (Optional[str], optional): The group of prompt
477
+ templates to use. Defaults to None.
478
+ aggressive_exploration (bool, optional): Whether to use aggressive
479
+ exploration. Defaults to False.
480
 
481
  Returns:
482
+ tuple: A tuple containing the best system message, output, analysis,
483
+ acceptance criteria, and chat log in JSON format.
 
 
 
484
  """
485
  input_state = AgentState(
486
  user_message=user_message,
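
The graph round trip this function wraps can be seen end to end in tests/meta_prompt_graph_test.py further down; as a compact sketch, with the model key and texts as placeholders:

llm = initialize_llm("example-model")   # any key under config.llms
graph = MetaPromptGraph(llms=llm)
state = AgentState(
    user_message="How do I reverse a list in Python?",
    expected_output="Use `list.reverse()` or `[::-1]`.",
    acceptance_criteria="Similar in meaning, length and style.",
)
output_state = graph(state, recursion_limit=25)
best_system_message = output_state["best_system_message"]
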
 
534
  """
535
  Initialize and return a language model (LLM) based on its name.
536
 
537
+ This function retrieves the configuration for the specified language model
538
+ from the application's configuration, creates an instance of the appropriate
539
+ type of language model using that configuration, and returns it.
540
 
541
  Args:
542
+ model_name (str): The name of the language model to initialize. This
543
+ should correspond to a key in the 'llms' section of the application's
544
+ configuration.
545
+ model_config (Optional[Dict[str, Any]], optional): Optional model
546
+ configurations. Defaults to None.
547
 
548
  Returns:
549
+ Any: An instance of the specified type of language model, initialized
550
+ with its configured settings.
551
 
552
  Raises:
553
  KeyError: If no configuration exists for the specified model name.
554
+ NotImplementedError: If an unrecognized type is configured for the
555
+ language model. This should not occur under normal circumstances
556
+ because the LLMModelFactory class checks and validates the type when
557
+ creating a new language model.
558
  """
559
  try:
560
  llm_config = config.llms[model_name]
561
  model_type = llm_config.type
562
  dumped_config = llm_config.model_dump(exclude={'type'})
563
+
564
  if model_config:
565
  dumped_config.update(model_config)
566
+
567
  return LLMModelFactory().create(model_type, **dumped_config)
568
  except KeyError:
569
  raise KeyError(f"No configuration exists for the model name: {model_name}")
570
  except NotImplementedError:
571
+ raise NotImplementedError(
572
+ f"Unrecognized type configured for the language model: {model_type}"
573
+ )
574
 
575
 
576
+ def process_message_with_single_llm(
577
+ user_message: str, expected_output: str, acceptance_criteria: str,
578
+ initial_system_message: str, recursion_limit: int, max_output_age: int,
579
+ model_name: str, prompt_template_group: Optional[str] = None,
580
+ aggressive_exploration: bool = False
581
+ ) -> tuple:
582
  """
583
  Process a user message using a single language model.
584
 
585
+ This function initializes a language model based on the provided model name
586
+ and uses it to process the user's message. The function takes in additional
587
+ parameters such as the user's message, expected output, acceptance criteria,
588
+ initial system message, recursion limit, and max output age. It then calls
589
+ the `process_message` function with the initialized language model to obtain
590
+ the best system message, output, analysis, and chat log.
591
 
592
  Parameters:
593
+ user_message (str): The user's input message to be processed by the language
594
+ model.
595
+ expected_output (str): The anticipated response or outcome from the language
596
+ model based on the user's message.
597
+ acceptance_criteria (str): Criteria that determines whether the output is
598
+ acceptable or not.
599
+ initial_system_message (str): Initial instruction given to the language
600
+ model before processing the user's message.
601
+ recursion_limit (int): The maximum number of times the MetaPromptGraph can
602
+ call itself recursively.
603
+ max_output_age (int): The maximum age of output messages that should be
604
+ considered in the conversation history.
605
+ model_name (str): The name of the language model to initialize and use for
606
+ processing the user's message. This should correspond to a key in the
607
+ 'llms' section of the application's configuration.
608
+ prompt_template_group (Optional[str], optional): The name of the prompt
609
+ template group to use for processing the user's message. Defaults to None.
610
+ aggressive_exploration (bool, optional): Whether to use aggressive
611
+ exploration techniques. Defaults to False.
612
 
613
  Returns:
614
+ tuple: A tuple containing the best system message, output, analysis, and
615
+ chat log in JSON format.
616
+ - best_system_message (str): The system message that resulted in the
617
+ most appropriate response based on the acceptance criteria.
618
+ - best_output (str): The output generated by the language model that
619
+ best meets the expected outcome and acceptance criteria.
620
+ - analysis (str): An analysis of how well the generated output
621
+ matches the expected output and acceptance criteria.
622
+ - chat_log (list): A list containing JSON objects representing the
623
+ conversation log, with each object containing a timestamp, logger
624
+ name, levelname, and message.
625
  """
626
  llm = initialize_llm(model_name)
627
+ return process_message(
628
+ user_message, expected_output, acceptance_criteria, initial_system_message,
629
+ recursion_limit, max_output_age, llm, prompt_template_group, aggressive_exploration
630
+ )
631
 
632
 
633
+ def process_message_with_2_llms(
634
+ user_message: str, expected_output: str, acceptance_criteria: str,
635
+ initial_system_message: str, recursion_limit: int, max_output_age: int,
636
+ optimizer_model_name: str, executor_model_name: str,
637
+ prompt_template_group: Optional[str] = None,
638
+ aggressive_exploration: bool = False
639
+ ) -> tuple:
640
  """
641
+ Process a user message using two language models - one for optimization and
642
+ another for execution.
643
 
644
+ This function initializes the specified optimizer and executor language
645
+ models and then uses them to process the user's message along with other
646
+ provided input parameters such as expected output, acceptance criteria,
647
+ initial system message, recursion limit, and max output age. The result is
648
+ obtained by calling the `process_message` function with a dictionary of
649
+ language models where all nodes except for NODE_PROMPT_EXECUTOR use the
650
+ optimizer model and NODE_PROMPT_EXECUTOR uses the executor model.
651
 
652
  Args:
653
+ user_message (str): The user's input message to be processed by the
654
+ language models.
655
+ expected_output (str): The anticipated response or outcome from the
656
+ language models based on the user's message.
657
+ acceptance_criteria (str): Criteria that determines whether the output
658
+ is acceptable or not.
659
+ initial_system_message (str): Initial instruction given to the language
660
+ models before processing the user's message.
661
+ recursion_limit (int): The maximum number of times the MetaPromptGraph
662
+ can call itself recursively.
663
+ max_output_age (int): The maximum age of output messages that should be
664
+ considered in the conversation history.
665
+ optimizer_model_name (str): The name of the language model to initialize
666
+ and use for optimization tasks like prompt development, analysis,
667
+ and suggestion. This should correspond to a key in the 'llms' section
668
+ of the application's configuration.
669
+ executor_model_name (str): The name of the language model to initialize
670
+ and use for execution tasks like running code or providing final
671
+ outputs. This should correspond to a key in the 'llms' section of the
672
+ application's configuration.
673
+ prompt_template_group (Optional[str], optional): The name of the prompt
674
+ template group to use for processing the user's message. Defaults to
675
+ None.
676
+ aggressive_exploration (bool, optional): Whether to use aggressive
677
+ exploration techniques. Defaults to False.
678
 
679
  Returns:
680
+ tuple: A tuple containing the best system message, output, analysis, and
681
+ chat log in JSON format.
682
+ - best_system_message (str): The system message that resulted in the
683
+ most appropriate response based on the acceptance criteria.
684
+ - best_output (str): The output generated by the language models that
685
+ best meets the expected outcome and acceptance criteria.
686
+ - analysis (str): An analysis of how well the generated output
687
+ matches the expected output and acceptance criteria.
688
+ - chat_log (list): A list containing JSON objects representing the
689
+ conversation log, with each object containing a timestamp,
690
+ logger name, levelname, and message.
691
  """
692
  optimizer_model = initialize_llm(optimizer_model_name)
693
  executor_model = initialize_llm(executor_model_name)
 
700
  NODE_PROMPT_ANALYZER: optimizer_model,
701
  NODE_PROMPT_SUGGESTER: optimizer_model
702
  }
703
+ return process_message(
704
+ user_message, expected_output, acceptance_criteria,
705
+ initial_system_message, recursion_limit, max_output_age, llms,
706
+ prompt_template_group, aggressive_exploration
707
+ )
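
Spelled out, the split the docstring describes (every node on the optimizer model except NODE_PROMPT_EXECUTOR) looks like the sketch below; the exact node list lives in the dict above, the node constants come from meta_prompt.consts, and the model keys here are placeholders:

optimizer_model = initialize_llm("optimizer-model")
executor_model = initialize_llm("executor-model")
llms = {
    NODE_PROMPT_INITIAL_DEVELOPER: optimizer_model,
    NODE_ACCEPTANCE_CRITERIA_DEVELOPER: optimizer_model,
    NODE_PROMPT_DEVELOPER: optimizer_model,
    NODE_PROMPT_EXECUTOR: executor_model,   # the one node that gets the second model
    NODE_OUTPUT_HISTORY_ANALYZER: optimizer_model,
    NODE_PROMPT_ANALYZER: optimizer_model,
    NODE_PROMPT_SUGGESTER: optimizer_model,
}
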
708
+
709
+
710
+ def process_message_with_expert_llms(
711
+ user_message: str, expected_output: str, acceptance_criteria: str,
712
+ initial_system_message: str, recursion_limit: int, max_output_age: int,
713
+ initial_developer_model_name: str, initial_developer_temperature: float,
714
+ acceptance_criteria_model_name: str, acceptance_criteria_temperature: float,
715
+ developer_model_name: str, developer_temperature: float,
716
+ executor_model_name: str, executor_temperature: float,
717
+ output_history_analyzer_model_name: str, output_history_analyzer_temperature: float,
718
+ analyzer_model_name: str, analyzer_temperature: float,
719
+ suggester_model_name: str, suggester_temperature: float,
720
+ prompt_template_group: Optional[str] = None, aggressive_exploration: bool = False
721
+ ) -> tuple:
722
+ """
723
+ Process a message using expert language models with specified temperatures.
724
 
725
+ Args:
726
+ user_message (str): The user's input message.
727
+ expected_output (str): The anticipated response or outcome from the language model.
728
+ acceptance_criteria (str): Criteria for accepting the generated output.
729
+ initial_system_message (str): The initial system message to use.
730
+ recursion_limit (int): The maximum number of recursive calls.
731
+ max_output_age (int): The maximum age of output messages to consider.
732
+ initial_developer_model_name (str): The name of the initial developer model.
733
+ initial_developer_temperature (float): The temperature for the initial developer model.
734
+ acceptance_criteria_model_name (str): The name of the acceptance criteria model.
735
+ acceptance_criteria_temperature (float): The temperature for the acceptance criteria model.
736
+ developer_model_name (str): The name of the developer model.
737
+ developer_temperature (float): The temperature for the developer model.
738
+ executor_model_name (str): The name of the executor model.
739
+ executor_temperature (float): The temperature for the executor model.
740
+ output_history_analyzer_model_name (str): The name of the output history analyzer model.
741
+ output_history_analyzer_temperature (float): The temperature for the output history analyzer model.
742
+ analyzer_model_name (str): The name of the analyzer model.
743
+ analyzer_temperature (float): The temperature for the analyzer model.
744
+ suggester_model_name (str): The name of the suggester model.
745
+ suggester_temperature (float): The temperature for the suggester model.
746
+ prompt_template_group (Optional[str], optional): The group of prompt templates to use. Defaults to None.
747
+ aggressive_exploration (bool, optional): Whether to use aggressive exploration. Defaults to False.
748
+
749
+ Returns:
750
+ tuple: A tuple containing the processed message results.
751
+ """
752
  llms = {
753
+ NODE_PROMPT_INITIAL_DEVELOPER: initialize_llm(
754
+ initial_developer_model_name, {"temperature": initial_developer_temperature}
755
+ ),
756
+ NODE_ACCEPTANCE_CRITERIA_DEVELOPER: initialize_llm(
757
+ acceptance_criteria_model_name, {"temperature": acceptance_criteria_temperature}
758
+ ),
759
+ NODE_PROMPT_DEVELOPER: initialize_llm(
760
+ developer_model_name, {"temperature": developer_temperature}
761
+ ),
762
+ NODE_PROMPT_EXECUTOR: initialize_llm(
763
+ executor_model_name, {"temperature": executor_temperature}
764
+ ),
765
+ NODE_OUTPUT_HISTORY_ANALYZER: initialize_llm(
766
+ output_history_analyzer_model_name,
767
+ {"temperature": output_history_analyzer_temperature}
768
+ ),
769
+ NODE_PROMPT_ANALYZER: initialize_llm(
770
+ analyzer_model_name, {"temperature": analyzer_temperature}
771
+ ),
772
+ NODE_PROMPT_SUGGESTER: initialize_llm(
773
+ suggester_model_name, {"temperature": suggester_temperature}
774
+ )
775
  }
776
+ return process_message(
777
+ user_message,
778
+ expected_output,
779
+ acceptance_criteria,
780
+ initial_system_message,
781
+ recursion_limit,
782
+ max_output_age,
783
+ llms,
784
+ prompt_template_group,
785
+ aggressive_exploration
786
+ )
787
 
788
 
789
  class FileConfig(BaseConfig):
 
1056
 
1057
  evaluate_initial_system_message_button.click(
1058
  evaluate_system_message,
1059
+ inputs=[
1060
+ initial_system_message_input,
1061
+ user_message_input,
1062
+ simple_model_name_input,
1063
+ advanced_executor_model_name_input,
1064
+ expert_prompt_executor_model_name_input,
1065
+ expert_prompt_executor_temperature_input
1066
+ ],
1067
  outputs=[output_output]
1068
  )
1069
  evaluate_system_message_button.click(
1070
  evaluate_system_message,
1071
+ inputs=[
1072
+ system_message_output,
1073
+ user_message_input,
1074
+ simple_model_name_input,
1075
+ advanced_executor_model_name_input,
1076
+ expert_prompt_executor_model_name_input,
1077
+ expert_prompt_executor_temperature_input
1078
+ ],
1079
  outputs=[output_output]
1080
  )
1081
  copy_to_initial_system_message_button.click(
 
1143
  initial_system_message_input,
1144
  recursion_limit_input,
1145
  max_output_age,
1146
+ expert_prompt_initial_developer_model_name_input,
1147
+ expert_prompt_initial_developer_temperature_input,
1148
+ expert_prompt_acceptance_criteria_model_name_input,
1149
+ expert_prompt_acceptance_criteria_temperature_input,
1150
+ expert_prompt_developer_model_name_input,
1151
+ expert_prompt_developer_temperature_input,
1152
+ expert_prompt_executor_model_name_input,
1153
+ expert_prompt_executor_temperature_input,
1154
+ expert_output_history_analyzer_model_name_input,
1155
+ expert_output_history_analyzer_temperature_input,
1156
+ expert_prompt_analyzer_model_name_input,
1157
+ expert_prompt_analyzer_temperature_input,
1158
+ expert_prompt_suggester_model_name_input,
1159
+ expert_prompt_suggester_temperature_input,
1160
  prompt_template_group,
1161
  aggressive_exploration
1162
  ],
meta_prompt/meta_prompt.py CHANGED
@@ -1,16 +1,14 @@
1
- import typing
2
- import pprint
3
  import logging
4
  import operator
5
- from typing import Dict, Any, Callable, List, Union, Optional, Annotated
6
  from langchain_core.language_models import BaseLanguageModel
7
- from langchain_core.messages import HumanMessage, SystemMessage
8
  from langchain_core.prompts import ChatPromptTemplate
9
- from langgraph.graph import StateGraph, START, END
10
  from langgraph.checkpoint.memory import MemorySaver
11
  from langgraph.errors import GraphRecursionError
 
12
  from langchain_core.runnables.base import RunnableLike
13
  from pydantic import BaseModel
 
14
  from .consts import *
15
 
16
  def first_non_empty(a, b):
 
 
 
1
  import logging
2
  import operator
3
+ import pprint
4
  from langchain_core.language_models import BaseLanguageModel
 
5
  from langchain_core.prompts import ChatPromptTemplate
 
6
  from langgraph.checkpoint.memory import MemorySaver
7
  from langgraph.errors import GraphRecursionError
8
+ from langgraph.graph import StateGraph, START, END
9
  from langchain_core.runnables.base import RunnableLike
10
  from pydantic import BaseModel
11
+ from typing import Annotated, Dict, Optional, Union
12
  from .consts import *
13
 
14
  def first_non_empty(a, b):
tests/meta_prompt_graph_test.py CHANGED
@@ -1,15 +1,12 @@
1
  import unittest
2
- import pprint
3
- import logging
4
- import functools
5
  from unittest.mock import MagicMock, Mock
 
 
6
  from langchain_core.language_models import BaseLanguageModel
7
  from langchain_openai import ChatOpenAI
8
-
9
- # Assuming the necessary imports are made for the classes and functions used in meta_prompt_graph.py
10
  from meta_prompt import *
11
  from meta_prompt.consts import NODE_ACCEPTANCE_CRITERIA_DEVELOPER
12
- from langgraph.graph import StateGraph, END
13
 
14
  class TestMetaPromptGraph(unittest.TestCase):
15
  def setUp(self):
@@ -21,45 +18,50 @@ class TestMetaPromptGraph(unittest.TestCase):
21
  """
22
  Test the _prompt_node method of MetaPromptGraph.
23
 
24
- This test case sets up a mock language model that returns a response content and verifies that the
25
- updated state has the output attribute updated with the mocked response content.
 
26
  """
27
  llms = {
28
  NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
29
- invoke=MagicMock(return_value=MagicMock(content="Mocked response content"))
 
 
30
  )
31
  }
32
 
33
- # Create an instance of MetaPromptGraph with the mocked language model and template
34
  graph = MetaPromptGraph(llms=llms)
35
-
36
- # Create a mock AgentState
37
- state = AgentState(user_message="Test message", expected_output="Expected output")
38
-
39
- # Invoke the _prompt_node method with the mock node, target attribute, and state
40
  updated_state = graph._prompt_node(
41
  NODE_PROMPT_INITIAL_DEVELOPER, "output", state
42
  )
43
 
44
- # Assertions
45
- assert updated_state.output == "Mocked response content", \
46
- "The output attribute should be updated with the mocked response content"
47
 
48
 
49
  def test_output_history_analyzer(self):
50
  """
51
  Test the _output_history_analyzer method of MetaPromptGraph.
52
 
53
- This test case sets up a mock language model that returns an analysis response and verifies that the
54
- updated state has the best output, best system message, and best output age updated correctly.
 
55
  """
56
- # Setup
57
  llms = {
58
- "output_history_analyzer": MagicMock(invoke=lambda prompt: MagicMock(content="""# Analysis
 
 
59
 
60
- This analysis compares two outputs to the expected output based on specific criteria.
 
61
 
62
- # Output ID closer to Expected Output: B"""))
 
 
63
  }
64
  prompts = {}
65
  meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
@@ -70,43 +72,50 @@ class TestMetaPromptGraph(unittest.TestCase):
70
  system_message="To reverse a list, use slicing or the reverse method.",
71
  best_output="To reverse a list in Python, use the `reverse()` method.",
72
  best_system_message="To reverse a list, use the `reverse()` method.",
73
- acceptance_criteria="The output should correctly describe how to reverse a list in Python."
74
  )
75
 
76
- # Invoke the output history analyzer node
77
  updated_state = meta_prompt_graph._output_history_analyzer(state)
78
 
79
- # Assertions
80
- assert updated_state.best_output == state.output, \
81
- "Best output should be updated to the current output."
82
- assert updated_state.best_system_message == state.system_message, \
83
- "Best system message should be updated to the current system message."
84
- assert updated_state.best_output_age == 0, \
85
- "Best output age should be reset to 0."
 
 
86
 
87
 
88
  def test_prompt_analyzer_accept(self):
89
  """
90
- Test the _prompt_analyzer method of MetaPromptGraph when the prompt analyzer accepts the output.
 
91
 
92
- This test case sets up a mock language model that returns an acceptance response and verifies that the
93
- updated state has the accepted attribute set to True.
 
94
  """
95
  llms = {
96
  NODE_PROMPT_ANALYZER: MagicMock(
97
- invoke=lambda prompt: MagicMock(content="Accept: Yes"))
 
98
  }
99
- meta_prompt_graph = MetaPromptGraph(llms)
100
- state = AgentState(output="Test output", expected_output="Expected output")
 
 
101
  updated_state = meta_prompt_graph._prompt_analyzer(state)
102
- assert updated_state.accepted == True
103
 
104
 
105
  def test_get_node_names(self):
106
  """
107
  Test the get_node_names method of MetaPromptGraph.
108
 
109
- This test case verifies that the get_node_names method returns the correct list of node names.
 
110
  """
111
  graph = MetaPromptGraph()
112
  node_names = graph.get_node_names()
@@ -121,40 +130,36 @@ class TestMetaPromptGraph(unittest.TestCase):
121
  executes it with a given input state. It then verifies that the output
122
  state contains the expected keys and values.
123
  """
124
- # MODEL_NAME = "anthropic/claude-3.5-sonnet:beta"
125
- # MODEL_NAME = "meta-llama/llama-3-70b-instruct"
126
- # MODEL_NAME = "deepseek/deepseek-chat"
127
- MODEL_NAME = "google/gemma-2-9b-it"
128
- # MODEL_NAME = "recursal/eagle-7b"
129
- # MODEL_NAME = "meta-llama/llama-3-8b-instruct"
130
- llm = ChatOpenAI(model_name=MODEL_NAME)
131
 
132
  meta_prompt_graph = MetaPromptGraph(llms=llm)
133
  input_state = AgentState(
134
  user_message="How do I reverse a list in Python?",
135
- expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
 
136
  acceptance_criteria="Similar in meaning, text length and style."
137
- )
138
  output_state = meta_prompt_graph(input_state, recursion_limit=25)
139
 
140
  pprint.pp(output_state)
141
- # if output_state has key 'best_system_message', print it
142
- assert 'best_system_message' in output_state, \
143
- "The output state should contain the key 'best_system_message'"
144
- assert output_state['best_system_message'] is not None, \
145
- "The best system message should not be None"
146
- if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
147
- print(output_state['best_system_message'])
148
-
149
- # try another similar user message with the generated system message
 
 
 
150
  user_message = "How can I create a list of numbers in Python?"
151
- messages = [("system", output_state['best_system_message']),
152
- ("human", user_message)]
153
  result = llm.invoke(messages)
154
 
155
- # assert attr 'content' in result
156
- assert hasattr(result, 'content'), \
157
- "The result should have the attribute 'content'"
158
  print(result.content)
159
 
160
 
@@ -166,8 +171,12 @@ class TestMetaPromptGraph(unittest.TestCase):
166
  executes it with a given input state. It then verifies that the output
167
  state contains the expected keys and values.
168
  """
169
- optimizer_llm = ChatOpenAI(model_name="deepseek/deepseek-chat", temperature=0.5)
170
- executor_llm = ChatOpenAI(model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
 
 
 
 
171
 
172
  llms = {
173
  NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
@@ -175,35 +184,36 @@ class TestMetaPromptGraph(unittest.TestCase):
175
  NODE_PROMPT_EXECUTOR: executor_llm,
176
  NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
177
  NODE_PROMPT_ANALYZER: optimizer_llm,
178
- NODE_PROMPT_SUGGESTER: optimizer_llm
179
  }
180
 
181
  meta_prompt_graph = MetaPromptGraph(llms=llms)
182
  input_state = AgentState(
183
  user_message="How do I reverse a list in Python?",
184
- expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
 
185
  acceptance_criteria="Similar in meaning, text length and style."
186
- )
187
  output_state = meta_prompt_graph(input_state, recursion_limit=25)
188
 
189
  pprint.pp(output_state)
190
- # if output_state has key 'best_system_message', print it
191
- assert 'best_system_message' in output_state, \
192
- "The output state should contain the key 'best_system_message'"
193
- assert output_state['best_system_message'] is not None, \
194
- "The best system message should not be None"
195
- if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
196
- print(output_state['best_system_message'])
197
-
198
- # try another similar user message with the generated system message
 
 
 
199
  user_message = "How can I create a list of numbers in Python?"
200
- messages = [("system", output_state['best_system_message']),
201
- ("human", user_message)]
202
  result = executor_llm.invoke(messages)
203
 
204
- # assert attr 'content' in result
205
- assert hasattr(result, 'content'), \
206
- "The result should have the attribute 'content'"
207
  print(result.content)
208
 
209
 
@@ -318,14 +328,15 @@ class TestMetaPromptGraph(unittest.TestCase):
318
 
319
 
320
  def test_run_acceptance_criteria_graph(self):
321
- """
322
- Test the run_acceptance_criteria_graph method of MetaPromptGraph.
323
 
324
- This test case verifies that the run_acceptance_criteria_graph method returns a state with acceptance criteria.
 
325
  """
326
  llms = {
327
  NODE_ACCEPTANCE_CRITERIA_DEVELOPER: MagicMock(
328
- invoke=lambda prompt: MagicMock(content="Acceptance criteria: ..."))
 
329
  }
330
  meta_prompt_graph = MetaPromptGraph(llms=llms)
331
  state = AgentState(
@@ -335,21 +346,22 @@ class TestMetaPromptGraph(unittest.TestCase):
335
  output_state = meta_prompt_graph.run_acceptance_criteria_graph(state)
336
 
337
  # Check if the output state contains the acceptance criteria
338
- self.assertIsNotNone(output_state['acceptance_criteria'])
339
 
340
  # Check if the acceptance criteria includes the expected content
341
- self.assertIn("Acceptance criteria: ...", output_state['acceptance_criteria'])
342
 
343
 
344
  def test_run_prompt_initial_developer_graph(self):
345
- """
346
- Test the run_prompt_initial_developer_graph method of MetaPromptGraph.
347
 
348
- This test case verifies that the run_prompt_initial_developer_graph method returns a state with an initial developer prompt.
 
349
  """
350
  llms = {
351
  NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
352
- invoke=lambda prompt: MagicMock(content="Initial developer prompt: ..."))
 
353
  }
354
  meta_prompt_graph = MetaPromptGraph(llms=llms)
355
  state = AgentState(user_message="How do I reverse a list in Python?")
 
1
  import unittest
 
 
 
2
  from unittest.mock import MagicMock, Mock
3
+ import functools
4
+ import pprint
5
  from langchain_core.language_models import BaseLanguageModel
6
  from langchain_openai import ChatOpenAI
 
 
7
  from meta_prompt import *
8
  from meta_prompt.consts import NODE_ACCEPTANCE_CRITERIA_DEVELOPER
9
+ from langgraph.graph import END
10
 
11
  class TestMetaPromptGraph(unittest.TestCase):
12
  def setUp(self):
 
18
  """
19
  Test the _prompt_node method of MetaPromptGraph.
20
 
21
+ This test case sets up a mock language model that returns a response content
22
+ and verifies that the updated state has the output attribute updated with
23
+ the mocked response content.
24
  """
25
  llms = {
26
  NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
27
+ invoke=MagicMock(
28
+ return_value=MagicMock(content="Mocked response content")
29
+ )
30
  )
31
  }
32
 
 
33
  graph = MetaPromptGraph(llms=llms)
34
+ state = AgentState(
35
+ user_message="Test message", expected_output="Expected output"
36
+ )
 
 
37
  updated_state = graph._prompt_node(
38
  NODE_PROMPT_INITIAL_DEVELOPER, "output", state
39
  )
40
 
41
+ assert (
42
+ updated_state.output == "Mocked response content"
43
+ ), "The output attribute should be updated with the mocked response content"
44
 
45
 
46
  def test_output_history_analyzer(self):
47
  """
48
  Test the _output_history_analyzer method of MetaPromptGraph.
49
 
50
+ This test case sets up a mock language model that returns an analysis
51
+ response and verifies that the updated state has the best output, best
52
+ system message, and best output age updated correctly.
53
  """
 
54
  llms = {
55
+ "output_history_analyzer": MagicMock(
56
+ invoke=lambda prompt: MagicMock(
57
+ content="""# Analysis
58
 
59
+ This analysis compares two outputs to the expected output based on specific
60
+ criteria.
61
 
62
+ # Output ID closer to Expected Output: B"""
63
+ )
64
+ )
65
  }
66
  prompts = {}
67
  meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
 
72
  system_message="To reverse a list, use slicing or the reverse method.",
73
  best_output="To reverse a list in Python, use the `reverse()` method.",
74
  best_system_message="To reverse a list, use the `reverse()` method.",
75
+ acceptance_criteria="The output should correctly describe how to reverse a list in Python.",
76
  )
77
 
 
78
  updated_state = meta_prompt_graph._output_history_analyzer(state)
79
 
80
+ assert (
81
+ updated_state.best_output == state.output
82
+ ), "Best output should be updated to the current output."
83
+ assert (
84
+ updated_state.best_system_message == state.system_message
85
+ ), "Best system message should be updated to the current system message."
86
+ assert (
87
+ updated_state.best_output_age == 0
88
+ ), "Best output age should be reset to 0."
89
 
90
 
91
  def test_prompt_analyzer_accept(self):
92
  """
93
+ Test the _prompt_analyzer method of MetaPromptGraph when the prompt analyzer
94
+ accepts the output.
95
 
96
+ This test case sets up a mock language model that returns an acceptance
97
+ response and verifies that the updated state has the accepted attribute
98
+ set to True.
99
  """
100
  llms = {
101
  NODE_PROMPT_ANALYZER: MagicMock(
102
+ invoke=lambda prompt: MagicMock(content="Accept: Yes")
103
+ )
104
  }
105
+ meta_prompt_graph = MetaPromptGraph(llms=llms)
106
+ state = AgentState(
107
+ output="Test output", expected_output="Expected output"
108
+ )
109
  updated_state = meta_prompt_graph._prompt_analyzer(state)
110
+ assert updated_state.accepted is True
111
 
112
 
113
  def test_get_node_names(self):
114
  """
115
  Test the get_node_names method of MetaPromptGraph.
116
 
117
+ This test case verifies that the get_node_names method returns the
118
+ correct list of node names.
119
  """
120
  graph = MetaPromptGraph()
121
  node_names = graph.get_node_names()
 
130
  executes it with a given input state. It then verifies that the output
131
  state contains the expected keys and values.
132
  """
133
+ model_name = "google/gemma-2-9b-it"
134
+ llm = ChatOpenAI(model_name=model_name)
 
 
 
 
 
135
 
136
  meta_prompt_graph = MetaPromptGraph(llms=llm)
137
  input_state = AgentState(
138
  user_message="How do I reverse a list in Python?",
139
+ expected_output="Use the `[::-1]` slicing technique or the "
140
+ "`list.reverse()` method.",
141
  acceptance_criteria="Similar in meaning, text length and style."
142
+ )
143
  output_state = meta_prompt_graph(input_state, recursion_limit=25)
144
 
145
  pprint.pp(output_state)
146
+ assert (
147
+ "best_system_message" in output_state
148
+ ), "The output state should contain the key 'best_system_message'"
149
+ assert (
150
+ output_state["best_system_message"] is not None
151
+ ), "The best system message should not be None"
152
+ if (
153
+ "best_system_message" in output_state
154
+ and output_state["best_system_message"] is not None
155
+ ):
156
+ print(output_state["best_system_message"])
157
+
158
  user_message = "How can I create a list of numbers in Python?"
159
+ messages = [("system", output_state["best_system_message"]), ("human", user_message)]
 
160
  result = llm.invoke(messages)
161
 
162
+ assert hasattr(result, "content"), "The result should have the attribute 'content'"
 
 
163
  print(result.content)
164
 
165
 
 
171
  executes it with a given input state. It then verifies that the output
172
  state contains the expected keys and values.
173
  """
174
+ optimizer_llm = ChatOpenAI(
175
+ model_name="deepseek/deepseek-chat", temperature=0.5
176
+ )
177
+ executor_llm = ChatOpenAI(
178
+ model_name="meta-llama/llama-3-8b-instruct", temperature=0.01
179
+ )
180
 
181
  llms = {
182
  NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
 
184
  NODE_PROMPT_EXECUTOR: executor_llm,
185
  NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
186
  NODE_PROMPT_ANALYZER: optimizer_llm,
187
+ NODE_PROMPT_SUGGESTER: optimizer_llm,
188
  }
189
 
190
  meta_prompt_graph = MetaPromptGraph(llms=llms)
191
  input_state = AgentState(
192
  user_message="How do I reverse a list in Python?",
193
+ expected_output="Use the `[::-1]` slicing technique or the "
194
+ "`list.reverse()` method.",
195
  acceptance_criteria="Similar in meaning, text length and style."
196
+ )
197
  output_state = meta_prompt_graph(input_state, recursion_limit=25)
198
 
199
  pprint.pp(output_state)
200
+ assert (
201
+ "best_system_message" in output_state
202
+ ), "The output state should contain the key 'best_system_message'"
203
+ assert (
204
+ output_state["best_system_message"] is not None
205
+ ), "The best system message should not be None"
206
+ if (
207
+ "best_system_message" in output_state
208
+ and output_state["best_system_message"] is not None
209
+ ):
210
+ print(output_state["best_system_message"])
211
+
212
  user_message = "How can I create a list of numbers in Python?"
213
+ messages = [("system", output_state["best_system_message"]), ("human", user_message)]
 
214
  result = executor_llm.invoke(messages)
215
 
216
+ assert hasattr(result, "content"), "The result should have the attribute 'content'"
 
 
217
  print(result.content)
218
 
219
 
 
328
 
329
 
330
  def test_run_acceptance_criteria_graph(self):
331
+ """Test the run_acceptance_criteria_graph method of MetaPromptGraph.
 
332
 
333
+ This test case verifies that the run_acceptance_criteria_graph method
334
+ returns a state with acceptance criteria.
335
  """
336
  llms = {
337
  NODE_ACCEPTANCE_CRITERIA_DEVELOPER: MagicMock(
338
+ invoke=lambda prompt: MagicMock(content="Acceptance criteria: ...")
339
+ )
340
  }
341
  meta_prompt_graph = MetaPromptGraph(llms=llms)
342
  state = AgentState(
 
346
  output_state = meta_prompt_graph.run_acceptance_criteria_graph(state)
347
 
348
  # Check if the output state contains the acceptance criteria
349
+ self.assertIsNotNone(output_state["acceptance_criteria"])
350
 
351
  # Check if the acceptance criteria includes the expected content
352
+ self.assertIn("Acceptance criteria: ...", output_state["acceptance_criteria"])
353
 
354
 
355
  def test_run_prompt_initial_developer_graph(self):
356
+ """Test the run_prompt_initial_developer_graph method of MetaPromptGraph.
 
357
 
358
+ This test case verifies that the run_prompt_initial_developer_graph method
359
+ returns a state with an initial developer prompt.
360
  """
361
  llms = {
362
  NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
363
+ invoke=lambda prompt: MagicMock(content="Initial developer prompt: ...")
364
+ )
365
  }
366
  meta_prompt_graph = MetaPromptGraph(llms=llms)
367
  state = AgentState(user_message="How do I reverse a list in Python?")
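
The stub pattern these tests rely on, shown standalone: a MagicMock whose invoke() returns an object with a .content attribute, so graph nodes can be exercised without any real model calls.

from unittest.mock import MagicMock

fake_llm = MagicMock(
    invoke=MagicMock(return_value=MagicMock(content="Mocked response content"))
)
reply = fake_llm.invoke("any prompt")
assert reply.content == "Mocked response content"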