""" This is a rate limiter implementation based on a similar one by Envoy proxy. This is currently in development and not yet ready for production. """ import os from datetime import datetime from typing import ( TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union, cast, ) from fastapi import HTTPException from litellm import DualCache from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth if TYPE_CHECKING: from opentelemetry.trace import Span as _Span from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache from litellm.types.caching import RedisPipelineIncrementOperation Span = Union[_Span, Any] InternalUsageCache = _InternalUsageCache else: Span = Any InternalUsageCache = Any BATCH_RATE_LIMITER_SCRIPT = """ local results = {} local now = tonumber(ARGV[1]) local window_size = tonumber(ARGV[2]) -- Process each window/counter pair for i = 1, #KEYS, 2 do local window_key = KEYS[i] local counter_key = KEYS[i + 1] local increment_value = 1 -- Check if window exists and is valid local window_start = redis.call('GET', window_key) if not window_start or (now - tonumber(window_start)) >= window_size then -- Reset window and counter redis.call('SET', window_key, tostring(now)) redis.call('SET', counter_key, increment_value) redis.call('EXPIRE', window_key, window_size) redis.call('EXPIRE', counter_key, window_size) table.insert(results, tostring(now)) -- window_start table.insert(results, increment_value) -- counter else local counter = redis.call('INCR', counter_key) table.insert(results, window_start) -- window_start table.insert(results, counter) -- counter end end return results """ class RateLimitDescriptorRateLimitObject(TypedDict, total=False): requests_per_unit: Optional[int] tokens_per_unit: Optional[int] max_parallel_requests: Optional[int] window_size: Optional[int] class RateLimitDescriptor(TypedDict): key: str value: str rate_limit: Optional[RateLimitDescriptorRateLimitObject] class RateLimitStatus(TypedDict): code: str current_limit: int limit_remaining: int rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"] descriptor_key: str class RateLimitResponse(TypedDict): overall_code: str statuses: List[RateLimitStatus] class RateLimitResponseWithDescriptors(TypedDict): descriptors: List[RateLimitDescriptor] response: RateLimitResponse class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): def __init__(self, internal_usage_cache: InternalUsageCache): self.internal_usage_cache = internal_usage_cache if self.internal_usage_cache.dual_cache.redis_cache is not None: self.batch_rate_limiter_script = ( self.internal_usage_cache.dual_cache.redis_cache.async_register_script( BATCH_RATE_LIMITER_SCRIPT ) ) else: self.batch_rate_limiter_script = None self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60)) async def in_memory_cache_sliding_window( self, keys: List[str], now_int: int, window_size: int, ) -> List[Any]: """ Implement sliding window rate limiting logic using in-memory cache operations. This follows the same logic as the Redis Lua script but uses async cache operations. """ results: List[Any] = [] # Process each window/counter pair for i in range(0, len(keys), 2): window_key = keys[i] counter_key = keys[i + 1] increment_value = 1 # Get the window start time window_start = await self.internal_usage_cache.async_get_cache( key=window_key, litellm_parent_otel_span=None, local_only=True, ) # Check if window exists and is valid if window_start is None or (now_int - int(window_start)) >= window_size: # Reset window and counter await self.internal_usage_cache.async_set_cache( key=window_key, value=str(now_int), ttl=window_size, litellm_parent_otel_span=None, local_only=True, ) await self.internal_usage_cache.async_set_cache( key=counter_key, value=increment_value, ttl=window_size, litellm_parent_otel_span=None, local_only=True, ) results.append(str(now_int)) # window_start results.append(increment_value) # counter else: # Increment the counter current_counter = await self.internal_usage_cache.async_get_cache( key=counter_key, litellm_parent_otel_span=None, local_only=True, ) new_counter_value = ( int(current_counter) if current_counter is not None else 0 ) + increment_value await self.internal_usage_cache.async_set_cache( key=counter_key, value=new_counter_value, ttl=window_size, litellm_parent_otel_span=None, local_only=True, ) results.append(window_start) # window_start results.append(new_counter_value) # counter return results def create_rate_limit_keys( self, key: str, value: str, rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"], ) -> str: """ Create the rate limit keys for the given key and value. """ counter_key = f"{{{key}:{value}}}:{rate_limit_type}" return counter_key def is_cache_list_over_limit( self, keys_to_fetch: List[str], cache_values: List[Any], key_metadata: Dict[str, Any], ) -> RateLimitResponse: """ Check if the cache values are over the limit. """ statuses: List[RateLimitStatus] = [] overall_code = "OK" for i in range(0, len(cache_values), 2): item_code = "OK" window_key = keys_to_fetch[i] counter_key = keys_to_fetch[i + 1] counter_value = cache_values[i + 1] requests_limit = key_metadata[window_key]["requests_limit"] max_parallel_requests_limit = key_metadata[window_key][ "max_parallel_requests_limit" ] tokens_limit = key_metadata[window_key]["tokens_limit"] # Determine which limit to use for current_limit and limit_remaining current_limit: Optional[int] = None rate_limit_type: Optional[ Literal["requests", "tokens", "max_parallel_requests"] ] = None if counter_key.endswith(":requests"): current_limit = requests_limit rate_limit_type = "requests" elif counter_key.endswith(":max_parallel_requests"): current_limit = max_parallel_requests_limit rate_limit_type = "max_parallel_requests" elif counter_key.endswith(":tokens"): current_limit = tokens_limit rate_limit_type = "tokens" if current_limit is None or rate_limit_type is None: continue if counter_value is not None and int(counter_value) + 1 > current_limit: overall_code = "OVER_LIMIT" item_code = "OVER_LIMIT" # Only compute limit_remaining if current_limit is not None limit_remaining = ( current_limit - int(counter_value) if counter_value is not None else current_limit ) statuses.append( { "code": item_code, "current_limit": current_limit, "limit_remaining": limit_remaining, "rate_limit_type": rate_limit_type, "descriptor_key": key_metadata[window_key]["descriptor_key"], } ) return RateLimitResponse(overall_code=overall_code, statuses=statuses) async def should_rate_limit( self, descriptors: List[RateLimitDescriptor], parent_otel_span: Optional[Span] = None, read_only: bool = False, ) -> RateLimitResponse: """ Check if any of the rate limit descriptors should be rate limited. Returns a RateLimitResponse with the overall code and status for each descriptor. Uses batch operations for Redis to improve performance. """ now = datetime.now().timestamp() now_int = int(now) # Convert to integer for Redis Lua script # Collect all keys and their metadata upfront keys_to_fetch: List[str] = [] key_metadata = {} # Store metadata for each key for descriptor in descriptors: descriptor_key = descriptor["key"] descriptor_value = descriptor["value"] rate_limit = descriptor.get("rate_limit", {}) or {} requests_limit = rate_limit.get("requests_per_unit") tokens_limit = rate_limit.get("tokens_per_unit") max_parallel_requests_limit = rate_limit.get("max_parallel_requests") window_size = rate_limit.get("window_size") or self.window_size window_key = f"{{{descriptor_key}:{descriptor_value}}}:window" rate_limit_set = False if requests_limit is not None: rpm_key = self.create_rate_limit_keys( descriptor_key, descriptor_value, "requests" ) keys_to_fetch.extend([window_key, rpm_key]) rate_limit_set = True if tokens_limit is not None: tpm_key = self.create_rate_limit_keys( descriptor_key, descriptor_value, "tokens" ) keys_to_fetch.extend([window_key, tpm_key]) rate_limit_set = True if max_parallel_requests_limit is not None: max_parallel_requests_key = self.create_rate_limit_keys( descriptor_key, descriptor_value, "max_parallel_requests" ) keys_to_fetch.extend([window_key, max_parallel_requests_key]) rate_limit_set = True if not rate_limit_set: continue key_metadata[window_key] = { "requests_limit": int(requests_limit) if requests_limit is not None else None, "tokens_limit": int(tokens_limit) if tokens_limit is not None else None, "max_parallel_requests_limit": int(max_parallel_requests_limit) if max_parallel_requests_limit is not None else None, "window_size": int(window_size), "descriptor_key": descriptor_key, } ## CHECK IN-MEMORY CACHE cache_values = await self.internal_usage_cache.async_batch_get_cache( keys=keys_to_fetch, parent_otel_span=parent_otel_span, local_only=True, ) if cache_values is not None: rate_limit_response = self.is_cache_list_over_limit( keys_to_fetch, cache_values, key_metadata ) if rate_limit_response["overall_code"] == "OVER_LIMIT": return rate_limit_response ## IF under limit, check Redis if self.batch_rate_limiter_script is not None: cache_values = await self.batch_rate_limiter_script( keys=keys_to_fetch, args=[now_int, self.window_size], # Use integer timestamp ) # update in-memory cache with new values for i in range(0, len(cache_values), 2): window_key = keys_to_fetch[i] counter_key = keys_to_fetch[i + 1] window_value = cache_values[i] counter_value = cache_values[i + 1] await self.internal_usage_cache.async_set_cache( key=counter_key, value=counter_value, ttl=self.window_size, litellm_parent_otel_span=parent_otel_span, local_only=True, ) await self.internal_usage_cache.async_set_cache( key=window_key, value=window_value, ttl=self.window_size, litellm_parent_otel_span=parent_otel_span, local_only=True, ) else: cache_values = await self.in_memory_cache_sliding_window( keys=keys_to_fetch, now_int=now_int, window_size=self.window_size, ) rate_limit_response = self.is_cache_list_over_limit( keys_to_fetch, cache_values, key_metadata ) return rate_limit_response async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, ): """ Pre-call hook to check rate limits before making the API call. """ from litellm.proxy.auth.auth_utils import ( get_key_model_rpm_limit, get_key_model_tpm_limit, ) verbose_proxy_logger.debug("Inside Rate Limit Pre-Call Hook") # Create rate limit descriptors descriptors = [] # API Key rate limits if user_api_key_dict.api_key: descriptors.append( RateLimitDescriptor( key="api_key", value=user_api_key_dict.api_key, rate_limit={ "requests_per_unit": user_api_key_dict.rpm_limit, "tokens_per_unit": user_api_key_dict.tpm_limit, "max_parallel_requests": user_api_key_dict.max_parallel_requests, "window_size": self.window_size, # 1 minute window }, ) ) # User rate limits if user_api_key_dict.user_id: descriptors.append( RateLimitDescriptor( key="user", value=user_api_key_dict.user_id, rate_limit={ "requests_per_unit": user_api_key_dict.user_rpm_limit, "tokens_per_unit": user_api_key_dict.user_tpm_limit, "window_size": self.window_size, }, ) ) # Team rate limits if user_api_key_dict.team_id: descriptors.append( RateLimitDescriptor( key="team", value=user_api_key_dict.team_id, rate_limit={ "requests_per_unit": user_api_key_dict.team_rpm_limit, "tokens_per_unit": user_api_key_dict.team_tpm_limit, "window_size": self.window_size, }, ) ) # End user rate limits if user_api_key_dict.end_user_id: descriptors.append( RateLimitDescriptor( key="end_user", value=user_api_key_dict.end_user_id, rate_limit={ "requests_per_unit": user_api_key_dict.end_user_rpm_limit, "tokens_per_unit": user_api_key_dict.end_user_tpm_limit, "window_size": self.window_size, }, ) ) # Model rate limits requested_model = data.get("model", None) if requested_model and ( get_key_model_tpm_limit(user_api_key_dict) is not None or get_key_model_rpm_limit(user_api_key_dict) is not None ): _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) or {} _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) or {} should_check_rate_limit = False if requested_model in _tpm_limit_for_key_model: should_check_rate_limit = True elif requested_model in _rpm_limit_for_key_model: should_check_rate_limit = True if should_check_rate_limit: model_specific_tpm_limit: Optional[int] = None model_specific_rpm_limit: Optional[int] = None if requested_model in _tpm_limit_for_key_model: model_specific_tpm_limit = _tpm_limit_for_key_model[requested_model] if requested_model in _rpm_limit_for_key_model: model_specific_rpm_limit = _rpm_limit_for_key_model[requested_model] descriptors.append( RateLimitDescriptor( key="model_per_key", value=f"{user_api_key_dict.api_key}:{requested_model}", rate_limit={ "requests_per_unit": model_specific_rpm_limit, "tokens_per_unit": model_specific_tpm_limit, "window_size": self.window_size, }, ) ) # Check rate limits response = await self.should_rate_limit( descriptors=descriptors, parent_otel_span=user_api_key_dict.parent_otel_span, ) if response["overall_code"] == "OVER_LIMIT": # Find which descriptor hit the limit for i, status in enumerate(response["statuses"]): if status["code"] == "OVER_LIMIT": descriptor = descriptors[i] raise HTTPException( status_code=429, detail=f"Rate limit exceeded for {descriptor['key']}: {descriptor['value']}. Remaining: {status['limit_remaining']}", headers={ "retry-after": str(self.window_size) }, # Retry after 1 minute ) else: # add descriptors to request headers data["litellm_proxy_rate_limit_response"] = response def _create_pipeline_operations( self, key: str, value: str, rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"], total_tokens: int, ) -> List["RedisPipelineIncrementOperation"]: """ Create pipeline operations for TPM increments """ from litellm.types.caching import RedisPipelineIncrementOperation pipeline_operations: List[RedisPipelineIncrementOperation] = [] counter_key = self.create_rate_limit_keys( key=key, value=value, rate_limit_type="tokens", ) pipeline_operations.append( RedisPipelineIncrementOperation( key=counter_key, increment_value=total_tokens, ttl=self.window_size, ) ) return pipeline_operations def get_rate_limit_type(self) -> Literal["output", "input", "total"]: from litellm.proxy.proxy_server import general_settings specified_rate_limit_type = general_settings.get( "token_rate_limit_type", "output" ) if not specified_rate_limit_type or specified_rate_limit_type not in [ "output", "input", "total", ]: return "total" # default to total return specified_rate_limit_type async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): """ Update TPM usage on successful API calls by incrementing counters using pipeline """ from litellm.litellm_core_utils.core_helpers import ( _get_parent_otel_span_from_kwargs, ) from litellm.proxy.common_utils.callback_utils import ( get_model_group_from_litellm_kwargs, ) from litellm.types.caching import RedisPipelineIncrementOperation from litellm.types.utils import ModelResponse, Usage rate_limit_type = self.get_rate_limit_type() litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs( kwargs ) try: verbose_proxy_logger.debug( "INSIDE parallel request limiter ASYNC SUCCESS LOGGING" ) # Get metadata from kwargs user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key") user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id" ) user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_team_id" ) user_api_key_end_user_id = kwargs.get("user") or kwargs["litellm_params"][ "metadata" ].get("user_api_key_end_user_id") model_group = get_model_group_from_litellm_kwargs(kwargs) # Get total tokens from response total_tokens = 0 if isinstance(response_obj, ModelResponse): _usage = getattr(response_obj, "usage", None) if _usage and isinstance(_usage, Usage): if rate_limit_type == "output": total_tokens = _usage.completion_tokens elif rate_limit_type == "input": total_tokens = _usage.prompt_tokens elif rate_limit_type == "total": total_tokens = _usage.total_tokens # Create pipeline operations for TPM increments pipeline_operations: List[RedisPipelineIncrementOperation] = [] # API Key TPM if user_api_key: # MAX PARALLEL REQUESTS - only support for API Key, just decrement the counter counter_key = self.create_rate_limit_keys( key="api_key", value=user_api_key, rate_limit_type="max_parallel_requests", ) pipeline_operations.append( RedisPipelineIncrementOperation( key=counter_key, increment_value=-1, ttl=self.window_size, ) ) pipeline_operations.extend( self._create_pipeline_operations( key="api_key", value=user_api_key, rate_limit_type="tokens", total_tokens=total_tokens, ) ) # User TPM if user_api_key_user_id: # TPM pipeline_operations.extend( self._create_pipeline_operations( key="user", value=user_api_key_user_id, rate_limit_type="tokens", total_tokens=total_tokens, ) ) # Team TPM if user_api_key_team_id: pipeline_operations.extend( self._create_pipeline_operations( key="team", value=user_api_key_team_id, rate_limit_type="tokens", total_tokens=total_tokens, ) ) # End User TPM if user_api_key_end_user_id: pipeline_operations.extend( self._create_pipeline_operations( key="end_user", value=user_api_key_end_user_id, rate_limit_type="tokens", total_tokens=total_tokens, ) ) # Model-specific TPM if model_group and user_api_key: pipeline_operations.extend( self._create_pipeline_operations( key="model_per_key", value=f"{user_api_key}:{model_group}", rate_limit_type="tokens", total_tokens=total_tokens, ) ) # Execute all increments in a single pipeline if pipeline_operations: await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline( increment_list=pipeline_operations, litellm_parent_otel_span=litellm_parent_otel_span, ) except Exception as e: verbose_proxy_logger.exception( f"Error in rate limit success event: {str(e)}" ) async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): """ Decrement max parallel requests counter for the API Key """ from litellm.litellm_core_utils.core_helpers import ( _get_parent_otel_span_from_kwargs, ) from litellm.types.caching import RedisPipelineIncrementOperation try: litellm_parent_otel_span: Union[ Span, None ] = _get_parent_otel_span_from_kwargs(kwargs) user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key") pipeline_operations: List[RedisPipelineIncrementOperation] = [] if user_api_key: # MAX PARALLEL REQUESTS - only support for API Key, just decrement the counter counter_key = self.create_rate_limit_keys( key="api_key", value=user_api_key, rate_limit_type="max_parallel_requests", ) pipeline_operations.append( RedisPipelineIncrementOperation( key=counter_key, increment_value=-1, ttl=self.window_size, ) ) # Execute all increments in a single pipeline if pipeline_operations: await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline( increment_list=pipeline_operations, litellm_parent_otel_span=litellm_parent_otel_span, ) except Exception as e: verbose_proxy_logger.exception( f"Error in rate limit failure event: {str(e)}" ) async def async_post_call_success_hook( self, data: dict, user_api_key_dict: UserAPIKeyAuth, response ): """ Post-call hook to update rate limit headers in the response. """ try: from pydantic import BaseModel litellm_proxy_rate_limit_response = cast( Optional[RateLimitResponse], data.get("litellm_proxy_rate_limit_response", None), ) if litellm_proxy_rate_limit_response is not None: # Update response headers if hasattr(response, "_hidden_params"): _hidden_params = getattr(response, "_hidden_params") else: _hidden_params = None if _hidden_params is not None and ( isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict) ): if isinstance(_hidden_params, BaseModel): _hidden_params = _hidden_params.model_dump() _additional_headers = ( _hidden_params.get("additional_headers", {}) or {} ) # Add rate limit headers for status in litellm_proxy_rate_limit_response["statuses"]: prefix = f"x-ratelimit-{status['descriptor_key']}" _additional_headers[ f"{prefix}-remaining-{status['rate_limit_type']}" ] = status["limit_remaining"] _additional_headers[ f"{prefix}-limit-{status['rate_limit_type']}" ] = status["current_limit"] setattr( response, "_hidden_params", {**_hidden_params, "additional_headers": _additional_headers}, ) except Exception as e: verbose_proxy_logger.exception( f"Error in rate limit post-call hook: {str(e)}" )