yaleh committed
Commit 1b5bff3 · 1 Parent(s): a9f60b0

Fixed the branch bug of Annotated state for LangGraph.

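For context, a minimal standalone sketch (not part of this commit) of the reducer semantics behind the fix: LangGraph merges state returned from a branch by calling the function in each `Annotated[...]` field roughly as reducer(current_value, branch_update). With the old `first_non_empty` reducer the already-stored value wins whenever it is non-empty, so a value updated inside a branch is silently dropped; the `last_non_empty` reducer introduced in the meta_prompt.py diff below lets the newest non-empty value win.

# Standalone illustration; both helpers are copied from the diff below, the
# call order mimics a LangGraph-style reducer seeing (current_value, new_value).
def first_non_empty(a, b):
    return next((s for s in (a, b) if s), None)

def last_non_empty(a, b):
    return next((s for s in (b, a) if s), None)

current, branch_update = "old system message", "revised system message"
assert first_non_empty(current, branch_update) == "old system message"     # branch update lost
assert last_non_empty(current, branch_update) == "revised system message"  # branch update wins
assert last_non_empty(current, None) == "old system message"               # empty updates keep prior state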
app/config.py CHANGED
@@ -19,6 +19,7 @@ class PromptGroup(BaseModel):
 
 class MetaPromptConfig(BaseConfig):
     llms: Optional[dict[str, LLMConfig]]
+    aggressive_exploration: Optional[bool] = False
     examples_path: Optional[str]
     server_name: Optional[str] = None
     server_port: Optional[int] = None
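A quick sanity sketch (not from the repo) of what the new field means for existing configurations: it is declared with a `False` default on the pydantic-based config class, so configs that never mention `aggressive_exploration` keep the previous behaviour. The class below is a hypothetical stand-in for the confz-based `MetaPromptConfig`.

from typing import Optional
from pydantic import BaseModel

class MetaPromptConfigSketch(BaseModel):
    # hypothetical stand-in for the confz-based MetaPromptConfig above
    aggressive_exploration: Optional[bool] = False

assert MetaPromptConfigSketch().aggressive_exploration is False                    # unchanged configs stay off
assert MetaPromptConfigSketch(aggressive_exploration=True).aggressive_exploration  # opt in explicitly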
app/gradio_meta_prompt.py CHANGED
@@ -423,7 +423,8 @@ def process_message(user_message: str, expected_output: str,
                     acceptance_criteria: str, initial_system_message: str,
                     recursion_limit: int, max_output_age: int,
                     llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]],
-                    prompt_template_group: Optional[str] = None) -> tuple:
+                    prompt_template_group: Optional[str] = None,
+                    aggressive_exploration: bool = False) -> tuple:
     """
     Process a user message by executing the MetaPromptGraph with provided language models and input state.
     This function sets up the initial state of the conversation, logs the execution if verbose mode is enabled,
@@ -465,6 +466,7 @@ def process_message(user_message: str, expected_output: str,
         prompt_template_group = 'default'
     prompt_templates = prompt_templates_confz2langchain(config.prompt_templates[prompt_template_group])
     meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompt_templates,
+                                        aggressive_exploration=aggressive_exploration,
                                         verbose=config.verbose, logger=logger)
     try:
         output_state = meta_prompt_graph(input_state, recursion_limit=recursion_limit)
@@ -532,7 +534,9 @@ def initialize_llm(model_name: str, model_config: Optional[Dict[str, Any]] = Non
 def process_message_with_single_llm(user_message: str, expected_output: str,
                                     acceptance_criteria: str, initial_system_message: str,
                                     recursion_limit: int, max_output_age: int,
-                                    model_name: str, prompt_template_group: Optional[str] = None) -> tuple:
+                                    model_name: str,
+                                    prompt_template_group: Optional[str] = None,
+                                    aggressive_exploration: bool = False) -> tuple:
     """
     Process a user message using a single language model.
 
@@ -563,14 +567,15 @@ def process_message_with_single_llm(user_message: str, expected_output: str,
     """
     llm = initialize_llm(model_name)
     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                           recursion_limit, max_output_age, llm, prompt_template_group)
+                           recursion_limit, max_output_age, llm, prompt_template_group, aggressive_exploration)
 
 
 def process_message_with_2_llms(user_message: str, expected_output: str,
                                 acceptance_criteria: str, initial_system_message: str,
                                 recursion_limit: int, max_output_age: int,
                                 optimizer_model_name: str, executor_model_name: str,
-                                prompt_template_group: Optional[str] = None) -> tuple:
+                                prompt_template_group: Optional[str] = None,
+                                aggressive_exploration: bool = False) -> tuple:
     """
     Process a user message using two language models - one for optimization and another for execution.
 
@@ -612,7 +617,7 @@ def process_message_with_2_llms(user_message: str, expected_output: str,
         NODE_PROMPT_SUGGESTER: optimizer_model
     }
     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                           recursion_limit, max_output_age, llms, prompt_template_group)
+                           recursion_limit, max_output_age, llms, prompt_template_group, aggressive_exploration)
 
 
 def process_message_with_expert_llms(user_message: str, expected_output: str,
@@ -625,7 +630,8 @@ def process_message_with_expert_llms(user_message: str, expected_output: str,
                                      output_history_analyzer_model_name: str, output_history_analyzer_temperature: float,
                                      analyzer_model_name: str, analyzer_temperature: float,
                                      suggester_model_name: str, suggester_temperature: float,
-                                     prompt_template_group: Optional[str] = None) -> tuple:
+                                     prompt_template_group: Optional[str] = None,
+                                     aggressive_exploration: bool = False) -> tuple:
 
     llms = {
         NODE_PROMPT_INITIAL_DEVELOPER: initialize_llm(initial_developer_model_name, {"temperature": initial_developer_temperature}),
@@ -637,7 +643,7 @@ def process_message_with_expert_llms(user_message: str, expected_output: str,
         NODE_PROMPT_SUGGESTER: initialize_llm(suggester_model_name, {"temperature": suggester_temperature})
     }
     return process_message(user_message, expected_output, acceptance_criteria, initial_system_message,
-                           recursion_limit, max_output_age, llms, prompt_template_group=prompt_template_group)
+                           recursion_limit, max_output_age, llms, prompt_template_group, aggressive_exploration)
 
 
 class FileConfig(BaseConfig):
@@ -725,6 +731,10 @@ with gr.Blocks(title='Meta Prompt') as demo:
         choices=list(config.prompt_templates.keys()),
         value=list(config.prompt_templates.keys())[0]
     )
+    aggressive_exploration = gr.Checkbox(
+        label="Aggressive Exploration",
+        value=config.aggressive_exploration
+    )
     with gr.Row():
         with gr.Tabs() as llm_tabs:
             with gr.Tab('Simple') as simple_llm_tab:
@@ -888,7 +898,9 @@ with gr.Blocks(title='Meta Prompt') as demo:
         inputs=[user_message_input, expected_output_input,
                 simple_model_name_input,
                 advanced_optimizer_model_name_input,
-                expert_prompt_acceptance_criteria_model_name_input, expert_prompt_acceptance_criteria_temperature_input],
+                expert_prompt_acceptance_criteria_model_name_input,
+                expert_prompt_acceptance_criteria_temperature_input,
+                prompt_template_group],
         outputs=[acceptance_criteria_input, logs_chatbot]
     )
     generate_initial_system_message_button.click(
@@ -939,7 +951,8 @@ with gr.Blocks(title='Meta Prompt') as demo:
             recursion_limit_input,
            max_output_age,
            simple_model_name_input,
-           prompt_template_group
+           prompt_template_group,
+           aggressive_exploration
        ],
        outputs=[
            system_message_output,
@@ -961,7 +974,8 @@ with gr.Blocks(title='Meta Prompt') as demo:
            max_output_age,
            advanced_optimizer_model_name_input,
            advanced_executor_model_name_input,
-           prompt_template_group
+           prompt_template_group,
+           aggressive_exploration
        ],
        outputs=[
            system_message_output,
@@ -988,7 +1002,8 @@ with gr.Blocks(title='Meta Prompt') as demo:
            expert_output_history_analyzer_model_name_input, expert_output_history_analyzer_temperature_input,
            expert_prompt_analyzer_model_name_input, expert_prompt_analyzer_temperature_input,
            expert_prompt_suggester_model_name_input, expert_prompt_suggester_temperature_input,
-           prompt_template_group
+           prompt_template_group,
+           aggressive_exploration
        ],
        outputs=[
            system_message_output,
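As a usage sketch (argument values and the model name are placeholders, not from the commit), the new flag is simply forwarded from the Gradio checkbox through the process_message_with_* wrappers to MetaPromptGraph; the call below assumes process_message_with_single_llm from app/gradio_meta_prompt.py is in scope.

# Hypothetical call mirroring the updated signature in the diff above.
result = process_message_with_single_llm(
    user_message="Translate to French: good morning",   # placeholder inputs
    expected_output="Bonjour",
    acceptance_criteria="Exact match",
    initial_system_message="",
    recursion_limit=16,
    max_output_age=2,
    model_name="gpt-4o-mini",                            # hypothetical model name
    prompt_template_group="default",
    aggressive_exploration=True,                         # new flag added by this commit
)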
meta_prompt/meta_prompt.py CHANGED
@@ -14,8 +14,13 @@ from pydantic import BaseModel
 from .consts import *
 
 def first_non_empty(a, b):
+    # return the first non-none value
     return next((s for s in (a, b) if s), None)
 
+def last_non_empty(a, b):
+    # return the last non-none value
+    return next((s for s in (b, a) if s), None)
+
 class AgentState(BaseModel):
     """
     Represents the state of an agent in a conversation.
@@ -35,16 +40,16 @@ class AgentState(BaseModel):
     - best_output_age (int): The age of the best output.
     """
     max_output_age: Annotated[int, lambda x, y: max(x, y)] = 0
-    user_message: Annotated[Optional[str], first_non_empty] = None
-    expected_output: Annotated[Optional[str], first_non_empty] = None
-    acceptance_criteria: Annotated[Optional[str], first_non_empty] = None
-    system_message: Annotated[Optional[str], first_non_empty] = None
-    output: Annotated[Optional[str], first_non_empty] = None
-    suggestions: Annotated[Optional[str], first_non_empty] = None
+    user_message: Annotated[Optional[str], last_non_empty] = None
+    expected_output: Annotated[Optional[str], last_non_empty] = None
+    acceptance_criteria: Annotated[Optional[str], last_non_empty] = None
+    system_message: Annotated[Optional[str], last_non_empty] = None
+    output: Annotated[Optional[str], last_non_empty] = None
+    suggestions: Annotated[Optional[str], last_non_empty] = None
     accepted: Annotated[bool, operator.or_] = False
-    analysis: Annotated[Optional[str], first_non_empty] = None
-    best_output: Annotated[Optional[str], first_non_empty] = None
-    best_system_message: Annotated[Optional[str], first_non_empty] = None
+    analysis: Annotated[Optional[str], last_non_empty] = None
+    best_output: Annotated[Optional[str], last_non_empty] = None
+    best_system_message: Annotated[Optional[str], last_non_empty] = None
     best_output_age: Annotated[int, lambda x, y: max(x, y)] = 0
 
 class MetaPromptGraph:
@@ -82,6 +87,7 @@ class MetaPromptGraph:
                  llms: Union[BaseLanguageModel,
                              Dict[str, BaseLanguageModel]] = {},
                  prompts: Dict[str, ChatPromptTemplate] = {},
+                 aggressive_exploration: bool = False,
                  logger: Optional[logging.Logger] = None,
                  verbose=False):
         """
@@ -118,6 +124,8 @@ class MetaPromptGraph:
                                   ChatPromptTemplate] = DEFAULT_PROMPT_TEMPLATES.copy()
         self.prompt_templates.update(prompts)
 
+        self.aggressive_exploration = aggressive_exploration
+
 
     def _create_acceptance_criteria_workflow(self) -> StateGraph:
         workflow = StateGraph(AgentState)
@@ -426,16 +434,22 @@ class MetaPromptGraph:
 
         analysis = response.content
 
-        if state.best_output is None or (
-                "# Output ID closer to Expected Output: B" in analysis):
+        if (state.best_output is None or
+                "# Output ID closer to Expected Output: B" in analysis or
+                (self.aggressive_exploration and
+                 "# Output ID closer to Expected Output: A" not in analysis)):
             state.best_output = state.output
             state.best_system_message = state.system_message
             state.best_output_age = 0
-            logger.debug(
-                "Best output updated to the current output:\n%s", state.output)
+            logger.debug("Best output updated to the current output:\n%s",
+                         state.output)
         else:
             state.best_output_age += 1
-            logger.debug("Best output age incremented to %s", state.best_output_age)
+            # rollback output and system message
+            state.output = state.best_output
+            state.system_message = state.best_system_message
+            logger.debug("Best output age incremented to %s",
+                         state.best_output_age)
 
         return state
 
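Finally, a small sketch (illustrative strings, not repository code) of when the output-history analysis step now promotes the candidate output to best_output: conservatively only when the analysis explicitly prefers B, but with aggressive exploration enabled also when the analysis does not explicitly prefer A (ties or ambiguous analyses); otherwise the state is rolled back to the previous best, per the diff above.

# Mirrors the promotion condition from the diff above; analysis strings are illustrative.
def should_promote(analysis: str, has_best: bool, aggressive: bool) -> bool:
    return (not has_best
            or "# Output ID closer to Expected Output: B" in analysis
            or (aggressive and "# Output ID closer to Expected Output: A" not in analysis))

assert should_promote("# Output ID closer to Expected Output: B", has_best=True, aggressive=False)
assert not should_promote("no clear winner", has_best=True, aggressive=False)  # keep the previous best
assert should_promote("no clear winner", has_best=True, aggressive=True)       # explore the new output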