Spaces:
Configuration error
Configuration error
import asyncio | |
import copy | |
import json | |
import os | |
import sys | |
from unittest.mock import MagicMock, patch | |
import pytest | |
from fastapi import Request | |
from litellm.proxy._types import UserAPIKeyAuth | |
from litellm.proxy.litellm_pre_call_utils import ( | |
_get_enforced_params, | |
add_litellm_data_to_request, | |
check_if_token_is_service_account, | |
) | |
# Prepend the directory three levels above the current working directory to
# the import search path.
# NOTE(review): this executes after the module-level litellm imports above, so
# it cannot influence how those imports resolve -- confirm it is still needed.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..", "..")))
def test_check_if_token_is_service_account():
    """
    Only keys whose metadata contains `service_account_id` count as service accounts.

    Identity checks (`is True` / `is False`) are used instead of `== True` /
    `== False` (flake8/ruff E712); they also pin that the helper returns a real bool.
    """
    # Case 1: metadata carries a service_account_id -> service account.
    service_account_token = UserAPIKeyAuth(
        api_key="test-key", metadata={"service_account_id": "test-service-account"}
    )
    assert check_if_token_is_service_account(service_account_token) is True

    # Case 2: empty metadata -> regular user key.
    regular_token = UserAPIKeyAuth(api_key="test-key", metadata={})
    assert check_if_token_is_service_account(regular_token) is False

    # Case 3: unrelated metadata keys do not make a key a service account.
    other_metadata_token = UserAPIKeyAuth(
        api_key="test-key", metadata={"user_id": "test-user"}
    )
    assert check_if_token_is_service_account(other_metadata_token) is False
def test_get_enforced_params_for_service_account_settings():
    """
    `service_account_settings.enforced_params` applies only to service-account keys.

    A regular key (no `service_account_id` in its metadata) does not receive the
    service-account list; it gets the `enforced_params` from its own metadata.
    """
    settings = {
        "service_account_settings": {"enforced_params": ["metadata.service"]},
    }

    # Each entry: (key under test, enforced params expected for that key).
    cases = [
        (
            UserAPIKeyAuth(
                api_key="test-key",
                metadata={"service_account_id": "test-service-account"},
            ),
            ["metadata.service"],
        ),
        (
            UserAPIKeyAuth(api_key="test-key", metadata={"enforced_params": ["user"]}),
            ["user"],
        ),
    ]

    for token, expected in cases:
        assert (
            _get_enforced_params(general_settings=settings, user_api_key_dict=token)
            == expected
        )
@pytest.mark.parametrize(
    "general_settings, user_api_key_dict, expected_enforced_params",
    [
        # Service-account key: receives service_account_settings.enforced_params.
        (
            {"service_account_settings": {"enforced_params": ["metadata.service"]}},
            UserAPIKeyAuth(
                api_key="test-key",
                metadata={"service_account_id": "test-service-account"},
            ),
            ["metadata.service"],
        ),
        # Regular key: only the enforced_params from its own metadata apply.
        (
            {"service_account_settings": {"enforced_params": ["metadata.service"]}},
            UserAPIKeyAuth(api_key="test-key", metadata={"enforced_params": ["user"]}),
            ["user"],
        ),
    ],
)
def test_get_enforced_params(
    general_settings, user_api_key_dict, expected_enforced_params
):
    """
    `_get_enforced_params` returns the expected list for each (settings, key) pair.

    Without the parametrize table above, pytest would try to resolve the three
    arguments as fixtures and error out. Each case builds its own settings dict,
    so a mutation in one case cannot leak into another.
    """
    enforced_params = _get_enforced_params(general_settings, user_api_key_dict)
    assert enforced_params == expected_enforced_params
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_parses_string_metadata():
    """
    A JSON-encoded `metadata` string in the request body is decoded into a dict.
    """
    # Configure the URL mock completely before attaching it. Assigning a fresh
    # `request_mock.url` after setting `.url.path` would discard the path and
    # leave an auto-generated MagicMock in its place.
    url_mock = MagicMock()
    url_mock.path = "/v1/completions"
    url_mock.__str__.return_value = "http://localhost/v1/completions"

    request_mock = MagicMock(spec=Request)
    request_mock.url = url_mock
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Simulate a body whose metadata arrives as a JSON string.
    fake_metadata = {"generation_name": "gen123"}
    data = {"metadata": json.dumps(fake_metadata), "model": "gpt-3.5-turbo"}

    user_api_key_dict = UserAPIKeyAuth(
        api_key="hashed-key",
        metadata={},
        team_metadata={},
        spend=0.0,
        max_budget=100.0,
        model_max_budget={},  # this one can be a dict
        team_spend=0.0,
        team_max_budget=200.0,
    )

    updated_data = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=MagicMock(),
        general_settings={},
        version="test-version",
    )

    # The stringified metadata must have been parsed back into a dict.
    metadata = updated_data.get("metadata", {})
    assert isinstance(metadata, dict)
    assert metadata["generation_name"] == "gen123"
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_audio_transcription_multipart():
    """
    Stringified metadata sent as a multipart form field on
    /v1/audio/transcriptions is parsed into a dict and its tags are preserved.
    """
    # Configure the URL mock completely before attaching it. Assigning a fresh
    # `request_mock.url` after setting `.url.path` would discard the path.
    url_mock = MagicMock()
    url_mock.path = "/v1/audio/transcriptions"
    url_mock.__str__.return_value = "http://localhost/v1/audio/transcriptions"

    request_mock = MagicMock(spec=Request)
    request_mock.url = url_mock
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {
        "Content-Type": "multipart/form-data",
        "Authorization": "Bearer sk-1234",
    }
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Simulate multipart data: form fields such as metadata arrive as strings.
    metadata_dict = {
        "tags": ["jobID:214590dsff09fds", "taskName:run_page_classification"]
    }
    data = {
        "model": "fake-openai-endpoint",
        "metadata": json.dumps(metadata_dict),  # Simulating multipart-form field
        "file": b"Fake audio bytes",
    }

    user_api_key_dict = UserAPIKeyAuth(
        api_key="hashed-key",
        metadata={},
        team_metadata={},
        spend=0.0,
        max_budget=100.0,
        model_max_budget={},
        team_spend=0.0,
        team_max_budget=200.0,
    )

    updated_data = await add_litellm_data_to_request(
        data=data,
        request=request_mock,
        user_api_key_dict=user_api_key_dict,
        proxy_config=MagicMock(),
        general_settings={},
        version="test-version",
    )

    # Metadata must have been parsed and the tags carried through unchanged.
    metadata_field = updated_data.get("metadata", {})
    assert isinstance(metadata_field, dict)
    assert "tags" in metadata_field
    assert metadata_field["tags"] == [
        "jobID:214590dsff09fds",
        "taskName:run_page_classification",
    ]