File size: 3,178 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import asyncio
import json
import os
import sys
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import pytest
from prisma.errors import ClientNotConnectedError, HTTPClientClosedError, PrismaError

from litellm.proxy._types import ProxyErrorTypes, ProxyException
from litellm.proxy.health_endpoints._health_endpoints import (
    _db_health_readiness_check,
    db_health_cache,
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_error",
    [
        PrismaError(),
        ClientNotConnectedError(),
        HTTPClientClosedError(),
    ],
)
async def test_db_health_readiness_check_with_prisma_error(prisma_error):
    """
    Test that when prisma_client.health_check() raises a PrismaError and
    allow_requests_on_db_unavailable is True, the function should not raise an error
    and return the cached health status.

    The health cache is patched on the module under test
    (``litellm.proxy.health_endpoints._health_endpoints``). Rebinding a
    ``global`` in this test module would only change this module's copy of the
    name and would leave the cache seen by ``_db_health_readiness_check``
    untouched, making the test depend on state left by earlier tests.
    """
    # Mock the prisma client; calling health_check() raises the parametrized error.
    mock_prisma_client = MagicMock()
    mock_prisma_client.health_check.side_effect = prisma_error

    # Stale cache entry with status "unknown". The assertions below expect
    # health_check() to be called and this "unknown" status to be returned.
    stale_cache = {
        "status": "unknown",
        "last_updated": datetime.now() - timedelta(minutes=5),
    }

    # Patch the cache where it actually lives, plus the proxy_server globals
    # the function reads (prisma_client and general_settings).
    with patch(
        "litellm.proxy.health_endpoints._health_endpoints.db_health_cache",
        stale_cache,
    ), patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch(
        "litellm.proxy.proxy_server.general_settings",
        {"allow_requests_on_db_unavailable": True},
    ):
        # Call the function; with the flag on, the DB error must not propagate.
        result = await _db_health_readiness_check()

        # Verify that the function called health_check
        mock_prisma_client.health_check.assert_called_once()

        # Verify that the function returned the cached status
        assert result is not None
        assert result["status"] == "unknown"  # Should retain the status from the cache


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_error",
    [
        PrismaError(),
        ClientNotConnectedError(),
        HTTPClientClosedError(),
    ],
)
async def test_db_health_readiness_check_with_error_and_flag_off(prisma_error):
    """
    Test that when prisma_client.health_check() raises a DB error but
    allow_requests_on_db_unavailable is False, the exception should be raised.

    The health cache is patched on the module under test
    (``litellm.proxy.health_endpoints._health_endpoints``). A ``global``
    rebinding here would not reach it, so a cache left "fresh" by an earlier
    test could otherwise prevent health_check() from being called at all.
    """
    # Mock the prisma client; calling health_check() raises the parametrized error.
    mock_prisma_client = MagicMock()
    mock_prisma_client.health_check.side_effect = prisma_error

    # Stale cache entry with status "unknown", installed on the module under test.
    stale_cache = {
        "status": "unknown",
        "last_updated": datetime.now() - timedelta(minutes=5),
    }

    # Patch the cache where it actually lives, plus the proxy_server globals,
    # with the allow_requests_on_db_unavailable flag turned off.
    with patch(
        "litellm.proxy.health_endpoints._health_endpoints.db_health_cache",
        stale_cache,
    ), patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch(
        "litellm.proxy.proxy_server.general_settings",
        {"allow_requests_on_db_unavailable": False},
    ):
        # The function should raise the exception
        with pytest.raises(Exception) as excinfo:
            await _db_health_readiness_check()

        # Verify that the raised exception is the very same object
        assert excinfo.value is prisma_error