# DesertWolf's picture
# Upload folder using huggingface_hub
# 447ebeb verified
# tests/test_budget_endpoints.py
import os
import sys
import types
import pytest
from unittest.mock import AsyncMock, MagicMock
from fastapi.testclient import TestClient
import litellm.proxy.proxy_server as ps
from litellm.proxy.proxy_server import app
from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, CommonProxyErrors
import litellm.proxy.management_endpoints.budget_management_endpoints as bm
# Put the directory three levels above the current working directory at the
# front of the import path.
# NOTE(review): this executes after the imports above, so it cannot affect how
# they resolve, and the path is relative to the CWD rather than to this file.
sys.path.insert(0, os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir)))
@pytest.fixture
def client_and_mocks(monkeypatch):
    """Yield a ``TestClient`` backed by a mocked Prisma client and a fake auth user.

    Yields:
        A ``(client, mock_prisma, mock_table)`` tuple. ``mock_table`` backs both
        ``litellm_budgettable`` and ``litellm_dailyspend``. Its ``create`` returns
        the ``data`` kwarg unchanged, and its ``update`` returns ``{**where, **data}``.
    """
    # The endpoint handlers await these table methods, so they must be AsyncMocks.
    mock_table = MagicMock()
    mock_table.create = AsyncMock(side_effect=lambda *, data: data)
    mock_table.update = AsyncMock(side_effect=lambda *, where, data: {**where, **data})

    mock_prisma = MagicMock()
    mock_prisma.db = types.SimpleNamespace(
        litellm_budgettable=mock_table,
        litellm_dailyspend=mock_table,
    )
    # monkeypatch puts the original prisma_client back automatically at
    # teardown, so no manual restore is needed.
    monkeypatch.setattr(ps, "prisma_client", mock_prisma)

    # Every request is authenticated as this user.
    fake_user = UserAPIKeyAuth(
        user_id="test_user",
        user_role=LitellmUserRoles.INTERNAL_USER,
    )
    # Snapshot the existing overrides so teardown restores them rather than
    # wiping overrides this fixture did not install.
    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[ps.user_api_key_auth] = lambda: fake_user

    client = TestClient(app)
    yield client, mock_prisma, mock_table

    # Teardown: restore the dependency overrides exactly as they were.
    app.dependency_overrides.clear()
    app.dependency_overrides.update(saved_overrides)
@pytest.mark.asyncio
async def test_new_budget_success(client_and_mocks):
    """Posting to /budget/new returns the created budget with audit fields set."""
    client, _prisma, table = client_and_mocks
    request_body = {
        "budget_id": "budget_123",
        "max_budget": 42.0,
        "budget_duration": "30d",
    }

    response = client.post("/budget/new", json=request_body)
    assert response.status_code == 200, response.text

    created = response.json()
    # Every submitted field must come back unchanged.
    for key, expected in request_body.items():
        assert created[key] == expected
    # Both audit columns are stamped with the authenticated user.
    for audit_field in ("created_by", "updated_by"):
        assert created[audit_field] == "test_user"
    table.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_new_budget_db_not_connected(client_and_mocks, monkeypatch):
    """/budget/new returns 500 with the db-not-connected error when there is no DB."""
    client, _, mock_table = client_and_mocks
    # The handler imports ``prisma_client`` from proxy_server at runtime, so
    # patching the module attribute (``ps`` is imported at module level) is
    # enough to simulate a missing database.
    monkeypatch.setattr(ps, "prisma_client", None)

    resp = client.post("/budget/new", json={"budget_id": "no_db", "max_budget": 1.0})

    assert resp.status_code == 500, resp.text
    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
    # With no database client, the mocked table must never be written to.
    mock_table.create.assert_not_awaited()
@pytest.mark.asyncio
async def test_update_budget_success(client_and_mocks, monkeypatch):
    """/budget/update returns the updated fields and stamps ``updated_by``."""
    client, _, mock_table = client_and_mocks
    payload = {
        "budget_id": "budget_456",
        "max_budget": 99.0,
        "soft_budget": 50.0,
    }

    resp = client.post("/budget/update", json=payload)

    assert resp.status_code == 200, resp.text
    body = resp.json()
    for field in ("budget_id", "max_budget", "soft_budget"):
        assert body[field] == payload[field], field
    assert body["updated_by"] == "test_user"
    # Mirrors test_new_budget_success: the write must reach the table exactly once.
    mock_table.update.assert_awaited_once()
@pytest.mark.asyncio
async def test_update_budget_missing_id(client_and_mocks, monkeypatch):
    """Omitting ``budget_id`` from /budget/update is rejected with a 400."""
    client, _prisma, _table = client_and_mocks

    response = client.post("/budget/update", json={"max_budget": 10.0})

    assert response.status_code == 400, response.text
    error_message = response.json()["detail"]["error"]
    assert error_message == "budget_id is required"
@pytest.mark.asyncio
async def test_update_budget_db_not_connected(client_and_mocks, monkeypatch):
    """/budget/update returns 500 with the db-not-connected error when there is no DB."""
    client, _, mock_table = client_and_mocks
    # The handler imports ``prisma_client`` from proxy_server at runtime, so
    # patching the module attribute (``ps`` is imported at module level) is
    # enough to simulate a missing database.
    monkeypatch.setattr(ps, "prisma_client", None)

    payload = {"budget_id": "any", "max_budget": 1.0}
    resp = client.post("/budget/update", json=payload)

    assert resp.status_code == 500, resp.text
    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
    # With no database client, the mocked table must never be written to.
    mock_table.update.assert_not_awaited()