# meta-prompt/tests/meta_prompt_graph_test.py
# Commit 68c6b73 (yaleh): Update meta_prompt_graph.py to handle llms as a
# single BaseLanguageModel or a dictionary of BaseLanguageModels.
# Page chrome from the repository viewer: raw | history | blame | 7.32 kB
import unittest
import pprint
import logging
from unittest.mock import MagicMock
from unittest.mock import patch
# Assuming the necessary imports are made for the classes and functions used in meta_prompt_graph.py
from meta_prompt import AgentState, MetaPromptGraph
from langchain_openai import ChatOpenAI
class TestMetaPromptGraph(unittest.TestCase):
    """Tests for ``MetaPromptGraph`` nodes and for the full workflow.

    ``test_prompt_node``, ``test_output_history_analyzer`` and
    ``test_prompt_analyzer_accept`` drive single nodes with ``MagicMock``
    language models. The two ``test_workflow_execution*`` tests construct
    real ``ChatOpenAI`` clients and call them.
    NOTE(review): those two presumably need network access and API
    credentials for an OpenAI-compatible endpoint -- confirm before running
    them in CI.

    All checks use ``unittest`` assertion methods rather than bare
    ``assert`` statements, so they are still enforced when Python runs with
    ``-O`` (which strips ``assert``).
    """

    def setUp(self):
        """No per-test fixtures; kept as a hook for enabling debug logging."""
        # logging.basicConfig(level=logging.DEBUG)
        pass

    def _check_generated_system_message(self, output_state, llm):
        """Verify a workflow result and reuse its system message once.

        Args:
            output_state: Mapping-like result returned by calling the graph.
                It must contain a non-``None`` ``'best_system_message'``.
            llm: Object with an ``invoke(messages)`` method; it is called
                with the generated system message and a new user message.
        """
        self.assertIn(
            'best_system_message', output_state,
            "The output state should contain the key 'best_system_message'")
        best_system_message = output_state['best_system_message']
        self.assertIsNotNone(
            best_system_message,
            "The best system message should not be None")
        print(best_system_message)

        # Try another, similar user message with the generated system message.
        user_message = "How can I create a list of numbers in Python?"
        messages = [("system", best_system_message),
                    ("human", user_message)]
        result = llm.invoke(messages)
        self.assertTrue(
            hasattr(result, 'content'),
            "The result should have the attribute 'content'")
        print(result.content)

    def test_prompt_node(self):
        """``_prompt_node`` stores the LLM reply in the target attribute."""
        llms = {
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
                invoke=MagicMock(
                    return_value=MagicMock(content="Mocked response content"))
            )
        }
        # Build the graph around the mocked language model.
        graph = MetaPromptGraph(llms=llms)
        state = AgentState(user_message="Test message",
                           expected_output="Expected output")

        # Run the node, asking it to write the reply into ``output``.
        updated_state = graph._prompt_node(
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER, "output", state
        )

        self.assertEqual(
            updated_state.output, "Mocked response content",
            "The output attribute should be updated with the mocked "
            "response content")

    def test_output_history_analyzer(self):
        """A mocked analysis preferring ID ``B`` promotes the current output."""
        # Adjacent literals keep the mocked text independent of indentation.
        analysis = (
            "# Analysis\n"
            "This analysis compares two outputs to the expected output "
            "based on specific criteria.\n"
            "# Preferred Output ID: B"
        )
        llms = {
            "output_history_analyzer": MagicMock(
                invoke=lambda prompt: MagicMock(content=analysis))
        }
        prompts = {}
        meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
        state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            output="To reverse a list in Python, you can use the `[::-1]` slicing.",
            system_message="To reverse a list, use slicing or the reverse method.",
            best_output="To reverse a list in Python, use the `reverse()` method.",
            best_system_message="To reverse a list, use the `reverse()` method.",
            acceptance_criteria="The output should correctly describe how to reverse a list in Python."
        )

        # Invoke the output history analyzer node.
        updated_state = meta_prompt_graph._output_history_analyzer(state)

        self.assertEqual(
            updated_state.best_output, state.output,
            "Best output should be updated to the current output.")
        self.assertEqual(
            updated_state.best_system_message, state.system_message,
            "Best system message should be updated to the current system message.")
        self.assertEqual(
            updated_state.best_output_age, 0,
            "Best output age should be reset to 0.")

    def test_prompt_analyzer_accept(self):
        """An ``Accept: Yes`` reply marks the state as accepted."""
        llms = {
            MetaPromptGraph.NODE_PROMPT_ANALYZER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Accept: Yes"))
        }
        meta_prompt_graph = MetaPromptGraph(llms)
        state = AgentState(output="Test output",
                           expected_output="Expected output")

        updated_state = meta_prompt_graph._prompt_analyzer(state)

        self.assertTrue(updated_state.accepted)

    def test_workflow_execution(self):
        """Run the whole graph with one ``ChatOpenAI`` model for every node."""
        # Alternative model names kept for quick switching:
        #   "meta-llama/llama-3-70b-instruct", "deepseek/deepseek-chat",
        #   "google/gemma-2-9b-it", "recursal/eagle-7b",
        #   "meta-llama/llama-3-8b-instruct"
        model_name = "anthropic/claude-3.5-sonnet:haiku"
        llm = ChatOpenAI(model_name=model_name)
        meta_prompt_graph = MetaPromptGraph(llms=llm)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )

        output_state = meta_prompt_graph(input_state, recursion_limit=25)
        pprint.pp(output_state)

        self._check_generated_system_message(output_state, llm)

    def test_workflow_execution_with_llms(self):
        """Run the whole graph with a separate model assigned per node."""
        optimizer_llm = ChatOpenAI(
            model_name="anthropic/claude-3.5-sonnet:haiku", temperature=0.5)
        executor_llm = ChatOpenAI(
            model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
        llms = {
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor_llm,
            MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer_llm,
        }
        meta_prompt_graph = MetaPromptGraph(llms=llms)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )

        output_state = meta_prompt_graph(input_state, recursion_limit=25)
        pprint.pp(output_state)

        # The follow-up prompt is sent through the executor model.
        self._check_generated_system_message(output_state, executor_llm)
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()