File size: 11,321 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import os
import sys
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system-path

from litellm.proxy.proxy_cli import ProxyInitializationHelpers


class TestProxyInitializationHelpers:
    """Unit tests for the static helpers of ``ProxyInitializationHelpers``
    and for the ``run_server`` CLI entry point.

    External side effects (HTTP calls, subprocesses, sockets, event-loop and
    server start-up) are replaced with ``unittest.mock`` patches.
    """

    @staticmethod
    def _close_unawaited(awaitable):
        """Stand-in for a mocked ``asyncio.run``: close what it was given.

        With ``asyncio.run`` patched out, the object passed to it (expected to
        be a coroutine) is never awaited; closing it explicitly avoids a
        "coroutine ... was never awaited" RuntimeWarning at garbage collection.
        Objects without a ``close`` method are ignored.
        """
        close = getattr(awaitable, "close", None)
        if callable(close):
            close()

    @staticmethod
    def _patch_proxy_server_module():
        """Return a ``patch.dict`` context stubbing ``sys.modules["proxy_server"]``.

        The stub is a MagicMock exposing ``app``, ``ProxyConfig``,
        ``KeyManagementSettings`` and ``save_worker_config``.
        NOTE(review): presumably these are the names ``run_server --local``
        imports from ``proxy_server``; confirm against proxy_cli.
        """
        return patch.dict(
            "sys.modules",
            {
                "proxy_server": MagicMock(
                    app=MagicMock(),
                    ProxyConfig=MagicMock(),
                    KeyManagementSettings=MagicMock(),
                    save_worker_config=MagicMock(),
                )
            },
        )

    @patch("importlib.metadata.version")
    @patch("click.echo")
    def test_echo_litellm_version(self, mock_echo, mock_version):
        """The installed ``litellm`` version is looked up and echoed."""
        # Setup
        mock_version.return_value = "1.0.0"

        # Execute
        ProxyInitializationHelpers._echo_litellm_version()

        # Assert
        mock_version.assert_called_once_with("litellm")
        mock_echo.assert_called_once_with("\nLiteLLM: Current Version = 1.0.0\n")

    @patch("httpx.get")
    @patch("builtins.print")
    @patch("json.dumps")
    def test_run_health_check(self, mock_dumps, mock_print, mock_get):
        """The /health endpoint is fetched and its JSON body is serialized
        with ``json.dumps(..., indent=4)``."""
        # Setup (``print`` is patched only to keep test output quiet)
        mock_response = MagicMock()
        mock_response.json.return_value = {"status": "healthy"}
        mock_get.return_value = mock_response
        mock_dumps.return_value = '{"status": "healthy"}'

        # Execute
        ProxyInitializationHelpers._run_health_check("localhost", 8000)

        # Assert
        mock_get.assert_called_once_with(url="http://localhost:8000/health")
        mock_response.json.assert_called_once()
        mock_dumps.assert_called_once_with({"status": "healthy"}, indent=4)

    @patch("openai.OpenAI")
    @patch("click.echo")
    @patch("builtins.print")
    def test_run_test_chat_completion(self, mock_print, mock_echo, mock_openai):
        """A non-string ``test`` value is rejected with ``ValueError``; a string
        value is used as the OpenAI client's ``base_url``."""
        # Setup
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        mock_response = MagicMock()
        mock_stream_response = MagicMock()
        mock_stream_response.__iter__.return_value = [MagicMock(), MagicMock()]
        # Successive create() calls return the plain response, then the stream.
        # (side_effect takes precedence over return_value, so only this is set.)
        mock_client.chat.completions.create.side_effect = [
            mock_response,
            mock_stream_response,
        ]

        # Execute
        with pytest.raises(ValueError, match="Invalid test value"):
            ProxyInitializationHelpers._run_test_chat_completion(
                "localhost", 8000, "gpt-3.5-turbo", True
            )

        # Test with valid string test value
        ProxyInitializationHelpers._run_test_chat_completion(
            "localhost", 8000, "gpt-3.5-turbo", "http://test-url"
        )

        # Assert
        mock_openai.assert_called_once_with(
            api_key="My API Key", base_url="http://test-url"
        )
        mock_client.chat.completions.create.assert_called()

    def test_get_default_unvicorn_init_args(self):
        """Default uvicorn kwargs reflect host/port, ``log_config``,
        ``litellm.json_logs`` and ``keepalive_timeout``."""
        # Test without log_config
        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
            "localhost", 8000
        )
        assert args["app"] == "litellm.proxy.proxy_server:app"
        assert args["host"] == "localhost"
        assert args["port"] == 8000

        # Test with log_config
        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
            "localhost", 8000, "log_config.json"
        )
        assert args["log_config"] == "log_config.json"

        # Test with json_logs=True
        with patch("litellm.json_logs", True):
            args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
                "localhost", 8000
            )
            assert args["log_config"] is None

        # Test with keepalive_timeout
        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
            "localhost", 8000, None, 60
        )
        assert args["timeout_keep_alive"] == 60

        # Test with both log_config and keepalive_timeout
        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
            "localhost", 8000, "log_config.json", 120
        )
        assert args["log_config"] == "log_config.json"
        assert args["timeout_keep_alive"] == 120

    @patch("asyncio.run")
    @patch("builtins.print")
    def test_init_hypercorn_server(self, mock_print, mock_asyncio_run):
        """Hypercorn start-up hands off to ``asyncio.run`` both without and
        with SSL certificate/key files."""
        # Setup
        mock_app = MagicMock()
        mock_asyncio_run.side_effect = self._close_unawaited

        # Execute
        ProxyInitializationHelpers._init_hypercorn_server(
            mock_app, "localhost", 8000, None, None
        )

        # Assert
        mock_asyncio_run.assert_called_once()

        # Test with SSL: reset call history (side_effect is kept) so the
        # assertion below covers only the SSL branch.
        mock_asyncio_run.reset_mock()
        ProxyInitializationHelpers._init_hypercorn_server(
            mock_app, "localhost", 8000, "cert.pem", "key.pem"
        )
        mock_asyncio_run.assert_called_once()

    @patch("subprocess.Popen")
    def test_run_ollama_serve(self, mock_popen):
        """``ollama serve`` is spawned via ``subprocess.Popen``; a spawn
        failure is not propagated to the caller."""
        # Execute
        ProxyInitializationHelpers._run_ollama_serve()

        # Assert
        mock_popen.assert_called_once()

        # Test exception handling
        mock_popen.side_effect = Exception("Test exception")
        ProxyInitializationHelpers._run_ollama_serve()  # Should not raise
        # The spawn was attempted again; its failure was handled internally.
        assert mock_popen.call_count == 2

    @patch("socket.socket")
    def test_is_port_in_use(self, mock_socket):
        """``connect_ex`` returning 0 reports the port as in use; a non-zero
        result reports it as free."""
        # Setup for port in use
        mock_socket_instance = MagicMock()
        mock_socket_instance.connect_ex.return_value = 0
        mock_socket.return_value.__enter__.return_value = mock_socket_instance

        # Execute and Assert
        assert ProxyInitializationHelpers._is_port_in_use(8000) is True

        # Setup for port not in use
        mock_socket_instance.connect_ex.return_value = 1

        # Execute and Assert
        assert ProxyInitializationHelpers._is_port_in_use(8000) is False

    def test_get_loop_type(self):
        """Windows gets no explicit loop type (``None``); Linux gets ``uvloop``."""
        # Test on Windows
        with patch("sys.platform", "win32"):
            assert ProxyInitializationHelpers._get_loop_type() is None

        # Test on Linux
        with patch("sys.platform", "linux"):
            assert ProxyInitializationHelpers._get_loop_type() == "uvloop"

    @patch.dict(os.environ, {}, clear=True)
    def test_database_url_construction_with_special_characters(self):
        """Credentials containing URL-reserved characters are percent-encoded
        with ``quote_plus``, and ``append_query_params`` adds pool settings.

        NOTE(review): this mirrors the URL-building logic of ``run_server``
        (used when ``database_url`` is None) rather than calling it, so it can
        drift from the real implementation.
        """
        # Setup environment variables with special characters that need escaping
        test_env = {
            "DATABASE_HOST": "localhost:5432",
            "DATABASE_USERNAME": "user@with+special",
            "DATABASE_PASSWORD": "pass&word!@#$%",
            "DATABASE_NAME": "db_name/test",
        }

        with patch.dict(os.environ, test_env):
            import urllib.parse

            from litellm.proxy.proxy_cli import append_query_params

            database_host = os.environ["DATABASE_HOST"]
            database_username = os.environ["DATABASE_USERNAME"]
            database_password = os.environ["DATABASE_PASSWORD"]
            database_name = os.environ["DATABASE_NAME"]

            # Test the URL encoding part
            database_username_enc = urllib.parse.quote_plus(database_username)
            database_password_enc = urllib.parse.quote_plus(database_password)
            database_name_enc = urllib.parse.quote_plus(database_name)

            # Construct DATABASE_URL from the provided variables
            database_url = f"postgresql://{database_username_enc}:{database_password_enc}@{database_host}/{database_name_enc}"

            # Assert the correct URL was constructed with properly escaped characters
            expected_url = "postgresql://user%40with%2Bspecial:pass%26word%21%40%23%24%25@localhost:5432/db_name%2Ftest"
            assert database_url == expected_url

            # Test appending query parameters
            params = {"connection_limit": 10, "pool_timeout": 60}
            modified_url = append_query_params(database_url, params)
            assert "connection_limit=10" in modified_url
            assert "pool_timeout=60" in modified_url

    @patch("uvicorn.run")
    @patch("builtins.print")
    def test_skip_server_startup(self, mock_print, mock_uvicorn_run):
        """Test that the skip_server_startup flag prevents server startup when True"""
        from click.testing import CliRunner

        from litellm.proxy.proxy_cli import run_server

        runner = CliRunner()

        with self._patch_proxy_server_module(), patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args:
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            result = runner.invoke(run_server, ["--local", "--skip_server_startup"])

            assert result.exit_code == 0
            mock_uvicorn_run.assert_not_called()
            mock_print.assert_any_call(
                "LiteLLM: Setup complete. Skipping server startup as requested."
            )

            # Without the flag the server must be started.
            mock_uvicorn_run.reset_mock()
            mock_print.reset_mock()

            result = runner.invoke(run_server, ["--local"])

            assert result.exit_code == 0
            mock_uvicorn_run.assert_called_once()

    @patch("uvicorn.run")
    @patch("builtins.print")
    def test_keepalive_timeout_flag(self, mock_print, mock_uvicorn_run):
        """Test that the keepalive_timeout flag is properly passed to uvicorn"""
        from click.testing import CliRunner

        from litellm.proxy.proxy_cli import run_server

        runner = CliRunner()

        with self._patch_proxy_server_module(), patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args:
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
                "timeout_keep_alive": 30,
            }

            result = runner.invoke(run_server, ["--local", "--keepalive_timeout", "30"])

            assert result.exit_code == 0
            # The CLI value is parsed to an int and forwarded alongside the
            # default host/port.
            mock_get_args.assert_called_once_with(
                host="0.0.0.0",
                port=4000,
                log_config=None,
                keepalive_timeout=30,
            )
            mock_uvicorn_run.assert_called_once()

            # Check that the uvicorn.run was called with the timeout_keep_alive parameter
            assert mock_uvicorn_run.call_args.kwargs["timeout_keep_alive"] == 30