Spaces:
Configuration error
Configuration error
File size: 5,810 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import json
import os
import sys
from datetime import datetime
from typing import Dict, List, Optional
from unittest.mock import AsyncMock
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
)  # Adds the directory three levels above the current working directory to sys.path
from fastapi import HTTPException
from litellm.proxy.guardrails.guardrail_endpoints import (
get_guardrail_info,
list_guardrails_v2,
)
from litellm.proxy.guardrails.guardrail_registry import (
IN_MEMORY_GUARDRAIL_HANDLER,
InMemoryGuardrailHandler,
)
from litellm.types.guardrails import (
GuardrailInfoLiteLLMParamsResponse,
GuardrailInfoResponse,
)
# Mock data for testing
# Guardrail record used as the canned result of the mocked database lookups.
# Each timestamp comes from its own datetime.now() call (naive local time).
MOCK_DB_GUARDRAIL = dict(
    guardrail_id="test-db-guardrail",
    guardrail_name="Test DB Guardrail",
    litellm_params=dict(guardrail="test.guardrail", mode="pre_call"),
    guardrail_info=dict(description="Test guardrail from DB"),
    created_at=datetime.now(),
    updated_at=datetime.now(),
)
# Guardrail record used as the canned result of the mocked in-memory handler.
# Unlike the DB record, it carries no created_at / updated_at timestamps.
MOCK_CONFIG_GUARDRAIL = dict(
    guardrail_id="test-config-guardrail",
    guardrail_name="Test Config Guardrail",
    litellm_params=dict(
        guardrail="custom_guardrail.myCustomGuardrail",
        mode="during_call",
    ),
    guardrail_info=dict(description="Test guardrail from config"),
)
@pytest.fixture
def mock_prisma_client(mocker):
    """Provide a stand-in Prisma client with a stubbed guardrails table.

    On the returned mock, ``db.litellm_guardrailstable.find_many`` resolves to
    a one-element list holding ``MOCK_DB_GUARDRAIL`` and ``find_unique``
    resolves to ``MOCK_DB_GUARDRAIL`` itself.
    """
    # The table methods are AsyncMocks so that callers can await them.
    guardrails_table = mocker.Mock()
    guardrails_table.find_many = AsyncMock(return_value=[MOCK_DB_GUARDRAIL])
    guardrails_table.find_unique = AsyncMock(return_value=MOCK_DB_GUARDRAIL)

    database = mocker.Mock()
    database.litellm_guardrailstable = guardrails_table

    client = mocker.Mock()
    client.db = database
    return client
@pytest.fixture
def mock_in_memory_handler(mocker):
    """Provide a spec'd ``InMemoryGuardrailHandler`` mock with canned results.

    ``list_in_memory_guardrails()`` returns a one-element list holding
    ``MOCK_CONFIG_GUARDRAIL`` and ``get_guardrail_by_id()`` returns
    ``MOCK_CONFIG_GUARDRAIL`` itself.
    """
    handler = mocker.Mock(spec=InMemoryGuardrailHandler)
    # Dotted keys let configure_mock set return values on child mocks in one call.
    handler.configure_mock(
        **{
            "list_in_memory_guardrails.return_value": [MOCK_CONFIG_GUARDRAIL],
            "get_guardrail_by_id.return_value": MOCK_CONFIG_GUARDRAIL,
        }
    )
    return handler
@pytest.mark.asyncio
async def test_list_guardrails_v2_with_db_and_config(
    mocker, mock_prisma_client, mock_in_memory_handler
):
    """Check that ``list_guardrails_v2`` lists guardrails from both DB and config.

    The Prisma client and the in-memory guardrail handler are replaced with
    the fixture mocks, so the response is expected to contain exactly two
    entries: one with ``guardrail_definition_location == "db"`` and one with
    ``guardrail_definition_location == "config"``.
    """
    # Swap in the mocked Prisma client and in-memory handler.
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    mocker.patch(
        "litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER",
        mock_in_memory_handler,
    )

    response = await list_guardrails_v2()

    assert len(response.guardrails) == 2

    # next() is given an explicit default: a bare next() on an exhausted
    # generator raises StopIteration, which inside a coroutine surfaces as an
    # opaque RuntimeError instead of a readable assertion failure.
    db_guardrail = next(
        (g for g in response.guardrails if g.guardrail_id == "test-db-guardrail"),
        None,
    )
    assert db_guardrail is not None, "DB guardrail missing from response"
    assert db_guardrail.guardrail_name == "Test DB Guardrail"
    assert db_guardrail.guardrail_definition_location == "db"
    assert isinstance(db_guardrail.litellm_params, GuardrailInfoLiteLLMParamsResponse)

    config_guardrail = next(
        (
            g
            for g in response.guardrails
            if g.guardrail_id == "test-config-guardrail"
        ),
        None,
    )
    assert config_guardrail is not None, "config guardrail missing from response"
    assert config_guardrail.guardrail_name == "Test Config Guardrail"
    assert config_guardrail.guardrail_definition_location == "config"
    assert isinstance(
        config_guardrail.litellm_params, GuardrailInfoLiteLLMParamsResponse
    )
@pytest.mark.asyncio
async def test_get_guardrail_info_from_db(mocker, mock_prisma_client):
    """Check that ``get_guardrail_info`` returns the record served by the DB mock."""
    target_id = "test-db-guardrail"
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

    info = await get_guardrail_info(target_id)

    # The fields should mirror the record returned by the mocked find_unique.
    assert info.guardrail_id == target_id
    assert info.guardrail_name == "Test DB Guardrail"
    assert isinstance(info.litellm_params, GuardrailInfoLiteLLMParamsResponse)
    expected_info = {"description": "Test guardrail from DB"}
    assert info.guardrail_info == expected_info
@pytest.mark.asyncio
async def test_get_guardrail_info_from_config(
    mocker, mock_prisma_client, mock_in_memory_handler
):
    """Check that ``get_guardrail_info`` uses config data when the DB has no match."""
    target_id = "test-config-guardrail"

    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    mocker.patch(
        "litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER",
        mock_in_memory_handler,
    )

    # Simulate a DB miss so only the in-memory handler can supply the record.
    guardrails_table = mock_prisma_client.db.litellm_guardrailstable
    guardrails_table.find_unique = AsyncMock(return_value=None)

    info = await get_guardrail_info(target_id)

    assert info.guardrail_id == target_id
    assert info.guardrail_name == "Test Config Guardrail"
    assert isinstance(info.litellm_params, GuardrailInfoLiteLLMParamsResponse)
    expected_info = {"description": "Test guardrail from config"}
    assert info.guardrail_info == expected_info
@pytest.mark.asyncio
async def test_get_guardrail_info_not_found(
    mocker, mock_prisma_client, mock_in_memory_handler
):
    """Check that ``get_guardrail_info`` raises a 404 when no source has the id."""
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    mocker.patch(
        "litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER",
        mock_in_memory_handler,
    )

    # Make both lookups come back empty: the DB first, then the in-memory handler.
    guardrails_table = mock_prisma_client.db.litellm_guardrailstable
    guardrails_table.find_unique = AsyncMock(return_value=None)
    mock_in_memory_handler.get_guardrail_by_id.return_value = None

    with pytest.raises(HTTPException) as exc_info:
        await get_guardrail_info("non-existent-guardrail")

    # Inspect the captured exception only after the context manager has exited.
    raised = exc_info.value
    assert raised.status_code == 404
    assert "not found" in str(raised.detail)
|