import unittest
import pprint
import logging
from unittest.mock import MagicMock, patch

# The classes under test are assumed to be importable from the meta_prompt package.
from meta_prompt import AgentState, MetaPromptGraph
from langchain_openai import ChatOpenAI


class TestMetaPromptGraph(unittest.TestCase):
    def setUp(self):
        # logging.basicConfig(level=logging.DEBUG)
        pass

    def test_prompt_node(self):
        """_prompt_node should copy the LLM response into the target state attribute."""
        llms = {
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: MagicMock(
                invoke=MagicMock(return_value=MagicMock(content="Mocked response content"))
            )
        }

        # Create an instance of MetaPromptGraph with the mocked language model
        graph = MetaPromptGraph(llms=llms)

        # Create a mock AgentState
        state = AgentState(user_message="Test message", expected_output="Expected output")

        # Invoke _prompt_node with the node name, target attribute, and state
        updated_state = graph._prompt_node(
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER, "output", state
        )

        self.assertEqual(
            updated_state.output, "Mocked response content",
            "The output attribute should be updated with the mocked response content",
        )

    def test_output_history_analyzer(self):
        """A 'Preferred Output ID: B' verdict should promote the current output to best."""
        llms = {
            MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: MagicMock(
                invoke=lambda prompt: MagicMock(content="""# Analysis
This analysis compares two outputs to the expected output based on specific criteria.
# Preferred Output ID: B""")
            )
        }
        prompts = {}
        meta_prompt_graph = MetaPromptGraph(llms=llms, prompts=prompts)
        state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            output="To reverse a list in Python, you can use the `[::-1]` slicing.",
            system_message="To reverse a list, use slicing or the reverse method.",
            best_output="To reverse a list in Python, use the `reverse()` method.",
            best_system_message="To reverse a list, use the `reverse()` method.",
            acceptance_criteria="The output should correctly describe how to reverse a list in Python."
        )

        # Invoke the output history analyzer node
        updated_state = meta_prompt_graph._output_history_analyzer(state)

        self.assertEqual(updated_state.best_output, state.output,
                         "Best output should be updated to the current output.")
        self.assertEqual(updated_state.best_system_message, state.system_message,
                         "Best system message should be updated to the current system message.")
        self.assertEqual(updated_state.best_output_age, 0,
                         "Best output age should be reset to 0.")

    def test_prompt_analyzer_accept(self):
        """An 'Accept: Yes' verdict from the analyzer should mark the state as accepted."""
        llms = {
            MetaPromptGraph.NODE_PROMPT_ANALYZER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Accept: Yes")
            )
        }
        meta_prompt_graph = MetaPromptGraph(llms)
        state = AgentState(output="Test output", expected_output="Expected output")

        updated_state = meta_prompt_graph._prompt_analyzer(state)

        self.assertTrue(updated_state.accepted)
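
    # A minimal mirror-image sketch for the rejection branch, assuming
    # _prompt_analyzer parses an "Accept: No" verdict into accepted == False;
    # that branch is not exercised elsewhere in this file, so the parsing
    # behavior is an assumption.
    def test_prompt_analyzer_reject(self):
        llms = {
            MetaPromptGraph.NODE_PROMPT_ANALYZER: MagicMock(
                invoke=lambda prompt: MagicMock(content="Accept: No")
            )
        }
        meta_prompt_graph = MetaPromptGraph(llms)
        state = AgentState(output="Test output", expected_output="Expected output")

        updated_state = meta_prompt_graph._prompt_analyzer(state)

        self.assertFalse(updated_state.accepted)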

    def test_workflow_execution(self):
        """End-to-end run of the full graph against a live model endpoint."""
        MODEL_NAME = "anthropic/claude-3.5-sonnet:haiku"
        # MODEL_NAME = "meta-llama/llama-3-70b-instruct"
        # MODEL_NAME = "deepseek/deepseek-chat"
        # MODEL_NAME = "google/gemma-2-9b-it"
        # MODEL_NAME = "recursal/eagle-7b"
        # MODEL_NAME = "meta-llama/llama-3-8b-instruct"
        llm = ChatOpenAI(model_name=MODEL_NAME)

        meta_prompt_graph = MetaPromptGraph(llms=llm)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )
        output_state = meta_prompt_graph(input_state, recursion_limit=25)
        pprint.pp(output_state)

        self.assertIn('best_system_message', output_state,
                      "The output state should contain the key 'best_system_message'")
        self.assertIsNotNone(output_state['best_system_message'],
                             "The best system message should not be None")
        print(output_state['best_system_message'])

        # Try another similar user message with the generated system message.
        user_message = "How can I create a list of numbers in Python?"
        messages = [("system", output_state['best_system_message']),
                    ("human", user_message)]
        result = llm.invoke(messages)

        self.assertTrue(hasattr(result, 'content'),
                        "The result should have the attribute 'content'")
        print(result.content)

    def test_workflow_execution_with_llms(self):
        """End-to-end run with separate optimizer and executor models per node."""
        optimizer_llm = ChatOpenAI(model_name="anthropic/claude-3.5-sonnet:haiku", temperature=0.5)
        executor_llm = ChatOpenAI(model_name="meta-llama/llama-3-8b-instruct", temperature=0.01)
        llms = {
            MetaPromptGraph.NODE_PROMPT_INITIAL_DEVELOPER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_DEVELOPER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_EXECUTOR: executor_llm,
            MetaPromptGraph.NODE_OUTPUT_HISTORY_ANALYZER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_ANALYZER: optimizer_llm,
            MetaPromptGraph.NODE_PROMPT_SUGGESTER: optimizer_llm
        }
        meta_prompt_graph = MetaPromptGraph(llms=llms)
        input_state = AgentState(
            user_message="How do I reverse a list in Python?",
            expected_output="Use the `[::-1]` slicing technique or the `list.reverse()` method.",
            acceptance_criteria="Similar in meaning, text length and style."
        )
        output_state = meta_prompt_graph(input_state, recursion_limit=25)
        pprint.pp(output_state)

        self.assertIn('best_system_message', output_state,
                      "The output state should contain the key 'best_system_message'")
        self.assertIsNotNone(output_state['best_system_message'],
                             "The best system message should not be None")
        print(output_state['best_system_message'])

        # Try another similar user message with the generated system message.
        user_message = "How can I create a list of numbers in Python?"
        messages = [("system", output_state['best_system_message']),
                    ("human", user_message)]
        result = executor_llm.invoke(messages)

        self.assertTrue(hasattr(result, 'content'),
                        "The result should have the attribute 'content'")
        print(result.content)
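
    # Note: the two workflow tests above call a live model endpoint through
    # ChatOpenAI (the "vendor/model" names suggest an OpenRouter-compatible
    # base URL). A minimal sketch for keeping the suite green offline,
    # assuming the credential is exposed as OPENAI_API_KEY:
    #
    #     import os
    #     @unittest.skipUnless(os.environ.get("OPENAI_API_KEY"),
    #                          "live-API test; no API key configured")
    #     def test_workflow_execution(self): ...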


if __name__ == '__main__':
    unittest.main()