File size: 6,807 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import io
import os
import sys


# Put the directory two levels above the current working directory at the
# front of the import path. The path is resolved against the CWD, not against
# this file's location, so it depends on where pytest is launched from.
sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import litellm
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch

import pytest

# NOTE(review): litellm is already imported a few lines above; this second
# import is redundant.
import litellm
from litellm import completion
from litellm._logging import verbose_logger
# NOTE(review): wildcard import; every name the tests below reference is also
# imported explicitly elsewhere in this block — confirm nothing relies on it
# before removing.
from litellm.integrations.gcs_pubsub.pub_sub import *
from datetime import datetime, timedelta
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingModelInformation,
    StandardLoggingMetadata,
    StandardLoggingHiddenParams,
)

# Emit litellm's verbose logger output at DEBUG level while these tests run.
verbose_logger.setLevel(logging.DEBUG)
from litellm_enterprise.enterprise_callbacks.generic_api_callback import GenericAPILogger



@pytest.mark.asyncio
async def test_generic_api_callback():
    """
    Verify that GenericAPILogger flushes a single completion to its endpoint.

    The logger's ``async_httpx_client.post`` is replaced with an ``AsyncMock``
    so no network traffic happens. One mocked ``litellm.acompletion`` call is
    made, the test waits longer than the logger's ``flush_interval`` and then
    checks the URL, the ``Content-Type`` header and the JSON body handed to
    ``post``.

    ``os.environ`` and ``litellm.callbacks`` are patched only for the duration
    of the call, so the values assigned here are rolled back afterwards instead
    of leaking into other tests.
    """
    # Stand-in for the logger's HTTP POST; reports a successful response.
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 200
    mock_post.return_value.text = "OK"

    test_endpoint = "https://example.com/api/logs"
    test_headers = {"Authorization": "Bearer test_token"}
    test_messages = [{"role": "user", "content": "Hello, world!"}]

    # patch.dict restores the previous environment when the block exits.
    with patch.dict(os.environ, {"GENERIC_LOGGER_ENDPOINT": test_endpoint}):
        generic_logger = GenericAPILogger(
            endpoint=test_endpoint,
            headers=test_headers,
            flush_interval=1,
        )
        generic_logger.async_httpx_client.post = mock_post

        # patch.object puts the previous litellm.callbacks value back on exit.
        with patch.object(litellm, "callbacks", [generic_logger]):
            await litellm.acompletion(
                model="gpt-4o",
                messages=test_messages,
                mock_response="hi",
                user="test_user",
            )

            # Sleep longer than flush_interval so the queued log gets sent.
            await asyncio.sleep(3)

    # Exactly one POST is expected for the single completion.
    mock_post.assert_called_once()
    post_kwargs = mock_post.call_args.kwargs

    actual_url = post_kwargs["url"]
    print("##########\n")
    print("logs were flushed to URL", actual_url, "with the following headers", post_kwargs["headers"])
    assert actual_url == test_endpoint, f"Expected URL {test_endpoint}, got {actual_url}"

    assert post_kwargs["headers"]["Content-Type"] == "application/json", "Content-Type should be application/json"

    # The body is passed to post() as a JSON string in the ``data`` keyword.
    json_data = post_kwargs["data"]
    print("##########\n")
    print("json_data", json_data)
    actual_request = json.loads(json_data)

    # The decoded body is a list with one entry per queued log.
    assert isinstance(actual_request, list), "Request body should be a list"
    assert len(actual_request) > 0, "Request body list should not be empty"

    payload_item: StandardLoggingPayload = StandardLoggingPayload(**actual_request[0])
    print("##########\n")
    print(json.dumps(payload_item, indent=4))
    print("##########\n")

    # Basic assertions for the standard logging payload.
    assert payload_item["response_cost"] > 0, "Response cost should be greater than 0"
    assert payload_item["model"] == "gpt-4o", "Model should be gpt-4o"
    assert payload_item["model_parameters"]["user"] == "test_user", "User should be test_user"
    assert payload_item["messages"] == test_messages, "Messages should be the same"
    assert payload_item["response"]["choices"][0]["message"]["content"] == "hi", "Response should be hi"
    
    


@pytest.mark.asyncio
async def test_generic_api_callback_multiple_logs():
    """
    Verify that GenericAPILogger batches several completions into one POST.

    The logger's ``async_httpx_client.post`` is replaced with an ``AsyncMock``.
    Ten mocked ``litellm.acompletion`` calls are made, the test waits longer
    than the logger's ``flush_interval`` and then checks that a single POST
    carried all ten payloads and that each payload has the expected fields.

    ``os.environ`` and ``litellm.callbacks`` are patched only for the duration
    of the calls, so the values assigned here are rolled back afterwards
    instead of leaking into other tests.
    """
    num_calls = 10

    # Stand-in for the logger's HTTP POST; reports a successful response.
    mock_post = AsyncMock()
    mock_post.return_value.status_code = 200
    mock_post.return_value.text = "OK"

    test_endpoint = "https://example.com/api/logs"
    test_headers = {"Authorization": "Bearer test_token"}
    test_messages = [{"role": "user", "content": "Hello, world!"}]

    # patch.dict restores the previous environment when the block exits.
    with patch.dict(os.environ, {"GENERIC_LOGGER_ENDPOINT": test_endpoint}):
        generic_logger = GenericAPILogger(
            endpoint=test_endpoint,
            headers=test_headers,
            flush_interval=5,
        )
        generic_logger.async_httpx_client.post = mock_post

        # patch.object puts the previous litellm.callbacks value back on exit.
        with patch.object(litellm, "callbacks", [generic_logger]):
            for _ in range(num_calls):
                await litellm.acompletion(
                    model="gpt-4o",
                    messages=test_messages,
                    mock_response="hi",
                    user="test_user",
                )

            # Sleep longer than flush_interval so the queued logs get sent.
            await asyncio.sleep(6)

    # All completions are expected to arrive in one batched POST.
    mock_post.assert_called_once()
    post_kwargs = mock_post.call_args.kwargs

    actual_url = post_kwargs["url"]
    print("##########\n")
    print("logs were flushed to URL", actual_url, "with the following headers", post_kwargs["headers"])
    assert actual_url == test_endpoint, f"Expected URL {test_endpoint}, got {actual_url}"

    # Same header check as the single-log test in this file.
    assert post_kwargs["headers"]["Content-Type"] == "application/json", "Content-Type should be application/json"

    # The body is passed to post() as a JSON string in the ``data`` keyword.
    json_data = post_kwargs["data"]
    print("##########\n")
    print("json_data", json_data)
    actual_request = json.loads(json_data)

    # The decoded body is a list with one entry per completion call.
    assert isinstance(actual_request, list), "Request body should be a list"
    assert len(actual_request) == num_calls, "Request body list should be 10 items, since we made 10 calls"

    # Validate every payload in the batch.
    for raw_item in actual_request:
        payload_item: StandardLoggingPayload = StandardLoggingPayload(**raw_item)
        print("##########\n")
        print(json.dumps(payload_item, indent=4))
        print("##########\n")

        assert payload_item["response_cost"] > 0, "Response cost should be greater than 0"
        assert payload_item["model"] == "gpt-4o", "Model should be gpt-4o"
        assert payload_item["model_parameters"]["user"] == "test_user", "User should be test_user"
        assert payload_item["messages"] == test_messages, "Messages should be the same"
        assert payload_item["response"]["choices"][0]["message"]["content"] == "hi", "Response should be hi"