# What is this?
## Common auth checks between jwt + key based auth
"""
Got Valid Token from Cache, DB
Run checks for:
1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast

from fastapi import status
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
    RBAC_ROLES,
    CallInfo,
    LiteLLM_EndUserTable,
    LiteLLM_JWTAuth,
    LiteLLM_OrganizationMembershipTable,
    LiteLLM_OrganizationTable,
    LiteLLM_TeamTable,
    LiteLLM_TeamTableCachedObj,
    LiteLLM_UserTable,
    LiteLLMRoutes,
    LitellmUserRoles,
    ProxyErrorTypes,
    ProxyException,
    RoleBasedPermissions,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.types.services import ServiceTypes

from .auth_checks_organization import organization_role_based_access_check

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any

last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = 5  # refresh every 5s

all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value


async def common_checks(
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
    end_user_object: Optional[LiteLLM_EndUserTable],
    global_proxy_spend: Optional[float],
    general_settings: dict,
    route: str,
    llm_router: Optional[Router],
    proxy_logging_obj: ProxyLogging,
    valid_token: Optional[UserAPIKeyAuth],
) -> bool:
    """
    Common checks across jwt + key-based auth.

    1. If team is blocked
    2. If team can call model
    3. If team is in budget
    4. If user passed in (JWT or key.user_id) - is in budget
    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
    6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    8. [OPTIONAL] If guardrails modified - is request allowed to change this
    9. Check if request body is safe
    10. [OPTIONAL] Organization checks - if user_object.organization_id is set, run these checks
    """
    _model = request_body.get("model", None)

    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if you're the admin."
        )

    # 2. If team can call model
    _team_model_access_check(
        team_object=team_object,
        model=_model,
        llm_router=llm_router,
    )

    ## 2.1 If user can call model (if personal key)
    if team_object is None and user_object is not None:
        await can_user_call_model(
            model=_model,
            llm_router=llm_router,
            user_object=user_object,
        )

    # 3. If team is in budget
    await _team_max_budget_check(
        team_object=team_object,
        proxy_logging_obj=proxy_logging_obj,
        valid_token=valid_token,
    )
    # 4. If user is in budget
    ## 4.1 check personal budget, if personal key
    if (
        (team_object is None or team_object.team_id is None)
        and user_object is not None
        and user_object.max_budget is not None
    ):
        user_budget = user_object.max_budget
        if user_budget < user_object.spend:
            raise litellm.BudgetExceededError(
                current_cost=user_object.spend,
                max_budget=user_budget,
                message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
            )

    ## 4.2 check team member budget, if team key

    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    if end_user_object is not None and end_user_object.litellm_budget_table is not None:
        end_user_budget = end_user_object.litellm_budget_table.max_budget
        if end_user_budget is not None and end_user_object.spend > end_user_budget:
            raise litellm.BudgetExceededError(
                current_cost=end_user_object.spend,
                max_budget=end_user_budget,
                message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
            )

    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    if (
        general_settings.get("enforce_user_param", None) is not None
        and general_settings["enforce_user_param"] is True
    ):
        if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
            raise Exception(
                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
            )

    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    if (
        litellm.max_budget > 0
        and global_proxy_spend is not None
        # only run global budget checks for OpenAI routes
        # Reason - the Admin UI should continue working if the proxy crosses its global budget
        and RouteChecks.is_llm_api_route(route=route)
        and route != "/v1/models"
        and route != "/models"
    ):
        if global_proxy_spend > litellm.max_budget:
            raise litellm.BudgetExceededError(
                current_cost=global_proxy_spend, max_budget=litellm.max_budget
            )

    # 8. [OPTIONAL] If guardrails modified - is request allowed to change this
    _request_metadata: dict = request_body.get("metadata", {}) or {}
    if _request_metadata.get("guardrails"):
        # check if team allowed to modify guardrails
        from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails

        can_modify: bool = can_modify_guardrails(team_object)
        if can_modify is False:
            from fastapi import HTTPException

            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Your team does not have permission to modify guardrails."
                },
            )

    # 10. [OPTIONAL] Organization RBAC checks
    organization_role_based_access_check(
        user_object=user_object, route=route, request_body=request_body
    )

    return True
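
# Example: how `common_checks` is typically invoked after token validation.
# Illustrative sketch only - the `team_obj` / `user_obj` / `end_user_obj` /
# `llm_router` / `proxy_logging_obj` / `valid_token` values below are
# hypothetical; in the proxy they come from `get_team_object`,
# `get_user_object`, etc.
#
#   allowed = await common_checks(
#       request_body={"model": "gpt-4o", "user": "customer-123"},
#       team_object=team_obj,          # Optional[LiteLLM_TeamTable]
#       user_object=user_obj,          # Optional[LiteLLM_UserTable]
#       end_user_object=end_user_obj,  # Optional[LiteLLM_EndUserTable]
#       global_proxy_spend=None,
#       general_settings={"enforce_user_param": True},
#       route="/chat/completions",
#       llm_router=llm_router,
#       proxy_logging_obj=proxy_logging_obj,
#       valid_token=valid_token,
#   )
#   # raises (BudgetExceededError / ProxyException / HTTPException) on a
#   # failed check; returns True otherwise.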
""" for allowed_route in allowed_routes: if ( allowed_route in LiteLLMRoutes.__members__ and user_route in LiteLLMRoutes[allowed_route].value ): return True elif allowed_route == user_route: return True return False def allowed_routes_check( user_role: Literal[ LitellmUserRoles.PROXY_ADMIN, LitellmUserRoles.TEAM, LitellmUserRoles.INTERNAL_USER, ], user_route: str, litellm_proxy_roles: LiteLLM_JWTAuth, ) -> bool: """ Check if user -> not admin - allowed to access these routes """ if user_role == LitellmUserRoles.PROXY_ADMIN: is_allowed = _allowed_routes_check( user_route=user_route, allowed_routes=litellm_proxy_roles.admin_allowed_routes, ) return is_allowed elif user_role == LitellmUserRoles.TEAM: if litellm_proxy_roles.team_allowed_routes is None: """ By default allow a team to call openai + info routes """ is_allowed = _allowed_routes_check( user_route=user_route, allowed_routes=["openai_routes", "info_routes"] ) return is_allowed elif litellm_proxy_roles.team_allowed_routes is not None: is_allowed = _allowed_routes_check( user_route=user_route, allowed_routes=litellm_proxy_roles.team_allowed_routes, ) return is_allowed return False def allowed_route_check_inside_route( user_api_key_dict: UserAPIKeyAuth, requested_user_id: Optional[str], ) -> bool: ret_val = True if ( user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY ): ret_val = False if requested_user_id is not None and user_api_key_dict.user_id is not None: if user_api_key_dict.user_id == requested_user_id: ret_val = True return ret_val def get_actual_routes(allowed_routes: list) -> list: actual_routes: list = [] for route_name in allowed_routes: try: route_value = LiteLLMRoutes[route_name].value if isinstance(route_value, set): actual_routes.extend(list(route_value)) else: actual_routes.extend(route_value) except KeyError: actual_routes.append(route_name) return actual_routes @log_db_metrics async def get_end_user_object( end_user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, parent_otel_span: Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, ) -> Optional[LiteLLM_EndUserTable]: """ Returns end user object, if in db. Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user). 
""" if prisma_client is None: raise Exception("No db connected") if end_user_id is None: return None _key = "end_user_id:{}".format(end_user_id) def check_in_budget(end_user_obj: LiteLLM_EndUserTable): if end_user_obj.litellm_budget_table is None: return end_user_budget = end_user_obj.litellm_budget_table.max_budget if end_user_budget is not None and end_user_obj.spend > end_user_budget: raise litellm.BudgetExceededError( current_cost=end_user_obj.spend, max_budget=end_user_budget ) # check if in cache cached_user_obj = await user_api_key_cache.async_get_cache(key=_key) if cached_user_obj is not None: if isinstance(cached_user_obj, dict): return_obj = LiteLLM_EndUserTable(**cached_user_obj) check_in_budget(end_user_obj=return_obj) return return_obj elif isinstance(cached_user_obj, LiteLLM_EndUserTable): return_obj = cached_user_obj check_in_budget(end_user_obj=return_obj) return return_obj # else, check db try: response = await prisma_client.db.litellm_endusertable.find_unique( where={"user_id": end_user_id}, include={"litellm_budget_table": True}, ) if response is None: raise Exception # save the end-user object to cache await user_api_key_cache.async_set_cache( key="end_user_id:{}".format(end_user_id), value=response ) _response = LiteLLM_EndUserTable(**response.dict()) check_in_budget(end_user_obj=_response) return _response except Exception as e: # if end-user not in db if isinstance(e, litellm.BudgetExceededError): raise e return None def model_in_access_group( model: str, team_models: Optional[List[str]], llm_router: Optional[Router] ) -> bool: from collections import defaultdict if team_models is None: return True if model in team_models: return True access_groups: dict[str, list[str]] = defaultdict(list) if llm_router: access_groups = llm_router.get_model_access_groups(model_name=model) if len(access_groups) > 0: # check if token contains any model access groups for idx, m in enumerate( team_models ): # loop token models, if any of them are an access group add the access group if m in access_groups: return True # Filter out models that are access_groups filtered_models = [m for m in team_models if m not in access_groups] if model in filtered_models: return True return False def _should_check_db( key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int ) -> bool: """ Prevent calling db repeatedly for items that don't exist in the db. """ current_time = time.time() # if key doesn't exist in last_db_access_time -> check db if key not in last_db_access_time: return True elif ( last_db_access_time[key][0] is not None ): # check db for non-null values (for refresh operations) return True elif last_db_access_time[key][0] is None: if current_time - last_db_access_time[key] >= db_cache_expiry: return True return False def _update_last_db_access_time( key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict ): last_db_access_time[key] = (value, time.time()) def _get_role_based_permissions( rbac_role: RBAC_ROLES, general_settings: dict, key: Literal["models", "routes"], ) -> Optional[List[str]]: """ Get the role based permissions from the general settings. 
""" role_based_permissions = cast( Optional[List[RoleBasedPermissions]], general_settings.get("role_permissions", []), ) if role_based_permissions is None: return None for role_based_permission in role_based_permissions: if role_based_permission.role == rbac_role: return getattr(role_based_permission, key) return None def get_role_based_models( rbac_role: RBAC_ROLES, general_settings: dict, ) -> Optional[List[str]]: """ Get the models allowed for a user role. Used by JWT Auth. """ return _get_role_based_permissions( rbac_role=rbac_role, general_settings=general_settings, key="models", ) def get_role_based_routes( rbac_role: RBAC_ROLES, general_settings: dict, ) -> Optional[List[str]]: """ Get the routes allowed for a user role. """ return _get_role_based_permissions( rbac_role=rbac_role, general_settings=general_settings, key="routes", ) async def _get_fuzzy_user_object( prisma_client: PrismaClient, sso_user_id: Optional[str] = None, user_email: Optional[str] = None, ) -> Optional[LiteLLM_UserTable]: """ Checks if sso user is in db. Called when user id match is not found in db. - Check if sso_user_id is user_id in db - Check if sso_user_id is sso_user_id in db - Check if user_email is user_email in db - If not, create new user with user_email and sso_user_id and user_id = sso_user_id """ response = None if sso_user_id is not None: response = await prisma_client.db.litellm_usertable.find_unique( where={"sso_user_id": sso_user_id}, include={"organization_memberships": True}, ) if response is None and user_email is not None: response = await prisma_client.db.litellm_usertable.find_first( where={"user_email": user_email}, include={"organization_memberships": True}, ) if response is not None and sso_user_id is not None: # update sso_user_id asyncio.create_task( # background task to update user with sso id prisma_client.db.litellm_usertable.update( where={"user_id": response.user_id}, data={"sso_user_id": sso_user_id}, ) ) return response @log_db_metrics async def get_user_object( user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, user_id_upsert: bool, parent_otel_span: Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, sso_user_id: Optional[str] = None, user_email: Optional[str] = None, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table - if valid, return LiteLLM_UserTable object with defined limits - if not, then raise an error """ if user_id is None: return None # check if in cache cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if cached_user_obj is not None: if isinstance(cached_user_obj, dict): return LiteLLM_UserTable(**cached_user_obj) elif isinstance(cached_user_obj, LiteLLM_UserTable): return cached_user_obj # else, check db if prisma_client is None: raise Exception("No db connected") try: db_access_time_key = "user_id:{}".format(user_id) should_check_db = _should_check_db( key=db_access_time_key, last_db_access_time=last_db_access_time, db_cache_expiry=db_cache_expiry, ) if should_check_db: response = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id}, include={"organization_memberships": True} ) if response is None: response = await _get_fuzzy_user_object( prisma_client=prisma_client, sso_user_id=sso_user_id, user_email=user_email, ) else: response = None if response is None: if user_id_upsert: response = await prisma_client.db.litellm_usertable.create( data={"user_id": user_id}, include={"organization_memberships": True}, ) else: raise 
@log_db_metrics
async def get_user_object(
    user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    user_id_upsert: bool,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    - Check if user id in proxy User Table
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error
    """
    if user_id is None:
        return None

    # check if in cache
    cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj
    # else, check db
    if prisma_client is None:
        raise Exception("No db connected")
    try:
        db_access_time_key = "user_id:{}".format(user_id)
        should_check_db = _should_check_db(
            key=db_access_time_key,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
        )

        if should_check_db:
            response = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": user_id}, include={"organization_memberships": True}
            )

            if response is None:
                response = await _get_fuzzy_user_object(
                    prisma_client=prisma_client,
                    sso_user_id=sso_user_id,
                    user_email=user_email,
                )
        else:
            response = None

        if response is None:
            if user_id_upsert:
                response = await prisma_client.db.litellm_usertable.create(
                    data={"user_id": user_id},
                    include={"organization_memberships": True},
                )
            else:
                raise Exception

        if (
            response.organization_memberships is not None
            and len(response.organization_memberships) > 0
        ):
            # dump each organization membership to type LiteLLM_OrganizationMembershipTable
            _dumped_memberships = [
                LiteLLM_OrganizationMembershipTable(**membership.model_dump())
                for membership in response.organization_memberships
                if membership is not None
            ]
            response.organization_memberships = _dumped_memberships

        _response = LiteLLM_UserTable(**dict(response))
        response_dict = _response.model_dump()

        # save the user object to cache
        await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)

        # save to db access time
        _update_last_db_access_time(
            key=db_access_time_key,
            value=response_dict,
            last_db_access_time=last_db_access_time,
        )

        return _response
    except Exception as e:  # if user not in db
        raise ValueError(
            f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}"
        )
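
# Example: lookup order in `get_user_object` - cache first, then db
# (rate-limited by `_should_check_db`), then the sso-id / email fallback,
# and finally an upsert if `user_id_upsert=True`. Sketch with hypothetical
# arguments:
#
#   user = await get_user_object(
#       user_id="user-123",
#       prisma_client=prisma_client,
#       user_api_key_cache=user_api_key_cache,
#       user_id_upsert=False,          # True -> create the row if missing
#       sso_user_id=None,
#       user_email="dev@example.com",  # used for fuzzy match if the id misses
#   )
#   # raises ValueError if the user can't be found (and upsert is disabled)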
async def _cache_management_object(
    key: str,
    value: BaseModel,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    await user_api_key_cache.async_set_cache(key=key, value=value)


async def _cache_team_object(
    team_id: str,
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    key = "team_id:{}".format(team_id)

    ## CACHE REFRESH TIME!
    team_table.last_refreshed_at = time.time()

    await _cache_management_object(
        key=key,
        value=team_table,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )


async def _cache_key_object(
    hashed_token: str,
    user_api_key_obj: UserAPIKeyAuth,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    key = hashed_token

    ## CACHE REFRESH TIME
    user_api_key_obj.last_refreshed_at = time.time()

    await _cache_management_object(
        key=key,
        value=user_api_key_obj,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )


async def _delete_cache_key_object(
    hashed_token: str,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    key = hashed_token

    user_api_key_cache.delete_cache(key=key)

    ## UPDATE REDIS CACHE ##
    if proxy_logging_obj is not None:
        await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
            key=key
        )


@log_db_metrics
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
    return await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )


async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
    return await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )


async def _get_team_object_from_user_api_key_cache(
    team_id: str,
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    last_db_access_time: LimitedSizeOrderedDict,
    db_cache_expiry: int,
    proxy_logging_obj: Optional[ProxyLogging],
    key: str,
) -> LiteLLM_TeamTableCachedObj:
    db_access_time_key = key
    should_check_db = _should_check_db(
        key=db_access_time_key,
        last_db_access_time=last_db_access_time,
        db_cache_expiry=db_cache_expiry,
    )
    if should_check_db:
        response = await _get_team_db_check(
            team_id=team_id, prisma_client=prisma_client
        )
    else:
        response = None

    if response is None:
        raise Exception

    _response = LiteLLM_TeamTableCachedObj(**response.dict())

    # save the team object to cache
    await _cache_team_object(
        team_id=team_id,
        team_table=_response,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    # save to db access time
    _update_last_db_access_time(
        key=db_access_time_key,
        value=_response,
        last_db_access_time=last_db_access_time,
    )

    return _response


async def _get_team_object_from_cache(
    key: str,
    proxy_logging_obj: Optional[ProxyLogging],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span],
) -> Optional[LiteLLM_TeamTableCachedObj]:
    cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None

    ## CHECK REDIS CACHE ##
    if (
        proxy_logging_obj is not None
        and proxy_logging_obj.internal_usage_cache.dual_cache
    ):
        cached_team_obj = (
            await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
                key=key, parent_otel_span=parent_otel_span
            )
        )

    if cached_team_obj is None:
        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)

    if cached_team_obj is not None:
        if isinstance(cached_team_obj, dict):
            return LiteLLM_TeamTableCachedObj(**cached_team_obj)
        elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
            return cached_team_obj

    return None


async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
    check_db_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    - Check if team id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error

    Raises:
    - Exception: If team doesn't exist in db or cache
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    key = "team_id:{}".format(team_id)

    if not check_db_only:
        cached_team_obj = await _get_team_object_from_cache(
            key=key,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )

        if cached_team_obj is not None:
            return cached_team_obj

        if check_cache_only:
            raise Exception(
                f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
            )

    # else, check db
    try:
        return await _get_team_object_from_user_api_key_cache(
            team_id=team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
            key=key,
        )
    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )
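
# Example: the two optional flags on `get_team_object` pick the lookup
# strategy. Illustrative only - `pc` / `cache` are hypothetical stand-ins:
#
#   await get_team_object(team_id="t1", prisma_client=pc, user_api_key_cache=cache)
#   # default: redis/in-memory cache first, db fallback
#   await get_team_object(..., check_cache_only=True)  # raise if not cached
#   await get_team_object(..., check_db_only=True)     # skip the cache read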
@log_db_metrics
async def get_key_object(
    hashed_token: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
) -> UserAPIKeyAuth:
    """
    - Check if the hashed token is in the proxy's key table
    - if valid, return UserAPIKeyAuth object with defined limits
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    key = hashed_token

    cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
        key=key
    )

    if cached_key_obj is not None:
        if isinstance(cached_key_obj, dict):
            return UserAPIKeyAuth(**cached_key_obj)
        elif isinstance(cached_key_obj, UserAPIKeyAuth):
            return cached_key_obj

    if check_cache_only:
        raise Exception(
            f"Key doesn't exist in cache + check_cache_only=True. key={key}."
        )

    # else, check db
    try:
        _valid_token: Optional[BaseModel] = await prisma_client.get_data(
            token=hashed_token,
            table_name="combined_view",
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )

        if _valid_token is None:
            raise Exception

        _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))

        # save the key object to cache
        await _cache_key_object(
            hashed_token=hashed_token,
            user_api_key_obj=_response,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

        return _response
    except DB_CONNECTION_ERROR_TYPES as e:
        return await _handle_failed_db_connection_for_get_key_object(e=e)
    except Exception:
        traceback.print_exc()
        raise Exception(
            f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
        )


async def _handle_failed_db_connection_for_get_key_object(
    e: Exception,
) -> UserAPIKeyAuth:
    """
    Handles httpx.ConnectError when reading a Virtual Key from the LiteLLM DB.

    Use this if you don't want failed DB queries to block LLM API requests.

    Returns:
    - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

    Raises:
    - Original Exception in all other cases
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        litellm_proxy_admin_name,
        proxy_logging_obj,
    )

    # If this flag is on, requests failing to connect to the DB will be allowed
    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
        # log this as a DB failure on prometheus
        proxy_logging_obj.service_logging_obj.service_failure_hook(
            service=ServiceTypes.DB,
            call_type="get_key_object",
            error=e,
            duration=0.0,
        )

        return UserAPIKeyAuth(
            key_name="failed-to-connect-to-db",
            token="failed-to-connect-to-db",
            user_id=litellm_proxy_admin_name,
        )
    else:
        # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
        raise e


@log_db_metrics
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_OrganizationTable]:
    """
    - Check if org id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache (async_get_cache is a coroutine - it must be awaited)
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return LiteLLM_OrganizationTable(**cached_org_obj)
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj
    # else, check db
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )

        if response is None:
            raise Exception

        return response
    except Exception:
        raise Exception(
            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
        )
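
# Example: behavior of `_handle_failed_db_connection_for_get_key_object`
# under the `allow_requests_on_db_unavailable` flag. Sketch of the proxy
# config that enables it:
#
#   general_settings:
#     allow_requests_on_db_unavailable: true
#
# With the flag on, a db connection error during key lookup yields a
# placeholder UserAPIKeyAuth(key_name="failed-to-connect-to-db", ...) so LLM
# traffic keeps flowing; with it off, the original exception propagates.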
async def _can_object_call_model(
    model: str,
    llm_router: Optional[Router],
    models: List[str],
) -> Literal[True]:
    """
    Checks if token can call a given model

    Returns:
        - True: if token allowed to call model

    Raises:
        - Exception: If token not allowed to call model
    """
    if model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]

    ## check if model in allowed model names
    from collections import defaultdict

    access_groups: Dict[str, List[str]] = defaultdict(list)
    if llm_router:
        access_groups = llm_router.get_model_access_groups(model_name=model)

    if (
        len(access_groups) > 0 and llm_router is not None
    ):  # check if token contains any model access groups
        for m in models:
            # if any of the token's models are an access group, allow access
            if m in access_groups:
                return True

    # Filter out models that are access_groups
    filtered_models = [m for m in models if m not in access_groups]

    verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")

    if _model_matches_any_wildcard_pattern_in_list(
        model=model, allowed_model_list=filtered_models
    ):
        return True

    all_model_access: bool = False

    if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
        all_model_access = True

    if model is not None and model not in filtered_models and all_model_access is False:
        raise ProxyException(
            message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}",
            type=ProxyErrorTypes.key_model_access_denied,
            param="model",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    verbose_proxy_logger.debug(
        f"filtered allowed_models: {filtered_models}; models: {models}"
    )

    return True


async def can_key_call_model(
    model: str,
    llm_model_list: Optional[list],
    valid_token: UserAPIKeyAuth,
    llm_router: Optional[litellm.Router],
) -> Literal[True]:
    """
    Checks if token can call a given model

    Returns:
        - True: if token allowed to call model

    Raises:
        - Exception: If token not allowed to call model
    """
    return await _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=valid_token.models,
    )


async def can_user_call_model(
    model: str,
    llm_router: Optional[Router],
    user_object: Optional[LiteLLM_UserTable],
) -> Literal[True]:
    if user_object is None:
        return True

    return await _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=user_object.models,
    )


async def is_valid_fallback_model(
    model: str,
    llm_router: Optional[Router],
    user_model: Optional[str],
) -> Literal[True]:
    """
    Try to route the fallback model request.

    Validate if it can't be routed.

    Help catch invalid fallback models.
    """
    await route_request(
        data={
            "model": model,
            "messages": [{"role": "user", "content": "Who was Alexander?"}],
        },
        llm_router=llm_router,
        user_model=user_model,
        route_type="acompletion",  # route type shouldn't affect the fallback model check
    )

    return True
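
# Example: how `_can_object_call_model` resolves access, with a hypothetical
# router access group "gen-ai-group" -> ["claude-3-5-sonnet"]. The order is:
# alias map -> access groups -> wildcard patterns -> "*" / empty list
# (= all models) -> explicit list.
#
#   await _can_object_call_model(
#       model="claude-3-5-sonnet",
#       llm_router=llm_router,
#       models=["gen-ai-group", "gpt-4o"],
#   )
#   # -> True (granted via the "gen-ai-group" access group)
#
#   await _can_object_call_model(
#       model="gemini-pro",
#       llm_router=llm_router,
#       models=["gen-ai-group", "gpt-4o"],
#   )
#   # -> raises ProxyException (401, key_model_access_denied)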
""" if valid_token.spend is not None and valid_token.max_budget is not None: #################################### # collect information for alerting # #################################### user_email = None # Check if the token has any user id information if user_obj is not None: user_email = user_obj.user_email call_info = CallInfo( token=valid_token.token, spend=valid_token.spend, max_budget=valid_token.max_budget, user_id=valid_token.user_id, team_id=valid_token.team_id, user_email=user_email, key_alias=valid_token.key_alias, ) asyncio.create_task( proxy_logging_obj.budget_alerts( type="token_budget", user_info=call_info, ) ) #################################### # collect information for alerting # #################################### if valid_token.spend >= valid_token.max_budget: raise litellm.BudgetExceededError( current_cost=valid_token.spend, max_budget=valid_token.max_budget, ) async def _virtual_key_soft_budget_check( valid_token: UserAPIKeyAuth, proxy_logging_obj: ProxyLogging, ): """ Triggers a budget alert if the token is over it's soft budget. """ if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget: verbose_proxy_logger.debug( "Crossed Soft Budget for token %s, spend %s, soft_budget %s", valid_token.token, valid_token.spend, valid_token.soft_budget, ) call_info = CallInfo( token=valid_token.token, spend=valid_token.spend, max_budget=valid_token.max_budget, soft_budget=valid_token.soft_budget, user_id=valid_token.user_id, team_id=valid_token.team_id, team_alias=valid_token.team_alias, user_email=None, key_alias=valid_token.key_alias, ) asyncio.create_task( proxy_logging_obj.budget_alerts( type="soft_budget", user_info=call_info, ) ) async def _team_max_budget_check( team_object: Optional[LiteLLM_TeamTable], valid_token: Optional[UserAPIKeyAuth], proxy_logging_obj: ProxyLogging, ): """ Check if the team is over it's max budget. Raises: BudgetExceededError if the team is over it's max budget. Triggers a budget alert if the team is over it's max budget. """ if ( team_object is not None and team_object.max_budget is not None and team_object.spend is not None and team_object.spend > team_object.max_budget ): if valid_token: call_info = CallInfo( token=valid_token.token, spend=team_object.spend, max_budget=team_object.max_budget, user_id=valid_token.user_id, team_id=valid_token.team_id, team_alias=valid_token.team_alias, ) asyncio.create_task( proxy_logging_obj.budget_alerts( type="team_budget", user_info=call_info, ) ) raise litellm.BudgetExceededError( current_cost=team_object.spend, max_budget=team_object.max_budget, message=f"Budget has been exceeded! 
def _team_model_access_check(
    model: Optional[str],
    team_object: Optional[LiteLLM_TeamTable],
    llm_router: Optional[Router],
):
    """
    Access check for team models.

    Raises:
        Exception if the team is not allowed to call the `model`
    """
    if (
        model is not None
        and team_object is not None
        and team_object.models is not None
        and len(team_object.models) > 0
        and model not in team_object.models
    ):
        if "all-proxy-models" in team_object.models or "*" in team_object.models:
            # this means the team has access to all models on the proxy
            pass
        # check if the team model is an access_group
        elif (
            model_in_access_group(
                model=model, team_models=team_object.models, llm_router=llm_router
            )
            is True
        ):
            pass
        elif model and "*" in model:
            pass
        elif _model_matches_any_wildcard_pattern_in_list(
            model=model, allowed_model_list=team_object.models
        ):
            pass
        else:
            raise ProxyException(
                message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}",
                type=ProxyErrorTypes.team_model_access_denied,
                param="model",
                code=status.HTTP_401_UNAUTHORIZED,
            )


def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
    """
    Check if a model matches an allowed wildcard pattern.

    Note: only wildcard patterns are evaluated here; this returns False when
    the pattern contains no `*` (exact matches are handled by the callers).

    Args:
        model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
        allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")

    Returns:
        bool: True if model matches the pattern, False otherwise.
    """
    if "*" in allowed_model_pattern:
        pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
        return bool(re.match(pattern, model))

    return False


def _model_matches_any_wildcard_pattern_in_list(
    model: str, allowed_model_list: list
) -> bool:
    """
    Returns True if a model matches any wildcard pattern in a list.

    eg.
    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
    - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
    """
    if any(
        _is_wildcard_pattern(allowed_model_pattern)
        and is_model_allowed_by_pattern(
            model=model, allowed_model_pattern=allowed_model_pattern
        )
        for allowed_model_pattern in allowed_model_list
    ):
        return True

    if any(
        _is_wildcard_pattern(allowed_model_pattern)
        and _model_custom_llm_provider_matches_wildcard_pattern(
            model=model, allowed_model_pattern=allowed_model_pattern
        )
        for allowed_model_pattern in allowed_model_list
    ):
        return True

    return False


def _model_custom_llm_provider_matches_wildcard_pattern(
    model: str, allowed_model_pattern: str
) -> bool:
    """
    Returns True for this scenario:
    - `model=gpt-4o`
    - `allowed_model_pattern=openai/*`

    or
    - `model=claude-3-5-sonnet-20240620`
    - `allowed_model_pattern=anthropic/*`
    """
    try:
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
    except Exception:
        return False

    return is_model_allowed_by_pattern(
        model=f"{custom_llm_provider}/{model}",
        allowed_model_pattern=allowed_model_pattern,
    )


def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
    """
    Returns True if the pattern is a wildcard pattern - i.e. it contains a `*`.
    """
    return "*" in allowed_model_pattern
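
# Example: the wildcard helpers above in action. `get_llm_provider` lets a
# bare model name also match its provider-prefixed pattern.
#
#   is_model_allowed_by_pattern("bedrock/anthropic.claude-3", "bedrock/*")  # True
#   is_model_allowed_by_pattern("openai/gpt-4o", "bedrock/*")               # False
#   _model_matches_any_wildcard_pattern_in_list(
#       model="gpt-4o", allowed_model_list=["openai/*"]
#   )
#   # -> True: "gpt-4o" resolves to provider "openai", so "openai/gpt-4o"
#   #    matches the "openai/*" pattern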