Spaces:
Configuration error
Configuration error
#### What this tests #### | |
# Unit tests for JWT-Auth | |
import asyncio | |
import os | |
import random | |
import sys | |
import time | |
import traceback | |
import uuid | |
from dotenv import load_dotenv | |
load_dotenv() | |
import os | |
sys.path.insert( | |
0, os.path.abspath("../..") | |
) # Adds the parent directory to the system path | |
from datetime import datetime, timedelta | |
from unittest.mock import AsyncMock, MagicMock, patch | |
import pytest | |
from fastapi import Request, HTTPException | |
from fastapi.routing import APIRoute | |
from fastapi.responses import Response | |
import litellm | |
from litellm.caching.caching import DualCache | |
from litellm.proxy._types import ( | |
LiteLLM_JWTAuth, | |
LiteLLM_UserTable, | |
LiteLLMRoutes, | |
JWTAuthBuilderResult, | |
) | |
from litellm.proxy.auth.handle_jwt import JWTHandler, JWTAuthManager | |
from litellm.proxy.management_endpoints.team_endpoints import new_team | |
from litellm.proxy.proxy_server import chat_completion | |
from typing import Literal | |
# Sample RS256 public key in JSON Web Key (JWK) form. Its modulus ("n") is the
# same value test_token_single_public_key caches and expects back from
# JWTHandler.get_public_key().
public_key = {
    "kty": "RSA",
    "e": "AQAB",
    "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
    "alg": "RS256",
}
def test_load_config_with_custom_role_names():
    """A custom ``admin_jwt_scope`` under ``general_settings.litellm_proxy_roles``
    ends up on the constructed ``LiteLLM_JWTAuth`` object."""
    config = {
        "general_settings": {
            "litellm_proxy_roles": {"admin_jwt_scope": "litellm-proxy-admin"}
        }
    }
    # Missing sections fall back to an empty dict, i.e. model defaults.
    role_settings = config.get("general_settings", {}).get("litellm_proxy_roles", {})
    proxy_roles = LiteLLM_JWTAuth(**role_settings)
    print(f"proxy_roles: {proxy_roles}")
    assert proxy_roles.admin_jwt_scope == "litellm-proxy-admin"
@pytest.mark.asyncio
async def test_token_single_public_key(monkeypatch):
    """With exactly one key cached for ``JWT_PUBLIC_KEY_URL``,
    ``get_public_key(kid=None)`` returns that key as a dict."""
    cached_key = {
        "kty": "RSA",
        "use": "sig",
        "e": "AQAB",
        "n": "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ",
        "alg": "RS256",
    }
    backend_keys = {"keys": [cached_key]}
    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")

    # Seed the cache entry keyed by the public-key URL; presumably the handler
    # reads this instead of fetching the URL.
    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key",
        value=backend_keys["keys"],
    )
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache

    # Named so it doesn't shadow the module-level ``public_key`` constant.
    resolved_key = await jwt_handler.get_public_key(kid=None)

    assert resolved_key is not None
    assert isinstance(resolved_key, dict)
    assert resolved_key["n"] == cached_key["n"]
# NOTE(review): decorators were missing from this copy; parametrization
# reconstructed from the values used in this file — confirm.
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token(audience, monkeypatch):
    """
    Tests
    - valid token: signed by the cached key and carrying the admin scope;
      ``auth_jwt`` must return a dict
    - token with an unknown scope: same signature, so ``auth_jwt`` must
      still decode it without raising
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # monkeypatch undoes these after the test, so JWT_AUDIENCE cannot leak
    # into later tests the way a direct os.environ write does.
    monkeypatch.delenv("JWT_AUDIENCE", raising=False)
    if audience:
        monkeypatch.setenv("JWT_AUDIENCE", audience)
    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")

    # RSA key pair: the private half signs tokens, the public half is
    # published as a JWK through the cache.
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_key_str = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))
    assert isinstance(public_jwk, dict)

    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
    )
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache

    # A naive datetime's .timestamp() is interpreted in local time, so this is
    # a correct epoch value ten minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())

    # VALID TOKEN
    valid_token = jwt.encode(
        {
            "sub": "user123",
            "exp": expiration_time,
            "scope": "litellm-proxy-admin",
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )
    response = await jwt_handler.auth_jwt(token=valid_token)
    assert response is not None
    assert isinstance(response, dict)
    print(f"response: {response}")

    # TOKEN WITH AN UNKNOWN SCOPE - decoding must still succeed
    unscoped_token = jwt.encode(
        {
            "sub": "user123",
            "exp": expiration_time,
            "scope": "litellm-NO-SCOPE",
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )
    try:
        await jwt_handler.auth_jwt(token=unscoped_token)
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
@pytest.fixture
def prisma_client():
    """Build a ``PrismaClient`` for the database at ``DATABASE_URL``.

    Side effect: rewrites ``os.environ["DATABASE_URL"]`` to carry the
    connection-pool query parameters below. This fixture does not call
    ``.connect()``; the tests that use it do.
    """
    from litellm.proxy.proxy_cli import append_query_params
    from litellm.proxy.utils import PrismaClient, ProxyLogging

    proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())

    # add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    # NOTE(review): os.getenv returns None when DATABASE_URL is unset; the
    # failure then surfaces inside append_query_params.
    modified_url = append_query_params(os.getenv("DATABASE_URL"), params)
    os.environ["DATABASE_URL"] = modified_url

    return PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )
@pytest.fixture
def team_token_tuple():
    """Return ``(team_id, token, public_jwk)`` for a freshly generated RSA key.

    ``token`` is an RS256-signed JWT with scope ``litellm_team`` whose
    ``client_id`` claim is ``team_id``; ``public_jwk`` is the JWK that
    verifies it.
    """
    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_key_str = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    # Export the public half directly; no PEM round-trip is needed.
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))

    # A naive datetime's .timestamp() is interpreted in local time, so this is
    # a correct epoch value ten minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())
    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": None,
    }
    token = jwt.encode(payload, private_key_str, algorithm="RS256")
    return team_id, token, public_jwk
# NOTE(review): decorators were missing from this copy; parametrization
# reconstructed from the values used in this file — confirm.
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience, monkeypatch):
    """
    A JWT whose ``client_id`` names a team is only accepted once that team exists.

    1. Team token is rejected -> team doesn't exist yet
    2. Create the team via an admin-scoped token
    3. Team token is accepted -> UserAPIKeyAuth carries the team's limits/models
    """
    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # monkeypatch restores JWT_AUDIENCE afterwards, so it cannot leak into
    # later tests.
    monkeypatch.delenv("JWT_AUDIENCE", raising=False)
    if audience:
        monkeypatch.setenv("JWT_AUDIENCE", audience)
    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_key_str = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))
    assert isinstance(public_jwk, dict)

    # Publish the public key through the cache entry for JWT_PUBLIC_KEY_URL.
    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
    )
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    # The team is read from the token's ``client_id`` claim.
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

    # A naive datetime's .timestamp() is interpreted in local time, so this is
    # a correct epoch value ten minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())
    team_id = f"team123_{uuid.uuid4()}"
    token = jwt.encode(
        {
            "sub": "user123",
            "exp": expiration_time,
            "scope": "litellm_team",
            "client_id": team_id,
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )
    admin_token = jwt.encode(
        {
            "sub": "user123",
            "exp": expiration_time,
            "scope": "litellm_proxy_admin",
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )

    # The team token must at least verify on its own.
    assert await jwt_handler.auth_jwt(token=token) is not None

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    ## 1. INITIAL TEAM CALL - must fail, the team doesn't exist yet
    with pytest.raises(Exception):
        await user_api_key_auth(request=request, api_key="Bearer " + token)

    ## 2. CREATE TEAM W/ ADMIN TOKEN - any exception fails the test
    request._url = URL(url="/team/new")
    admin_auth = await user_api_key_auth(
        request=request, api_key="Bearer " + admin_token
    )
    await new_team(
        data=NewTeamRequest(
            team_id=team_id,
            tpm_limit=100,
            rpm_limit=99,
            models=["gpt-3.5-turbo", "gpt-4"],
        ),
        user_api_key_dict=admin_auth,
        http_request=Request(scope={"type": "http"}),
    )

    ## 3. 2nd CALL W/ TEAM TOKEN - must succeed now that the team exists
    request._url = URL(url="/chat/completions")
    team_result: UserAPIKeyAuth = await user_api_key_auth(
        request=request, api_key="Bearer " + token
    )

    ## 4. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py)
    assert team_result.team_tpm_limit == 100
    assert team_result.team_rpm_limit == 99
    assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
# NOTE(review): decorators were missing from this copy; parametrization
# reconstructed from how the flags are used below — confirm. The ``aaaa``
# prefix keeps pytest from collecting this test at all.
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize("team_id_set", [True, False])
@pytest.mark.parametrize("default_team_id", [True, False])
@pytest.mark.parametrize("user_id_upsert", [True, False])
@pytest.mark.asyncio
async def aaaatest_user_token_output(
    prisma_client, audience, team_id_set, default_team_id, user_id_upsert, monkeypatch
):
    """
    If user required, check if it exists
    - fail initial request (when user doesn't exist)
    - create user
    - retry -> it should pass now

    Flags:
    - team_id_set: read the team from the token's ``client_id`` claim
    - default_team_id: when truthy, a fresh team id is generated, created in
      step 2 and configured as ``team_id_default``
    - user_id_upsert: when True the user is looked up (not created) in step 4;
      when False the user is created explicitly there
    """
    print(f"received args - {locals()}")

    import json
    import uuid

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy._types import NewTeamRequest, NewUserRequest, UserAPIKeyAuth
    from litellm.proxy.management_endpoints.internal_user_endpoints import (
        new_user,
        user_info,
    )
    from litellm.proxy.proxy_server import user_api_key_auth

    if default_team_id:
        # Replace the boolean flag with a unique team id for this run.
        default_team_id = "team_id_12344_{}".format(uuid.uuid4())

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    # monkeypatch restores JWT_AUDIENCE afterwards, so it cannot leak into
    # later tests.
    monkeypatch.delenv("JWT_AUDIENCE", raising=False)
    if audience:
        monkeypatch.setenv("JWT_AUDIENCE", audience)
    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_key_str = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))
    assert isinstance(public_jwk, dict)

    # Publish the public key through the cache entry for JWT_PUBLIC_KEY_URL.
    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
    )
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
    jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
    jwt_handler.litellm_jwtauth.team_id_default = default_team_id
    jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
    if team_id_set:
        jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"

    # A naive datetime's .timestamp() is interpreted in local time, so this is
    # a correct epoch value ten minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())
    team_id = f"team123_{uuid.uuid4()}"
    user_id = f"user123_{uuid.uuid4()}"
    token = jwt.encode(
        {
            "sub": user_id,
            "exp": expiration_time,
            "scope": "litellm_team",
            "client_id": team_id,
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )
    admin_token = jwt.encode(
        {
            "sub": user_id,
            "exp": expiration_time,
            "scope": "litellm_proxy_admin",
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )

    # The team token must at least verify on its own.
    assert await jwt_handler.auth_jwt(token=token) is not None

    # Steps:
    # 1. initial call fails -> team doesn't exist
    # 2. create team(s) via admin token
    # 3. 2nd call w/ same team -> fails unless upsert (user doesn't exist)
    # 4. create / look up the user via admin token
    # 5. 3rd call w/ same team, same user -> succeeds
    # 6. assert user api key auth format
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    ## 1. INITIAL TEAM CALL - must fail
    with pytest.raises(Exception):
        await user_api_key_auth(request=request, api_key="Bearer " + token)

    ## 2. CREATE TEAM(S) W/ ADMIN TOKEN - any exception fails the test
    request._url = URL(url="/team/new")
    admin_auth = await user_api_key_auth(
        request=request, api_key="Bearer " + admin_token
    )
    teams_to_create = [team_id] + ([default_team_id] if default_team_id else [])
    for new_team_id in teams_to_create:
        await new_team(
            data=NewTeamRequest(
                team_id=new_team_id,
                tpm_limit=100,
                rpm_limit=99,
                models=["gpt-3.5-turbo", "gpt-4"],
            ),
            user_api_key_dict=admin_auth,
            http_request=Request(scope={"type": "http"}),
        )

    ## 3. 2nd CALL W/ TEAM TOKEN - must fail when the user can't be upserted
    request._url = URL(url="/chat/completions")
    if user_id_upsert:
        # Either outcome is tolerated here; step 4 checks the user exists.
        try:
            await user_api_key_auth(request=request, api_key="Bearer " + token)
        except Exception:
            pass
    else:
        with pytest.raises(Exception):
            await user_api_key_auth(request=request, api_key="Bearer " + token)

    ## 4. Ensure the user exists (admin auth must succeed either way)
    request._url = URL(url="/team/new")
    await user_api_key_auth(request=request, api_key="Bearer " + admin_token)
    if user_id_upsert:
        ## check if user already exists
        await user_info(request=request, user_id=user_id)
    else:
        await new_user(data=NewUserRequest(user_id=user_id))

    ## 5. 3rd call w/ same team, same user -> must succeed
    request._url = URL(url="/chat/completions")
    team_result: UserAPIKeyAuth = await user_api_key_auth(
        request=request, api_key="Bearer " + token
    )

    ## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
    # default_team_id is either falsy or a generated id string; test its
    # truthiness. ``is not None`` would also be True for the flag value False.
    if team_id_set or default_team_id:
        assert team_result.team_tpm_limit == 100
        assert team_result.team_rpm_limit == 99
        assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
    assert team_result.user_id == user_id
# NOTE(review): decorators were missing from this copy; parametrization
# reconstructed — confirm "ui_routes" is the intended LiteLLMRoutes member.
@pytest.mark.parametrize("admin_allowed_routes", [None, ["ui_routes"]])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_allowed_routes_admin(
    prisma_client, audience, admin_allowed_routes, monkeypatch
):
    """
    A token carrying the proxy-admin scope must pass ``user_api_key_auth`` on
    every route the admin is allowed to use.

    - expand the ``admin_allowed_routes`` group names into concrete routes
      via ``LiteLLMRoutes``
    - call ``user_api_key_auth`` with the admin token for each route
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm
    from litellm.proxy.proxy_server import user_api_key_auth

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")
    # monkeypatch restores JWT_AUDIENCE afterwards, so it cannot leak into
    # later tests.
    monkeypatch.delenv("JWT_AUDIENCE", raising=False)
    if audience:
        monkeypatch.setenv("JWT_AUDIENCE", audience)

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_key_str = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))
    assert isinstance(public_jwk, dict)

    # Publish the public key through the cache entry for JWT_PUBLIC_KEY_URL.
    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
    )
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    if admin_allowed_routes:
        jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
            team_id_jwt_field="client_id", admin_allowed_routes=admin_allowed_routes
        )
    else:
        # Use LiteLLM_JWTAuth's default admin_allowed_routes.
        jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")

    # A naive datetime's .timestamp() is interpreted in local time, so this is
    # a correct epoch value ten minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())
    admin_token = jwt.encode(
        {
            "sub": "user123",
            "exp": expiration_time,
            "scope": "litellm_proxy_admin",
            "aud": audience,
        },
        private_key_str,
        algorithm="RS256",
    )
    print(f"admin_token: {admin_token}")
    assert await jwt_handler.auth_jwt(token=admin_token) is not None

    # admin_allowed_routes holds route-group names; expand the ones that are
    # LiteLLMRoutes members into concrete paths (other entries are skipped).
    actual_routes = []
    for group_name in jwt_handler.litellm_jwtauth.admin_allowed_routes:
        if group_name in LiteLLMRoutes.__members__:
            actual_routes.extend(LiteLLMRoutes[group_name].value)
    # Without this guard the loop below could pass without checking anything.
    assert actual_routes, "admin_allowed_routes expanded to no routes"

    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    bearer_token = "Bearer " + admin_token
    for route in actual_routes:
        request = Request(scope={"type": "http"})
        request._url = URL(url=route)
        # Admin must be let through; any exception fails the test.
        await user_api_key_auth(request=request, api_key=bearer_token)
import pytest


@pytest.mark.asyncio
async def test_team_cache_update_called():
    """Calling ``update_cache`` with a ``team_id`` must await the proxy cache's
    ``async_get_cache`` exactly once.

    The proxy server's ``user_api_key_cache`` is swapped for a fresh
    ``DualCache`` whose ``async_get_cache`` is replaced by an ``AsyncMock``.
    """
    import litellm

    cache = DualCache()
    # Point the proxy server module at our cache object.
    setattr(
        litellm.proxy.proxy_server,
        "user_api_key_cache",
        cache,
    )

    with patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
        await litellm.proxy.proxy_server.update_cache(
            token=None,
            user_id=None,
            end_user_id=None,
            team_id="1234",
            response_cost=20,
            parent_otel_span=None,
        )  # type: ignore
        # NOTE(review): presumably update_cache does its cache work in a
        # background task; the sleep gives it time to run.
        await asyncio.sleep(3)
        mock_call_cache.assert_awaited_once()
@pytest.fixture
def public_jwt_key():
    """Generate an RSA key pair for signing test JWTs.

    Returns:
        dict with
          - ``private_key``: unencrypted PKCS#8 PEM private key (bytes)
          - ``public_jwk``: the matching public key as a JWK dict
    """
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Export the public half directly; no PEM round-trip is needed.
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))
    return {"private_key": private_key, "public_jwk": public_jwk}
def mock_user_object(*args, **kwargs):
    """Stand-in for ``get_user_object``, used as a ``patch`` side effect.

    Prints the arguments it receives and checks that the caller requested
    user upsert. Returns None.

    Raises:
        KeyError: if ``user_id_upsert`` was not passed as a keyword argument.
        AssertionError: if ``user_id_upsert`` was passed but is not ``True``.
    """
    print(f"Args: {args}")
    print(f"kwargs: {kwargs}")
    assert kwargs["user_id_upsert"] is True
# NOTE(review): decorators were missing from this copy; parametrization
# reconstructed around the allowed domain "berri.ai" configured below — confirm.
@pytest.mark.parametrize(
    "user_email, should_work",
    [("ishaan@berri.ai", True), ("krrish@tassle.ai", False)],
)
@pytest.mark.asyncio
async def test_allow_access_by_email(
    public_jwt_key, user_email, should_work, monkeypatch
):
    """
    Allow anyone whose email is on the allowed domain (``berri.ai`` here) to
    make a request to the proxy; tokens with other email domains are rejected.

    Relevant issue: https://github.com/BerriAI/litellm/issues/5605
    """
    import jwt
    from starlette.datastructures import URL

    from litellm.proxy.proxy_server import user_api_key_auth

    public_jwk = public_jwt_key["public_jwk"]
    private_key = public_jwt_key["private_key"]
    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")

    # Publish the public key through the cache entry for JWT_PUBLIC_KEY_URL.
    cache = DualCache()
    await cache.async_set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key", value=[public_jwk]
    )
    jwt_handler = JWTHandler()
    jwt_handler.user_api_key_cache = cache
    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        user_email_jwt_field="email",
        user_allowed_email_domain="berri.ai",
        user_id_upsert=True,
    )

    # A naive datetime's .timestamp() is interpreted in local time, so this is
    # a correct epoch value ten minutes from now.
    expiration_time = int((datetime.now() + timedelta(minutes=10)).timestamp())
    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": "litellm-proxy",
        "email": user_email,
    }
    token = jwt.encode(payload, private_key.decode("utf-8"), algorithm="RS256")

    # The signature is valid whatever the email domain, so decoding succeeds
    # in both cases; the domain check happens in user_api_key_auth.
    response = await jwt_handler.auth_jwt(token=token)
    assert response is not None

    bearer_token = "Bearer " + token
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    # NOTE(review): placeholder client; get_user_object is patched below, so
    # presumably it only needs to be non-None.
    setattr(litellm.proxy.proxy_server, "prisma_client", {})

    with patch.object(
        litellm.proxy.auth.handle_jwt,
        "get_user_object",
        side_effect=mock_user_object,
    ):
        if should_work:
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            assert result is not None
        else:
            # NOTE(review): narrow to the concrete auth exception type if known.
            with pytest.raises(Exception):
                await user_api_key_auth(request=request, api_key=bearer_token)
def test_get_public_key_from_jwk_url():
    """With a single-key JWKS and a matching ``kid``, ``parse_keys`` returns
    that key dict unchanged."""
    from litellm.proxy.auth.handle_jwt import JWTHandler

    kid = "RaPJB8QVptWHjHcoHkVlUWO4f0D3BtcY6iSDXgGVBgk"
    jwk_response = [
        {
            "kty": "RSA",
            "alg": "RS256",
            "kid": kid,
            "use": "sig",
            "e": "AQAB",
            "n": "zgLDu57gLpkzzIkKrTKQVyjK8X40hvu6X_JOeFjmYmI0r3bh7FTOmre5rTEkDOL-1xvQguZAx4hjKmCzBU5Kz84FbsGiqM0ug19df4kwdTS6XOM6YEKUZrbaw4P7xTPsbZj7W2G_kxWNm3Xaxq6UKFdUF7n9snnBKKD6iUA-cE6HfsYmt9OhYZJfy44dbAbuanFmAsWw97SHrPFL3ueh3Ixt19KgpF4iSsXNg3YvoesdFM8psmivgePyyHA8k7pK1Yq7rNQX1Q9nzhvP-F7ocFbP52KYPlaSTu30YwPTVTFKYpDNmHT1fZ7LXZZNLrP_7-NSY76HS2ozSpzjsGVelQ",
        }
    ]

    jwt_handler = JWTHandler()
    public_key = jwt_handler.parse_keys(keys=jwk_response, kid=kid)

    assert public_key is not None
    assert public_key == jwk_response[0]
@pytest.mark.asyncio
async def test_end_user_jwt_auth(monkeypatch):
    """Check that a JWT's ``sub`` claim flows through as the request's end-user id.

    Steps:
      1. Seed the JWKS cache under the key derived from ``JWT_PUBLIC_KEY_URL``
         and configure ``JWTHandler`` with ``end_user_id_jwt_field="sub"``.
      2. Verify the fixture token and read the end-user id from it.
      3. Run ``user_api_key_auth`` and check ``end_user_id`` equals the ``sub`` claim.
      4. Run ``chat_completion`` with cost tracking registered and assert that
         ``_should_track_cost_callback`` receives the same ``end_user_id``.
    """
    import json

    import litellm
    from litellm.caching import DualCache
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler
    from litellm.proxy.proxy_server import user_api_key_auth

    # `sub` claim of the fixture token below.
    expected_end_user_id = "81b3e52a-67a6-4efb-9645-70527e101479"

    monkeypatch.delenv("JWT_AUDIENCE", raising=False)
    monkeypatch.setenv("JWT_PUBLIC_KEY_URL", "https://example.com/public-key")

    jwt_handler = JWTHandler()
    litellm_jwtauth = LiteLLM_JWTAuth(
        end_user_id_jwt_field="sub",
    )

    cache = DualCache()
    keys = [
        {
            "kid": "d-1733370597545",
            "alg": "RS256",
            "kty": "RSA",
            "use": "sig",
            "n": "j5Ik60FJSUIPMVdMSU8vhYvyPPH7rUsUllNI0BfBlIkgEYFk2mg4KX1XDQc6mcKFjbq9k_7TSkHWKnfPhNkkb0MdmZLKbwTrmte2k8xWDxp-rSmZpIJwC1zuPDx5joLHBgIb09-K2cPL2rvwzP75WtOr_QLXBerHAbXx8cOdI7mrSRWJ9iXbKv_pLDnZHnGNld75tztb8nCtgrywkF010jGi1xxaT8UKsTvK-QkIBkYI6m6WR9LMnG2OZm-ExuvNPUenfYUsqnynPF4SMNZwyQqJfavSLKI8uMzB2s9pcbx5HfQwIOYhMlgBHjhdDn2IUSnXSJqSsN6RQO18M2rgPQ",
            "e": "AQAB",
        },
        {
            "kid": "s-f836dd32-ef71-426a-8804-946a7f230bc9",
            "alg": "RS256",
            "kty": "RSA",
            "use": "sig",
            "n": "2A5-ZA18YKn7M4OtxsfXBc3Z7n2WyHTxbK4GEBlmD9T9TDr4sbJaI4oHfTvzsAC3H2r2YkASzrCISXMXQJjLHoeLgDVcKs8qTdLj7K5FNT9fA0kU9ayUjSGrqkz57SG7oNf9Wp__Qa-H-bs6Z8_CEfBy0JA9QSHUfrdOXp4vCB_qLn6DE0DJH9ELAq_0nktVQk_oxlvXlGtVZSZe31mNNgiD__RJMogf-SIFcYOkMLVGTTEBYiCk1mHxXS6oJZaVSWiBgHzu5wkra5AfQLUVelQaupT5H81hFPmiceEApf_2DacnqqRV4-Nl8sjhJtuTXiprVS2Z5r2pOMz_kVGNgw",
            "e": "AQAB",
        },
    ]

    # The cache key embeds JWT_PUBLIC_KEY_URL, presumably so the handler reads
    # these keys from cache instead of fetching the URL.
    cache.set_cache(
        key="litellm_jwt_auth_keys_https://example.com/public-key",
        value=keys,
    )

    jwt_handler.update_environment(
        prisma_client=None,
        user_api_key_cache=cache,
        litellm_jwtauth=litellm_jwtauth,
        # The fixture token's `exp` claim is in the past; the huge leeway keeps
        # it verifiable.
        leeway=100000000000000,
    )

    token = "eyJraWQiOiJkLTE3MzMzNzA1OTc1NDUiLCJ0eXAiOiJKV1QiLCJ2ZXJzaW9uIjoiNCIsImFsZyI6IlJTMjU2In0.eyJpYXQiOjE3MzM1MDcyNzcsImV4cCI6MTczMzUwNzg3Nywic3ViIjoiODFiM2U1MmEtNjdhNi00ZWZiLTk2NDUtNzA1MjdlMTAxNDc5IiwidElkIjoicHVibGljIiwic2Vzc2lvbkhhbmRsZSI6Ijg4M2Y4YWFmLWUwOTEtNGE1Ny04YTJhLTRiMjcwMmZhZjMzYyIsInJlZnJlc2hUb2tlbkhhc2gxIjoiNDVhNDRhYjlmOTMwMGQyOTY4ZjkxODZkYWQ0YTcwY2QwNjk2YzBiNTBmZmUxZmQ4ZTM2YzU1NGU0MWE4ODU0YiIsInBhcmVudFJlZnJlc2hUb2tlbkhhc2gxIjpudWxsLCJhbnRpQ3NyZlRva2VuIjpudWxsLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjMwMDEvYXV0aC9zdCIsImxlZ2FjeV9jb21wYW55X2lkIjoxNTI0OTM5LCJsZWdhY3lfaWQiOjM5NzAyNzk1LCJzY29wZSI6WyJza2lsbF91cCJdLCJzdC1ldiI6eyJ0IjoxNzMzNTA3Mjc4NzAwLCJ2IjpmYWxzZX19.XlYrT6dRIjaZKkJtdr7C_UuxajFRbNpA9BnIsny3rxiPVyS8rhIBwxW12tZwgttRywmXrXK-msowFhWU4XdL5Qfe4lwZb2HTbDeGiQPvQTlOjWWYMhgCoKdPtjCQsAcW45rg7aQ0p42JFQPoAQa8AnGfxXpgx2vSR7njiZ3ZZyHerDdKQHyIGSFVOxoK0TgR-hxBVY__Wjg8UTKgKSz9KU_uwnPgpe2DeYmP-LTK2oeoygsVRmbldY_GrrcRe3nqYcUfFkxSs0FSsoSv35jIxiptXfCjhEB1Y5eaJhHEjlYlP2rw98JysYxjO2rZbAdUpL3itPeo3T2uh1NZr_lArw"

    response = await jwt_handler.auth_jwt(token=token)
    assert response is not None

    end_user_id = jwt_handler.get_end_user_id(
        token=response,
        default_value=None,
    )
    assert end_user_id is not None

    ## CHECK USER API KEY AUTH ##
    bearer_token = "Bearer " + token

    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
    request = Request(
        {
            "type": "http",
            "route": api_route,
            "path": "/chat/completions",
            # bearer_token already carries the "Bearer " scheme prefix.
            "headers": [(b"authorization", bearer_token.encode("latin-1"))],
            "method": "POST",
        }
    )

    async def return_body():
        body_dict = {
            "model": "openai/gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
        }
        # Serialize the dictionary to JSON and encode it to bytes
        return json.dumps(body_dict).encode("utf-8")

    request.body = return_body

    ## 1. AUTH CALL - should succeed and expose the JWT `sub` as end_user_id
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {"enable_jwt_auth": True, "pass_through_all_models": True},
    )
    setattr(
        litellm.proxy.proxy_server,
        "llm_router",
        MagicMock(),
    )
    setattr(litellm.proxy.proxy_server, "prisma_client", {})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)

    from litellm.proxy.proxy_server import cost_tracking

    cost_tracking()
    result = await user_api_key_auth(request=request, api_key=bearer_token)
    assert result.end_user_id == expected_end_user_id  # jwt token decoded sub value

    ## 2. COMPLETION CALL - cost tracking must see the same end_user_id
    temp_response = Response()
    import litellm.proxy.hooks.proxy_track_cost_callback as proxy_track_cost_callback

    with patch.object(
        proxy_track_cost_callback, "_should_track_cost_callback"
    ) as mock_client:
        resp = await chat_completion(
            request=request,
            fastapi_response=temp_response,
            model="gpt-4o",
            user_api_key_dict=result,
        )
        assert resp is not None

        # The cost callback presumably fires after the response is returned;
        # wait while the patch is still active.
        await asyncio.sleep(1)

        mock_client.assert_called_once()
        assert mock_client.call_args.kwargs["end_user_id"] == expected_end_user_id
def test_can_rbac_role_call_route():
    """A TEAM role whose permissions only cover ``/v1/chat/completions`` must be
    rejected with an ``HTTPException`` when it calls ``/v1/embeddings``."""
    from litellm.proxy.auth.handle_jwt import JWTAuthManager
    from litellm.proxy._types import LitellmUserRoles, RoleBasedPermissions

    chat_only_team = RoleBasedPermissions(
        role=LitellmUserRoles.TEAM, routes=["/v1/chat/completions"]
    )
    settings = {"role_permissions": [chat_only_team]}

    with pytest.raises(HTTPException):
        JWTAuthManager.can_rbac_role_call_route(
            rbac_role=LitellmUserRoles.TEAM,
            general_settings=settings,
            route="/v1/embeddings",
        )
@pytest.mark.parametrize(
    "requested_model, should_work",
    [
        # Granted: scope "litellm.api.gpt_3_5_turbo" is held and maps to this model.
        ("gpt-3.5-turbo-testing", True),
        # Not granted by any scope mapping below.
        ("gpt-4o", False),
    ],
)
def test_check_scope_based_access(requested_model, should_work):
    """Check ``JWTAuthManager.check_scope_based_access`` against scope→model mappings.

    The token's ``scopes`` include both mapped scopes. A model listed under a
    held scope must be accepted. Any other model must raise ``HTTPException``.

    Args:
        requested_model: model name placed in ``request_data["model"]``.
        should_work: whether access is expected to be granted.
    """
    from litellm.proxy.auth.handle_jwt import JWTAuthManager
    from litellm.proxy._types import ScopeMapping

    args = {
        "scope_mappings": [
            ScopeMapping(
                models=["anthropic-claude"],
                routes=["/v1/chat/completions"],
                scope="litellm.api.consumer",
            ),
            ScopeMapping(
                models=["gpt-3.5-turbo-testing"],
                routes=None,
                scope="litellm.api.gpt_3_5_turbo",
            ),
        ],
        "scopes": [
            "profile",
            "groups-scope",
            "email",
            "litellm.api.gpt_3_5_turbo",
            "litellm.api.consumer",
        ],
        "request_data": {
            "model": requested_model,
            "messages": [{"role": "user", "content": "Hey, how's it going 1234?"}],
        },
        "general_settings": {
            "enable_jwt_auth": True,
            "litellm_jwtauth": {
                "team_id_jwt_field": "client_id",
                "team_id_upsert": True,
                "scope_mappings": [
                    {
                        "scope": "litellm.api.consumer",
                        "models": ["anthropic-claude"],
                        "routes": ["/v1/chat/completions"],
                    },
                    {
                        "scope": "litellm.api.gpt_3_5_turbo",
                        "models": ["gpt-3.5-turbo-testing"],
                    },
                ],
                "enforce_scope_based_access": True,
                "enforce_rbac": True,
            },
        },
    }

    if should_work:
        JWTAuthManager.check_scope_based_access(**args)
    else:
        with pytest.raises(HTTPException):
            JWTAuthManager.check_scope_based_access(**args)
@pytest.mark.asyncio
async def test_custom_validate_called():
    """Check that ``auth_builder`` calls ``litellm_jwtauth.custom_validate`` with
    the decoded JWT payload returned by ``jwt_handler.auth_jwt``.

    The handler, caches and logging object are mocks, so ``auth_builder`` may
    fail at a later step. That outcome does not matter here; only the
    ``custom_validate`` call is asserted.
    """
    from contextlib import suppress

    # Setup
    mock_custom_validate = MagicMock(return_value=True)
    jwt_handler = MagicMock()
    jwt_handler.litellm_jwtauth = MagicMock(
        custom_validate=mock_custom_validate, allowed_routes=["/chat/completions"]
    )
    # AsyncMock because auth_builder presumably awaits auth_jwt.
    jwt_handler.auth_jwt = AsyncMock(return_value={"sub": "test_user"})

    # Deliberately best-effort: errors raised after custom_validate (caused by
    # the mocked collaborators) are ignored; the assertion below is the check.
    with suppress(Exception):
        await JWTAuthManager.auth_builder(
            api_key="test",
            jwt_handler=jwt_handler,
            request_data={},
            general_settings={},
            route="/chat/completions",
            prisma_client=None,
            user_api_key_cache=MagicMock(),
            parent_otel_span=None,
            proxy_logging_obj=MagicMock(),
        )

    # Assert custom_validate was called with the jwt token
    mock_custom_validate.assert_called_once_with({"sub": "test_user"})