Spaces:
Configuration error
Configuration error
import asyncio | |
import json | |
import os | |
import sys | |
from unittest.mock import AsyncMock, MagicMock, patch | |
sys.path.insert( | |
0, os.path.abspath("../../..") | |
) # Adds the parent directory to the system path | |
from datetime import datetime, timedelta | |
import pytest | |
import litellm | |
from litellm.proxy._types import ( | |
LiteLLM_ObjectPermissionTable, | |
LiteLLM_TeamTable, | |
LiteLLM_UserTable, | |
LitellmUserRoles, | |
ProxyErrorTypes, | |
ProxyException, | |
SSOUserDefinedValues, | |
UserAPIKeyAuth, | |
) | |
from litellm.proxy.auth.auth_checks import ( | |
ExperimentalUIJWTToken, | |
_can_object_call_vector_stores, | |
get_user_object, | |
vector_store_access_check, | |
) | |
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper | |
from litellm.utils import get_utc_datetime | |
@pytest.fixture(autouse=True)
def set_salt_key(monkeypatch):
    """Set LITELLM_SALT_KEY for every test in this module.

    ``autouse=True`` is what makes this apply automatically; without the
    fixture decorator pytest would never call this function.
    """
    monkeypatch.setenv("LITELLM_SALT_KEY", "sk-1234")
@pytest.fixture
def valid_sso_user_defined_values():
    """Return a PROXY_ADMIN ``LiteLLM_UserTable`` with one model and a 100.0 budget.

    Requested by name as a test argument, so it must be registered as a pytest fixture.
    """
    return LiteLLM_UserTable(
        user_id="test_user",
        user_email="[email protected]",
        user_role=LitellmUserRoles.PROXY_ADMIN.value,
        models=["gpt-3.5-turbo"],
        max_budget=100.0,
    )
@pytest.fixture
def invalid_sso_user_defined_values():
    """Return a ``LiteLLM_UserTable`` whose ``user_role`` is None.

    Used to check that UI-login token generation rejects users without a role.
    Requested by name as a test argument, so it must be registered as a pytest fixture.
    """
    return LiteLLM_UserTable(
        user_id="test_user",
        user_email="[email protected]",
        user_role=None,  # deliberately missing user role
        models=["gpt-3.5-turbo"],
        max_budget=100.0,
    )
def test_get_experimental_ui_login_jwt_auth_token_valid(valid_sso_user_defined_values):
    """Mint a UI-login token for a PROXY_ADMIN user and inspect its decrypted payload.

    Checks the following:
    - The payload echoes the user's id, role and models.
    - ``max_budget`` equals ``litellm.max_ui_session_budget``.
    - The payload carries an ``expires`` timestamp no more than 10 minutes ahead.
    """
    encrypted = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )

    raw_payload = decrypt_value_helper(encrypted, exception_type="debug")
    # Guard against a None result before handing it to json.loads.
    assert raw_payload is not None
    payload = json.loads(raw_payload)

    expected_fields = {
        "user_id": "test_user",
        "user_role": LitellmUserRoles.PROXY_ADMIN.value,
        "models": ["gpt-3.5-turbo"],
        "max_budget": litellm.max_ui_session_budget,
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value

    # The expiry must be present, in the future, and within the 10-minute window.
    assert "expires" in payload
    expiry = datetime.fromisoformat(payload["expires"].replace("Z", "+00:00"))
    assert expiry > get_utc_datetime()
    assert expiry <= get_utc_datetime() + timedelta(minutes=10)
def test_get_experimental_ui_login_jwt_auth_token_invalid(
    invalid_sso_user_defined_values,
):
    """Minting a UI-login token for a user with no role must fail with a clear message."""
    expected_message = "User role is required for experimental UI login"

    with pytest.raises(Exception) as raised:
        ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
            invalid_sso_user_defined_values
        )

    assert str(raised.value) == expected_message
def test_get_key_object_from_ui_hash_key_valid(
    valid_sso_user_defined_values, monkeypatch
):
    """A freshly minted UI token decodes back into a key object for the same user."""
    # Turn on the experimental UI-login path; monkeypatch restores the env afterwards.
    monkeypatch.setenv("EXPERIMENTAL_UI_LOGIN", "True")

    ui_token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
        valid_sso_user_defined_values
    )
    decoded = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(ui_token)

    assert decoded is not None
    assert decoded.user_id == "test_user"
    assert decoded.user_role == LitellmUserRoles.PROXY_ADMIN
    assert decoded.models == ["gpt-3.5-turbo"]
    # The budget comes from the UI session cap, not the user's own max_budget.
    assert decoded.max_budget == litellm.max_ui_session_budget
def test_get_key_object_from_ui_hash_key_invalid():
    """An arbitrary string that is not a minted UI token yields no key object (None)."""
    bogus_token = "invalid_token"
    assert ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(bogus_token) is None
@pytest.mark.asyncio
async def test_default_internal_user_params_with_get_user_object(monkeypatch):
    """``get_user_object`` must apply ``litellm.default_internal_user_params`` on upsert.

    The DB lookup is mocked to find no user. With ``user_id_upsert=True``, that
    forces the create path. The test then inspects the ``data`` passed to
    ``litellm_usertable.create``.
    """
    default_params = {
        "models": ["gpt-4", "claude-3-opus"],
        "max_budget": 200.0,
        "user_role": "internal_user",
    }
    monkeypatch.setattr(litellm, "default_internal_user_params", default_params)

    mock_prisma_client = MagicMock()
    mock_prisma_client.db = AsyncMock()

    # Stand-in for the row returned by `create`.
    mock_user = MagicMock()
    mock_user.user_id = "new_test_user"
    mock_user.models = ["gpt-4", "claude-3-opus"]
    mock_user.max_budget = 200.0
    mock_user.user_role = "internal_user"
    mock_user.organization_memberships = []
    mock_user.dict = lambda: {
        "user_id": "new_test_user",
        "models": ["gpt-4", "claude-3-opus"],
        "max_budget": 200.0,
        "user_role": "internal_user",
        "organization_memberships": [],
    }

    # No existing user -> the upsert path must create one.
    mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
    mock_prisma_client.db.litellm_usertable.create = AsyncMock(return_value=mock_user)

    # Empty cache so the DB path is exercised; async methods need AsyncMock.
    mock_cache = MagicMock()
    mock_cache.async_get_cache = AsyncMock(return_value=None)
    mock_cache.async_set_cache = AsyncMock()

    try:
        await get_user_object(
            user_id="new_test_user",
            prisma_client=mock_prisma_client,
            user_api_key_cache=mock_cache,
            user_id_upsert=True,
            proxy_logging_obj=None,
        )
    except Exception as e:
        # Deliberately tolerated: the mocked row is a MagicMock, not a
        # LiteLLM_UserTable, so post-create conversion may fail. Only the
        # arguments passed to `create` matter for this test.
        print(e)

    mock_prisma_client.db.litellm_usertable.create.assert_called_once()
    creation_args = mock_prisma_client.db.litellm_usertable.create.call_args[1]["data"]

    # The defaults must have been merged into the creation payload.
    assert "models" in creation_args
    assert creation_args["models"] == ["gpt-4", "claude-3-opus"]
    assert creation_args["max_budget"] == 200.0
    assert creation_args["user_role"] == "internal_user"
# Vector Store Auth Check Tests


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_client,vector_store_registry,expected_result",
    [
        (None, MagicMock(), True),  # no DB client patched in
        (MagicMock(), None, True),  # no vector store registry configured
        (MagicMock(), MagicMock(), True),  # registry resolves no stores to run
    ],
)
async def test_vector_store_access_check_early_returns(
    prisma_client, vector_store_registry, expected_result
):
    """``vector_store_access_check`` returns True for its early-exit conditions.

    These cases involve no token or team object. When a registry is supplied,
    it is configured to resolve no vector store ids for the request.
    """
    request_body = {"messages": [{"role": "user", "content": "test"}]}

    if vector_store_registry:
        vector_store_registry.get_vector_store_ids_to_run.return_value = None

    with patch("litellm.proxy.proxy_server.prisma_client", prisma_client), patch(
        "litellm.vector_store_registry", vector_store_registry
    ):
        result = await vector_store_access_check(
            request_body=request_body,
            team_object=None,
            valid_token=None,
        )

    assert result == expected_result
@pytest.mark.parametrize(
    "object_permissions,vector_store_ids,should_raise,error_type",
    [
        # No permission record for the object.
        (None, ["store-1"], False, None),
        # Every requested store is in the allow-list.
        ({"vector_stores": ["store-1", "store-2"]}, ["store-1"], False, None),
        # Requested store is outside the allow-list -> key-level denial.
        (
            {"vector_stores": ["store-1", "store-2"]},
            ["store-3"],
            True,
            ProxyErrorTypes.key_vector_store_access_denied,
        ),
    ],
)
def test_can_object_call_vector_stores_scenarios(
    object_permissions, vector_store_ids, should_raise, error_type
):
    """Exercise ``_can_object_call_vector_stores`` across permission scenarios.

    ``object_permissions`` arrives as a plain dict (or None). It is wrapped in
    a mock exposing ``.vector_stores``. The object type is "key" when a
    key-level denial is expected, otherwise "team".
    """
    if object_permissions is not None:
        mock_permissions = MagicMock()
        mock_permissions.vector_stores = object_permissions["vector_stores"]
        object_permissions = mock_permissions

    object_type = (
        "key"
        if error_type == ProxyErrorTypes.key_vector_store_access_denied
        else "team"
    )

    if should_raise:
        with pytest.raises(ProxyException) as exc_info:
            _can_object_call_vector_stores(
                object_type=object_type,
                vector_store_ids_to_run=vector_store_ids,
                object_permissions=object_permissions,
            )
        assert exc_info.value.type == error_type
    else:
        result = _can_object_call_vector_stores(
            object_type=object_type,
            vector_store_ids_to_run=vector_store_ids,
            object_permissions=object_permissions,
        )
        assert result is True
@pytest.mark.asyncio
async def test_vector_store_access_check_with_permissions():
    """``vector_store_access_check`` enforces a key's vector-store allow-list.

    The key's permission row allows ``store-1`` and ``store-2``. A request
    resolving to ``store-1`` passes. One resolving to ``store-3`` raises a
    key-level access-denied ``ProxyException``.
    """
    request_body = {"tools": [{"type": "function", "function": {"name": "test"}}]}

    valid_token = UserAPIKeyAuth(
        token="test-token",
        object_permission_id="perm-123",
        models=["gpt-4"],
        max_budget=100.0,
    )

    mock_permissions = MagicMock()
    mock_permissions.vector_stores = ["store-1", "store-2"]

    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_objectpermissiontable.find_unique = AsyncMock(
        return_value=mock_permissions
    )

    mock_vector_store_registry = MagicMock()

    # Both checks share the same patched globals. Only the ids the registry
    # resolves for the request change between them.
    with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch(
        "litellm.vector_store_registry", mock_vector_store_registry
    ):
        # Allowed: the requested store is in the key's allow-list.
        mock_vector_store_registry.get_vector_store_ids_to_run.return_value = [
            "store-1"
        ]
        result = await vector_store_access_check(
            request_body=request_body,
            team_object=None,
            valid_token=valid_token,
        )
        assert result is True

        # Denied: the requested store is outside the allow-list.
        mock_vector_store_registry.get_vector_store_ids_to_run.return_value = [
            "store-3"
        ]
        with pytest.raises(ProxyException) as exc_info:
            await vector_store_access_check(
                request_body=request_body,
                team_object=None,
                valid_token=valid_token,
            )

    assert exc_info.value.type == ProxyErrorTypes.key_vector_store_access_denied