import os
import sys
import traceback
from unittest import mock

import pytest
from dotenv import load_dotenv

import litellm
import litellm.proxy
import litellm.proxy.proxy_server

load_dotenv()
import io

# this file is to test litellm/proxy
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

from litellm.proxy.proxy_server import ProxyConfig

INVALID_FILES = ["config_with_missing_include.yaml"]


@pytest.mark.asyncio
async def test_basic_reading_configs_from_files():
"""
Test that the config is read correctly from the files in the example_config_yaml folder
"""
proxy_config_instance = ProxyConfig()
current_path = os.path.dirname(os.path.abspath(__file__))
example_config_yaml_path = os.path.join(current_path, "example_config_yaml")
# get all the files from example_config_yaml
files = os.listdir(example_config_yaml_path)
print(files)
for file in files:
if file in INVALID_FILES: # these are intentionally invalid files
continue
print("reading file=", file)
config_path = os.path.join(example_config_yaml_path, file)
config = await proxy_config_instance.get_config(config_file_path=config_path)
print(config)
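

# A minimal, self-contained sketch (not tied to the files in example_config_yaml)
# showing the same ProxyConfig.get_config() call against a config written to a
# temporary path. The model name, api key value, and file name below are made up
# purely for illustration.
@pytest.mark.asyncio
async def test_reading_config_from_tmp_path(tmp_path):
    config_body = (
        "model_list:\n"
        "  - model_name: fake-model\n"
        "    litellm_params:\n"
        "      model: openai/gpt-4o\n"
        "      api_key: fake-key-for-testing\n"
        "litellm_settings:\n"
        '  callbacks: ["prometheus"]\n'
    )
    config_file = tmp_path / "inline_config.yaml"
    config_file.write_text(config_body)

    proxy_config_instance = ProxyConfig()
    config = await proxy_config_instance.get_config(config_file_path=str(config_file))

    assert config["model_list"][0]["model_name"] == "fake-model"
    assert config["litellm_settings"]["callbacks"] == ["prometheus"]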


@pytest.mark.asyncio
async def test_read_config_from_bad_file_path():
"""
Raise an exception if the file path is not valid
"""
proxy_config_instance = ProxyConfig()
config_path = "non-existent-file.yaml"
with pytest.raises(Exception):
config = await proxy_config_instance.get_config(config_file_path=config_path)


@pytest.mark.asyncio
async def test_read_config_file_with_os_environ_vars():
"""
Ensures os.environ variables are read correctly from config.yaml
Following vars are set as os.environ variables in the config.yaml file
- DEFAULT_USER_ROLE
- AWS_ACCESS_KEY_ID
- AWS_SECRET_ACCESS_KEY
- AZURE_GPT_4O
- FIREWORKS
"""
_env_vars_for_testing = {
"DEFAULT_USER_ROLE": "admin",
"AWS_ACCESS_KEY_ID": "1234567890",
"AWS_SECRET_ACCESS_KEY": "1234567890",
"AZURE_GPT_4O": "1234567890",
"FIREWORKS": "1234567890",
}
_old_env_vars = {}
for key, value in _env_vars_for_testing.items():
if key in os.environ:
_old_env_vars[key] = os.environ.get(key)
os.environ[key] = value
    # Read config, then restore the original environment even if an assertion fails
    try:
        proxy_config_instance = ProxyConfig()
        current_path = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(
            current_path, "example_config_yaml", "config_with_env_vars.yaml"
        )
        config = await proxy_config_instance.get_config(config_file_path=config_path)
        print(config)
        # Add assertions
        assert (
            config["litellm_settings"]["default_internal_user_params"]["user_role"]
            == "admin"
        )
        assert (
            config["litellm_settings"]["s3_callback_params"]["s3_aws_access_key_id"]
            == "1234567890"
        )
        assert (
            config["litellm_settings"]["s3_callback_params"]["s3_aws_secret_access_key"]
            == "1234567890"
        )
        for model in config["model_list"]:
            if "azure" in model["litellm_params"]["model"]:
                assert model["litellm_params"]["api_key"] == "1234567890"
            elif "fireworks" in model["litellm_params"]["model"]:
                assert model["litellm_params"]["api_key"] == "1234567890"
    finally:
        # cleanup: restore pre-existing values, remove the vars this test added
        for key in _env_vars_for_testing:
            if key in _old_env_vars:
                os.environ[key] = _old_env_vars[key]
            else:
                del os.environ[key]
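

# For reference, a sketch of how config_with_env_vars.yaml is assumed to reference
# the environment (illustrative only; "os.environ/<VAR_NAME>" is the litellm config
# convention, but the real fixture file may differ in detail):
#
#   litellm_settings:
#     default_internal_user_params:
#       user_role: os.environ/DEFAULT_USER_ROLE
#     s3_callback_params:
#       s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
#       s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
#   model_list:
#     - model_name: azure-gpt-4o
#       litellm_params:
#         model: azure/gpt-4o
#         api_key: os.environ/AZURE_GPT_4O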


@pytest.mark.asyncio
async def test_basic_include_directive():
"""
Test that the include directive correctly loads and merges configs
"""
proxy_config_instance = ProxyConfig()
current_path = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(
current_path, "example_config_yaml", "config_with_include.yaml"
)
config = await proxy_config_instance.get_config(config_file_path=config_path)
# Verify the included model list was merged
assert len(config["model_list"]) > 0
assert any(
model["model_name"] == "included-model" for model in config["model_list"]
)
# Verify original config settings remain
assert config["litellm_settings"]["callbacks"] == ["prometheus"]


@pytest.mark.asyncio
async def test_missing_include_file():
"""
Test that a missing included file raises FileNotFoundError
"""
proxy_config_instance = ProxyConfig()
current_path = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(
current_path, "example_config_yaml", "config_with_missing_include.yaml"
)
with pytest.raises(FileNotFoundError):
await proxy_config_instance.get_config(config_file_path=config_path)


@pytest.mark.asyncio
async def test_multiple_includes():
"""
Test that multiple files in the include list are all processed correctly
"""
proxy_config_instance = ProxyConfig()
current_path = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(
current_path, "example_config_yaml", "config_with_multiple_includes.yaml"
)
config = await proxy_config_instance.get_config(config_file_path=config_path)
# Verify models from both included files are present
assert len(config["model_list"]) == 2
assert any(
model["model_name"] == "included-model-1" for model in config["model_list"]
)
assert any(
model["model_name"] == "included-model-2" for model in config["model_list"]
)
# Verify original config settings remain
assert config["litellm_settings"]["callbacks"] == ["prometheus"]


def test_add_callbacks_from_db_config():
"""Test that callbacks are added correctly and duplicates are prevented"""
# Setup
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
proxy_config = ProxyConfig()
# Reset litellm callbacks before test
litellm.success_callback = []
litellm.failure_callback = []
# Test Case 1: Add new callbacks
config_data = {
"litellm_settings": {
"success_callback": ["langfuse", "custom_callback_api"],
"failure_callback": ["langfuse"],
}
}
proxy_config._add_callbacks_from_db_config(config_data)
# 1 instance of LangfusePromptManagement should exist in litellm.success_callback
num_langfuse_instances = sum(
isinstance(callback, LangfusePromptManagement)
for callback in litellm.success_callback
)
assert num_langfuse_instances == 1
assert len(litellm.success_callback) == 2
assert len(litellm.failure_callback) == 1
# Test Case 2: Try adding duplicate callbacks
proxy_config._add_callbacks_from_db_config(config_data)
# Verify no duplicates were added
assert len(litellm.success_callback) == 2
assert len(litellm.failure_callback) == 1
# Cleanup
litellm.success_callback = []
litellm.failure_callback = []
litellm._known_custom_logger_compatible_callbacks = []


def test_add_callbacks_invalid_input():
"""Test handling of invalid input for callbacks"""
proxy_config = ProxyConfig()
# Reset callbacks
litellm.success_callback = []
litellm.failure_callback = []
# Test Case 1: Invalid callback format
config_data = {
"litellm_settings": {
"success_callback": "invalid_string_format", # Should be a list
"failure_callback": 123, # Should be a list
}
}
proxy_config._add_callbacks_from_db_config(config_data)
# Verify no callbacks were added with invalid input
assert len(litellm.success_callback) == 0
assert len(litellm.failure_callback) == 0
# Test Case 2: Missing litellm_settings
config_data = {}
proxy_config._add_callbacks_from_db_config(config_data)
# Verify no callbacks were added
assert len(litellm.success_callback) == 0
assert len(litellm.failure_callback) == 0
# Cleanup
litellm.success_callback = []
litellm.failure_callback = []
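

# A sketch of how the manual reset/cleanup of litellm callbacks in the two tests
# above could be centralised in a pytest fixture (hypothetical helper; the tests
# as written do not use it):
@pytest.fixture
def reset_litellm_callbacks():
    litellm.success_callback = []
    litellm.failure_callback = []
    yield
    litellm.success_callback = []
    litellm.failure_callback = []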