Spaces:
Configuration error
Configuration error
import io | |
import os | |
import sys | |
sys.path.insert(0, os.path.abspath("../..")) | |
import asyncio | |
import gzip | |
import json | |
import logging | |
import time | |
from unittest.mock import AsyncMock, patch | |
import pytest | |
import litellm | |
from litellm import completion | |
from litellm._logging import verbose_logger | |
from litellm.proxy.utils import log_db_metrics, ServiceTypes | |
from datetime import datetime | |
import httpx | |
from prisma.errors import ClientNotConnectedError | |
# Test async function to decorate
@log_db_metrics
async def sample_db_function(*args, **kwargs):
    """Minimal coroutine wrapped with ``log_db_metrics`` for the tests below.

    Accepts arbitrary positional and keyword arguments (the tests pass
    ``parent_otel_span``) and always returns the string ``"success"``.
    """
    return "success"
async def sample_proxy_function(*args, **kwargs):
    """Minimal coroutine that ignores its arguments and returns ``"success"``.

    NOTE(review): no test visible in this file calls this helper; confirm
    whether it was meant to be wrapped with ``log_db_metrics``.
    """
    return "success"
@pytest.mark.asyncio
async def test_log_db_metrics_success():
    """A successful ``sample_db_function`` call reports one DB service success.

    Patches ``proxy_logging_obj`` on the proxy server module, awaits the
    function, and checks the keyword arguments passed to
    ``async_service_success_hook``.
    """
    # Mock the proxy_logging_obj
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        # Setup mock
        success_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_success_hook = success_hook

        # Call the decorated function
        result = await sample_db_function(parent_otel_span="test_span")

        # Assertions
        assert result == "success"
        success_hook.assert_called_once()

        # The hook is inspected through its keyword arguments only.
        call_args = success_hook.call_args.kwargs
        assert call_args["service"] == ServiceTypes.DB
        assert call_args["call_type"] == "sample_db_function"
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert "function_name" in call_args["event_metadata"]
@pytest.mark.asyncio
async def test_log_db_metrics_duration():
    """The ``duration`` passed to the success hook tracks wall-clock time.

    Wraps a coroutine that sleeps for one second, times the call from the
    outside with ``time.time()``, and requires the logged duration to be
    within 0.1 seconds of that measurement.
    """
    # Mock the proxy_logging_obj
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        # Setup mock
        success_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_success_hook = success_hook

        # Add a delay to the function to test duration
        @log_db_metrics
        async def delayed_function(**kwargs):
            await asyncio.sleep(1)  # 1 second delay
            return "success"

        # Call the decorated function
        start = time.time()
        result = await delayed_function(parent_otel_span="test_span")
        end = time.time()

        # Get the actual duration
        actual_duration = end - start

        # Get the logged duration from the mock call
        logged_duration = success_hook.call_args.kwargs["duration"]

        # Assert the logged duration is approximately equal to actual duration
        # (within 0.1 seconds)
        assert abs(logged_duration - actual_duration) < 0.1
        assert result == "success"
@pytest.mark.asyncio
async def test_log_db_metrics_failure():
    """
    should log a failure if a prisma error is raised
    """
    from prisma.errors import ClientNotConnectedError

    # Mock the proxy_logging_obj
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        # Setup mock
        failure_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = failure_hook

        # Create a failing function
        @log_db_metrics
        async def failing_function(**kwargs):
            raise ClientNotConnectedError()

        # Call the decorated function and expect it to raise
        with pytest.raises(ClientNotConnectedError) as exc_info:
            await failing_function(parent_otel_span="test_span")

        # Assertions (outside the ``raises`` block so they actually run)
        assert "Client is not connected to the query engine" in str(exc_info.value)
        failure_hook.assert_called_once()

        call_args = failure_hook.call_args.kwargs
        assert call_args["service"] == ServiceTypes.DB
        assert call_args["call_type"] == "failing_function"
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert isinstance(call_args["error"], ClientNotConnectedError)
# NOTE(review): the parameter sets below are reconstructed from this test's
# docstring (Prisma/httpx errors are logged, generic Python errors are not);
# confirm them against the error types log_db_metrics actually treats as
# DB-related and extend the list as needed.
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "exception,should_log",
    [
        (ValueError("Generic error"), False),
        (KeyError("Missing key"), False),
        (TypeError("Type error"), False),
        (httpx.ConnectError("Failed to connect"), True),
        (ClientNotConnectedError(), True),  # Prisma error
    ],
)
async def test_log_db_metrics_failure_error_types(exception, should_log):
    """
    Why Test?
    Users were seeing that non-DB errors were being logged as DB Service Failures
    Example a failure to read a value from cache was being logged as a DB Service Failure

    Parameterized test to verify:
    - DB-related errors (Prisma, httpx) are logged as service failures
    - Non-DB errors (ValueError, KeyError, etc.) are not logged
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        failure_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = failure_hook

        @log_db_metrics
        async def failing_function(**kwargs):
            raise exception

        # Call the function and expect it to raise the exception
        with pytest.raises(type(exception)):
            await failing_function(parent_otel_span="test_span")

        if should_log:
            # Assert failure was logged for DB-related errors
            failure_hook.assert_called_once()
            call_args = failure_hook.call_args.kwargs

            assert call_args["service"] == ServiceTypes.DB
            assert call_args["call_type"] == "failing_function"
            assert call_args["parent_otel_span"] == "test_span"
            assert isinstance(call_args["duration"], float)
            assert isinstance(call_args["start_time"], datetime)
            assert isinstance(call_args["end_time"], datetime)
            assert isinstance(call_args["error"], type(exception))
        else:
            # Assert failure was NOT logged for non-DB errors
            failure_hook.assert_not_called()
@pytest.mark.asyncio
async def test_dd_log_db_spend_failure_metrics():
    """Send a DB service failure through ``ServiceLogging`` with a DataDog
    logger registered as the only service callback.

    The logger's ``async_service_failure_hook`` is replaced with an
    ``AsyncMock`` for the duration of the call.
    """
    from litellm._service_logger import ServiceLogging
    from litellm.integrations.datadog.datadog import DataDogLogger

    dd_logger = DataDogLogger()
    with patch.object(dd_logger, "async_service_failure_hook", new_callable=AsyncMock):
        service_logging_obj = ServiceLogging()

        # Scope the global callback list to this test so the module-level
        # setting is restored afterwards instead of leaking into other tests.
        with patch.object(litellm, "service_callback", [dd_logger]):
            await service_logging_obj.async_service_failure_hook(
                service=ServiceTypes.DB,
                call_type="test_call_type",
                error="test_error",
                duration=1.0,
            )