"""
In-Memory Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import json
import sys
import time
from typing import TYPE_CHECKING, Any, List, Optional
if TYPE_CHECKING:
from litellm.types.caching import RedisPipelineIncrementOperation
from pydantic import BaseModel
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
from .base_cache import BaseCache


class InMemoryCache(BaseCache):
    def __init__(
        self,
        max_size_in_memory: Optional[int] = 200,
        default_ttl: Optional[int] = 600,  # default ttl is 10 minutes; at most, litellm's rate limiting logic requires objects to stay in memory for 1 minute
        max_size_per_item: Optional[int] = 1024,  # 1MB = 1024KB
    ):
        """
        max_size_in_memory [int]: Maximum number of items in the cache, bounded to prevent memory leaks. Defaults to 200 items.
        """
        self.max_size_in_memory = (
            max_size_in_memory or 200
        )  # set an upper bound of 200 items in-memory
        self.default_ttl = default_ttl or 600
        self.max_size_per_item = (
            max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
        )  # 1MB = 1024KB

        # in-memory cache
        self.cache_dict: dict = {}
        self.ttl_dict: dict = {}
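
    # Example (illustrative): InMemoryCache(max_size_in_memory=500, default_ttl=60)
    # caps the cache at 500 items and expires new entries ~60 seconds after they
    # are set, unless a per-call ttl is passed to set_cache.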

    def check_value_size(self, value: Any):
        """
        Check whether the value's size exceeds max_size_per_item (1MB by default).

        Returns True if the value size is acceptable, False otherwise.
        """
        try:
            # Fast path for common primitive types that are typically small
            if (
                isinstance(value, (bool, int, float, str))
                and len(str(value))
                < self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
            ):  # conservative estimate: the KB limit * 1024 approximates a byte budget, and each character is at least one byte
                return True

            # Direct size check for bytes objects
            if isinstance(value, bytes):
                return sys.getsizeof(value) / 1024 <= self.max_size_per_item

            # Normalize special types before the generic size check
            if isinstance(value, BaseModel) and hasattr(
                value, "model_dump"
            ):  # Pydantic v2
                value = value.model_dump()
            elif hasattr(value, "isoformat"):  # datetime objects
                return True  # datetime strings are always small

            # Only convert to JSON if absolutely necessary
            if not isinstance(value, (str, bytes)):
                value = json.dumps(value, default=str)

            return sys.getsizeof(value) / 1024 <= self.max_size_per_item
        except Exception:
            return False
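
    # Example (illustrative): with max_size_per_item=1024 (i.e. 1024 KB), a 2 MB
    # bytes payload measures ~2048 KB and is rejected, while a short string passes
    # the fast path above.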

    def _is_key_expired(self, key: str) -> bool:
        """
        Check whether a specific key has expired.
        """
        return key in self.ttl_dict and time.time() > self.ttl_dict[key]

    def _remove_key(self, key: str) -> None:
        """
        Remove a key from both cache_dict and ttl_dict.
        """
        self.cache_dict.pop(key, None)
        self.ttl_dict.pop(key, None)

    def evict_cache(self):
        """
        Eviction policy:
        - Check whether any items in ttl_dict have expired -> remove them from both ttl_dict and cache_dict.

        This guarantees the following:
        1. When an item's ttl is not set: the item remains in memory for at least the default TTL (600 seconds).
        2. When a ttl is set: the item remains in memory for at least that amount of time.
        3. The size of the in-memory cache is bounded.
        """
        for key in list(self.ttl_dict.keys()):
            if self._is_key_expired(key):
                self._remove_key(key)
                # De-reference the removed item so it can be garbage collected.
                # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
                # A common cause of memory leaks in Python is retaining objects that
                # are no longer used because a reference to them is never removed.

    def allow_ttl_override(self, key: str) -> bool:
        """
        Return True if the key's TTL may be (re)set: i.e. no TTL is set for the key,
        or the existing TTL has already expired.
        """
        ttl_time = self.ttl_dict.get(key)
        if ttl_time is None:  # if ttl is not set, allow override
            return True
        elif float(ttl_time) < time.time():  # if ttl is expired, allow override
            return True
        else:
            return False

    def set_cache(self, key, value, **kwargs):
        if len(self.cache_dict) >= self.max_size_in_memory:
            # only evict when the cache is full
            self.evict_cache()
        if not self.check_value_size(value):
            return
        self.cache_dict[key] = value
        if self.allow_ttl_override(key):  # if ttl is not set, set it to the default ttl
            if "ttl" in kwargs and kwargs["ttl"] is not None:
                self.ttl_dict[key] = time.time() + float(kwargs["ttl"])
            else:
                self.ttl_dict[key] = time.time() + self.default_ttl
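
    # Example (illustrative): set_cache("k", "v", ttl=5) stores "v" and records
    # time.time() + 5 in ttl_dict. A second set_cache("k", ...) before expiry
    # overwrites the value but keeps the original TTL, since allow_ttl_override
    # returns False for an unexpired key.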

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
        for cache_key, cache_value in cache_list:
            if ttl is not None:
                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
            else:
                self.set_cache(key=cache_key, value=cache_value)

    async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
        """
        Add the given values to the set stored under key (SADD-style).
        """
        # get the existing set (or start a new one)
        init_value = self.get_cache(key=key) or set()
        for val in value:
            init_value.add(val)
        self.set_cache(key, init_value, ttl=ttl)
        return value

    def evict_element_if_expired(self, key: str) -> bool:
        """
        Returns True if the element has expired and was removed from the cache.
        Returns False if the element has not expired.
        """
        if self._is_key_expired(key):
            self._remove_key(key)
            return True
        return False

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if self.evict_element_if_expired(key):
                return None
            original_cached_response = self.cache_dict[key]
            try:
                # stored values may be JSON strings; decode them back into objects
                cached_response = json.loads(original_cached_response)
            except Exception:
                cached_response = original_cached_response
            return cached_response
        return None
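
    # Example (illustrative): after set_cache("k", '{"a": 1}'), get_cache("k")
    # returns {"a": 1}, because stored JSON strings are decoded on read; values
    # that fail json.loads are returned unchanged.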

    def batch_get_cache(self, keys: list, **kwargs):
        return_val = []
        for k in keys:
            val = self.get_cache(key=k, **kwargs)
            return_val.append(val)
        return return_val

    def increment_cache(self, key, value: int, **kwargs) -> int:
        # get the value
        init_value = self.get_cache(key=key) or 0
        value = init_value + value
        self.set_cache(key, value, **kwargs)
        return value

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key=key, **kwargs)

    async def async_batch_get_cache(self, keys: list, **kwargs):
        return_val = []
        for k in keys:
            val = self.get_cache(key=k, **kwargs)
            return_val.append(val)
        return return_val

    async def async_increment(self, key, value: float, **kwargs) -> float:
        # get the value
        init_value = await self.async_get_cache(key=key) or 0
        value = init_value + value
        await self.async_set_cache(key, value, **kwargs)
        return value

    async def async_increment_pipeline(
        self, increment_list: List["RedisPipelineIncrementOperation"], **kwargs
    ) -> Optional[List[float]]:
        results = []
        for increment in increment_list:
            result = await self.async_increment(
                increment["key"], increment["increment_value"], **kwargs
            )
            results.append(result)
        return results
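
    # Example (illustrative): each entry in increment_list is a
    # RedisPipelineIncrementOperation mapping with at least the two keys read
    # above, e.g. [{"key": "k1", "increment_value": 1.0}] -> returns [1.0].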

    def flush_cache(self):
        self.cache_dict.clear()
        self.ttl_dict.clear()

    async def disconnect(self):
        pass

    def delete_cache(self, key):
        self._remove_key(key)

    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the TTL entry for a key in the in-memory cache.

        Note: ttl_dict stores absolute expiry timestamps (based on time.time()),
        so this returns the key's expiry time, or None if no TTL is set.
        """
        return self.ttl_dict.get(key, None)

    async def async_get_oldest_n_keys(self, n: int) -> List[str]:
        """
        Get the n keys with the earliest expiry times (a proxy for the oldest keys in the cache).
        """
        # sort the ttl dict by expiry time, ascending
        sorted_ttl_dict = sorted(self.ttl_dict.items(), key=lambda x: x[1])
        return [key for key, _ in sorted_ttl_dict[:n]]
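

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; not part of the original module). It
    # exercises the four primary methods named in the module docstring. Run with
    # ``python -m`` from the package root so the relative import resolves.
    import asyncio

    async def _demo():
        cache = InMemoryCache(max_size_in_memory=10, default_ttl=60)
        cache.set_cache("sync_key", {"a": 1}, ttl=5)  # expires ~5s after being set
        print(cache.get_cache("sync_key"))  # -> {'a': 1}
        await cache.async_set_cache("async_key", "value")
        print(await cache.async_get_cache("async_key"))  # -> 'value'
        print(await cache.async_increment("hits", 1.0))  # -> 1.0

    asyncio.run(_demo())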