File size: 4,566 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import asyncio
import json
import os
import sys
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException, Request, status
from prisma import errors as prisma_errors
from prisma.errors import (
    ClientNotConnectedError,
    DataError,
    ForeignKeyViolationError,
    HTTPClientClosedError,
    MissingRequiredValueError,
    PrismaError,
    RawQueryError,
    RecordNotFoundError,
    TableNotFoundError,
    UniqueViolationError,
)

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the directory three levels above the current working directory to the system path

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler


# Test is_database_connection_error method
@pytest.mark.parametrize(
    "prisma_error",
    [
        PrismaError(),
        DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        UniqueViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        ForeignKeyViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        MissingRequiredValueError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        TableNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RecordNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        HTTPClientClosedError(),
        ClientNotConnectedError(),
    ],
)
def test_is_database_connection_error_prisma_errors(prisma_error):
    """
    Test that all Prisma errors are considered database connection errors

    Args:
        prisma_error: One instantiated Prisma exception from the parametrized
            list; the test runs once per entry.
    """
    # Identity check pins the result to the bool True, not merely a value
    # that happens to compare equal to it (e.g. 1).
    assert PrismaDBExceptionHandler.is_database_connection_error(prisma_error) is True


def test_is_database_connection_generic_errors():
    """
    Test non-Prisma error cases for database connection checking

    Covers two cases:
    - a plain ``Exception`` is not reported as a database connection error
    - a ``ProxyException`` typed ``ProxyErrorTypes.no_db_connection`` is
    """
    # Test with non-DB error
    regular_exception = Exception("Regular error")
    assert (
        PrismaDBExceptionHandler.is_database_connection_error(regular_exception)
        is False
    )

    # Test with ProxyException (DB connection)
    db_proxy_exception = ProxyException(
        message="DB Connection Error",
        type=ProxyErrorTypes.no_db_connection,
        param="test-param",
    )
    assert (
        PrismaDBExceptionHandler.is_database_connection_error(db_proxy_exception)
        is True
    )


# Test should_allow_request_on_db_unavailable method
@patch(
    "litellm.proxy.proxy_server.general_settings",
    {"allow_requests_on_db_unavailable": True},
)
def test_should_allow_request_on_db_unavailable_true():
    """
    Test that should_allow_request_on_db_unavailable returns True when
    general_settings has allow_requests_on_db_unavailable set to True
    """
    # `is True` pins the exact boolean instead of any value equal to True
    assert PrismaDBExceptionHandler.should_allow_request_on_db_unavailable() is True


@patch(
    "litellm.proxy.proxy_server.general_settings",
    {"allow_requests_on_db_unavailable": False},
)
def test_should_allow_request_on_db_unavailable_false():
    """
    Test that should_allow_request_on_db_unavailable returns False when
    general_settings has allow_requests_on_db_unavailable set to False
    """
    # `is False` pins the exact boolean instead of any falsy-equal value (e.g. 0)
    assert PrismaDBExceptionHandler.should_allow_request_on_db_unavailable() is False


@patch(
    "litellm.proxy.proxy_server.general_settings",
    new={"allow_requests_on_db_unavailable": True},
)
def test_handle_db_exception_with_connection_error():
    """
    With allow_requests_on_db_unavailable set to True, passing a DB connection
    error to handle_db_exception should return None rather than raise
    """
    outcome = PrismaDBExceptionHandler.handle_db_exception(ClientNotConnectedError())
    assert outcome is None


@patch(
    "litellm.proxy.proxy_server.general_settings",
    new={"allow_requests_on_db_unavailable": False},
)
def test_handle_db_exception_raises_error():
    """
    With allow_requests_on_db_unavailable set to False, a DB connection error
    passed to handle_db_exception should be raised to the caller
    """
    connection_failure = ClientNotConnectedError()
    with pytest.raises(ClientNotConnectedError):
        PrismaDBExceptionHandler.handle_db_exception(connection_failure)


def test_handle_db_exception_with_non_db_error():
    """
    Test that non-DB errors are always raised regardless of allow_requests_on_db_unavailable setting

    The handler is exercised with the unpatched (default) settings and then
    with the flag explicitly patched to both True and False.
    """
    # Default settings (no patching), as originally covered
    with pytest.raises(litellm.BudgetExceededError):
        PrismaDBExceptionHandler.handle_db_exception(
            litellm.BudgetExceededError(current_cost=10, max_budget=10)
        )

    # Explicitly cover both values of the flag so the "regardless of setting"
    # claim is actually verified.
    for allow_requests in (True, False):
        # Fresh exception per iteration so tracebacks don't accumulate on a
        # re-raised instance.
        regular_error = litellm.BudgetExceededError(
            current_cost=10,
            max_budget=10,
        )
        with patch(
            "litellm.proxy.proxy_server.general_settings",
            {"allow_requests_on_db_unavailable": allow_requests},
        ):
            with pytest.raises(litellm.BudgetExceededError):
                PrismaDBExceptionHandler.handle_db_exception(regular_error)