Spaces:
Configuration error
Configuration error
File size: 9,856 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 |
import asyncio
import os
import sys
from unittest.mock import Mock
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
import pytest
from fastapi import Request
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from unittest.mock import MagicMock, patch, AsyncMock
import httpx
from litellm.proxy.utils import update_spend, DB_CONNECTION_ERROR_TYPES
class MockPrismaClient:
    """Minimal stand-in for the Prisma client handed to ``update_spend``.

    Provides an ``AsyncMock``-backed ``db`` handle (including
    ``db.litellm_spendlogs.create_many``) plus empty containers for pending
    spend-log and daily-user-spend transactions.
    """

    def __init__(self):
        # Assemble the fake database bottom-up: table first, then the handle.
        spendlogs_table = AsyncMock()
        spendlogs_table.create_many = AsyncMock()
        database = AsyncMock()
        database.litellm_spendlogs = spendlogs_table
        self.db = database

        # Nothing is queued on a fresh client.
        self.spend_log_transactions = []
        self.daily_user_spend_transactions = {}

    def jsonify_object(self, obj):
        """Return ``obj`` unchanged (identity stand-in)."""
        return obj

    def add_spend_log_transaction_to_daily_user_transaction(self, payload):
        """Accept ``payload`` and do nothing; always returns ``None``."""
        return None
def create_mock_proxy_logging():
    """Build the proxy-logging test double used by the tests in this module.

    Returns:
        A ``MagicMock`` whose ``failure_handler`` and
        ``db_spend_update_writer.db_update_spend_transaction_handler``
        attributes are ``AsyncMock`` instances, so they can be awaited.
    """
    print("creating mock proxy logging")

    # Prepare the awaitable spend-update writer before wiring it in.
    spend_writer = AsyncMock()
    spend_writer.db_update_spend_transaction_handler = AsyncMock()

    logging_double = MagicMock()
    logging_double.failure_handler = AsyncMock()
    logging_double.db_spend_update_writer = spend_writer

    print("returning proxy logging obj")
    return logging_double
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_type",
    [
        httpx.ConnectError("Failed to connect"),
        httpx.ReadError("Failed to read response"),
        httpx.ReadTimeout("Request timed out"),
    ],
)
async def test_update_spend_logs_connection_errors(error_type):
    """Test that a connection error is retried until the write succeeds.

    ``create_many`` raises ``error_type`` on its first three calls and returns
    ``None`` on the fourth, so ``update_spend`` is expected to call it four
    times and leave no pending spend logs behind.

    Args:
        error_type: The httpx exception instance the mocked write raises.
    """
    # Setup. create_mock_proxy_logging() already supplies an AsyncMock-backed
    # db_spend_update_writer, so it is not rebuilt here (matches sibling tests).
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Add test spend logs
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # Mock the database to fail with the connection error three times, then succeed
    create_many_mock = AsyncMock(
        side_effect=[
            error_type,  # First attempt fails
            error_type,  # Second attempt fails
            error_type,  # Third attempt fails
            None,  # Fourth attempt succeeds
        ]
    )
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute
    await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify
    assert create_many_mock.call_count == 4  # 3 failed attempts + 1 success
    assert (
        len(prisma_client.spend_log_transactions) == 0
    )  # Should have cleared after success
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_type",
    [
        httpx.ConnectError("Failed to connect"),
        httpx.ReadError("Failed to read response"),
        httpx.ReadTimeout("Request timed out"),
    ],
)
async def test_update_spend_logs_max_retries_exceeded(error_type):
    """Check that a persistent connection error is re-raised once retries run out.

    Args:
        error_type: The httpx exception instance every mocked write raises.
    """
    # A write that never succeeds.
    always_failing_create = AsyncMock(side_effect=error_type)

    prisma_client = MockPrismaClient()
    prisma_client.db.litellm_spendlogs.create_many = always_failing_create
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]
    proxy_logging_obj = create_mock_proxy_logging()

    # The same exception class must surface to the caller.
    expected_exception = type(error_type)
    with pytest.raises(expected_exception) as raised:
        await update_spend(prisma_client, None, proxy_logging_obj)

    # The surfaced error carries the original message.
    assert str(raised.value) == str(error_type)

    # Four write attempts in total before giving up.
    assert always_failing_create.call_count == 4

    # NOTE(review): presumably gives background failure reporting time to
    # run before the handler is inspected -- confirm against update_spend.
    await asyncio.sleep(2)
    assert proxy_logging_obj.failure_handler.call_count == 1
@pytest.mark.asyncio
async def test_update_spend_logs_non_connection_error():
    """Check that a non-connection error is raised without any retry."""
    # A ValueError stands in for an error that is not connection-related.
    failing_create = AsyncMock(
        side_effect=ValueError("Unexpected database error")
    )

    prisma_client = MockPrismaClient()
    prisma_client.db.litellm_spendlogs.create_many = failing_create
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]
    proxy_logging_obj = create_mock_proxy_logging()

    # The error must propagate straight to the caller.
    with pytest.raises(ValueError) as raised:
        await update_spend(prisma_client, None, proxy_logging_obj)

    assert str(raised.value) == "Unexpected database error"

    # Exactly one write attempt: no retries for this kind of error.
    assert failing_create.call_count == 1

    # The failure handler must have been invoked.
    assert proxy_logging_obj.failure_handler.called
@pytest.mark.asyncio
async def test_update_spend_logs_exponential_backoff():
    """Check that the delays between retries double: 1 second, then 2 seconds."""
    recorded_delays = []

    async def record_delay(seconds):
        """Stand-in for ``asyncio.sleep`` that only records the requested delay."""
        recorded_delays.append(seconds)

    # Two connection failures followed by a successful write.
    failures = [httpx.ConnectError("Failed to connect") for _ in range(2)]
    flaky_create = AsyncMock(side_effect=[*failures, None])

    prisma_client = MockPrismaClient()
    prisma_client.spend_log_transactions = [{"id": "1", "spend": 10}]
    prisma_client.db.litellm_spendlogs.create_many = flaky_create
    proxy_logging_obj = create_mock_proxy_logging()

    # Swap out asyncio.sleep so the delays are captured instead of waited on.
    with patch("asyncio.sleep", record_delay):
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Exactly two sleeps: 2^0 seconds before the first retry, 2^1 before the second.
    assert recorded_delays == [1, 2]
@pytest.mark.asyncio
async def test_update_spend_logs_multiple_batches_success():
    """
    Check that spend logs spanning more than one batch are all written.

    The code under test uses a batch size of 100; 150 logs are queued here,
    so two writes are expected (100 logs, then 50).
    """
    total_logs = 150  # 1.5x the batch size

    create_many_mock = AsyncMock(return_value=None)
    prisma_client = MockPrismaClient()
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock
    prisma_client.spend_log_transactions = [
        {"id": str(n), "spend": 10} for n in range(total_logs)
    ]
    proxy_logging_obj = create_mock_proxy_logging()

    await update_spend(prisma_client, None, proxy_logging_obj)

    # One write per batch.
    assert create_many_mock.call_count == 2

    # Pull the ``data`` keyword argument out of each recorded call.
    first_call, second_call = create_many_mock.call_args_list
    first_batch = first_call.kwargs["data"]
    second_batch = second_call.kwargs["data"]

    # Batch sizes: one full batch followed by the remainder.
    assert len(first_batch) == 100
    assert len(second_batch) == 50

    # Each batch must hold exactly the expected ids.
    assert {row["id"] for row in first_batch} == {str(n) for n in range(100)}
    assert {row["id"] for row in second_batch} == {
        str(n) for n in range(100, 150)
    }

    # Nothing may remain queued afterwards.
    assert len(prisma_client.spend_log_transactions) == 0
@pytest.mark.asyncio
async def test_update_spend_logs_multiple_batches_with_failure():
    """
    Check multi-batch processing when one write fails and is then retried.

    400 logs (four batches of 100) are queued; the second write raises a
    connection error once and every other write succeeds.
    """
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()
    prisma_client.spend_log_transactions = [
        {"id": str(n), "spend": 10} for n in range(400)
    ]

    # Mutable counter shared with the side effect below.
    attempts = {"count": 0}

    async def fail_on_second_call(**kwargs):
        """Raise a connection error on the second invocation only."""
        attempts["count"] += 1
        if attempts["count"] == 2:
            raise httpx.ConnectError("Failed to connect")
        return None

    create_many_mock = AsyncMock(side_effect=fail_on_second_call)
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    await update_spend(prisma_client, None, proxy_logging_obj)

    # NOTE(review): six calls are expected -- presumably the two calls of the
    # failed pass plus a fresh pass over all four batches; confirm against
    # update_spend's retry loop.
    assert create_many_mock.call_count == 6

    # Collect every id that was passed to any write, successful or not.
    processed_ids = set()
    for recorded_call in create_many_mock.call_args_list:
        processed_ids.update(row["id"] for row in recorded_call.kwargs["data"])

    # Ids 0-399 must all have been written.
    print("all processed ids", sorted(processed_ids, key=int))
    assert processed_ids == {str(n) for n in range(400)}

    # Nothing may remain queued afterwards.
    assert len(prisma_client.spend_log_transactions) == 0
|