Spaces:
Runtime error
Runtime error
from unittest.mock import Mock, patch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from swarms import ToolAgent | |
def test_tool_agent_init():
    """Constructing a ToolAgent positionally stores each argument on the instance."""
    fake_model = Mock(spec=AutoModelForCausalLM)
    fake_tokenizer = Mock(spec=AutoTokenizer)
    schema_properties = {
        "name": {"type": "string"},
        "age": {"type": "number"},
        "is_student": {"type": "boolean"},
        "courses": {"type": "array", "items": {"type": "string"}},
    }
    schema = {"type": "object", "properties": schema_properties}
    agent_name = "Test Agent"
    agent_description = "This is a test agent"

    agent = ToolAgent(
        agent_name, agent_description, fake_model, fake_tokenizer, schema
    )

    # Each attribute must round-trip the value handed to the constructor.
    expected_attributes = (
        ("name", agent_name),
        ("description", agent_description),
        ("model", fake_model),
        ("tokenizer", fake_tokenizer),
        ("json_schema", schema),
    )
    for attribute, expected in expected_attributes:
        assert getattr(agent, attribute) == expected
@patch.object(ToolAgent, "run")
def test_tool_agent_run(mock_run):
    """Calling run on a ToolAgent instance forwards the task unchanged.

    ``patch.object`` replaces ``ToolAgent.run`` on the class for the duration
    of this test and injects the replacement mock as ``mock_run``. Without the
    decorator, pytest would look up ``mock_run`` as a fixture and fail at
    setup before the test body ever runs.

    Because ``run`` itself is mocked, this only verifies that the call made
    through the instance reaches the mock exactly once with ``task`` (a plain
    mock set on the class is not bound, so it does not receive ``self``); it
    does not exercise ToolAgent's real ``run`` implementation.
    """
    model = Mock(spec=AutoModelForCausalLM)
    tokenizer = Mock(spec=AutoTokenizer)
    json_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
            "is_student": {"type": "boolean"},
            "courses": {"type": "array", "items": {"type": "string"}},
        },
    }
    name = "Test Agent"
    description = "This is a test agent"
    task = (
        "Generate a person's information based on the following"
        " schema:"
    )
    agent = ToolAgent(
        name, description, model, tokenizer, json_schema
    )
    agent.run(task)
    mock_run.assert_called_once_with(task)
def test_tool_agent_init_with_kwargs():
    """Extra keyword options passed to ToolAgent are exposed as attributes."""
    fake_model = Mock(spec=AutoModelForCausalLM)
    fake_tokenizer = Mock(spec=AutoTokenizer)
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
            "is_student": {"type": "boolean"},
            "courses": {"type": "array", "items": {"type": "string"}},
        },
    }
    agent_name = "Test Agent"
    agent_description = "This is a test agent"
    options = dict(
        debug=True,
        max_array_length=20,
        max_number_tokens=12,
        temperature=0.5,
        max_string_token_length=20,
    )
    positional = (
        agent_name,
        agent_description,
        fake_model,
        fake_tokenizer,
        schema,
    )

    agent = ToolAgent(*positional, **options)

    # The positional arguments should be stored under these names, in order.
    for attribute, expected in zip(
        ("name", "description", "model", "tokenizer", "json_schema"),
        positional,
    ):
        assert getattr(agent, attribute) == expected

    # Every keyword option should surface unchanged as an attribute.
    for option, value in options.items():
        assert getattr(agent, option) == value