File size: 7,353 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import io
import os
import sys


sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch

import pytest

import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.proxy.utils import log_db_metrics, ServiceTypes
from datetime import datetime
import httpx
from prisma.errors import ClientNotConnectedError


# Test async function to decorate
@log_db_metrics
async def sample_db_function(*args, **kwargs):
    """Minimal coroutine wrapped by ``log_db_metrics`` for the tests below.

    Accepts and ignores any positional/keyword arguments and always
    returns the string ``"success"``.
    """
    outcome = "success"
    return outcome


@log_db_metrics
async def sample_proxy_function(*args, **kwargs):
    """Second minimal coroutine wrapped by ``log_db_metrics``.

    Ignores every argument it receives and always returns ``"success"``.
    """
    outcome = "success"
    return outcome


@pytest.mark.asyncio
async def test_log_db_metrics_success():
    """A successful decorated call triggers exactly one success-hook call.

    Patches ``litellm.proxy.proxy_server.proxy_logging_obj``, awaits
    ``sample_db_function`` and then inspects the keyword arguments that were
    passed to ``service_logging_obj.async_service_success_hook``.
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as proxy_logging_stub:
        # Keep a direct handle on the hook so the assertions stay short.
        success_hook = AsyncMock()
        proxy_logging_stub.service_logging_obj.async_service_success_hook = (
            success_hook
        )

        outcome = await sample_db_function(parent_otel_span="test_span")

        assert outcome == "success"
        success_hook.assert_called_once()

        # The test reads the hook's arguments as keyword arguments only.
        logged = success_hook.call_args.kwargs

        assert logged["service"] == ServiceTypes.DB
        assert logged["call_type"] == "sample_db_function"
        assert logged["parent_otel_span"] == "test_span"
        assert isinstance(logged["duration"], float)
        for timestamp_key in ("start_time", "end_time"):
            assert isinstance(logged[timestamp_key], datetime)
        assert "function_name" in logged["event_metadata"]


@pytest.mark.asyncio
async def test_log_db_metrics_duration():
    """The logged ``duration`` matches the measured wall-clock time.

    Wraps a coroutine that sleeps for one second, times the call from the
    outside with ``time.time()`` and requires the ``duration`` passed to the
    success hook to be within 0.1 seconds of that measurement.
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as proxy_logging_stub:
        success_hook = AsyncMock()
        proxy_logging_stub.service_logging_obj.async_service_success_hook = (
            success_hook
        )

        # Deliberately slow coroutine so there is a measurable duration.
        @log_db_metrics
        async def delayed_function(**kwargs):
            await asyncio.sleep(1)  # 1 second delay
            return "success"

        started_at = time.time()
        outcome = await delayed_function(parent_otel_span="test_span")
        finished_at = time.time()

        # Wall-clock time observed by the test itself.
        measured = finished_at - started_at

        # Duration reported to the (mocked) success hook.
        reported = success_hook.call_args.kwargs["duration"]

        # Both measurements must agree to within 0.1 seconds.
        assert abs(reported - measured) < 0.1
        assert outcome == "success"


@pytest.mark.asyncio
async def test_log_db_metrics_failure():
    """
    A prisma ``ClientNotConnectedError`` raised by the wrapped coroutine must
    propagate to the caller and be reported once through the failure hook.
    """
    from prisma.errors import ClientNotConnectedError

    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as proxy_logging_stub:
        failure_hook = AsyncMock()
        proxy_logging_stub.service_logging_obj.async_service_failure_hook = (
            failure_hook
        )

        # Coroutine that always fails with the prisma connection error.
        @log_db_metrics
        async def failing_function(**kwargs):
            raise ClientNotConnectedError()

        # The decorator is expected to re-raise the original exception.
        with pytest.raises(ClientNotConnectedError) as raised:
            await failing_function(parent_otel_span="test_span")

        assert "Client is not connected to the query engine" in str(raised.value)
        failure_hook.assert_called_once()

        # Keyword arguments the failure hook was invoked with.
        logged = failure_hook.call_args.kwargs

        assert logged["service"] == ServiceTypes.DB
        assert logged["call_type"] == "failing_function"
        assert logged["parent_otel_span"] == "test_span"
        assert isinstance(logged["duration"], float)
        for timestamp_key in ("start_time", "end_time"):
            assert isinstance(logged[timestamp_key], datetime)
        assert isinstance(logged["error"], ClientNotConnectedError)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "exception,should_log",
    [
        (ValueError("Generic error"), False),
        (KeyError("Missing key"), False),
        (TypeError("Type error"), False),
        (httpx.ConnectError("Failed to connect"), True),
        (httpx.TimeoutException("Request timed out"), True),
        (ClientNotConnectedError(), True),  # Prisma error
    ],
)
async def test_log_db_metrics_failure_error_types(exception, should_log):
    """
    Only DB-related exceptions may be reported as DB service failures.

    Background: users saw non-DB errors (for example a failed cache read)
    being logged as DB service failures.

    Each parametrized case raises ``exception`` from a decorated coroutine and
    then checks the failure hook:
    - ``should_log=True`` (prisma / httpx errors): hook called exactly once
      with the expected keyword arguments.
    - ``should_log=False`` (ValueError, KeyError, TypeError): hook never called.
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as proxy_logging_stub:
        failure_hook = AsyncMock()
        proxy_logging_stub.service_logging_obj.async_service_failure_hook = (
            failure_hook
        )

        @log_db_metrics
        async def failing_function(**kwargs):
            raise exception

        expected_type = type(exception)

        # Whatever the error type, it must still reach the caller.
        with pytest.raises(expected_type):
            await failing_function(parent_otel_span="test_span")

        if not should_log:
            # Non-DB errors must not be reported as DB failures.
            failure_hook.assert_not_called()
            return

        # DB-related errors: exactly one failure report with full details.
        failure_hook.assert_called_once()
        logged = failure_hook.call_args.kwargs
        assert logged["service"] == ServiceTypes.DB
        assert logged["call_type"] == "failing_function"
        assert logged["parent_otel_span"] == "test_span"
        assert isinstance(logged["duration"], float)
        for timestamp_key in ("start_time", "end_time"):
            assert isinstance(logged[timestamp_key], datetime)
        assert isinstance(logged["error"], expected_type)


@pytest.mark.asyncio
async def test_dd_log_db_spend_failure_metrics():
    """Smoke test: reporting a DB failure with a DataDog callback registered.

    Builds a ``DataDogLogger`` whose ``async_service_failure_hook`` is replaced
    by an ``AsyncMock``, registers it in ``litellm.service_callback`` and awaits
    ``ServiceLogging.async_service_failure_hook`` for a DB failure. The test
    passes as long as that call does not raise.

    ``litellm.service_callback`` is module-global state, so it is overridden
    via ``patch.object`` and restored on exit; otherwise the DataDog callback
    would stay registered for every test that runs afterwards in the same
    process.
    """
    from litellm._service_logger import ServiceLogging
    from litellm.integrations.datadog.datadog import DataDogLogger

    dd_logger = DataDogLogger()
    with patch.object(dd_logger, "async_service_failure_hook", new_callable=AsyncMock):
        # Constructed before the callback override, as in the original flow.
        service_logging_obj = ServiceLogging()

        # Scoped override: the previous callback list is restored on exit.
        with patch.object(litellm, "service_callback", [dd_logger]):
            await service_logging_obj.async_service_failure_hook(
                service=ServiceTypes.DB,
                call_type="test_call_type",
                error="test_error",
                duration=1.0,
            )
            # NOTE(review): no assertion on the mocked hook here; whether
            # ServiceLogging forwards to this exact instance is not visible
            # from this file. Confirm before adding a call-count check.