yaleh committed
Commit 2d1b8d7 · 1 Parent(s): fcaac18

Update meta_prompt_graph.py to handle llms as a single BaseLanguageModel or a dictionary of BaseLanguageModels

Files changed (2):
  1. meta_prompt_graph.py +8 -5
  2. meta_prompt_graph_test.py +57 -7
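
In short, MetaPromptGraph can now be constructed either from a single model that drives every node, or from an explicit node-to-model mapping. Below is a minimal sketch of both call styles, assuming the import path for MetaPromptGraph and reusing the model names and node constants that appear in this commit's diffs:

    from langchain_openai import ChatOpenAI
    from meta_prompt_graph import MetaPromptGraph  # assumed import path

    # One model for every node: the constructor broadcasts it internally.
    llm = ChatOpenAI(model_name="anthropic/claude-3.5-sonnet:haiku")
    graph = MetaPromptGraph(llms=llm)

    # Per-node models: the dict is stored as-is, so it should cover every
    # node the workflow will invoke.
    optimizer_llm = ChatOpenAI(model_name="anthropic/claude-3.5-sonnet:haiku", temperature=0.5)
    executor_llm = ChatOpenAI(model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
    graph = MetaPromptGraph(llms={
        MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
        MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer_llm,
        MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor_llm,
        MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
        MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer_llm,
        MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer_llm,
    })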
meta_prompt_graph.py CHANGED

@@ -265,7 +265,7 @@ Analysis:
     ]
 
     def __init__(self,
-                 llms: Dict[str, BaseLanguageModel] = {},
+                 llms: Union[BaseLanguageModel, Dict[str, BaseLanguageModel]] = {},
                  prompts: Dict[str, ChatPromptTemplate] = {},
                  verbose = False):
         self.logger = logging.getLogger(__name__)
@@ -274,7 +274,10 @@ Analysis:
         else:
             self.logger.setLevel(logging.INFO)
 
-        self.llms: Dict[str, BaseLanguageModel] = llms
+        if isinstance(llms, BaseLanguageModel):
+            self.llms: Dict[str, BaseLanguageModel] = {node: llms for node in self.get_node_names()}
+        else:
+            self.llms: Dict[str, BaseLanguageModel] = llms
         self.prompt_templates: Dict[str, ChatPromptTemplate] = self.DEFAULT_PROMPT_TEMPLATES.copy()
         self.prompt_templates.update(prompts)
 
@@ -363,7 +366,7 @@ Analysis:
 
         self.logger.debug("Invoking %s with prompt: %s", node, pprint.pformat(prompt))
         response = self.llms[node].invoke(self.prompt_templates[node].format_messages(**state.model_dump()))
-        self.logger.debug("Response: %s", pprint.pformat(response.content))
+        self.logger.debug("Response: %s", response.content)
 
         setattr(state, target_attribute, response.content)
         return state
@@ -384,7 +387,7 @@ Analysis:
                           self.NODE_OUTPUT_HISTORY_ANALYZER,
                           pprint.pformat(prompt))
         response = self.llms[self.NODE_OUTPUT_HISTORY_ANALYZER].invoke(prompt)
-        self.logger.debug("Response: %s", pprint.pformat(response.content))
+        self.logger.debug("Response: %s", response.content)
 
         analysis = response.content
 
@@ -408,7 +411,7 @@ Analysis:
                           self.NODE_PROMPT_ANALYZER,
                           pprint.pformat(prompt))
         response = self.llms[self.NODE_PROMPT_ANALYZER].invoke(prompt)
-        self.logger.debug("Response: %s", pprint.pformat(response.content))
+        self.logger.debug("Response: %s", response.content)
 
         state.analysis = response.content
         state.accepted = "Accept: Yes" in response.content
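
Design note: when a single BaseLanguageModel is passed, __init__ broadcasts it across every node returned by get_node_names(), so the rest of the class can keep indexing self.llms[node] unconditionally. A dict passed by the caller is stored as-is, which means it must supply an entry for every node the workflow actually invokes.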
meta_prompt_graph_test.py CHANGED

@@ -11,8 +11,8 @@ from langchain_openai import ChatOpenAI
 
 class TestMetaPromptGraph(unittest.TestCase):
     def setUp(self):
-        # Mocking the BaseLanguageModel and ChatPromptTemplate for testing
-        logging.basicConfig(level=logging.DEBUG)
+        # logging.basicConfig(level=logging.DEBUG)
+        pass
 
     def test_prompt_node(self):
         llms = {
@@ -79,17 +79,56 @@ class TestMetaPromptGraph(unittest.TestCase):
         assert updated_state.accepted == True
 
     def test_workflow_execution(self):
-        # MODEL_NAME = "google/gemma-2-9b-it"
         MODEL_NAME = "anthropic/claude-3.5-sonnet:haiku"
+        # MODEL_NAME = "meta-llama/llama-3-70b-instruct"
+        # MODEL_NAME = "deepseek/deepseek-chat"
+        # MODEL_NAME = "google/gemma-2-9b-it"
+        # MODEL_NAME = "recursal/eagle-7b"
+        # MODEL_NAME = "meta-llama/llama-3-8b-instruct"
         llm = ChatOpenAI(model_name=MODEL_NAME)
 
-        node_names = MetaPromptGraph.get_node_names()
+        meta_prompt_graph = MetaPromptGraph(llms=llm)
+        input_state = AgentState(
+            user_message="How do I reverse a list in Python?",
+            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
+            acceptance_criteria="Similar in meaning, text length and style."
+        )
+        output_state = meta_prompt_graph(input_state, recursion_limit=25)
+
+        pprint.pp(output_state)
+        # if output_state has key 'best_system_message', print it
+        assert 'best_system_message' in output_state, \
+            "The output state should contain the key 'best_system_message'"
+        assert output_state['best_system_message'] is not None, \
+            "The best system message should not be None"
+        if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
+            print(output_state['best_system_message'])
+
+        # try another similar user message with the generated system message
+        user_message = "How can I create a list of numbers in Python?"
+        messages = [("system", output_state['best_system_message']),
+                    ("human", user_message)]
+        result = llm.invoke(messages)
+
+        # assert attr 'content' in result
+        assert hasattr(result, 'content'), \
+            "The result should have the attribute 'content'"
+        print(result.content)
+
+    def test_workflow_execution_with_llms(self):
+        optimizer_llm = ChatOpenAI(model_name="anthropic/claude-3.5-sonnet:haiku", temperature=0.5)
+        executor_llm = ChatOpenAI(model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
+
         llms = {
+            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
+            MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer_llm,
+            MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor_llm,
+            MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
+            MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer_llm,
+            MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer_llm
         }
-        for node_name in node_names:
-            llms[node_name] = llm
 
-        meta_prompt_graph = MetaPromptGraph(llms=llms, verbose=True)
+        meta_prompt_graph = MetaPromptGraph(llms=llms)
         input_state = AgentState(
             user_message="How do I reverse a list in Python?",
             expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
@@ -106,5 +145,16 @@ class TestMetaPromptGraph(unittest.TestCase):
         if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
             print(output_state['best_system_message'])
 
+        # try another similar user message with the generated system message
+        user_message = "How can I create a list of numbers in Python?"
+        messages = [("system", output_state['best_system_message']),
+                    ("human", user_message)]
+        result = executor_llm.invoke(messages)
+
+        # assert attr 'content' in result
+        assert hasattr(result, 'content'), \
+            "The result should have the attribute 'content'"
+        print(result.content)
+
 if __name__ == '__main__':
     unittest.main()
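
Together the two tests cover both constructor paths: test_workflow_execution passes a bare ChatOpenAI instance and lets the graph fan it out to all nodes, while test_workflow_execution_with_llms assigns a higher-temperature optimizer model to the prompt-engineering nodes and a near-deterministic executor model to NODE_PROMPT_EXECUTOR, then reuses the generated best_system_message directly against the executor.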