# tests/test_budget_endpoints.py
"""Tests for the proxy's ``/budget/new`` and ``/budget/update`` endpoints.

The Prisma client on ``litellm.proxy.proxy_server`` is replaced with a mock
and the auth dependency is overridden, so the requests run without a database
or real credentials.
"""

import os
import sys
import types
import pytest
from unittest.mock import AsyncMock, MagicMock
from fastapi.testclient import TestClient

import litellm.proxy.proxy_server as ps
from litellm.proxy.proxy_server import app
from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, CommonProxyErrors

import litellm.proxy.management_endpoints.budget_management_endpoints as bm

# NOTE(review): this runs after the litellm imports above, so it cannot change
# how those imports were resolved. The path is also resolved relative to the
# current working directory, not relative to this file.
sys.path.insert(
    0, os.path.abspath("../../../")
)  # Adds the directory three levels above the working directory to sys.path


@pytest.fixture
def client_and_mocks(monkeypatch):
    """Provide a ``TestClient`` backed by a mocked Prisma client and a fake user.

    Yields:
        A ``(client, mock_prisma, mock_table)`` tuple. ``mock_table`` is
        exposed as both ``litellm_budgettable`` and ``litellm_dailyspend`` on
        ``mock_prisma.db``.

    Both patches are applied through ``monkeypatch``, which reverts them when
    the test finishes, so no explicit teardown is needed and overrides that
    other code registered on ``app`` are left untouched.
    """
    # Setup MagicMock Prisma
    mock_prisma = MagicMock()
    mock_table = MagicMock()
    # The fake table echoes its keyword arguments back, so a test can assert
    # on exactly what the handler passed to the database layer.
    mock_table.create = AsyncMock(side_effect=lambda *, data: data)
    mock_table.update = AsyncMock(side_effect=lambda *, where, data: {**where, **data})

    mock_prisma.db = types.SimpleNamespace(
        litellm_budgettable=mock_table,
        litellm_dailyspend=mock_table,
    )

    # Monkeypatch Mocked Prisma client into the server module
    monkeypatch.setattr(ps, "prisma_client", mock_prisma)

    # override returned auth user
    fake_user = UserAPIKeyAuth(
        user_id="test_user",
        user_role=LitellmUserRoles.INTERNAL_USER,
    )
    # setitem (rather than a plain assignment followed by ``clear()``) restores
    # only this one entry of the shared ``dependency_overrides`` dict.
    monkeypatch.setitem(
        app.dependency_overrides, ps.user_api_key_auth, lambda: fake_user
    )

    client = TestClient(app)

    yield client, mock_prisma, mock_table


@pytest.mark.asyncio
async def test_new_budget_success(client_and_mocks):
    """POST /budget/new returns the created budget stamped with the caller's id."""
    client, _, mock_table = client_and_mocks

    # Call /budget/new endpoint
    payload = {
        "budget_id": "budget_123",
        "max_budget": 42.0,
        "budget_duration": "30d",
    }
    resp = client.post("/budget/new", json=payload)
    assert resp.status_code == 200, resp.text

    body = resp.json()
    assert body["budget_id"] == payload["budget_id"]
    assert body["max_budget"] == payload["max_budget"]
    assert body["budget_duration"] == payload["budget_duration"]
    # "test_user" is the user_id of the fake auth user installed by the fixture.
    assert body["created_by"] == "test_user"
    assert body["updated_by"] == "test_user"

    mock_table.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_new_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/new answers 500 with the db-not-connected error when no Prisma client is set."""
    client, _, _ = client_and_mocks

    # Replace the fixture's mock with None on the server module. ``ps`` is the
    # module-level import, so no function-local re-import is needed.
    monkeypatch.setattr(ps, "prisma_client", None)

    # Call /budget/new endpoint
    resp = client.post("/budget/new", json={"budget_id": "no_db", "max_budget": 1.0})
    assert resp.status_code == 500
    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value


@pytest.mark.asyncio
async def test_update_budget_success(client_and_mocks, monkeypatch):
    """POST /budget/update returns the updated fields stamped with the caller's id."""
    client, _, mock_table = client_and_mocks

    payload = {
        "budget_id": "budget_456",
        "max_budget": 99.0,
        "soft_budget": 50.0,
    }
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 200, resp.text

    body = resp.json()
    assert body["budget_id"] == payload["budget_id"]
    assert body["max_budget"] == payload["max_budget"]
    assert body["soft_budget"] == payload["soft_budget"]
    # "test_user" is the user_id of the fake auth user installed by the fixture.
    assert body["updated_by"] == "test_user"

    # Mirror the create-path test: the handler must have gone through the table.
    mock_table.update.assert_awaited_once()


@pytest.mark.asyncio
async def test_update_budget_missing_id(client_and_mocks, monkeypatch):
    """POST /budget/update without a budget_id is rejected with a 400."""
    client, _, _ = client_and_mocks

    payload = {"max_budget": 10.0}
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 400, resp.text
    detail = resp.json()["detail"]
    assert detail["error"] == "budget_id is required"


@pytest.mark.asyncio
async def test_update_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/update answers 500 with the db-not-connected error when no Prisma client is set."""
    client, _, _ = client_and_mocks

    # Replace the fixture's mock with None on the server module. ``ps`` is the
    # module-level import, so no function-local re-import is needed.
    monkeypatch.setattr(ps, "prisma_client", None)

    payload = {"budget_id": "any", "max_budget": 1.0}
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 500
    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value