DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
import pytest
import requests
from litellm.proxy.client.chat import ChatClient
from litellm.proxy.client.exceptions import UnauthorizedError
@pytest.fixture
def base_url():
    """Base URL handed to ``ChatClient`` in these tests."""
    url = "http://localhost:8000"
    return url
@pytest.fixture
def api_key():
    """Placeholder API key used to construct the test client."""
    key = "test-api-key"
    return key
@pytest.fixture
def client(base_url, api_key):
    """A ``ChatClient`` built from the ``base_url`` and ``api_key`` fixtures."""
    chat_client = ChatClient(base_url=base_url, api_key=api_key)
    return chat_client
@pytest.fixture
def sample_messages():
    """A two-message conversation (system prompt + user question)."""
    system_msg = {"role": "system", "content": "You are a helpful assistant."}
    user_msg = {"role": "user", "content": "Name 3 countries"}
    return [system_msg, user_msg]
def test_client_initialization(base_url, api_key):
    """ChatClient should store the base URL and API key it was constructed with."""
    chat = ChatClient(base_url=base_url, api_key=api_key)
    assert (chat._base_url, chat._api_key) == (base_url, api_key)
def test_client_initialization_strips_trailing_slash():
    """Any number of trailing slashes on base_url should be removed at construction."""
    root = "http://localhost:8000"
    # Five trailing slashes appended to the bare root URL.
    chat = ChatClient(base_url=root + "/" * 5)
    assert chat._base_url == root
def test_client_without_api_key(base_url):
    """Constructing a client without an api_key should leave the key as None."""
    assert ChatClient(base_url=base_url)._api_key is None
def test_completions_request_creation(client, base_url, api_key, sample_messages):
    """completions(return_request=True) should build a POST with the right URL, headers and JSON body."""
    optional_params = {"temperature": 0.7, "max_tokens": 100}
    prepared = client.completions(
        model="gpt-4",
        messages=sample_messages,
        **optional_params,
        return_request=True,
    )

    # HTTP method and target endpoint
    assert prepared.method == "POST"
    assert prepared.url == base_url + "/chat/completions"

    # JSON content type plus bearer-token authorization
    expected_headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    for header_name, header_value in expected_headers.items():
        assert prepared.headers[header_name] == header_value

    # Body holds the required fields followed by the optional ones passed above
    assert prepared.json == {
        "model": "gpt-4",
        "messages": sample_messages,
        **optional_params,
    }
def test_completions_minimal_request(client, sample_messages):
    """With only model and messages supplied, the body should contain exactly those two fields."""
    prepared = client.completions(
        model="gpt-4",
        messages=sample_messages,
        return_request=True,
    )
    expected_body = dict(model="gpt-4", messages=sample_messages)
    # No optional parameters should appear in the payload
    assert prepared.json == expected_body
def test_completions_all_parameters(client, sample_messages):
    """Each optional parameter exercised here should be forwarded into the request body."""
    optional = {
        "temperature": 0.7,
        "top_p": 0.9,
        "n": 2,
        "max_tokens": 100,
        "presence_penalty": 0.5,
        "frequency_penalty": -0.5,
        "user": "test-user",
    }
    prepared = client.completions(
        model="gpt-4",
        messages=sample_messages,
        **optional,
        return_request=True,
    )

    expected = {"model": "gpt-4", "messages": sample_messages}
    expected.update(optional)
    assert prepared.json == expected
def test_completions_mock_response(client, api_key, sample_messages, requests_mock):
    """Test completions with a mocked successful response.

    Checks both directions of the round trip:

    - the request the client sends to ``/chat/completions`` (one call, bearer
      auth header, JSON body containing only ``model`` and ``messages``);
    - the decoded JSON it returns to the caller.
    """
    mock_response = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677858242,
        "model": "gpt-4",
        "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you today?",
                },
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }
    # Mock the POST request; an unmatched URL would raise, so this also pins the endpoint.
    requests_mock.post(f"{client._base_url}/chat/completions", json=mock_response)

    response = client.completions(model="gpt-4", messages=sample_messages)

    # Outgoing side: exactly one request, authenticated, with the minimal payload.
    assert requests_mock.call_count == 1
    sent = requests_mock.last_request
    assert sent.headers["Authorization"] == f"Bearer {api_key}"
    assert sent.json() == {"model": "gpt-4", "messages": sample_messages}

    # Incoming side: the decoded response body is returned unchanged.
    assert response == mock_response
    assert (
        response["choices"][0]["message"]["content"]
        == "Hello! How can I help you today?"
    )
def test_completions_unauthorized_error(client, sample_messages, requests_mock):
    """A 401 response from the endpoint should surface as UnauthorizedError."""
    endpoint = client._base_url + "/chat/completions"
    requests_mock.post(endpoint, status_code=401, json={"error": "Unauthorized"})

    with pytest.raises(UnauthorizedError):
        client.completions(model="gpt-4", messages=sample_messages)
def test_completions_other_errors(client, sample_messages, requests_mock):
    """A 500 response should propagate as requests.exceptions.HTTPError carrying the response."""
    endpoint = client._base_url + "/chat/completions"
    requests_mock.post(
        endpoint, status_code=500, json={"error": "Internal Server Error"}
    )

    with pytest.raises(requests.exceptions.HTTPError) as raised:
        client.completions(model="gpt-4", messages=sample_messages)

    # Checked outside the `with` so it runs after the expected exception is caught.
    assert raised.value.response.status_code == 500