Spaces:
Configuration error
Configuration error
### What this tests #### | |
## This test asserts the type of data passed into each method of the custom callback handler | |
import asyncio | |
import inspect | |
import os | |
import sys | |
import time | |
import traceback | |
import uuid | |
from datetime import datetime | |
import pytest | |
from pydantic import BaseModel | |
sys.path.insert(0, os.path.abspath("../..")) | |
from typing import List, Literal, Optional, Union | |
from unittest.mock import AsyncMock, MagicMock, patch | |
import litellm | |
from litellm import Cache, completion, embedding | |
from litellm.integrations.custom_logger import CustomLogger | |
from litellm.types.utils import LiteLLMCommonStrings | |
# Test Scenarios (test across completion, streaming, embedding) | |
## 1: Pre-API-Call | |
## 2: Post-API-Call | |
## 3: On LiteLLM Call success | |
## 4: On LiteLLM Call failure | |
## 5. Caching | |
# Test models | |
## 1. OpenAI | |
## 2. Azure OpenAI | |
## 3. Non-OpenAI/Azure - e.g. Bedrock | |
# Test interfaces | |
## 1. litellm.completion() + litellm.embedding()
## refer to test_custom_callback_input_router.py for the router + proxy tests | |
class CompletionCustomHandler(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    Custom callback handler that type-checks the inputs passed to each logging hook.

    Every hook appends its own name to ``self.states`` and then asserts the types
    of its arguments. Assertion failures (and any other exception raised while
    checking) are never propagated: the formatted traceback is printed and
    appended to ``self.errors`` so the calling test can inspect it afterwards.
    """

    def __init__(self):
        # NOTE(review): the base-class __init__ is not called (unchanged from the
        # previous version) -- confirm CustomLogger needs no initialisation.
        # Formatted tracebacks; at most one is added per failing hook invocation.
        self.errors = []
        # Names of the hooks that have run, in call order.
        self.states: List[
            Literal[
                "sync_pre_api_call",
                "async_pre_api_call",
                "post_api_call",
                "sync_stream",
                "async_stream",
                "sync_success",
                "async_success",
                "sync_failure",
                "async_failure",
            ]
        ] = []

    def _record_failure(self):
        """
        Print and store the traceback of the exception currently being handled.

        Must be called from inside an ``except`` block, because it relies on
        ``traceback.format_exc()`` to find the active exception.
        """
        formatted = traceback.format_exc()
        print(f"Assertion Error: {formatted}")
        self.errors.append(formatted)

    @staticmethod
    def _assert_core_kwargs(kwargs):
        """
        Assert the types of the ``kwargs`` entries that every hook in this class checks.

        Raises AssertionError on a type mismatch and KeyError on a missing key;
        callers invoke this inside their ``try`` block so both get recorded.
        """
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["optional_params"], dict)
        assert isinstance(kwargs["litellm_params"], dict)
        assert isinstance(kwargs["start_time"], (datetime, type(None)))
        assert isinstance(kwargs["stream"], bool)
        assert isinstance(kwargs["user"], (str, type(None)))

    def log_pre_api_call(self, model, messages, kwargs):
        """Check the inputs of the synchronous pre-API-call hook."""
        try:
            self.states.append("sync_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list)
            ## KWARGS
            assert isinstance(kwargs["messages"], list)
            self._assert_core_kwargs(kwargs)
            ### METADATA
            metadata_value = kwargs["litellm_params"].get("metadata")
            assert metadata_value is None or isinstance(metadata_value, dict)
            if metadata_value is not None:
                if litellm.turn_off_message_logging is True:
                    # Compare by value: two equal strings are not guaranteed to
                    # be the same object, so an identity check is unreliable.
                    assert (
                        metadata_value["raw_request"]
                        == LiteLLMCommonStrings.redacted_by_litellm.value
                    )
                else:
                    assert "raw_request" not in metadata_value or isinstance(
                        metadata_value["raw_request"], str
                    )
        except Exception:
            self._record_failure()

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        """Check the inputs of the post-API-call hook (no end time or response object yet)."""
        try:
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time is None
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["messages"], list)
            self._assert_core_kwargs(kwargs)
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"],
                    (str, litellm.CustomStreamWrapper, BaseModel),
                )
                or inspect.iscoroutine(kwargs["original_response"])
                or inspect.isasyncgen(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_failure()

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Check the inputs of the asynchronous streaming hook."""
        try:
            self.states.append("async_stream")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            self._assert_core_kwargs(kwargs)
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_failure()

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Check the inputs of the synchronous success hook."""
        # Imported locally so the hook does not depend on a module-level
        # `import json` being present elsewhere in the file.
        import json

        try:
            print(f"\n\nkwargs={kwargs}\n\n")
            print(
                json.dumps(kwargs, default=str)
            )  # this is a test to confirm no circular references are in the logging object
            self.states.append("sync_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.ImageResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            self._assert_core_kwargs(kwargs)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], (dict, str))
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert isinstance(
                kwargs["original_response"],
                (str, litellm.CustomStreamWrapper, BaseModel),
            ), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
                kwargs["original_response"]
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            self._record_failure()

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Check the inputs of the synchronous failure hook (no response object)."""
        try:
            print(f"kwargs: {kwargs}")
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            self._assert_core_kwargs(kwargs)
            # A tuple of classes works on every Python version; isinstance()
            # only accepts typing.Optional[...] from Python 3.10 onwards.
            assert isinstance(
                kwargs["litellm_params"]["metadata"], (dict, type(None))
            )
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or kwargs["original_response"] is None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_failure()

    async def async_log_pre_api_call(self, model, messages, kwargs):
        """Check the inputs of the asynchronous pre-API-call hook."""
        try:
            self.states.append("async_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            self._assert_core_kwargs(kwargs)
        except Exception:
            self._record_failure()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Check the inputs of the asynchronous success hook."""
        try:
            print(
                "in async_log_success_event", kwargs, response_obj, start_time, end_time
            )
            self.states.append("async_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.TextCompletionResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["messages"], list)
            self._assert_core_kwargs(kwargs)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert isinstance(kwargs["completion_start_time"], datetime)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            self._record_failure()

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Check the inputs of the asynchronous failure hook (no response object)."""
        try:
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["messages"], list)
            self._assert_core_kwargs(kwargs)
            assert isinstance(kwargs["input"], (list, str, dict))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
                or kwargs["original_response"] is None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_failure()