import os
import sys

import pytest
from dotenv import load_dotenv

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the repo root to the system path so litellm resolves from the local checkout

load_dotenv()

import litellm
from litellm.proxy.proxy_server import ProxyConfig

# This file tests config loading for litellm/proxy.

INVALID_FILES = ["config_with_missing_include.yaml"]


@pytest.mark.asyncio
async def test_basic_reading_configs_from_files():
    """
    Test that the config is read correctly from the files in the example_config_yaml folder
    """
    proxy_config_instance = ProxyConfig()
    current_path = os.path.dirname(os.path.abspath(__file__))
    example_config_yaml_path = os.path.join(current_path, "example_config_yaml")

    # get all the files from example_config_yaml
    files = os.listdir(example_config_yaml_path)
    print(files)

    for file in files:
        if file in INVALID_FILES:  # these are intentionally invalid files
            continue
        print("reading file=", file)
        config_path = os.path.join(example_config_yaml_path, file)
        config = await proxy_config_instance.get_config(config_file_path=config_path)
        print(config)


@pytest.mark.asyncio
async def test_read_config_from_bad_file_path():
    """
    Ensure an exception is raised when the file path is not valid
    """
    proxy_config_instance = ProxyConfig()
    config_path = "non-existent-file.yaml"
    with pytest.raises(Exception):
        await proxy_config_instance.get_config(config_file_path=config_path)


@pytest.mark.asyncio
async def test_read_config_file_with_os_environ_vars():
    """
    Ensures os.environ variables are resolved correctly when reading the config.
    The following vars are referenced via os.environ in config_with_env_vars.yaml:
    - DEFAULT_USER_ROLE
    - AWS_ACCESS_KEY_ID
    - AWS_SECRET_ACCESS_KEY
    - AZURE_GPT_4O
    - FIREWORKS
    """

    _env_vars_for_testing = {
        "DEFAULT_USER_ROLE": "admin",
        "AWS_ACCESS_KEY_ID": "1234567890",
        "AWS_SECRET_ACCESS_KEY": "1234567890",
        "AZURE_GPT_4O": "1234567890",
        "FIREWORKS": "1234567890",
    }

    _old_env_vars = {}
    for key, value in _env_vars_for_testing.items():
        if key in os.environ:
            _old_env_vars[key] = os.environ.get(key)
        os.environ[key] = value

    # Read config
    proxy_config_instance = ProxyConfig()
    current_path = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(
        current_path, "example_config_yaml", "config_with_env_vars.yaml"
    )
    config = await proxy_config_instance.get_config(config_file_path=config_path)
    print(config)

    # Add assertions
    assert (
        config["litellm_settings"]["default_internal_user_params"]["user_role"]
        == "admin"
    )
    assert (
        config["litellm_settings"]["s3_callback_params"]["s3_aws_access_key_id"]
        == "1234567890"
    )
    assert (
        config["litellm_settings"]["s3_callback_params"]["s3_aws_secret_access_key"]
        == "1234567890"
    )

    for model in config["model_list"]:
        if "azure" in model["litellm_params"]["model"]:
            assert model["litellm_params"]["api_key"] == "1234567890"
        elif "fireworks" in model["litellm_params"]["model"]:
            assert model["litellm_params"]["api_key"] == "1234567890"

    # cleanup
    for key, value in _env_vars_for_testing.items():
        if key in _old_env_vars:
            os.environ[key] = _old_env_vars[key]
        else:
            del os.environ[key]
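

# Hedged sketch (not one of the shipped fixtures): the proxy config resolves
# values written as "os.environ/<VAR_NAME>" against the process environment.
# The test above assumes config_with_env_vars.yaml follows roughly this shape;
# model names and model strings here are illustrative only.
_EXAMPLE_ENV_VAR_CONFIG_SKETCH = """
litellm_settings:
  default_internal_user_params:
    user_role: os.environ/DEFAULT_USER_ROLE
  s3_callback_params:
    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
model_list:
  - model_name: azure-model
    litellm_params:
      model: azure/gpt-4o
      api_key: os.environ/AZURE_GPT_4O
  - model_name: fireworks-model
    litellm_params:
      model: fireworks_ai/some-model
      api_key: os.environ/FIREWORKS
"""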


@pytest.mark.asyncio
async def test_basic_include_directive():
    """
    Test that the include directive correctly loads and merges configs
    """
    proxy_config_instance = ProxyConfig()
    current_path = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(
        current_path, "example_config_yaml", "config_with_include.yaml"
    )

    config = await proxy_config_instance.get_config(config_file_path=config_path)

    # Verify the included model list was merged
    assert len(config["model_list"]) > 0
    assert any(
        model["model_name"] == "included-model" for model in config["model_list"]
    )

    # Verify original config settings remain
    assert config["litellm_settings"]["callbacks"] == ["prometheus"]
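

# Hedged sketch of the include shape this test relies on: a parent config may
# list other YAML files under `include`, and their model_list entries get
# merged into the parent config. The file name below is illustrative, not the
# actual fixture used by config_with_include.yaml.
_EXAMPLE_INCLUDE_CONFIG_SKETCH = """
include:
  - included_models.yaml  # hypothetical file defining `included-model`
litellm_settings:
  callbacks: ["prometheus"]
"""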


@pytest.mark.asyncio
async def test_missing_include_file():
    """
    Test that a missing included file raises FileNotFoundError
    """
    proxy_config_instance = ProxyConfig()
    current_path = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(
        current_path, "example_config_yaml", "config_with_missing_include.yaml"
    )

    with pytest.raises(FileNotFoundError):
        await proxy_config_instance.get_config(config_file_path=config_path)


@pytest.mark.asyncio
async def test_multiple_includes():
    """
    Test that multiple files in the include list are all processed correctly
    """
    proxy_config_instance = ProxyConfig()
    current_path = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(
        current_path, "example_config_yaml", "config_with_multiple_includes.yaml"
    )

    config = await proxy_config_instance.get_config(config_file_path=config_path)

    # Verify models from both included files are present
    assert len(config["model_list"]) == 2
    assert any(
        model["model_name"] == "included-model-1" for model in config["model_list"]
    )
    assert any(
        model["model_name"] == "included-model-2" for model in config["model_list"]
    )

    # Verify original config settings remain
    assert config["litellm_settings"]["callbacks"] == ["prometheus"]
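

# Hedged sketch with more than one include entry, matching what the test above
# assumes about config_with_multiple_includes.yaml (file names illustrative):
_EXAMPLE_MULTIPLE_INCLUDES_SKETCH = """
include:
  - models_part_1.yaml  # hypothetical, would define `included-model-1`
  - models_part_2.yaml  # hypothetical, would define `included-model-2`
litellm_settings:
  callbacks: ["prometheus"]
"""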


def test_add_callbacks_from_db_config():
    """Test that callbacks are added correctly and duplicates are prevented"""
    # Setup
    from litellm.integrations.langfuse.langfuse_prompt_management import (
        LangfusePromptManagement,
    )

    proxy_config = ProxyConfig()

    # Reset litellm callbacks before test
    litellm.success_callback = []
    litellm.failure_callback = []

    # Test Case 1: Add new callbacks
    config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse", "custom_callback_api"],
            "failure_callback": ["langfuse"],
        }
    }

    proxy_config._add_callbacks_from_db_config(config_data)

    # 1 instance of LangfusePromptManagement should exist in litellm.success_callback
    num_langfuse_instances = sum(
        isinstance(callback, LangfusePromptManagement)
        for callback in litellm.success_callback
    )
    assert num_langfuse_instances == 1
    assert len(litellm.success_callback) == 2
    assert len(litellm.failure_callback) == 1

    # Test Case 2: Try adding duplicate callbacks
    proxy_config._add_callbacks_from_db_config(config_data)

    # Verify no duplicates were added
    assert len(litellm.success_callback) == 2
    assert len(litellm.failure_callback) == 1

    # Cleanup
    litellm.success_callback = []
    litellm.failure_callback = []
    litellm._known_custom_logger_compatible_callbacks = []


def test_add_callbacks_invalid_input():
    """Test handling of invalid input for callbacks"""
    proxy_config = ProxyConfig()

    # Reset callbacks
    litellm.success_callback = []
    litellm.failure_callback = []

    # Test Case 1: Invalid callback format
    config_data = {
        "litellm_settings": {
            "success_callback": "invalid_string_format",  # Should be a list
            "failure_callback": 123,  # Should be a list
        }
    }

    proxy_config._add_callbacks_from_db_config(config_data)

    # Verify no callbacks were added with invalid input
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0

    # Test Case 2: Missing litellm_settings
    config_data = {}
    proxy_config._add_callbacks_from_db_config(config_data)

    # Verify no callbacks were added
    assert len(litellm.success_callback) == 0
    assert len(litellm.failure_callback) == 0

    # Cleanup
    litellm.success_callback = []
    litellm.failure_callback = []