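"""
Unit tests for litellm Router helper utilities: completion/fallback/retry wrappers,
deployment management, retry and allowed-fails policies, and pattern-based model routing.
"""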
import sys
import os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock
from create_mock_standard_logging_payload import create_standard_logging_payload
from litellm.types.utils import StandardLoggingPayload

load_dotenv()


@pytest.fixture
def model_list():
    return [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {
                "access_groups": ["group1", "group2"],
            },
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "dall-e-3",
            "litellm_params": {
                "model": "dall-e-3",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "claude-*",
            "litellm_params": {
                "model": "anthropic/*",
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
        },
    ]


def test_validate_fallbacks(model_list):
    router = Router(model_list=model_list, fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}])
    router.validate_fallbacks(fallback_param=[{"gpt-4o": "gpt-3.5-turbo"}])


def test_routing_strategy_init(model_list):
    """Test if all routing strategies are initialized correctly"""
    from litellm.types.router import RoutingStrategy

    router = Router(model_list=model_list)
    for strategy in RoutingStrategy._member_names_:
        router.routing_strategy_init(
            routing_strategy=strategy, routing_strategy_args={}
        )


def test_print_deployment(model_list):
    """Test if the api key is masked correctly"""
    router = Router(model_list=model_list)
    deployment = {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    }
    printed_deployment = router.print_deployment(deployment)
    assert 10 * "*" in printed_deployment["litellm_params"]["api_key"]


def test_completion(model_list):
    """Test if the completion function is working correctly"""
    router = Router(model_list=model_list)
    response = router._completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="I'm fine, thank you!",
    )
    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_image_generation(model_list, sync_mode):
    """Test if the underlying '_image_generation' function is working correctly"""
    from litellm.types.utils import ImageResponse

    router = Router(model_list=model_list)
    if sync_mode:
        response = router._image_generation(
            model="dall-e-3",
            prompt="A cute baby sea otter",
        )
    else:
        response = await router._aimage_generation(
            model="dall-e-3",
            prompt="A cute baby sea otter",
        )
    ImageResponse.model_validate(response)


@pytest.mark.asyncio
async def test_router_acompletion_util(model_list):
    """Test if the underlying '_acompletion' function is working correctly"""
    router = Router(model_list=model_list)
    response = await router._acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="I'm fine, thank you!",
    )
    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"


@pytest.mark.asyncio
async def test_router_abatch_completion_one_model_multiple_requests_util(model_list):
    """Test if the 'abatch_completion_one_model_multiple_requests' function is working correctly"""
    router = Router(model_list=model_list)
    response = await router.abatch_completion_one_model_multiple_requests(
        model="gpt-3.5-turbo",
        messages=[
            [{"role": "user", "content": "Hello, how are you?"}],
            [{"role": "user", "content": "Hello, how are you?"}],
        ],
        mock_response="I'm fine, thank you!",
    )
    print(response)
    assert response[0]["choices"][0]["message"]["content"] == "I'm fine, thank you!"
    assert response[1]["choices"][0]["message"]["content"] == "I'm fine, thank you!"


@pytest.mark.asyncio
async def test_router_schedule_acompletion(model_list):
    """Test if the 'schedule_acompletion' function is working correctly"""
    router = Router(model_list=model_list)
    response = await router.schedule_acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="I'm fine, thank you!",
        priority=1,
    )
    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"


@pytest.mark.asyncio
async def test_router_schedule_atext_completion(model_list):
    """Test if the 'schedule_atext_completion' function is working correctly"""
    from litellm.types.utils import TextCompletionResponse

    router = Router(model_list=model_list)
    with patch.object(
        router, "_atext_completion", AsyncMock()
    ) as mock_atext_completion:
        mock_atext_completion.return_value = TextCompletionResponse()
        response = await router.atext_completion(
            model="gpt-3.5-turbo",
            prompt="Hello, how are you?",
            priority=1,
        )
        mock_atext_completion.assert_awaited_once()
        assert "priority" not in mock_atext_completion.call_args.kwargs


@pytest.mark.asyncio
async def test_router_schedule_factory(model_list):
    """Test if the '_schedule_factory' function is working correctly"""
    from litellm.types.utils import TextCompletionResponse

    router = Router(model_list=model_list)
    with patch.object(
        router, "_atext_completion", AsyncMock()
    ) as mock_atext_completion:
        mock_atext_completion.return_value = TextCompletionResponse()
        response = await router._schedule_factory(
            model="gpt-3.5-turbo",
            args=(
                "gpt-3.5-turbo",
                "Hello, how are you?",
            ),
            priority=1,
            kwargs={},
            original_function=router.atext_completion,
        )
        mock_atext_completion.assert_awaited_once()
        assert "priority" not in mock_atext_completion.call_args.kwargs


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_router_function_with_fallbacks(model_list, sync_mode):
    """Test if the router 'async_function_with_fallbacks' + 'function_with_fallbacks' are working correctly"""
    router = Router(model_list=model_list)
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "mock_response": "I'm fine, thank you!",
        "num_retries": 0,
    }
    if sync_mode:
        response = router.function_with_fallbacks(
            original_function=router._completion,
            **data,
        )
    else:
        response = await router.async_function_with_fallbacks(
            original_function=router._acompletion,
            **data,
        )
    assert response.choices[0].message.content == "I'm fine, thank you!"


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_router_function_with_retries(model_list, sync_mode):
    """Test if the router 'async_function_with_retries' + 'function_with_retries' are working correctly"""
    router = Router(model_list=model_list)
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "mock_response": "I'm fine, thank you!",
        "num_retries": 0,
    }
    if sync_mode:
        response = router.function_with_retries(
            original_function=router._completion,
            **data,
        )
    else:
        response = await router.async_function_with_retries(
            original_function=router._acompletion,
            **data,
        )
    assert response.choices[0].message.content == "I'm fine, thank you!"


@pytest.mark.asyncio
async def test_router_make_call(model_list):
    """Test if the router 'make_call' function is working correctly"""
    ## ACOMPLETION
    router = Router(model_list=model_list)
    response = await router.make_call(
        original_function=router._acompletion,
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="I'm fine, thank you!",
    )
    assert response.choices[0].message.content == "I'm fine, thank you!"

    ## ATEXT_COMPLETION
    response = await router.make_call(
        original_function=router._atext_completion,
        model="gpt-3.5-turbo",
        prompt="Hello, how are you?",
        mock_response="I'm fine, thank you!",
    )
    assert response.choices[0].text == "I'm fine, thank you!"

    ## AEMBEDDING
    response = await router.make_call(
        original_function=router._aembedding,
        model="gpt-3.5-turbo",
        input="Hello, how are you?",
        mock_response=[0.1, 0.2, 0.3],
    )
    assert response.data[0].embedding == [0.1, 0.2, 0.3]

    ## AIMAGE_GENERATION
    response = await router.make_call(
        original_function=router._aimage_generation,
        model="dall-e-3",
        prompt="A cute baby sea otter",
        mock_response="https://example.com/image.png",
    )
    assert response.data[0].url == "https://example.com/image.png"


def test_update_kwargs_with_deployment(model_list):
    """Test if the '_update_kwargs_with_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    kwargs: dict = {"metadata": {}}
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    router._update_kwargs_with_deployment(
        deployment=deployment,
        kwargs=kwargs,
    )
    set_fields = ["deployment", "api_base", "model_info"]
    assert all(field in kwargs["metadata"] for field in set_fields)


def test_update_kwargs_with_default_litellm_params(model_list):
    """Test if the '_update_kwargs_with_default_litellm_params' function is working correctly"""
    router = Router(
        model_list=model_list,
        default_litellm_params={"api_key": "test", "metadata": {"key": "value"}},
    )
    kwargs: dict = {"metadata": {"key2": "value2"}}
    router._update_kwargs_with_default_litellm_params(kwargs=kwargs)
    assert kwargs["api_key"] == "test"
    assert kwargs["metadata"]["key"] == "value"
    assert kwargs["metadata"]["key2"] == "value2"


def test_get_timeout(model_list):
    """Test if the '_get_timeout' function is working correctly"""
    router = Router(model_list=model_list)
    timeout = router._get_timeout(kwargs={}, data={"timeout": 100})
    assert timeout == 100


@pytest.mark.parametrize(
    "fallback_kwarg, expected_error",
    [  # assumed pairs, based on the mock-testing kwargs the router recognizes
        ("mock_testing_fallbacks", litellm.InternalServerError),
        ("mock_testing_context_fallbacks", litellm.ContextWindowExceededError),
        ("mock_testing_content_policy_fallbacks", litellm.ContentPolicyViolationError),
    ],
)
def test_handle_mock_testing_fallbacks(model_list, fallback_kwarg, expected_error):
    """Test if the '_handle_mock_testing_fallbacks' function is working correctly"""
    router = Router(model_list=model_list)
    with pytest.raises(expected_error):
        data = {
            fallback_kwarg: True,
        }
        router._handle_mock_testing_fallbacks(
            kwargs=data,
        )


def test_handle_mock_testing_rate_limit_error(model_list):
    """Test if the '_handle_mock_testing_rate_limit_error' function is working correctly"""
    router = Router(model_list=model_list)
    with pytest.raises(litellm.RateLimitError):
        data = {
            "mock_testing_rate_limit_error": True,
        }
        router._handle_mock_testing_rate_limit_error(
            kwargs=data,
        )


def test_get_fallback_model_group_from_fallbacks(model_list):
    """Test if the '_get_fallback_model_group_from_fallbacks' function is working correctly"""
    router = Router(model_list=model_list)
    fallback_model_group_name = router._get_fallback_model_group_from_fallbacks(
        model_group="gpt-4o",
        fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}],
    )
    assert fallback_model_group_name == "gpt-3.5-turbo"


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_deployment_callback_on_success(model_list, sync_mode):
    """Test if the '_deployment_callback_on_success' function is working correctly"""
    import time

    router = Router(model_list=model_list)
    standard_logging_payload = create_standard_logging_payload()
    standard_logging_payload["total_tokens"] = 100
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
            },
            "model_info": {"id": 100},
        },
        "standard_logging_object": standard_logging_payload,
    }
    response = litellm.ModelResponse(
        model="gpt-3.5-turbo",
        usage={"total_tokens": 100},
    )
    if sync_mode:
        tpm_key = router.sync_deployment_callback_on_success(
            kwargs=kwargs,
            completion_response=response,
            start_time=time.time(),
            end_time=time.time(),
        )
    else:
        tpm_key = await router.deployment_callback_on_success(
            kwargs=kwargs,
            completion_response=response,
            start_time=time.time(),
            end_time=time.time(),
        )
    assert tpm_key is not None


@pytest.mark.asyncio
async def test_deployment_callback_on_failure(model_list):
    """Test if the '_deployment_callback_on_failure' function is working correctly"""
    import time

    router = Router(model_list=model_list)
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
            },
            "model_info": {"id": 100},
        },
    }
    result = router.deployment_callback_on_failure(
        kwargs=kwargs,
        completion_response=None,
        start_time=time.time(),
        end_time=time.time(),
    )
    assert isinstance(result, bool)
    assert result is False

    model_response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="I'm fine, thank you!",
    )
    result = await router.async_deployment_callback_on_failure(
        kwargs=kwargs,
        completion_response=model_response,
        start_time=time.time(),
        end_time=time.time(),
    )


def test_log_retry(model_list):
    """Test if the '_log_retry' function is working correctly"""
    import time

    router = Router(model_list=model_list)
    new_kwargs = router.log_retry(
        kwargs={"metadata": {}},
        e=Exception(),
    )
    assert "metadata" in new_kwargs
    assert "previous_models" in new_kwargs["metadata"]


def test_update_usage(model_list):
    """Test if the '_update_usage' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    deployment_id = deployment["model_info"]["id"]
    request_count = router._update_usage(
        deployment_id=deployment_id, parent_otel_span=None
    )
    assert request_count == 1

    request_count = router._update_usage(
        deployment_id=deployment_id, parent_otel_span=None
    )
    assert request_count == 2


# Assumed parametrization: a 'content_filter' finish reason should trigger the
# content-policy fallback, a normal 'stop' should not, for both fallback config styles.
@pytest.mark.parametrize("fallback_type", ["model-specific", "default"])
@pytest.mark.parametrize(
    "finish_reason, expected_fallback", [("content_filter", True), ("stop", False)]
)
def test_should_raise_content_policy_error(
    model_list, finish_reason, expected_fallback, fallback_type
):
    """Test if the '_should_raise_content_policy_error' function is working correctly"""
    router = Router(
        model_list=model_list,
        default_fallbacks=["gpt-4o"] if fallback_type == "default" else None,
    )
    assert (
        router._should_raise_content_policy_error(
            model="gpt-3.5-turbo",
            response=litellm.ModelResponse(
                model="gpt-3.5-turbo",
                choices=[
                    {
                        "finish_reason": finish_reason,
                        "message": {"content": "I'm fine, thank you!"},
                    }
                ],
                usage={"total_tokens": 100},
            ),
            kwargs={
                "content_policy_fallbacks": (
                    [{"gpt-3.5-turbo": "gpt-4o"}]
                    if fallback_type == "model-specific"
                    else None
                )
            },
        )
        is expected_fallback
    )


def test_get_healthy_deployments(model_list):
    """Test if the '_get_healthy_deployments' function is working correctly"""
    router = Router(model_list=model_list)
    deployments = router._get_healthy_deployments(
        model="gpt-3.5-turbo", parent_otel_span=None
    )
    assert len(deployments) > 0


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_routing_strategy_pre_call_checks(model_list, sync_mode):
    """Test if the '_routing_strategy_pre_call_checks' function is working correctly"""
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.litellm_core_utils.litellm_logging import Logging

    callback = CustomLogger()
    litellm.callbacks = [callback]
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )
    if sync_mode:
        router.routing_strategy_pre_call_checks(deployment)
    else:
        ## NO EXCEPTION
        await router.async_routing_strategy_pre_call_checks(
            deployment, litellm_logging_obj
        )

        ## WITH EXCEPTION - rate limit error
        with patch.object(
            callback,
            "async_pre_call_check",
            AsyncMock(
                side_effect=litellm.RateLimitError(
                    message="Rate limit error",
                    llm_provider="openai",
                    model="gpt-3.5-turbo",
                )
            ),
        ):
            try:
                await router.async_routing_strategy_pre_call_checks(
                    deployment, litellm_logging_obj
                )
                pytest.fail("Exception was not raised")
            except Exception as e:
                assert isinstance(e, litellm.RateLimitError)

        ## WITH EXCEPTION - generic error
        with patch.object(
            callback, "async_pre_call_check", AsyncMock(side_effect=Exception("Error"))
        ):
            try:
                await router.async_routing_strategy_pre_call_checks(
                    deployment, litellm_logging_obj
                )
                pytest.fail("Exception was not raised")
            except Exception as e:
                assert isinstance(e, Exception)


# Assumed parametrized cases: LITELLM_ENVIRONMENT is set to "staging", so a deployment
# is active only if "staging" is in its supported_environments (or none are configured).
@pytest.mark.parametrize(
    "set_supported_environments, supported_environments, is_supported",
    [
        (True, ["staging"], True),
        (True, ["production"], False),
        (False, None, True),
    ],
)
def test_create_deployment(
    model_list, set_supported_environments, supported_environments, is_supported
):
    """Test if the '_create_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    if set_supported_environments:
        os.environ["LITELLM_ENVIRONMENT"] = "staging"
    deployment = router._create_deployment(
        deployment_info={},
        _model_name="gpt-3.5-turbo",
        _litellm_params={
            "model": "gpt-3.5-turbo",
            "api_key": "test",
            "custom_llm_provider": "openai",
        },
        _model_info={
            "id": 100,
            "supported_environments": supported_environments,
        },
    )
    if is_supported:
        assert deployment is not None
    else:
        assert deployment is None


@pytest.mark.parametrize(
    "set_supported_environments, supported_environments, is_supported",
    [
        (True, ["staging"], True),
        (True, ["production"], False),
        (False, None, True),
    ],
)
def test_deployment_is_active_for_environment(
    model_list, set_supported_environments, supported_environments, is_supported
):
    """Test if the '_deployment_is_active_for_environment' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    if set_supported_environments:
        os.environ["LITELLM_ENVIRONMENT"] = "staging"
    deployment["model_info"]["supported_environments"] = supported_environments
    if is_supported:
        assert (
            router.deployment_is_active_for_environment(deployment=deployment) is True
        )
    else:
        assert (
            router.deployment_is_active_for_environment(deployment=deployment) is False
        )


def test_set_model_list(model_list):
    """Test if the '_set_model_list' function is working correctly"""
    router = Router(model_list=model_list)
    router.set_model_list(model_list=model_list)
    assert len(router.model_list) == len(model_list)


def test_add_deployment(model_list):
    """Test if the '_add_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    deployment["model_info"]["id"] = 100
    ## Test 1: call user facing function
    router.add_deployment(deployment=deployment)

    ## Test 2: call internal function
    router._add_deployment(deployment=deployment)
    assert len(router.model_list) == len(model_list) + 1


def test_upsert_deployment(model_list):
    """Test if the 'upsert_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    print("model list", len(router.model_list))
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    deployment.litellm_params.model = "gpt-4o"
    router.upsert_deployment(deployment=deployment)
    assert len(router.model_list) == len(model_list)


def test_delete_deployment(model_list):
    """Test if the 'delete_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    router.delete_deployment(id=deployment["model_info"]["id"])
    assert len(router.model_list) == len(model_list) - 1


def test_get_model_info(model_list):
    """Test if the 'get_model_info' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    model_info = router.get_model_info(id=deployment["model_info"]["id"])
    assert model_info is not None


def test_get_model_group(model_list):
    """Test if the 'get_model_group' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    model_group = router.get_model_group(id=deployment["model_info"]["id"])
    assert model_group is not None
    assert model_group[0]["model_name"] == "gpt-3.5-turbo"


# Assumed parametrization: the user-facing name may equal the model group or be an alias.
@pytest.mark.parametrize(
    "user_facing_model_group_name", ["gpt-3.5-turbo", "gpt-3.5-turbo-TEST"]
)
def test_set_model_group_info(model_list, user_facing_model_group_name):
    """Test if the 'set_model_group_info' function is working correctly"""
    router = Router(model_list=model_list)
    resp = router._set_model_group_info(
        model_group="gpt-3.5-turbo",
        user_facing_model_group_name=user_facing_model_group_name,
    )
    assert resp is not None
    assert resp.model_group == user_facing_model_group_name


@pytest.mark.asyncio
async def test_set_response_headers(model_list):
    """Test if the 'set_response_headers' function is working correctly"""
    router = Router(model_list=model_list)
    resp = await router.set_response_headers(response=None, model_group=None)
    assert resp is None


def test_get_all_deployments(model_list):
    """Test if the 'get_all_deployments' function is working correctly"""
    router = Router(model_list=model_list)
    deployments = router._get_all_deployments(
        model_name="gpt-3.5-turbo", model_alias="gpt-3.5-turbo"
    )
    assert len(deployments) > 0


def test_get_model_access_groups(model_list):
    """Test if the 'get_model_access_groups' function is working correctly"""
    router = Router(model_list=model_list)
    access_groups = router.get_model_access_groups()
    assert len(access_groups) == 2


def test_update_settings(model_list):
    """Test if the 'update_settings' function is working correctly"""
    router = Router(model_list=model_list)
    pre_update_allowed_fails = router.allowed_fails
    router.update_settings(**{"allowed_fails": 20})
    assert router.allowed_fails != pre_update_allowed_fails
    assert router.allowed_fails == 20


def test_common_checks_available_deployment(model_list):
    """Test if the 'common_checks_available_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    _, available_deployments = router._common_checks_available_deployment(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        input="hi",
        specific_deployment=False,
    )
    assert len(available_deployments) > 0


def test_filter_cooldown_deployments(model_list):
    """Test if the 'filter_cooldown_deployments' function is working correctly"""
    router = Router(model_list=model_list)
    deployments = router._filter_cooldown_deployments(
        healthy_deployments=router._get_all_deployments(model_name="gpt-3.5-turbo"),  # type: ignore
        cooldown_deployments=[],
    )
    assert len(deployments) == len(
        router._get_all_deployments(model_name="gpt-3.5-turbo")
    )


def test_track_deployment_metrics(model_list):
    """Test if the 'track_deployment_metrics' function is working correctly"""
    from litellm.types.utils import ModelResponse

    router = Router(model_list=model_list)
    router._track_deployment_metrics(
        deployment=router.get_deployment_by_model_group_name(
            model_group_name="gpt-3.5-turbo"
        ),
        response=ModelResponse(
            model="gpt-3.5-turbo",
            usage={"total_tokens": 100},
        ),
        parent_otel_span=None,
    )


# Assumed parametrization: each exception type maps to its '<ExceptionName>Retries'
# field on RetryPolicy; the retry counts themselves are arbitrary.
@pytest.mark.parametrize(
    "exception_type, exception_name, num_retries",
    [
        (litellm.AuthenticationError, "AuthenticationError", 0),
        (litellm.BadRequestError, "BadRequestError", 1),
        (litellm.RateLimitError, "RateLimitError", 3),
    ],
)
def test_get_num_retries_from_retry_policy(
    model_list, exception_type, exception_name, num_retries
):
    """Test if the 'get_num_retries_from_retry_policy' function is working correctly"""
    from litellm.router import RetryPolicy

    data = {exception_name + "Retries": num_retries}
    print("data", data)
    router = Router(
        model_list=model_list,
        retry_policy=RetryPolicy(**data),
    )
    print("exception_type", exception_type)
    calc_num_retries = router.get_num_retries_from_retry_policy(
        exception=exception_type(
            message="test", llm_provider="openai", model="gpt-3.5-turbo"
        )
    )
    assert calc_num_retries == num_retries


# Assumed parametrization: each exception type maps to its '<ExceptionName>AllowedFails'
# field on AllowedFailsPolicy.
@pytest.mark.parametrize(
    "exception_type, exception_name, allowed_fails",
    [
        (litellm.RateLimitError, "RateLimitError", 100),
        (litellm.ContentPolicyViolationError, "ContentPolicyViolationError", 1000),
    ],
)
def test_get_allowed_fails_from_policy(
    model_list, exception_type, exception_name, allowed_fails
):
    """Test if the 'get_allowed_fails_from_policy' function is working correctly"""
    from litellm.types.router import AllowedFailsPolicy

    data = {exception_name + "AllowedFails": allowed_fails}
    router = Router(
        model_list=model_list, allowed_fails_policy=AllowedFailsPolicy(**data)
    )
    calc_allowed_fails = router.get_allowed_fails_from_policy(
        exception=exception_type(
            message="test", llm_provider="openai", model="gpt-3.5-turbo"
        )
    )
    assert calc_allowed_fails == allowed_fails


def test_initialize_alerting(model_list):
    """Test if the 'initialize_alerting' function is working correctly"""
    from litellm.types.router import AlertingConfig
    from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting

    router = Router(
        model_list=model_list, alerting_config=AlertingConfig(webhook_url="test")
    )
    router._initialize_alerting()

    callback_added = False
    for callback in litellm.callbacks:
        if isinstance(callback, SlackAlerting):
            callback_added = True
    assert callback_added is True


def test_flush_cache(model_list):
    """Test if the 'flush_cache' function is working correctly"""
    router = Router(model_list=model_list)
    router.cache.set_cache("test", "test")
    assert router.cache.get_cache("test") == "test"
    router.flush_cache()
    assert router.cache.get_cache("test") is None


def test_discard(model_list):
    """
    Test that discard properly removes a Router from the callback lists
    """
    litellm.callbacks = []
    litellm.success_callback = []
    litellm._async_success_callback = []
    litellm.failure_callback = []
    litellm._async_failure_callback = []
    litellm.input_callback = []
    litellm.service_callback = []

    router = Router(model_list=model_list)
    router.discard()

    # Verify all callback lists are empty
    assert len(litellm.callbacks) == 0
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0
    assert len(litellm._async_success_callback) == 0
    assert len(litellm._async_failure_callback) == 0
    assert len(litellm.input_callback) == 0
    assert len(litellm.service_callback) == 0


def test_initialize_assistants_endpoint(model_list):
    """Test if the 'initialize_assistants_endpoint' function is working correctly"""
    router = Router(model_list=model_list)
    router.initialize_assistants_endpoint()
    assert router.acreate_assistants is not None
    assert router.adelete_assistant is not None
    assert router.aget_assistants is not None
    assert router.acreate_thread is not None
    assert router.aget_thread is not None
    assert router.arun_thread is not None
    assert router.aget_messages is not None
    assert router.a_add_message is not None


def test_pass_through_assistants_endpoint_factory(model_list):
    """Test if the 'pass_through_assistants_endpoint_factory' function is working correctly"""
    router = Router(model_list=model_list)
    router._pass_through_assistants_endpoint_factory(
        original_function=litellm.acreate_assistants,
        custom_llm_provider="openai",
        client=None,
        **{},
    )


def test_factory_function(model_list):
    """Test if the 'factory_function' function is working correctly"""
    router = Router(model_list=model_list)
    router.factory_function(litellm.acreate_assistants)


def test_get_model_from_alias(model_list):
    """Test if the 'get_model_from_alias' function is working correctly"""
    router = Router(
        model_list=model_list,
        model_group_alias={"gpt-4o": "gpt-3.5-turbo"},
    )
    model = router._get_model_from_alias(model="gpt-4o")
    assert model == "gpt-3.5-turbo"


def test_get_deployment_by_litellm_model(model_list):
    """Test if the 'get_deployment_by_litellm_model' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router._get_deployment_by_litellm_model(model="gpt-3.5-turbo")
    assert deployment is not None


def test_get_pattern(model_list):
    router = Router(model_list=model_list)
    pattern = router.pattern_router.get_pattern(model="claude-3")
    assert pattern is not None


def test_deployments_by_pattern(model_list):
    router = Router(model_list=model_list)
    deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
    assert deployments is not None


def test_replace_model_in_jsonl(model_list):
    router = Router(model_list=model_list)
    deployments = router.pattern_router.get_deployments_by_pattern(model="claude-3")
    assert deployments is not None


# def test_pattern_match_deployments(model_list):
#     from litellm.router_utils.pattern_match_deployments import PatternMatchRouter
#     import re
#
#     patter_router = PatternMatchRouter()
#     request = "fo::hi::static::hello"
#     model_name = "fo::*:static::*"
#
#     model_name_regex = patter_router._pattern_to_regex(model_name)
#
#     # Match against the request
#     match = re.match(model_name_regex, request)
#
#     print(f"match: {match}")
#     print(f"match.end: {match.end()}")
#     if match is None:
#         raise ValueError("Match not found")
#
#     updated_model = patter_router.set_deployment_model_name(
#         matched_pattern=match, litellm_deployment_litellm_model="openai/*"
#     )
#     assert updated_model == "openai/fo::hi:static::hello"


# Assumed parametrization: a wildcard model_name maps the request onto the deployment's
# litellm model pattern; a deployment model without wildcards is returned unchanged.
@pytest.mark.parametrize(
    "user_request_model, model_name, litellm_model, expected_model",
    [
        ("gpt-4o", "gpt-4o", "openai/gpt-4o", "openai/gpt-4o"),
        (
            "claude-3-5-sonnet-20240620",
            "claude-*",
            "anthropic/claude-*",
            "anthropic/claude-3-5-sonnet-20240620",
        ),
    ],
)
def test_pattern_match_deployment_set_model_name(
    user_request_model, model_name, litellm_model, expected_model
):
    from re import Match
    from litellm.router_utils.pattern_match_deployments import PatternMatchRouter

    pattern_router = PatternMatchRouter()
    import re

    # Convert model_name into a proper regex
    model_name_regex = pattern_router._pattern_to_regex(model_name)

    # Match against the request
    match = re.match(model_name_regex, user_request_model)
    if match is None:
        raise ValueError("Match not found")

    # Call the set_deployment_model_name function
    updated_model = pattern_router.set_deployment_model_name(match, litellm_model)
    print(updated_model)
    assert updated_model == expected_model

    updated_models = pattern_router._return_pattern_matched_deployments(
        match,
        deployments=[
            {
                "model_name": model_name,
                "litellm_params": {"model": litellm_model},
            }
        ],
    )
    for model in updated_models:
        assert model["litellm_params"]["model"] == expected_model


@pytest.mark.asyncio
async def test_pass_through_moderation_endpoint_factory(model_list):
    router = Router(model_list=model_list)
    response = await router._pass_through_moderation_endpoint_factory(
        original_function=litellm.amoderation,
        input="this is valid good text",
        model=None,
    )
    assert response is not None


@pytest.mark.parametrize(
    "has_default_fallbacks, expected_result", [(True, True), (False, False)]
)
def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_result):
    router = Router(
        model_list=model_list,
        default_fallbacks=(
            ["my-default-fallback-model"] if has_default_fallbacks else None
        ),
    )
    assert router._has_default_fallbacks() is expected_result


def test_add_optional_pre_call_checks(model_list):
    router = Router(model_list=model_list)
    router.add_optional_pre_call_checks(["prompt_caching"])
    assert len(litellm.callbacks) > 0


@pytest.mark.asyncio
async def test_async_callback_filter_deployments(model_list):
    from litellm.router_strategy.budget_limiter import RouterBudgetLimiting

    router = Router(model_list=model_list)
    healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")
    new_healthy_deployments = await router.async_callback_filter_deployments(
        model="gpt-3.5-turbo",
        healthy_deployments=healthy_deployments,
        messages=[],
        parent_otel_span=None,
    )
    assert len(new_healthy_deployments) == len(healthy_deployments)


def test_cached_get_model_group_info(model_list):
    """Test if the '_cached_get_model_group_info' function is working correctly with LRU cache"""
    router = Router(model_list=model_list)

    # First call - should hit the actual function
    result1 = router._cached_get_model_group_info("gpt-3.5-turbo")

    # Second call with same argument - should hit the cache
    result2 = router._cached_get_model_group_info("gpt-3.5-turbo")

    # Verify results are the same
    assert result1 == result2

    # Verify the cache info shows hits
    cache_info = router._cached_get_model_group_info.cache_info()
    assert cache_info.hits > 0  # Should have at least one cache hit


def test_init_responses_api_endpoints(model_list):
    """Test if the '_init_responses_api_endpoints' function is working correctly"""
    from typing import Callable

    router = Router(model_list=model_list)
    assert router.aget_responses is not None
    assert isinstance(router.aget_responses, Callable)
    assert router.adelete_responses is not None
    assert isinstance(router.adelete_responses, Callable)