# Tests that repeated litellm OpenAI API calls reuse a single cached client.
import os
import sys
from unittest.mock import MagicMock, call, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.openai.common_utils import BaseOpenAILLM
# Parametrization matrix for the client-reuse test below.
# Each entry is (function_name, is_async, call_kwargs); sync/async variants
# of every API surface are exercised, plus one streaming completion case.
API_FUNCTION_PARAMS = [
    (
        "completion",
        False,
        {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10},
    ),
    (
        "completion",
        True,
        {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10},
    ),
    (
        "completion",
        True,
        {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
            "stream": True,
        },
    ),
    ("embedding", False, {"model": "text-embedding-ada-002", "input": "Hello world"}),
    ("embedding", True, {"model": "text-embedding-ada-002", "input": "Hello world"}),
    ("image_generation", False, {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"}),
    ("image_generation", True, {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"}),
    (
        "speech",
        False,
        {"model": "tts-1", "input": "Hello, this is a test of text to speech", "voice": "alloy"},
    ),
    (
        "speech",
        True,
        {"model": "tts-1", "input": "Hello, this is a test of text to speech", "voice": "alloy"},
    ),
    # Transcription needs a file-like object; a MagicMock stands in for it.
    ("transcription", False, {"model": "whisper-1", "file": MagicMock()}),
    ("transcription", True, {"model": "whisper-1", "file": MagicMock()}),
]
@pytest.mark.parametrize("function_name,is_async,args", API_FUNCTION_PARAMS)
@pytest.mark.asyncio
async def test_openai_client_reuse(function_name, is_async, args):
    """
    Verify that ten consecutive API calls construct exactly one OpenAI
    client, cache it once, and consult the cache on every request.
    """
    litellm.set_verbose = True

    # Async calls go through AsyncOpenAI; sync calls go through OpenAI.
    if is_async:
        client_path = "litellm.llms.openai.openai.AsyncOpenAI"
    else:
        client_path = "litellm.llms.openai.openai.OpenAI"

    with patch(client_path) as patched_client_cls, patch.object(
        BaseOpenAILLM, "set_cached_openai_client"
    ) as patched_set_cache, patch.object(
        BaseOpenAILLM, "get_cached_openai_client"
    ) as patched_get_cache:
        # First cache lookup misses (None); the nine that follow all hit
        # and hand back the same mock client.
        cached_client = MagicMock()
        patched_get_cache.side_effect = [None] + [cached_client] * 9

        request_count = 10
        # Async entry points carry an 'a' prefix (acompletion, aembedding, ...).
        target_name = f"a{function_name}" if is_async else function_name
        target = getattr(litellm, target_name)

        for _ in range(request_count):
            try:
                if is_async:
                    await target(**args)
                else:
                    target(**args)
            except Exception:
                # The mocked client can't complete a real request; failures
                # here are expected and irrelevant to what we're measuring.
                pass

        # Exactly one client was constructed across all requests.
        assert (
            patched_client_cls.call_count == 1
        ), f"{'Async' if is_async else ''}OpenAI client should be created only once"
        # That single client was written to the cache exactly once.
        assert patched_set_cache.call_count == 1, "Client should be cached once"
        # Every request performed a cache lookup.
        assert patched_get_cache.call_count == request_count, "Should check cache for each request"