yaleh committed
Commit e8de3ce · 1 Parent(s): 458530f

Updated unit test cases and docs.

Files changed (1):
  1. tests/meta_prompt_graph_test.py +63 -2
tests/meta_prompt_graph_test.py CHANGED

@@ -14,7 +14,14 @@ class TestMetaPromptGraph(unittest.TestCase):
         # logging.basicConfig(level=logging.DEBUG)
         pass
 
+
     def test_prompt_node(self):
+        """
+        Test the _prompt_node method of MetaPromptGraph.
+
+        This test case sets up a mock language model that returns a response content and verifies that the
+        updated state has the output attribute updated with the mocked response content.
+        """
         llms = {
             NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
                 invoke=MagicMock(return_value=MagicMock(content="Mocked response content"))
@@ -36,14 +43,21 @@ class TestMetaPromptGraph(unittest.TestCase):
         assert updated_state.output == "Mocked response content", \
             "The output attribute should be updated with the mocked response content"
 
+
     def test_output_history_analyzer(self):
+        """
+        Test the _output_history_analyzer method of MetaPromptGraph.
+
+        This test case sets up a mock language model that returns an analysis response and verifies that the
+        updated state has the best output, best system message, and best output age updated correctly.
+        """
         # Setup
         llms = {
             "output_history_analyzer": MagicMock(invoke=lambda prompt: MagicMock(content="""# Analysis
 
 This analysis compares two outputs to the expected output based on specific criteria.
 
-# Preferred Output ID: B"""))
+# Output ID closer to Expected Output: B"""))
         }
         prompts = {}
         meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
@@ -68,7 +82,14 @@ class TestMetaPromptGraph(unittest.TestCase):
         assert updated_state.best_output_age == 0, \
             "Best output age should be reset to 0."
 
+
     def test_prompt_analyzer_accept(self):
+        """
+        Test the _prompt_analyzer method of MetaPromptGraph when the prompt analyzer accepts the output.
+
+        This test case sets up a mock language model that returns an acceptance response and verifies that the
+        updated state has the accepted attribute set to True.
+        """
         llms = {
             NODE_PROMPT_ANALYZER: MagicMock(
                 invoke=lambda prompt: MagicMock(content="Accept: Yes"))
@@ -78,12 +99,26 @@ class TestMetaPromptGraph(unittest.TestCase):
         updated_state = meta_prompt_graph._prompt_analyzer(state)
         assert updated_state.accepted == True
 
+
     def test_get_node_names(self):
+        """
+        Test the get_node_names method of MetaPromptGraph.
+
+        This test case verifies that the get_node_names method returns the correct list of node names.
+        """
         graph = MetaPromptGraph()
         node_names = graph.get_node_names()
         self.assertEqual(node_names, META_PROMPT_NODES)
 
+
     def test_workflow_execution(self):
+        """
+        Test the workflow execution of the MetaPromptGraph.
+
+        This test case sets up a MetaPromptGraph with a single language model and
+        executes it with a given input state. It then verifies that the output
+        state contains the expected keys and values.
+        """
         # MODEL_NAME = "anthropic/claude-3.5-sonnet:beta"
         # MODEL_NAME = "meta-llama/llama-3-70b-instruct"
         MODEL_NAME = "deepseek/deepseek-chat"
@@ -120,7 +155,15 @@ class TestMetaPromptGraph(unittest.TestCase):
             "The result should have the attribute 'content'"
         print(result.content)
 
+
     def test_workflow_execution_with_llms(self):
+        """
+        Test the workflow execution of the MetaPromptGraph with multiple LLMs.
+
+        This test case sets up a MetaPromptGraph with multiple language models and
+        executes it with a given input state. It then verifies that the output
+        state contains the expected keys and values.
+        """
         optimizer_llm = ChatOpenAI(model_name="deepseek/deepseek-chat", temperature=0.5)
         executor_llm = ChatOpenAI(model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
 
@@ -160,8 +203,16 @@ class TestMetaPromptGraph(unittest.TestCase):
         assert hasattr(result, 'content'), \
             "The result should have the attribute 'content'"
         print(result.content)
+
 
     def test_simple_workflow_execution(self):
+        """
+        Test the simple workflow execution of the MetaPromptGraph.
+
+        This test case sets up a MetaPromptGraph with a mock LLM and executes it
+        with a given input state. It then verifies that the output state contains
+        the expected keys and values.
+        """
         # Create a mock LLM that returns predefined responses based on the input messages
         llm = Mock(spec=BaseLanguageModel)
         responses = [
@@ -184,8 +235,17 @@ class TestMetaPromptGraph(unittest.TestCase):
         self.assertIsNotNone(output_state['best_output'])
 
         pprint.pp(output_state["best_output"])
+
 
     def test_iterated_workflow_execution(self):
+        """
+        Test the iterated workflow execution of the MetaPromptGraph.
+
+        This test case sets up a MetaPromptGraph with a mock LLM and executes it
+        with a given input state. It then verifies that the output state contains
+        the expected keys and values. The test case simulates an iterated workflow
+        where the LLM provides multiple responses based on the input messages.
+        """
         # Create a mock LLM that returns predefined responses based on the input messages
         llm = Mock(spec=BaseLanguageModel)
         responses = [
@@ -195,7 +255,7 @@ class TestMetaPromptGraph(unittest.TestCase):
             Mock(type="content", content="Try using the `reverse()` method instead."), # NODE_PROMPT_SUGGESTER
             Mock(type="content", content="Explain how to reverse a list in Python. Output in a Markdown List."), # NODE_PROMPT_DEVELOPER
             Mock(type="content", content="Here's one way: `my_list.reverse()`"), # NODE_PROMPT_EXECUTOR
-            Mock(type="content", content="# Preferred Output ID: B"), # NODE_OUTPUT_HISTORY_ANALYZER
+            Mock(type="content", content="# Output ID closer to Expected Output: B"), # NODE_OUTPUT_HISTORY_ANALYZER
            Mock(type="content", content="Accept: Yes"), # NODE_PPROMPT_ANALYZER
         ]
         llm.invoke = lambda _: responses.pop(0)
@@ -214,5 +274,6 @@ class TestMetaPromptGraph(unittest.TestCase):
 
         pprint.pp(output_state["best_output"])
 
+
 if __name__ == '__main__':
     unittest.main()
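
The mocked nodes in the tests above all follow the same pattern: a MagicMock stands in for a chat model, and its invoke call returns an object exposing a content attribute, which is the interface the graph reads from. A minimal standalone sketch of that pattern follows; the run_node helper is hypothetical and not part of MetaPromptGraph, it only illustrates the mocking idiom.

from unittest.mock import MagicMock

# Mocked "LLM": invoke(...) returns an object with a .content attribute,
# mirroring the per-node mocks used in the tests above.
mock_llm = MagicMock(
    invoke=MagicMock(return_value=MagicMock(content="Accept: Yes"))
)

def run_node(llm, prompt):
    # Hypothetical helper: call the (mocked) model and read the response text.
    return llm.invoke(prompt).content

assert run_node(mock_llm, "any prompt") == "Accept: Yes"
mock_llm.invoke.assert_called_once_with("any prompt")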