Spaces:
Configuration error
Configuration error
File size: 6,807 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import litellm
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch
import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.gcs_pubsub.pub_sub import *
from datetime import datetime, timedelta
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingModelInformation,
StandardLoggingMetadata,
StandardLoggingHiddenParams,
)
verbose_logger.setLevel(logging.DEBUG)
from litellm_enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger
@pytest.mark.asyncio
async def test_generic_api_callback():
    """
    Test the GenericAPILogger callback with a standard logging payload.

    This test mocks the HTTP client and validates that the logger properly
    formats and sends the expected payload.

    ``GENERIC_LOGGER_ENDPOINT`` (in ``os.environ``) and ``litellm.callbacks``
    are patched only while the completion call and flush run, and are restored
    afterwards so this test does not leak global state into other tests.
    """
    # Create a mock for the async_httpx_client's post method
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 200
    mock_post.return_value.text = "OK"

    # Set up an endpoint for testing
    test_endpoint = "https://example.com/api/logs"
    test_headers = {"Authorization": "Bearer test_token"}

    with patch.dict(os.environ, {"GENERIC_LOGGER_ENDPOINT": test_endpoint}):
        # Initialize the GenericAPILogger and swap in the mocked HTTP post
        generic_logger = GenericAPILogger(
            endpoint=test_endpoint,
            headers=test_headers,
            flush_interval=1,
        )
        generic_logger.async_httpx_client.post = mock_post

        # patch.object restores the previous litellm.callbacks on exit
        with patch.object(litellm, "callbacks", [generic_logger]):
            # Make the completion call; mock_response supplies the canned reply "hi"
            response = await litellm.acompletion(
                model="gpt-4o",
                messages=[{"role": "user", "content": "Hello, world!"}],
                mock_response="hi",
                user="test_user",
            )
            assert (
                response.choices[0].message.content == "hi"
            ), "Completion response should be hi"

            # Give the background flush (flush_interval=1) time to run
            await asyncio.sleep(3)

    # Assert httpx post was called exactly once
    mock_post.assert_called_once()
    call_kwargs = mock_post.call_args.kwargs

    # Validate the URL the logs were flushed to
    actual_url = call_kwargs["url"]
    print("##########\n")
    print("logs were flushed to URL", actual_url, "with the following headers", call_kwargs["headers"])
    assert actual_url == test_endpoint, f"Expected URL {test_endpoint}, got {actual_url}"

    # Validate headers
    assert call_kwargs["headers"]["Content-Type"] == "application/json", "Content-Type should be application/json"

    # For the GenericAPILogger, it sends the payload directly as JSON in the data field
    json_data = call_kwargs["data"]
    print("##########\n")
    print("json_data", json_data)
    actual_request = json.loads(json_data)

    # The payload is a list of StandardLoggingPayload objects in the log queue
    assert isinstance(actual_request, list), "Request body should be a list"
    assert len(actual_request) > 0, "Request body list should not be empty"

    # Validate the first payload item
    payload_item: StandardLoggingPayload = StandardLoggingPayload(**actual_request[0])
    print("##########\n")
    print(json.dumps(payload_item, indent=4))
    print("##########\n")

    # Basic assertions for standard logging payload
    assert payload_item["response_cost"] > 0, "Response cost should be greater than 0"
    assert payload_item["model"] == "gpt-4o", "Model should be gpt-4o"
    assert payload_item["model_parameters"]["user"] == "test_user", "User should be test_user"
    assert payload_item["messages"] == [{"role": "user", "content": "Hello, world!"}], "Messages should be the same"
    assert payload_item["response"]["choices"][0]["message"]["content"] == "hi", "Response should be hi"
@pytest.mark.asyncio
async def test_generic_api_callback_multiple_logs():
    """
    Test the GenericAPILogger callback with multiple chat completions.

    Makes several mocked completion calls and checks that, after one flush
    interval, they are delivered together in a single POST whose JSON body
    holds one StandardLoggingPayload per call.

    ``GENERIC_LOGGER_ENDPOINT`` (in ``os.environ``) and ``litellm.callbacks``
    are patched only while the calls and flush run, and are restored
    afterwards so this test does not leak global state into other tests.
    """
    # Number of completion calls; also the expected number of logged items
    num_calls = 10

    # Create a mock for the async_httpx_client's post method
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 200
    mock_post.return_value.text = "OK"

    # Set up an endpoint for testing
    test_endpoint = "https://example.com/api/logs"
    test_headers = {"Authorization": "Bearer test_token"}

    with patch.dict(os.environ, {"GENERIC_LOGGER_ENDPOINT": test_endpoint}):
        # Initialize the GenericAPILogger and swap in the mocked HTTP post
        generic_logger = GenericAPILogger(
            endpoint=test_endpoint,
            headers=test_headers,
            flush_interval=5,
        )
        generic_logger.async_httpx_client.post = mock_post

        # patch.object restores the previous litellm.callbacks on exit
        with patch.object(litellm, "callbacks", [generic_logger]):
            # Make the completion calls; mock_response supplies the canned reply "hi"
            for _ in range(num_calls):
                response = await litellm.acompletion(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": "Hello, world!"}],
                    mock_response="hi",
                    user="test_user",
                )
                assert (
                    response.choices[0].message.content == "hi"
                ), "Completion response should be hi"

            # Wait past one flush_interval (5s) so the queued logs are flushed
            await asyncio.sleep(6)

    # All calls should be batched into a single POST
    mock_post.assert_called_once()
    call_kwargs = mock_post.call_args.kwargs

    # Validate the URL the logs were flushed to
    actual_url = call_kwargs["url"]
    print("##########\n")
    print("logs were flushed to URL", actual_url, "with the following headers", call_kwargs["headers"])
    assert actual_url == test_endpoint, f"Expected URL {test_endpoint}, got {actual_url}"

    # Validate headers (same contract as the single-log test)
    assert call_kwargs["headers"]["Content-Type"] == "application/json", "Content-Type should be application/json"

    # For the GenericAPILogger, it sends the payload directly as JSON in the data field
    json_data = call_kwargs["data"]
    print("##########\n")
    print("json_data", json_data)
    actual_request = json.loads(json_data)

    # The payload is a list of StandardLoggingPayload objects in the log queue
    assert isinstance(actual_request, list), "Request body should be a list"
    assert len(actual_request) == num_calls, "Request body list should be 10 items, since we made 10 calls"

    # Validate all payload items
    for raw_item in actual_request:
        payload_item = StandardLoggingPayload(**raw_item)
        print("##########\n")
        print(json.dumps(payload_item, indent=4))
        print("##########\n")
        assert payload_item["response_cost"] > 0, "Response cost should be greater than 0"
        assert payload_item["model"] == "gpt-4o", "Model should be gpt-4o"
        assert payload_item["model_parameters"]["user"] == "test_user", "User should be test_user"
        assert payload_item["messages"] == [{"role": "user", "content": "Hello, world!"}], "Messages should be the same"
        assert payload_item["response"]["choices"][0]["message"]["content"] == "hi", "Response should be hi"
|