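# NOTE: the text "Spaces: Configuration error / Configuration error" that
# preceded this module is hosting-page residue captured with the file; it is
# not part of the test code.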
# tests/test_budget_endpoints.py | |
import os | |
import sys | |
import types | |
import pytest | |
from unittest.mock import AsyncMock, MagicMock | |
from fastapi.testclient import TestClient | |
import litellm.proxy.proxy_server as ps | |
from litellm.proxy.proxy_server import app | |
from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, CommonProxyErrors | |
import litellm.proxy.management_endpoints.budget_management_endpoints as bm | |
# Prepend the directory three levels above the current working directory to
# the import search path. The path is resolved relative to the CWD, not to
# this file's location.
# NOTE(review): this executes after the litellm imports above, so it cannot
# affect how those imports are resolved - confirm it is still needed.
sys.path.insert(0, os.path.abspath("../../../"))
@pytest.fixture
def client_and_mocks(monkeypatch):
    """Provide a TestClient backed by a mocked Prisma client and a fixed user.

    Setup performed for the duration of one test:

    * ``ps.prisma_client`` is replaced by a ``MagicMock`` whose
      ``db.litellm_budgettable`` and ``db.litellm_dailyspend`` attributes are
      the *same* mock table object.
    * ``mock_table.create`` is an ``AsyncMock`` that returns its ``data``
      keyword argument unchanged; ``mock_table.update`` returns ``where``
      merged with ``data`` (``data`` wins on key collisions).
    * The ``ps.user_api_key_auth`` dependency is overridden so every request
      is authenticated as ``test_user`` with the ``INTERNAL_USER`` role.

    Both patches go through ``monkeypatch``, which restores the previous
    state when the test finishes, so no hand-written teardown is required
    and dependency overrides installed by other code are left untouched.

    Yields:
        tuple: ``(client, mock_prisma, mock_table)``.
    """
    # Mocked table: the async methods echo their keyword arguments back so
    # tests can assert on the response body without a real database.
    mock_table = MagicMock()
    mock_table.create = AsyncMock(side_effect=lambda *, data: data)
    mock_table.update = AsyncMock(
        side_effect=lambda *, where, data: {**where, **data}
    )

    mock_prisma = MagicMock()
    mock_prisma.db = types.SimpleNamespace(
        litellm_budgettable=mock_table,
        litellm_dailyspend=mock_table,
    )

    # Swap the mocked Prisma client into the server module; monkeypatch
    # undoes this automatically at teardown.
    monkeypatch.setattr(ps, "prisma_client", mock_prisma)

    # Override the auth dependency with a fixed user. setitem() restores the
    # mapping's previous state afterwards instead of clearing every override.
    fake_user = UserAPIKeyAuth(
        user_id="test_user",
        user_role=LitellmUserRoles.INTERNAL_USER,
    )
    monkeypatch.setitem(
        app.dependency_overrides, ps.user_api_key_auth, lambda: fake_user
    )

    client = TestClient(app)
    yield client, mock_prisma, mock_table
    # Teardown is handled entirely by monkeypatch (see docstring).
@pytest.mark.asyncio
async def test_new_budget_success(client_and_mocks):
    """POST /budget/new echoes the submitted budget and stamps the caller's id.

    Checks that the response is 200, that every submitted field is returned
    unchanged, that ``created_by`` / ``updated_by`` are ``test_user``, and
    that the mocked table's ``create`` was awaited exactly once.
    """
    client, _, mock_table = client_and_mocks

    payload = {
        "budget_id": "budget_123",
        "max_budget": 42.0,
        "budget_duration": "30d",
    }
    resp = client.post("/budget/new", json=payload)
    assert resp.status_code == 200, resp.text

    body = resp.json()
    # Every field that was sent must come back unchanged.
    for field, expected in payload.items():
        assert body[field] == expected, field
    # Audit fields are asserted to carry the authenticated user's id.
    assert body["created_by"] == "test_user"
    assert body["updated_by"] == "test_user"

    mock_table.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_new_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/new answers 500 with the db-not-connected error when
    ``ps.prisma_client`` is ``None``, and never touches the mocked table."""
    client, _, mock_table = client_and_mocks

    # Replace the fixture's mocked client with None on the server module
    # (``ps`` is the module-level import), simulating a missing database.
    monkeypatch.setattr(ps, "prisma_client", None)

    resp = client.post(
        "/budget/new", json={"budget_id": "no_db", "max_budget": 1.0}
    )
    assert resp.status_code == 500, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
    # With no client installed the mocked table cannot have been reached.
    mock_table.create.assert_not_awaited()
@pytest.mark.asyncio
async def test_update_budget_success(client_and_mocks, monkeypatch):
    """POST /budget/update returns the updated fields stamped with the caller.

    Checks that the response is 200, that every submitted field is returned
    unchanged, that ``updated_by`` is ``test_user``, and that the mocked
    table's ``update`` was awaited exactly once.
    """
    client, _, mock_table = client_and_mocks

    payload = {
        "budget_id": "budget_456",
        "max_budget": 99.0,
        "soft_budget": 50.0,
    }
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 200, resp.text

    body = resp.json()
    # Every field that was sent must come back unchanged.
    for field, expected in payload.items():
        assert body[field] == expected, field
    assert body["updated_by"] == "test_user"

    # Mirror the create test: the update must hit the table exactly once.
    mock_table.update.assert_awaited_once()
@pytest.mark.asyncio
async def test_update_budget_missing_id(client_and_mocks, monkeypatch):
    """POST /budget/update without ``budget_id`` answers 400 with the error
    ``"budget_id is required"`` and performs no table update."""
    client, _, mock_table = client_and_mocks

    resp = client.post("/budget/update", json={"max_budget": 10.0})
    assert resp.status_code == 400, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == "budget_id is required"
    # A rejected request should not have reached the mocked table.
    mock_table.update.assert_not_awaited()
@pytest.mark.asyncio
async def test_update_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/update answers 500 with the db-not-connected error when
    ``ps.prisma_client`` is ``None``, and never touches the mocked table."""
    client, _, mock_table = client_and_mocks

    # Replace the fixture's mocked client with None on the server module
    # (``ps`` is the module-level import), simulating a missing database.
    monkeypatch.setattr(ps, "prisma_client", None)

    resp = client.post(
        "/budget/update", json={"budget_id": "any", "max_budget": 1.0}
    )
    assert resp.status_code == 500, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
    # With no client installed the mocked table cannot have been reached.
    mock_table.update.assert_not_awaited()