File size: 5,266 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import pytest
import requests

from litellm.proxy.client import Client, ModelGroupsManagementClient
from litellm.proxy.client.exceptions import UnauthorizedError


@pytest.fixture
def base_url():
    """Base URL handed to every client constructed in this module."""
    url = "http://localhost:8000"
    return url


@pytest.fixture
def api_key():
    """Dummy token used when building authenticated clients."""
    token = "test-api-key"
    return token


@pytest.fixture
def client(base_url, api_key):
    """ModelGroupsManagementClient wired to the shared base_url and api_key fixtures."""
    return ModelGroupsManagementClient(api_key=api_key, base_url=base_url)


def test_info_request_creation(client, base_url, api_key):
    """info(return_request=True) builds an authenticated GET to /model_group/info."""
    req = client.info(return_request=True)
    headers = req.headers

    # HTTP verb
    assert req.method == "GET"

    # Endpoint path is appended to the base URL
    assert req.url == f"{base_url}/model_group/info"

    # Bearer token derived from the api_key fixture
    assert "Authorization" in headers
    assert headers["Authorization"] == f"Bearer {api_key}"


def test_info_request_no_auth(base_url):
    """Without an api_key, info(return_request=True) omits the Authorization header."""
    anonymous = ModelGroupsManagementClient(base_url=base_url)
    req = anonymous.info(return_request=True)

    # The missing key does not affect URL construction...
    assert req.url == f"{base_url}/model_group/info"
    # ...but no auth header is attached.
    assert "Authorization" not in req.headers


@pytest.mark.parametrize(
    "base_url,expected",
    [
        ("http://localhost:8000", "http://localhost:8000/model_group/info"),
        # A trailing slash on the base URL must not yield a double slash.
        ("http://localhost:8000/", "http://localhost:8000/model_group/info"),
        ("https://api.example.com", "https://api.example.com/model_group/info"),
        ("http://127.0.0.1:3000", "http://127.0.0.1:3000/model_group/info"),
    ],
)
def test_info_url_variants(base_url, expected):
    """The info endpoint URL is built correctly for several base URL shapes."""
    req = ModelGroupsManagementClient(base_url=base_url).info(return_request=True)
    assert req.url == expected


def test_info_with_mock_response(client, requests_mock):
    """A full info() call against a mocked endpoint returns the payload's data list."""
    gpt4_group = {
        "model_group_name": "gpt4-group",
        "models": ["gpt-4", "gpt-4-32k"],
        "litellm_params": {"timeout": 30, "max_retries": 3},
    }
    azure_group = {
        "model_group_name": "azure-group",
        "models": ["azure-gpt-4", "azure-gpt-35"],
        "litellm_params": {
            "api_base": "https://azure-endpoint.com",
            "api_version": "2023-05-15",
        },
    }
    payload = {"data": [gpt4_group, azure_group]}
    requests_mock.get(f"{client._base_url}/model_group/info", json=payload)

    groups = client.info()

    # info() returns the list stored under the payload's top-level "data" key.
    assert groups == payload["data"]
    assert len(groups) == 2
    assert [g["model_group_name"] for g in groups] == ["gpt4-group", "azure-group"]


def test_info_unauthorized_error(client, requests_mock):
    """A 401 response surfaces as UnauthorizedError wrapping the original exception."""
    endpoint = f"{client._base_url}/model_group/info"
    requests_mock.get(endpoint, status_code=401, json={"error": "Invalid API key"})

    with pytest.raises(UnauthorizedError) as raised:
        client.info()

    # The wrapped exception still carries the HTTP response.
    wrapped = raised.value.orig_exception
    assert wrapped.response.status_code == 401


@pytest.mark.parametrize(
    "status_code,error_message",
    [
        (400, "Bad Request"),
        (403, "Forbidden"),
        (404, "Not Found"),
        (500, "Internal Server Error"),
    ],
)
def test_info_other_errors(client, requests_mock, status_code, error_message):
    """Test that info raises a plain HTTPError (not UnauthorizedError) for non-401 errors.

    Only a 401 is expected to be mapped to UnauthorizedError. Every other error
    status must propagate as requests' HTTPError with the response attached.
    """
    requests_mock.get(
        f"{client._base_url}/model_group/info",
        status_code=status_code,
        json={"error": error_message},
    )

    with pytest.raises(requests.exceptions.HTTPError) as exc_info:
        client.info()
    # pytest.raises also matches subclasses, so explicitly rule out the
    # 401-specific wrapper being used for other status codes.
    assert not isinstance(exc_info.value, UnauthorizedError)
    assert exc_info.value.response.status_code == status_code


@pytest.mark.parametrize(
    "api_key",
    [
        "",  # empty key
        None,  # explicitly no key
    ],
)
def test_info_invalid_api_keys(base_url, api_key):
    """Falsy api_key values ("" and None) must not produce an Authorization header."""
    keyless = ModelGroupsManagementClient(base_url=base_url, api_key=api_key)
    req = keyless.info(return_request=True)
    assert "Authorization" not in req.headers


def test_client_initialization_strips_trailing_slash():
    """Every trailing slash on base_url is removed when the client is constructed."""
    slashy_url = "http://localhost:8000/////"
    stripped = ModelGroupsManagementClient(base_url=slashy_url)._base_url
    assert stripped == "http://localhost:8000"


def test_client_initialization(base_url, api_key):
    """The top-level Client exposes a model_groups sub-client sharing its URL and key."""
    top_level = Client(base_url=base_url, api_key=api_key)
    groups = top_level.model_groups

    # The sub-client is the dedicated model-groups client...
    assert isinstance(groups, ModelGroupsManagementClient)
    # ...configured with the same connection settings as its parent.
    assert groups._base_url == base_url
    assert groups._api_key == api_key


def test_client_initialization_without_api_key(base_url):
    """Omitting api_key leaves it unset on both the Client and its model_groups sub-client."""
    top_level = Client(base_url=base_url)

    assert top_level._api_key is None
    assert top_level.model_groups._api_key is None