# test3 / tests /test_litellm /proxy /guardrails /test_guardrail_endpoints.py
# DesertWolf's picture
# Upload folder using huggingface_hub
# 447ebeb verified
import json
import os
import sys
from datetime import datetime
from typing import Dict, List, Optional
from unittest.mock import AsyncMock
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from fastapi import HTTPException
from litellm.proxy.guardrails.guardrail_endpoints import (
get_guardrail_info,
list_guardrails_v2,
)
from litellm.proxy.guardrails.guardrail_registry import (
IN_MEMORY_GUARDRAIL_HANDLER,
InMemoryGuardrailHandler,
)
from litellm.types.guardrails import (
GuardrailInfoLiteLLMParamsResponse,
GuardrailInfoResponse,
)
# Mock data for testing.
# Guardrail record handed back by the mocked Prisma table
# (``find_many`` / ``find_unique``); unlike the config record it carries
# ``created_at`` / ``updated_at`` timestamps.
MOCK_DB_GUARDRAIL = {
    "guardrail_id": "test-db-guardrail",
    "guardrail_name": "Test DB Guardrail",
    "litellm_params": {"guardrail": "test.guardrail", "mode": "pre_call"},
    "guardrail_info": {"description": "Test guardrail from DB"},
    "created_at": datetime.now(),
    "updated_at": datetime.now(),
}
# Guardrail record handed back by the mocked in-memory handler
# (``list_in_memory_guardrails`` / ``get_guardrail_by_id``); it has no
# timestamp fields.
MOCK_CONFIG_GUARDRAIL = {
    "guardrail_id": "test-config-guardrail",
    "guardrail_name": "Test Config Guardrail",
    "litellm_params": {
        "guardrail": "custom_guardrail.myCustomGuardrail",
        "mode": "during_call",
    },
    "guardrail_info": {"description": "Test guardrail from config"},
}
@pytest.fixture
def mock_prisma_client(mocker):
    """Provide a stand-in Prisma client whose guardrails table is pre-seeded.

    The returned mock exposes ``db.litellm_guardrailstable`` with two
    awaitable methods:

    * ``find_many``   -> ``[MOCK_DB_GUARDRAIL]``
    * ``find_unique`` -> ``MOCK_DB_GUARDRAIL``

    Individual tests may replace either method to simulate other DB states.
    """
    # Build from the innermost object outwards: table -> db -> client.
    guardrails_table = mocker.Mock()
    guardrails_table.find_many = AsyncMock(return_value=[MOCK_DB_GUARDRAIL])
    guardrails_table.find_unique = AsyncMock(return_value=MOCK_DB_GUARDRAIL)

    database = mocker.Mock()
    database.litellm_guardrailstable = guardrails_table

    prisma = mocker.Mock()
    prisma.db = database
    return prisma
@pytest.fixture
def mock_in_memory_handler(mocker):
    """Provide a spec'd ``InMemoryGuardrailHandler`` double seeded with one guardrail.

    ``list_in_memory_guardrails()`` returns ``[MOCK_CONFIG_GUARDRAIL]`` and
    ``get_guardrail_by_id(...)`` returns ``MOCK_CONFIG_GUARDRAIL``.
    """
    handler = mocker.Mock(spec=InMemoryGuardrailHandler)
    # Dotted keys configure the return values of the child method mocks.
    handler.configure_mock(
        **{
            "list_in_memory_guardrails.return_value": [MOCK_CONFIG_GUARDRAIL],
            "get_guardrail_by_id.return_value": MOCK_CONFIG_GUARDRAIL,
        }
    )
    return handler
@pytest.mark.asyncio
async def test_list_guardrails_v2_with_db_and_config(
    mocker, mock_prisma_client, mock_in_memory_handler
):
    """``list_guardrails_v2`` returns both the DB guardrail and the config guardrail.

    Each entry must keep its name, report where it was defined ("db" or
    "config"), and expose ``litellm_params`` as a
    ``GuardrailInfoLiteLLMParamsResponse``.
    """
    # Replace the module-level Prisma client and in-memory handler with the doubles.
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    mocker.patch(
        "litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER",
        mock_in_memory_handler,
    )

    listing = await list_guardrails_v2()
    assert len(listing.guardrails) == 2

    def _by_id(wanted_id):
        # First guardrail in the response carrying the requested id.
        return next(
            entry for entry in listing.guardrails if entry.guardrail_id == wanted_id
        )

    # Entry that originated from the database.
    from_db = _by_id("test-db-guardrail")
    assert from_db.guardrail_name == "Test DB Guardrail"
    assert from_db.guardrail_definition_location == "db"
    assert isinstance(from_db.litellm_params, GuardrailInfoLiteLLMParamsResponse)

    # Entry that originated from the config / in-memory handler.
    from_config = _by_id("test-config-guardrail")
    assert from_config.guardrail_name == "Test Config Guardrail"
    assert from_config.guardrail_definition_location == "config"
    assert isinstance(from_config.litellm_params, GuardrailInfoLiteLLMParamsResponse)
@pytest.mark.asyncio
async def test_get_guardrail_info_from_db(mocker, mock_prisma_client):
    """``get_guardrail_info`` returns the guardrail stored in the (mocked) DB."""
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

    wanted_id = "test-db-guardrail"
    info = await get_guardrail_info(wanted_id)

    assert info.guardrail_id == wanted_id
    assert info.guardrail_name == "Test DB Guardrail"
    assert isinstance(info.litellm_params, GuardrailInfoLiteLLMParamsResponse)
    assert info.guardrail_info == {"description": "Test guardrail from DB"}
@pytest.mark.asyncio
async def test_get_guardrail_info_from_config(
    mocker, mock_prisma_client, mock_in_memory_handler
):
    """``get_guardrail_info`` returns the config guardrail when the DB lookup yields nothing."""
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    mocker.patch(
        "litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER",
        mock_in_memory_handler,
    )

    # Simulate a DB miss: the unique lookup resolves to None.
    guardrails_table = mock_prisma_client.db.litellm_guardrailstable
    guardrails_table.find_unique = AsyncMock(return_value=None)

    wanted_id = "test-config-guardrail"
    info = await get_guardrail_info(wanted_id)

    assert info.guardrail_id == wanted_id
    assert info.guardrail_name == "Test Config Guardrail"
    assert isinstance(info.litellm_params, GuardrailInfoLiteLLMParamsResponse)
    assert info.guardrail_info == {"description": "Test guardrail from config"}
@pytest.mark.asyncio
async def test_get_guardrail_info_not_found(
    mocker, mock_prisma_client, mock_in_memory_handler
):
    """``get_guardrail_info`` raises a 404 ``HTTPException`` when neither source has the id."""
    mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    mocker.patch(
        "litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER",
        mock_in_memory_handler,
    )

    # Make both lookups come back empty: the DB row and the in-memory entry.
    guardrails_table = mock_prisma_client.db.litellm_guardrailstable
    guardrails_table.find_unique = AsyncMock(return_value=None)
    mock_in_memory_handler.get_guardrail_by_id.return_value = None

    with pytest.raises(HTTPException) as exc_info:
        await get_guardrail_info("non-existent-guardrail")

    # Inspect the captured exception only after the context manager has exited.
    raised = exc_info.value
    assert raised.status_code == 404
    assert "not found" in str(raised.detail)