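"""
Tests that litellm's string-named, custom-logger-compatible callbacks
(e.g. "prometheus", "langfuse", "datadog") are initialized to the expected
logger classes when set via litellm.callbacks or litellm.success_callback.
"""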
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from typing import Literal

import pytest
import litellm
import asyncio
import logging
from litellm._logging import verbose_logger
from prometheus_client import REGISTRY, CollectorRegistry

from litellm.integrations.lago import LagoLogger
from litellm.integrations.openmeter import OpenMeterLogger
from litellm.integrations.braintrust_logging import BraintrustLogger
from litellm.integrations.galileo import GalileoObserve
from litellm.integrations.langsmith import LangsmithLogger
from litellm.integrations.literal_ai import LiteralAILogger
from litellm.integrations.prometheus import PrometheusLogger
from litellm.integrations.datadog.datadog import DataDogLogger
from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger
from litellm.integrations.gcs_pubsub.pub_sub import GcsPubSubLogger
from litellm.integrations.opik.opik import OpikLogger
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.argilla import ArgillaLogger
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
from litellm.integrations.s3_v2 import S3Logger
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
from litellm.integrations.langfuse.langfuse_prompt_management import (
    LangfusePromptManagement,
)
from litellm.integrations.azure_storage.azure_storage import AzureBlobStorageLogger
from litellm.integrations.agentops import AgentOps
from litellm.integrations.humanloop import HumanloopLogger
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
from litellm_enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger
from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ResendEmailLogger
from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import SMTPEmailLogger
from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import PagerDutyAlerting
from unittest.mock import patch

# clear prometheus collectors / registry
collectors = list(REGISTRY._collector_to_names.keys())
for collector in collectors:
    REGISTRY.unregister(collector)
######################################

callback_class_str_to_classType = {
    "lago": LagoLogger,
    "openmeter": OpenMeterLogger,
    "braintrust": BraintrustLogger,
    "galileo": GalileoObserve,
    "langsmith": LangsmithLogger,
    "literalai": LiteralAILogger,
    "prometheus": PrometheusLogger,
    "datadog": DataDogLogger,
    "datadog_llm_observability": DataDogLLMObsLogger,
    "gcs_bucket": GCSBucketLogger,
    "opik": OpikLogger,
    "argilla": ArgillaLogger,
    "opentelemetry": OpenTelemetry,
    "azure_storage": AzureBlobStorageLogger,
    "humanloop": HumanloopLogger,
    # OTEL compatible loggers
    "logfire": OpenTelemetry,
    "arize": OpenTelemetry,
    "arize_phoenix": OpenTelemetry,
    "langtrace": OpenTelemetry,
    "mlflow": MlflowLogger,
    "langfuse": LangfusePromptManagement,
    "otel": OpenTelemetry,
    "pagerduty": PagerDutyAlerting,
    "gcs_pubsub": GcsPubSubLogger,
    "anthropic_cache_control_hook": AnthropicCacheControlHook,
    "agentops": AgentOps,
    "bedrock_vector_store": BedrockVectorStore,
    "generic_api": GenericAPILogger,
    "resend_email": ResendEmailLogger,
    "smtp_email": SMTPEmailLogger,
    "deepeval": DeepEvalLogger,
    "s3_v2": S3Logger,
}
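
# Dummy environment variables seeded before the callbacks are initialized, so
# loggers that read credentials at construction time don't raise.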
expected_env_vars = {
    "LAGO_API_KEY": "api_key",
    "LAGO_API_BASE": "mock_base",
    "LAGO_API_EVENT_CODE": "mock_event_code",
    "OPENMETER_API_KEY": "openmeter_api_key",
    "BRAINTRUST_API_KEY": "braintrust_api_key",
    "GALILEO_API_KEY": "galileo_api_key",
    "LITERAL_API_KEY": "literal_api_key",
    "DD_API_KEY": "datadog_api_key",
    "DD_SITE": "datadog_site",
    "GOOGLE_APPLICATION_CREDENTIALS": "gcs_credentials",
    "OPIK_API_KEY": "opik_api_key",
    "LANGTRACE_API_KEY": "langtrace_api_key",
    "LOGFIRE_TOKEN": "logfire_token",
    "ARIZE_SPACE_KEY": "arize_space_key",
    "ARIZE_API_KEY": "arize_api_key",
    "PHOENIX_API_KEY": "phoenix_api_key",
    "ARGILLA_API_KEY": "argilla_api_key",
    "PAGERDUTY_API_KEY": "pagerduty_api_key",
    "GCS_PUBSUB_TOPIC_ID": "gcs_pubsub_topic_id",
    "GCS_PUBSUB_PROJECT_ID": "gcs_pubsub_project_id",
    "CONFIDENT_API_KEY": "confident_api_key",
    "LITELM_ENVIRONMENT": "development",
    "AWS_BUCKET_NAME": "aws_bucket_name",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_REGION": "aws_region",
}
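

# Helpers to reset global litellm callback state and to seed / restore
# environment variables between test runs.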
def reset_all_callbacks():
    litellm.callbacks = []
    litellm.input_callback = []
    litellm.success_callback = []
    litellm.failure_callback = []
    litellm._async_success_callback = []
    litellm._async_failure_callback = []


initial_env_vars = {}


def init_env_vars():
    for env_var, value in expected_env_vars.items():
        if env_var not in os.environ:
            os.environ[env_var] = value
        else:
            initial_env_vars[env_var] = os.environ[env_var]


def reset_env_vars():
    for env_var, value in initial_env_vars.items():
        os.environ[env_var] = value

all_callback_required_env_vars = []
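
# Sets the given callback (by its string alias) on litellm, fires a few mocked
# acompletion calls, and asserts the callback was initialized to the expected
# logger class exactly once.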
async def use_callback_in_llm_call(
    callback: str, used_in: Literal["callbacks", "success_callback"]
):
    if callback == "dynamic_rate_limiter":
        # internal CustomLogger class that expects an internal_usage_cache to be
        # passed in; it always fails when tested this way
        return
    elif callback == "argilla":
        litellm.argilla_transformation_object = {}
    elif callback == "openmeter":
        # currently handled in a janky way; TODO: fix openmeter and then actually run its test
        return
    elif callback == "prometheus":
        # pytest teardown - clear existing prometheus collectors
        collectors = list(REGISTRY._collector_to_names.keys())
        for collector in collectors:
            REGISTRY.unregister(collector)
    # Mock the httpx call for Argilla dataset retrieval
    if callback == "argilla":
        import httpx

        mock_response = httpx.Response(
            status_code=200, json={"items": [{"id": "mocked_dataset_id"}]}
        )
        patch.object(
            litellm.module_level_client, "get", return_value=mock_response
        ).start()
    if used_in == "callbacks":
        litellm.callbacks = [callback]
    elif used_in == "success_callback":
        litellm.success_callback = [callback]

    for _ in range(5):
        await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            temperature=0.1,
            mock_response="hello",
        )

    await asyncio.sleep(0.5)

    expected_class = callback_class_str_to_classType[callback]

    if used_in == "callbacks":
        assert isinstance(litellm._async_success_callback[0], expected_class)
        assert isinstance(litellm._async_failure_callback[0], expected_class)
        assert isinstance(litellm.success_callback[0], expected_class)
        assert isinstance(litellm.failure_callback[0], expected_class)

        assert (
            len(litellm._async_success_callback) == 1
        ), f"Got={litellm._async_success_callback}"
        assert len(litellm._async_failure_callback) == 1
        assert len(litellm.success_callback) == 1
        assert len(litellm.failure_callback) == 1
        assert len(litellm.callbacks) == 1
    elif used_in == "success_callback":
        print(f"litellm.success_callback: {litellm.success_callback}")
        print(f"litellm._async_success_callback: {litellm._async_success_callback}")

        assert isinstance(litellm.success_callback[0], expected_class)
        assert len(litellm.success_callback) == 1  # ["lago", LagoLogger]

        assert isinstance(litellm._async_success_callback[0], expected_class)
        assert len(litellm._async_success_callback) == 1

        # TODO: also assert that the callback is not set for failure_callback.
        # As of Oct 21 2024 it currently is; for now we only cover setting it in
        # success_callback / _async_success_callback.
    if callback == "argilla":
        patch.stopall()
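

# Every known custom-logger-compatible callback alias should initialize its
# logger class, whether set via litellm.callbacks or litellm.success_callback.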
@pytest.mark.asyncio
async def test_init_custom_logger_compatible_class_as_callback():
    init_env_vars()

    # used like litellm.callbacks = ["prometheus"]
    for callback in litellm._known_custom_logger_compatible_callbacks:
        print(f"Testing callback: {callback}")
        reset_all_callbacks()
        await use_callback_in_llm_call(callback, used_in="callbacks")

    # used like litellm.success_callback = ["prometheus"]
    for callback in litellm._known_custom_logger_compatible_callbacks:
        print(f"Testing callback: {callback}")
        reset_all_callbacks()
        await use_callback_in_llm_call(callback, used_in="success_callback")

    reset_env_vars()
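

# A dynamically requested callback (langfuse keys passed per-request) should not
# prevent a globally registered CustomLogger from firing its success handler.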
def test_dynamic_logging_global_callback():
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.types.utils import ModelResponse, Choices, Message, Usage

    cl = CustomLogger()

    litellm_logging = LiteLLMLoggingObj(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="completion",
        start_time=datetime.now(),
        litellm_call_id="123",
        function_id="456",
        kwargs={
            "langfuse_public_key": "my-mock-public-key",
            "langfuse_secret_key": "my-mock-secret-key",
        },
        dynamic_success_callbacks=["langfuse"],
    )

    with patch.object(cl, "log_success_event") as mock_log_success_event:
        cl.log_success_event = mock_log_success_event
        litellm.success_callback = [cl]

        try:
            litellm_logging.success_handler(
                result=ModelResponse(
                    id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
                    created=1732306261,
                    model="claude-3-opus-20240229",
                    object="chat.completion",
                    system_fingerprint=None,
                    choices=[
                        Choices(
                            finish_reason="stop",
                            index=0,
                            message=Message(
                                content="hello",
                                role="assistant",
                                tool_calls=None,
                                function_call=None,
                            ),
                        )
                    ],
                    usage=Usage(
                        completion_tokens=20,
                        prompt_tokens=10,
                        total_tokens=30,
                        completion_tokens_details=None,
                        prompt_tokens_details=None,
                    ),
                ),
                start_time=datetime.now(),
                end_time=datetime.now(),
                cache_hit=False,
            )
        except Exception as e:
            print(f"Error: {e}")

        mock_log_success_event.assert_called_once()
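

# get_combined_callback_list should merge dynamic (per-request) callbacks with
# globally configured ones.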
def test_get_combined_callback_list():
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    _logging = LiteLLMLoggingObj(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="completion",
        start_time=datetime.now(),
        litellm_call_id="123",
        function_id="456",
    )

    assert "langfuse" in _logging.get_combined_callback_list(
        dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
    )
    assert "lago" in _logging.get_combined_callback_list(
        dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
    )