import unittest
import pprint
import logging
import functools
from unittest.mock import MagicMock, Mock

from langchain_core.language_models import BaseLanguageModel
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END

# `from meta_prompt import *` provides MetaPromptGraph, AgentState, META_PROMPT_NODES
# and the NODE_* constants used throughout these tests.
from meta_prompt import *
from meta_prompt.consts import NODE_ACCEPTANCE_CRITERIA_DEVELOPER
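
# Note: the tests below that build real ChatOpenAI clients make live API calls.
# A rough sketch of how this module is usually run (the OpenRouter-style base URL,
# the environment variable names, and the module path are assumptions, not
# something this file configures):
#
#   export OPENAI_API_KEY=<your-key>
#   export OPENAI_API_BASE=https://openrouter.ai/api/v1
#   python -m unittest -v tests.test_meta_prompt_graph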


class TestMetaPromptGraph(unittest.TestCase):
    def setUp(self):
        # logging.basicConfig(level=logging.DEBUG)
        pass

    def test_prompt_node(self):
        """
        Test the _prompt_node method of MetaPromptGraph.

        This test case sets up a mock language model that returns a fixed response
        and verifies that the updated state has its output attribute set to the
        mocked response content.
        """
        llms = {
            NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
                invoke=MagicMock(return_value=MagicMock(content="Mocked response content"))
            )
        }

        # Create an instance of MetaPromptGraph with the mocked language model
        graph = MetaPromptGraph(llms=llms)

        # Create a mock AgentState
        state = AgentState(user_message="Test message", expected_output="Expected output")

        # Invoke the _prompt_node method with the node name, target attribute, and state
        updated_state = graph._prompt_node(
            NODE_PROMPT_INITIAL_DEVELOPER, "output", state
        )

        # Assertions
        assert updated_state.output == "Mocked response content", \
            "The output attribute should be updated with the mocked response content"

    def test_output_history_analyzer(self):
        """
        Test the _output_history_analyzer method of MetaPromptGraph.

        This test case sets up a mock language model that returns an analysis response
        and verifies that the best output, best system message, and best output age of
        the updated state are set correctly.
        """
        # Setup
        llms = {
            "output_history_analyzer": MagicMock(invoke=lambda prompt: MagicMock(content="""# Analysis
This analysis compares two outputs to the expected output based on specific criteria.
# Output ID closer to Expected Output: B"""))
        }
        prompts = {}
        meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
        state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            output="To reverse a list in Python, you can use the `[::-1]` slicing.",
            system_message="To reverse a list, use slicing or the reverse method.",
            best_output="To reverse a list in Python, use the `reverse()` method.",
            best_system_message="To reverse a list, use the `reverse()` method.",
            acceptance_criteria="The output should correctly describe how to reverse a list in Python."
        )

        # Invoke the output history analyzer node
        updated_state = meta_prompt_graph._output_history_analyzer(state)

        # Assertions: with the analyzer answering "B", the current output is expected
        # to replace the previous best output.
        assert updated_state.best_output == state.output, \
            "Best output should be updated to the current output."
        assert updated_state.best_system_message == state.system_message, \
            "Best system message should be updated to the current system message."
        assert updated_state.best_output_age == 0, \
            "Best output age should be reset to 0."

    def test_prompt_analyzer_accept(self):
        """
        Test the _prompt_analyzer method of MetaPromptGraph when the prompt analyzer
        accepts the output.

        This test case sets up a mock language model that returns an acceptance
        response and verifies that the accepted attribute of the updated state is
        set to True.
        """
        llms = {
            NODE_PROMPT_ANALYZER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Accept: Yes"))
        }
        meta_prompt_graph = MetaPromptGraph(llms)
        state = AgentState(output="Test output", expected_output="Expected output")
        updated_state = meta_prompt_graph._prompt_analyzer(state)
        assert updated_state.accepted is True, \
            "The accepted flag should be True when the analyzer responds 'Accept: Yes'."

    def test_get_node_names(self):
        """
        Test the get_node_names method of MetaPromptGraph.

        This test case verifies that the get_node_names method returns the correct
        list of node names.
        """
        graph = MetaPromptGraph()
        node_names = graph.get_node_names()
        self.assertEqual(node_names, META_PROMPT_NODES)

    def test_workflow_execution(self):
        """
        Test the workflow execution of the MetaPromptGraph.

        This test case sets up a MetaPromptGraph with a single language model and
        executes it with a given input state. It then verifies that the output
        state contains the expected keys and values.
        """
        # MODEL_NAME = "anthropic/claude-3.5-sonnet:beta"
        # MODEL_NAME = "meta-llama/llama-3-70b-instruct"
        MODEL_NAME = "deepseek/deepseek-chat"
        # MODEL_NAME = "google/gemma-2-9b-it"
        # MODEL_NAME = "recursal/eagle-7b"
        # MODEL_NAME = "meta-llama/llama-3-8b-instruct"
        llm = ChatOpenAI(model_name=MODEL_NAME)
        meta_prompt_graph = MetaPromptGraph(llms=llm)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )
        output_state = meta_prompt_graph(input_state, recursion_limit=25)

        pprint.pp(output_state)
        # if output_state has key 'best_system_message', print it
        assert 'best_system_message' in output_state, \
            "The output state should contain the key 'best_system_message'"
        assert output_state['best_system_message'] is not None, \
            "The best system message should not be None"
        if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
            print(output_state['best_system_message'])

        # try another similar user message with the generated system message
        user_message = "How can I create a list of numbers in Python?"
        messages = [("system", output_state['best_system_message']),
                    ("human", user_message)]
        result = llm.invoke(messages)

        # assert attr 'content' in result
        assert hasattr(result, 'content'), \
            "The result should have the attribute 'content'"
        print(result.content)

    def test_workflow_execution_with_llms(self):
        """
        Test the workflow execution of the MetaPromptGraph with multiple LLMs.

        This test case sets up a MetaPromptGraph with multiple language models and
        executes it with a given input state. It then verifies that the output
        state contains the expected keys and values.
        """
        optimizer_llm = ChatOpenAI(model_name="deepseek/deepseek-chat", temperature=0.5)
        executor_llm = ChatOpenAI(model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
        llms = {
            NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
            NODE_PROMPT_DEVELOPER: optimizer_llm,
            NODE_PROMPT_EXECUTOR: executor_llm,
            NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
            NODE_PROMPT_ANALYZER: optimizer_llm,
            NODE_PROMPT_SUGGESTER: optimizer_llm
        }
        meta_prompt_graph = MetaPromptGraph(llms=llms)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )
        output_state = meta_prompt_graph(input_state, recursion_limit=25)

        pprint.pp(output_state)
        # if output_state has key 'best_system_message', print it
        assert 'best_system_message' in output_state, \
            "The output state should contain the key 'best_system_message'"
        assert output_state['best_system_message'] is not None, \
            "The best system message should not be None"
        if 'best_system_message' in output_state and output_state['best_system_message'] is not None:
            print(output_state['best_system_message'])

        # try another similar user message with the generated system message
        user_message = "How can I create a list of numbers in Python?"
        messages = [("system", output_state['best_system_message']),
                    ("human", user_message)]
        result = executor_llm.invoke(messages)

        # assert attr 'content' in result
        assert hasattr(result, 'content'), \
            "The result should have the attribute 'content'"
        print(result.content)

    def test_simple_workflow_execution(self):
        """
        Test the simple workflow execution of the MetaPromptGraph.

        This test case sets up a MetaPromptGraph with a mock LLM and executes it
        with a given input state. It then verifies that the output state contains
        the expected keys and values.
        """
        # Create a mock LLM that returns predefined responses, one per node invocation
        llm = Mock(spec=BaseLanguageModel)
        responses = [
            Mock(type="content", content="Explain how to reverse a list in Python."),  # NODE_PROMPT_INITIAL_DEVELOPER
            Mock(type="content", content="Here's one way: `my_list[::-1]`"),  # NODE_PROMPT_EXECUTOR
            Mock(type="content", content="Accept: Yes"),  # NODE_PROMPT_ANALYZER
        ]
        # Each call to `invoke` consumes the next scripted response
        llm.invoke = functools.partial(next, iter(responses))

        meta_prompt_graph = MetaPromptGraph(llms=llm)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="The output should use the `reverse()` method.",
            acceptance_criteria="The output should be correct and efficient."
        )
        output_state = meta_prompt_graph(input_state)

        self.assertIsNotNone(output_state['best_system_message'])
        self.assertIsNotNone(output_state['best_output'])
        pprint.pp(output_state["best_output"])

    def test_iterated_workflow_execution(self):
        """
        Test the iterated workflow execution of the MetaPromptGraph.

        This test case sets up a MetaPromptGraph with a mock LLM and executes it
        with a given input state. It then verifies that the output state contains
        the expected keys and values. The scripted responses simulate an iterated
        workflow in which the first output is rejected and a revised prompt is
        accepted on the second round.
        """
        # Create a mock LLM that returns predefined responses, one per node invocation
        llm = Mock(spec=BaseLanguageModel)
        responses = [
            Mock(type="content", content="Explain how to reverse a list in Python."),  # NODE_PROMPT_INITIAL_DEVELOPER
            Mock(type="content", content="Here's one way: `my_list[::-1]`"),  # NODE_PROMPT_EXECUTOR
            Mock(type="content", content="Accept: No"),  # NODE_PROMPT_ANALYZER
            Mock(type="content", content="Try using the `reverse()` method instead."),  # NODE_PROMPT_SUGGESTER
            Mock(type="content", content="Explain how to reverse a list in Python. Output in a Markdown List."),  # NODE_PROMPT_DEVELOPER
            Mock(type="content", content="Here's one way: `my_list.reverse()`"),  # NODE_PROMPT_EXECUTOR
            Mock(type="content", content="# Output ID closer to Expected Output: B"),  # NODE_OUTPUT_HISTORY_ANALYZER
            Mock(type="content", content="Accept: Yes"),  # NODE_PROMPT_ANALYZER
        ]
        llm.invoke = lambda _: responses.pop(0)

        meta_prompt_graph = MetaPromptGraph(llms=llm)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="The output should use the `reverse()` method.",
            acceptance_criteria="The output should be correct and efficient."
        )
        output_state = meta_prompt_graph(input_state)

        self.assertIsNotNone(output_state['best_system_message'])
        self.assertIsNotNone(output_state['best_output'])
        pprint.pp(output_state["best_output"])
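
    # A reusable variant of the scripted-response pattern used in the two tests above
    # (a sketch for illustration, not part of the MetaPromptGraph API): failing loudly
    # when the script runs out makes a mismatch between the scripted responses and the
    # graph's actual node sequence easier to diagnose.
    @staticmethod
    def _make_scripted_llm(contents):
        llm = Mock(spec=BaseLanguageModel)
        responses = [Mock(type="content", content=text) for text in contents]

        def _invoke(_messages, **_kwargs):
            assert responses, "Mock LLM ran out of scripted responses"
            return responses.pop(0)

        llm.invoke = _invoke
        return llm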

    def test_create_acceptance_criteria_workflow(self):
        """
        Test the _create_acceptance_criteria_workflow method of MetaPromptGraph.

        This test case verifies that the workflow created by the
        _create_acceptance_criteria_workflow method contains the correct node and edge.
        """
        llms = {
            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: ChatOpenAI(model_name="deepseek/deepseek-chat")
        }
        meta_prompt_graph = MetaPromptGraph(llms=llms)
        workflow = meta_prompt_graph._create_acceptance_criteria_workflow()

        # Check if the workflow contains the correct node
        self.assertIn(NODE_ACCEPTANCE_CRITERIA_DEVELOPER, workflow.nodes)
        # Check if the workflow contains the correct edge
        self.assertIn((NODE_ACCEPTANCE_CRITERIA_DEVELOPER, END), workflow.edges)

        # Compile the workflow
        graph = workflow.compile()
        print(graph)

        # Invoke the workflow
        state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="The output should use the `reverse()` method.",
            # system_message="Create acceptance criteria for the task of reversing a list in Python."
        )
        output_state = graph.invoke(state)

        # Check if the output state contains the acceptance criteria
        self.assertIsNotNone(output_state['acceptance_criteria'])
        # Check if the acceptance criteria includes the string '`reverse()`'
        self.assertIn('`reverse()`', output_state['acceptance_criteria'])
        pprint.pp(output_state["acceptance_criteria"])

    def test_run_acceptance_criteria_graph(self):
        """
        Test the run_acceptance_criteria_graph method of MetaPromptGraph.

        This test case verifies that the run_acceptance_criteria_graph method
        returns a state with acceptance criteria.
        """
        llms = {
            NODE_ACCEPTANCE_CRITERIA_DEVELOPER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Acceptance criteria: ..."))
        }
        meta_prompt_graph = MetaPromptGraph(llms=llms)
        state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="The output should use the `reverse()` method.",
        )
        output_state = meta_prompt_graph.run_acceptance_criteria_graph(state)

        # Check if the output state contains the acceptance criteria
        self.assertIsNotNone(output_state['acceptance_criteria'])
        # Check if the acceptance criteria includes the expected content
        self.assertIn("Acceptance criteria: ...", output_state['acceptance_criteria'])

    def test_run_prompt_initial_developer_graph(self):
        """
        Test the run_prompt_initial_developer_graph method of MetaPromptGraph.

        This test case verifies that the run_prompt_initial_developer_graph method
        returns a state with an initial developer prompt.
        """
        llms = {
            NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Initial developer prompt: ..."))
        }
        meta_prompt_graph = MetaPromptGraph(llms=llms)
        state = AgentState(user_message="How do I reverse a list in Python?")
        output_state = meta_prompt_graph.run_prompt_initial_developer_graph(state)

        # Check if the output state contains the initial developer prompt
        self.assertIsNotNone(output_state['system_message'])
        # Check if the initial developer prompt includes the expected content
        self.assertIn("Initial developer prompt: ...", output_state['system_message'])


if __name__ == '__main__':
    unittest.main()