"""
Tests that LiteLLM caches and reuses OpenAI / AsyncOpenAI clients across
repeated API calls instead of constructing a new client per request.
"""
import os
import sys
from unittest.mock import MagicMock, call, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.openai.common_utils import BaseOpenAILLM
# Parameter matrix for the client-reuse tests. Each entry is a tuple of
# (litellm function name, call the async variant?, call kwargs).
_COMPLETION_KWARGS = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 10,
}
_SPEECH_KWARGS = {
    "model": "tts-1",
    "input": "Hello, this is a test of text to speech",
    "voice": "alloy",
}
_EMBEDDING_KWARGS = {"model": "text-embedding-ada-002", "input": "Hello world"}
_IMAGE_KWARGS = {"model": "dall-e-3", "prompt": "A beautiful sunset over mountains"}

API_FUNCTION_PARAMS = [
    # Shared kwargs are copied per entry so no two test cases alias one dict.
    ("completion", False, dict(_COMPLETION_KWARGS)),
    ("completion", True, dict(_COMPLETION_KWARGS)),
    ("completion", True, {**_COMPLETION_KWARGS, "stream": True}),
    ("embedding", False, dict(_EMBEDDING_KWARGS)),
    ("embedding", True, dict(_EMBEDDING_KWARGS)),
    ("image_generation", False, dict(_IMAGE_KWARGS)),
    ("image_generation", True, dict(_IMAGE_KWARGS)),
    ("speech", False, dict(_SPEECH_KWARGS)),
    ("speech", True, dict(_SPEECH_KWARGS)),
    # Transcription needs a file-like object; a MagicMock stands in for it.
    ("transcription", False, {"model": "whisper-1", "file": MagicMock()}),
    ("transcription", True, {"model": "whisper-1", "file": MagicMock()}),
]
@pytest.mark.parametrize("function_name,is_async,args", API_FUNCTION_PARAMS)
@pytest.mark.asyncio
async def test_openai_client_reuse(function_name, is_async, args):
    """
    Verify that repeated calls to the same litellm API function reuse a single
    cached OpenAI client instead of constructing a new client per request.

    Args:
        function_name: Name of the litellm API function (e.g. "completion").
        is_async: When True, call the async variant (``a`` prefix) and mock
            ``AsyncOpenAI``; otherwise mock the sync ``OpenAI`` client.
        args: Keyword arguments forwarded to the API function.
    """
    litellm.set_verbose = True

    # Mock whichever client class matches the sync/async variant under test.
    client_path = (
        "litellm.llms.openai.openai.AsyncOpenAI"
        if is_async
        else "litellm.llms.openai.openai.OpenAI"
    )

    with patch(client_path) as mock_client_class, patch.object(
        BaseOpenAILLM, "set_cached_openai_client"
    ) as mock_set_cache, patch.object(
        BaseOpenAILLM, "get_cached_openai_client"
    ) as mock_get_cache:
        # First cache lookup misses (forcing one client construction); the
        # nine subsequent lookups return the cached mock client.
        mock_client = MagicMock()
        mock_get_cache.side_effect = [None] + [mock_client] * 9

        # Issue 10 requests through the public litellm entry point.
        for _ in range(10):
            try:
                if is_async:
                    # Async variants are exposed with an 'a' prefix (acompletion, ...).
                    func = getattr(litellm, f"a{function_name}")
                    await func(**args)
                else:
                    func = getattr(litellm, function_name)
                    func(**args)
            except Exception:
                # Expected: the mocked client cannot produce a real response.
                pass

        # The client class must be instantiated exactly once across all calls.
        assert (
            mock_client_class.call_count == 1
        ), f"{'Async' if is_async else ''}OpenAI client should be created only once"
        # The freshly built client must be written to the cache exactly once.
        assert mock_set_cache.call_count == 1, "Client should be cached once"
        # Every one of the 10 requests must consult the cache first.
        assert mock_get_cache.call_count == 10, "Should check cache for each request"