# Unit test: rejected proxy requests must still be logged as failures.
# What is this?
## Unit test that rejected requests are also logged as failures
import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime

from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from typing import Literal

import pytest
from fastapi import Request, Response
from starlette.datastructures import URL

import litellm
from litellm import Router, mock_completion
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm_enterprise.enterprise_callbacks.secret_detection import (
    _ENTERPRISE_SecretDetection,
)
from litellm.proxy.proxy_server import (
    Depends,
    HTTPException,
    chat_completion,
    completion,
    embeddings,
)
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.router import Router
class testLogger(CustomLogger):
    """Callback used to observe failure logging on rejected requests.

    The pre-call hook unconditionally rejects every request with HTTP 429,
    so any request routed through this logger fails before reaching a
    provider.  The two boolean flags record whether the sync and async
    failure-logging hooks were subsequently invoked.
    """

    def __init__(self):
        # Set to True by log_failure_event / async_log_failure_event below.
        self.reaches_sync_failure_event = False
        self.reaches_async_failure_event = False

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ):
        # Simulate a rate-limit rejection for every incoming request.
        raise HTTPException(
            status_code=429, detail={"error": "Max parallel request limit reached"}
        )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # Sync failure hook fired — record it.
        self.reaches_sync_failure_event = True

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # Async failure hook fired — record it.
        self.reaches_async_failure_event = True
# Router with a single fake deployment pointed at a mock OpenAI-compatible
# endpoint; requests are rejected by the pre-call hook before any real call.
_fake_deployment = {
    "model_name": "fake-model",
    "litellm_params": {
        "model": "openai/fake",
        "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
        "api_key": "sk-12345",
    },
}

router = Router(model_list=[_fake_deployment])
async def test_chat_completion_request_with_redaction(route, body):
    """
    IMPORTANT Enterprise Test - Do not delete it:
    Makes a /chat/completions request on LiteLLM Proxy
    Ensures that the secret is redacted EVEN on the callback

    The request is always rejected with a 429 by testLogger's pre-call hook;
    the assertions at the end verify the rejection still reached both the
    sync and async failure-logging events.
    """
    # NOTE(review): this coroutine takes (route, body) but no
    # @pytest.mark.parametrize / asyncio markers are visible in this chunk —
    # confirm the decorators exist upstream before relying on collection.
    from litellm.proxy import proxy_server

    setattr(proxy_server, "llm_router", router)
    _test_logger = testLogger()
    litellm.callbacks = [_test_logger]
    litellm.set_verbose = True

    # Prepare the query string.
    # FIX: the original literal was "param1=value1¶m2=value2" — the "&para"
    # prefix of "&param2" had been mis-decoded into the pilcrow "¶";
    # restore the intended "&" separator between the two parameters.
    query_params = "param1=value1&param2=value2"

    # Create the Request object with query parameters
    request = Request(
        scope={
            "type": "http",
            "method": "POST",
            "headers": [(b"content-type", b"application/json")],
            "query_string": query_params.encode(),
        }
    )
    request._url = URL(url=route)

    async def return_body():
        # Serve `body` as the raw JSON payload of the request.
        import json

        return json.dumps(body).encode()

    request.body = return_body

    try:
        if route == "/v1/chat/completions":
            await chat_completion(
                request=request,
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="sk-12345", token="hashed_sk-12345", rpm_limit=0
                ),
                fastapi_response=Response(),
            )
        elif route == "/v1/completions":
            await completion(
                request=request,
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="sk-12345", token="hashed_sk-12345", rpm_limit=0
                ),
                fastapi_response=Response(),
            )
        elif route == "/v1/embeddings":
            await embeddings(
                request=request,
                user_api_key_dict=UserAPIKeyAuth(
                    api_key="sk-12345", token="hashed_sk-12345", rpm_limit=0
                ),
                fastapi_response=Response(),
            )
    except Exception:
        # Expected: the pre-call hook raises HTTPException(429); the point of
        # this test is what happens in the failure callbacks afterwards.
        pass

    # Give the async failure callback time to run before asserting.
    await asyncio.sleep(3)

    assert _test_logger.reaches_async_failure_event is True
    assert _test_logger.reaches_sync_failure_event is True