Spaces:
Configuration error
Configuration error
# stdlib imports | |
import json | |
import os | |
import time | |
from unittest.mock import patch | |
import pytest | |
# third party imports | |
from click.testing import CliRunner | |
# local imports | |
from litellm.proxy.client.cli import cli | |
from litellm.proxy.client.cli.commands.models import ( | |
format_cost_per_1k_tokens, | |
format_iso_datetime_str, | |
format_timestamp, | |
) | |
def mock_client():
    """Patch the ``Client`` name in the models CLI command module and yield the mock.

    The patch is active while the generator is suspended at ``yield`` and is
    undone when the ``with`` block exits.

    NOTE(review): described as a fixture, but no ``@pytest.fixture`` decorator
    is visible on this definition -- confirm it is applied.
    """
    client_target = "litellm.proxy.client.cli.commands.models.Client"
    with patch(client_target) as patched_client:
        yield patched_client
def cli_runner():
    """Return a new Click ``CliRunner`` instance.

    NOTE(review): described as a fixture, but no ``@pytest.fixture`` decorator
    is visible on this definition -- confirm it is applied.
    """
    runner = CliRunner()
    return runner
def mock_env():
    """Temporarily set the proxy URL and API key environment variables.

    ``LITELLM_PROXY_URL`` and ``LITELLM_PROXY_API_KEY`` are set while the
    generator is suspended at ``yield``; ``patch.dict`` restores
    ``os.environ`` when the ``with`` block exits.

    NOTE(review): no test requests this by name, so it presumably needs to be
    an autouse fixture -- the decorator is not visible here, confirm.
    """
    proxy_settings = {
        "LITELLM_PROXY_URL": "http://localhost:4000",
        "LITELLM_PROXY_API_KEY": "sk-test",
    }
    env_patch = patch.dict(os.environ, proxy_settings)
    with env_patch:
        yield
def mock_models_list(mock_client):
    """Make ``mock_client``'s instance return two sample models from ``models.list``.

    Args:
        mock_client: The mocked client class; its ``return_value`` stands in
            for the client instance.

    Returns:
        The same ``mock_client`` object, after checking it has not been
        called yet.
    """
    sample_models = [
        {
            "id": "model-123",
            "object": "model",
            "created": 1699848889,
            "owned_by": "organization-123",
        },
        {
            "id": "model-456",
            "object": "model",
            "created": 1699848890,
            "owned_by": "organization-456",
        },
    ]
    client_instance = mock_client.return_value
    client_instance.models.list.return_value = sample_models
    # Guard against a mock that was already used before the test starts.
    mock_client.assert_not_called()
    return mock_client
def mock_models_info(mock_client):
    """Make ``mock_client``'s instance return one sample entry from ``models.info``.

    Args:
        mock_client: The mocked client class; its ``return_value`` stands in
            for the client instance.

    Returns:
        The same ``mock_client`` object, after checking it has not been
        called yet.
    """
    gpt4_entry = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4", "litellm_credential_name": "openai-1"},
        "model_info": {
            "id": "model-123",
            "created_at": "2025-04-29T21:31:43.843000+00:00",
            "updated_at": "2025-04-29T21:31:43.843000+00:00",
            "input_cost_per_token": 0.00001,
            "output_cost_per_token": 0.00002,
        },
    }
    client_instance = mock_client.return_value
    client_instance.models.info.return_value = [gpt4_entry]
    # Guard against a mock that was already used before the test starts.
    mock_client.assert_not_called()
    return mock_client
def force_utc_tz():
    """Force the process timezone to UTC while suspended at ``yield``.

    Sets ``TZ=UTC`` in ``os.environ`` (calling ``time.tzset()`` when it
    exists), yields, then restores the previous ``TZ`` value, or removes the
    variable if it was unset before.

    The restore runs in a ``finally`` block so the environment is also put
    back when the generator is closed or an exception is thrown into it at
    the ``yield``, instead of leaking ``TZ=UTC`` into later code.

    NOTE(review): described as a fixture, but no ``@pytest.fixture`` decorator
    is visible on this definition -- confirm it is applied.
    """
    previous_tz = os.environ.get("TZ")
    # time.tzset() is not available on every platform, hence the guard.
    has_tzset = hasattr(time, "tzset")

    os.environ["TZ"] = "UTC"
    if has_tzset:
        time.tzset()
    try:
        yield
    finally:
        # Restore previous TZ (or remove it when it was not set originally).
        if previous_tz is None:
            os.environ.pop("TZ", None)
        else:
            os.environ["TZ"] = previous_tz
        if has_tzset:
            time.tzset()
def test_models_list_json_format(mock_models_list, cli_runner):
    """``models list --format json`` should print the mocked model list as JSON."""
    outcome = cli_runner.invoke(cli, ["models", "list", "--format", "json"])
    assert outcome.exit_code == 0

    # The printed JSON must decode to exactly what the mocked list() returns.
    parsed_output = json.loads(outcome.output)
    client_instance = mock_models_list.return_value
    assert parsed_output == client_instance.models.list.return_value

    # The client must be built once with these connection settings
    # (the same values mock_env puts into the environment).
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": "sk-test"}
    mock_models_list.assert_called_once_with(**expected_kwargs)
    client_instance.models.list.assert_called_once()
def test_models_list_table_format(mock_models_list, cli_runner):
    """``models list`` without ``--format`` should print a table of the mocked models."""
    outcome = cli_runner.invoke(cli, ["models", "list"])
    assert outcome.exit_code == 0

    # Column headers followed by values from the first mocked model.
    expected_fragments = (
        "ID",
        "Object",
        "Created",
        "Owned By",
        "model-123",
        "organization-123",
    )
    for fragment in expected_fragments:
        assert fragment in outcome.output
    # The creation time must appear in the form format_timestamp() produces.
    assert format_timestamp(1699848889) in outcome.output

    # The client must be built once with these connection settings.
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": "sk-test"}
    mock_models_list.assert_called_once_with(**expected_kwargs)
    mock_models_list.return_value.models.list.assert_called_once()
def test_models_list_with_base_url(mock_models_list, cli_runner):
    """``--base-url`` should reach the client in place of ``http://localhost:4000``."""
    custom_base_url = "http://custom.server:8000"
    cli_args = ["--base-url", custom_base_url, "models", "list"]

    outcome = cli_runner.invoke(cli, cli_args)
    assert outcome.exit_code == 0

    # Only the base URL changes; the API key is expected to stay "sk-test".
    expected_kwargs = {"base_url": custom_base_url, "api_key": "sk-test"}
    mock_models_list.assert_called_once_with(**expected_kwargs)
def test_models_list_with_api_key(mock_models_list, cli_runner):
    """``--api-key`` should reach the client in place of ``sk-test``."""
    custom_api_key = "custom-test-key"
    cli_args = ["--api-key", custom_api_key, "models", "list"]

    outcome = cli_runner.invoke(cli, cli_args)
    assert outcome.exit_code == 0

    # Only the API key changes; the base URL is expected to stay at
    # "http://localhost:4000".
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": custom_api_key}
    mock_models_list.assert_called_once_with(**expected_kwargs)
def test_models_list_error_handling(mock_client, cli_runner):
    """An exception raised by ``models.list()`` should make the command exit non-zero."""
    # Make the mocked list() call raise.
    failing_list = mock_client.return_value.models.list
    failing_list.side_effect = Exception("API Error")

    outcome = cli_runner.invoke(cli, ["models", "list"])

    # The command fails and the captured exception carries the original message.
    assert outcome.exit_code != 0
    assert "API Error" in str(outcome.exception)

    # The client must still have been built once with these settings.
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": "sk-test"}
    mock_client.assert_called_once_with(**expected_kwargs)
def test_models_info_json_format(mock_models_info, cli_runner):
    """``models info --format json`` should print the mocked info list as JSON."""
    outcome = cli_runner.invoke(cli, ["models", "info", "--format", "json"])
    assert outcome.exit_code == 0

    # The printed JSON must decode to exactly what the mocked info() returns.
    parsed_output = json.loads(outcome.output)
    client_instance = mock_models_info.return_value
    assert parsed_output == client_instance.models.info.return_value

    # The client must be built once with these connection settings.
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": "sk-test"}
    mock_models_info.assert_called_once_with(**expected_kwargs)
    client_instance.models.info.assert_called_once()
def test_models_info_table_format(mock_models_info, cli_runner):
    """``models info`` with default columns should print a table with minute-precision times."""
    outcome = cli_runner.invoke(cli, ["models", "info"])
    assert outcome.exit_code == 0

    # Column headers, the model name, and the timestamp cut off at minutes.
    for shown_text in (
        "Public Model",
        "Upstream Model",
        "Updated At",
        "gpt-4",
        "2025-04-29 21:31",
    ):
        assert shown_text in outcome.output

    # Seconds and microseconds of the mocked timestamp must not be shown.
    for hidden_text in ("21:31:43", "843000"):
        assert hidden_text not in outcome.output

    # The client must be built once with these connection settings.
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": "sk-test"}
    mock_models_info.assert_called_once_with(**expected_kwargs)
    mock_models_info.return_value.models.info.assert_called_once()
def test_models_import_only_models_matching_regex(tmp_path, mock_client, cli_runner):
    """``models import --only-models-matching-regex gpt`` should import only the gpt models.

    Writes a four-model YAML config to ``tmp_path``, runs the import, and
    inspects the ``model_params`` keyword passed to each ``models.new`` call.
    """
    import yaml as pyyaml

    # Two gpt models plus two that the regex is expected to exclude.
    model_entries = [
        {
            "model_name": "gpt-4-model",
            "litellm_params": {"model": "gpt-4"},
            "model_info": {"id": "id-1"},
        },
        {
            "model_name": "gpt-3.5-model",
            "litellm_params": {"model": "gpt-3.5-turbo"},
            "model_info": {"id": "id-2"},
        },
        {
            "model_name": "llama2-model",
            "litellm_params": {"model": "llama2"},
            "model_info": {"id": "id-3"},
        },
        {
            "model_name": "other-model",
            "litellm_params": {"model": "other"},
            "model_info": {"id": "id-4"},
        },
    ]
    yaml_file = tmp_path / "models.yaml"
    yaml_file.write_text(pyyaml.safe_dump({"model_list": model_entries}))

    # Calls to client.models.new are recorded on this mock.
    new_model_mock = mock_client.return_value.models.new

    cli_args = [
        "models",
        "import",
        str(yaml_file),
        "--only-models-matching-regex",
        "gpt",
    ]
    outcome = cli_runner.invoke(cli, cli_args)
    assert outcome.exit_code == 0

    # Collect the upstream model name passed to each models.new call.
    imported_models = [
        recorded.kwargs["model_params"]["model"]
        for recorded in new_model_mock.call_args_list
    ]
    assert set(imported_models) == {"gpt-4", "gpt-3.5-turbo"}
    for excluded_model in ("llama2", "other"):
        assert excluded_model not in imported_models

    # Both operands below look for the same text: "gpt-4".split("-")[0] is "gpt".
    provider_prefix = "gpt-4".split("-")[0]
    assert provider_prefix in outcome.output or "gpt" in outcome.output
def test_models_import_only_access_groups_matching_regex(
    tmp_path, mock_client, cli_runner
):
    """``models import --only-access-groups-matching-regex beta`` should import only beta-group models.

    Writes a five-model YAML config to ``tmp_path`` (one entry has no
    ``access_groups`` at all), runs the import, and inspects the
    ``model_params`` keyword passed to each ``models.new`` call.
    """
    import yaml as pyyaml

    # Two entries list "beta-models"; the other three should be excluded.
    model_entries = [
        {
            "model_name": "gpt-4-model",
            "litellm_params": {"model": "gpt-4"},
            "model_info": {
                "id": "id-1",
                "access_groups": ["beta-models", "prod-models"],
            },
        },
        {
            "model_name": "gpt-3.5-model",
            "litellm_params": {"model": "gpt-3.5-turbo"},
            "model_info": {"id": "id-2", "access_groups": ["alpha-models"]},
        },
        {
            "model_name": "llama2-model",
            "litellm_params": {"model": "llama2"},
            "model_info": {"id": "id-3", "access_groups": ["beta-models"]},
        },
        {
            "model_name": "other-model",
            "litellm_params": {"model": "other"},
            "model_info": {"id": "id-4", "access_groups": ["other-group"]},
        },
        {
            "model_name": "no-access-group-model",
            "litellm_params": {"model": "no-access"},
            "model_info": {"id": "id-5"},
        },
    ]
    yaml_file = tmp_path / "models.yaml"
    yaml_file.write_text(pyyaml.safe_dump({"model_list": model_entries}))

    # Calls to client.models.new are recorded on this mock.
    new_model_mock = mock_client.return_value.models.new

    cli_args = [
        "models",
        "import",
        str(yaml_file),
        "--only-access-groups-matching-regex",
        "beta",
    ]
    outcome = cli_runner.invoke(cli, cli_args)
    assert outcome.exit_code == 0

    # Collect the upstream model name passed to each models.new call.
    imported_models = [
        recorded.kwargs["model_params"]["model"]
        for recorded in new_model_mock.call_args_list
    ]
    assert set(imported_models) == {"gpt-4", "llama2"}
    for excluded_model in ("gpt-3.5-turbo", "other", "no-access"):
        assert excluded_model not in imported_models

    # Both operands below look for the same text: "gpt-4".split("-")[0] is "gpt".
    provider_prefix = "gpt-4".split("-")[0]
    assert provider_prefix in outcome.output or "gpt" in outcome.output
def test_format_iso_datetime_str(input_str, expected):
    """Check that ``format_iso_datetime_str(input_str)`` equals ``expected``.

    NOTE(review): ``input_str``/``expected`` look like parametrized arguments,
    but no ``@pytest.mark.parametrize`` decorator is visible -- confirm.
    """
    actual = format_iso_datetime_str(input_str)
    assert actual == expected
def test_format_timestamp(input_val, expected, force_utc_tz):
    """Check that ``format_timestamp(input_val)`` equals ``expected``.

    ``force_utc_tz`` is requested only for its side effect; its value is unused.

    NOTE(review): ``input_val``/``expected`` look like parametrized arguments,
    but no ``@pytest.mark.parametrize`` decorator is visible -- confirm.
    """
    actual = format_timestamp(input_val)
    if actual != expected:
        # Print all three values so a mismatch is easy to read in the test log.
        mismatch_report = f"input: {input_val}, expected: {expected}, actual: {actual}"
        print(mismatch_report)
    assert actual == expected
def test_format_cost_per_1k_tokens(input_val, expected):
    """Check that ``format_cost_per_1k_tokens(input_val)`` equals ``expected``.

    NOTE(review): ``input_val``/``expected`` look like parametrized arguments,
    but no ``@pytest.mark.parametrize`` decorator is visible -- confirm.
    """
    actual = format_cost_per_1k_tokens(input_val)
    if actual != expected:
        # Print all three values so a mismatch is easy to read in the test log.
        mismatch_report = f"input: {input_val}, expected: {expected}, actual: {actual}"
        print(mismatch_report)
    assert actual == expected