import litellm
import time
import logging
import json
import traceback
import ast
from typing import Optional


def get_prompt(*args, **kwargs):
    """Reconstruct a single prompt string from the messages passed to completion(), or None if unavailable."""
    if len(args) > 1:
        messages = args[1]
        prompt = " ".join(message["content"] for message in messages)
        return prompt
    if "messages" in kwargs:
        messages = kwargs["messages"]
        prompt = " ".join(message["content"] for message in messages)
        return prompt
    return None


def print_verbose(print_statement):
    if litellm.set_verbose:
        print(print_statement)


class BaseCache:
    def set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    def get_cache(self, key, **kwargs):
        raise NotImplementedError


class InMemoryCache(BaseCache):
    def __init__(self):
        # plain Python dict for cached values, plus a parallel dict of per-key expiry timestamps
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            # store the absolute expiry time for this key
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:
                if time.time() > self.ttl_dict[key]:
                    # entry has expired - evict it and report a cache miss
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except Exception:
                # value was not stored as JSON - return it as-is
                cached_response = original_cached_response
            if isinstance(cached_response, dict):
                cached_response["cache"] = True  # flag the response as served from cache
            return cached_response
        return None

    def flush_cache(self):
        self.cache_dict.clear()
        self.ttl_dict.clear()

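# Illustrative example of the per-key TTL support above (not part of the public API):
#
#   cache = InMemoryCache()
#   cache.set_cache("greeting", "hello", ttl=60)  # entry expires 60 seconds from now
#   cache.get_cache("greeting")                   # -> "hello"; after 60s the key is evicted and None is returned
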
class RedisCache(BaseCache):
    def __init__(self, host, port, password, **kwargs):
        # imported here so the redis package is only required when a Redis cache is configured
        import redis

        self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs)

    def set_cache(self, key, value, **kwargs):
        ttl = kwargs.get("ttl", None)
        try:
            self.redis_client.set(name=key, value=str(value), ex=ttl)
        except Exception as e:
            # non-blocking: a caching failure should never break the underlying request
            logging.debug(f"LiteLLM Caching: set() - Got exception from REDIS: {e}")

    def get_cache(self, key, **kwargs):
        try:
            cached_response = self.redis_client.get(key)
            if cached_response is not None:
                # Redis returns bytes - decode, then parse back into a Python object
                cached_response = cached_response.decode("utf-8")
                try:
                    cached_response = json.loads(cached_response)
                except Exception:
                    # fall back for values stored via str(dict), which is not valid JSON
                    cached_response = ast.literal_eval(cached_response)
                if isinstance(cached_response, dict):
                    cached_response["cache"] = True  # flag the response as served from cache
                return cached_response
        except Exception as e:
            # non-blocking: log the error and return None
            traceback.print_exc()
            logging.debug(f"LiteLLM Caching: get() - Got exception from REDIS: {e}")

    def flush_cache(self):
        self.redis_client.flushall()


class DualCache(BaseCache):
    """
    Write-through cache that updates both an in-memory cache and Redis.

    When data is updated or inserted, it is written to the in-memory cache and to Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(self, in_memory_cache: Optional[InMemoryCache] = None, redis_cache: Optional[RedisCache] = None) -> None:
        super().__init__()
        # default to an in-memory cache if none is provided; Redis stays optional
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        self.redis_cache = redis_cache

    def set_cache(self, key, value, **kwargs):
        try:
            print_verbose(f"set cache: key: {key}; value: {value}")
            if self.in_memory_cache is not None:
                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def get_cache(self, key, **kwargs):
        try:
            print_verbose(f"get cache: cache key: {key}")
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

                if in_memory_result is not None:
                    result = in_memory_result

            if self.redis_cache is not None:
                redis_result = self.redis_cache.get_cache(key, **kwargs)

                if redis_result is not None:
                    # backfill the in-memory cache so subsequent reads hit memory first
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)
                    result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            traceback.print_exc()

    def flush_cache(self):
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()


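# Illustrative sketch of the write-through / read-through behaviour described above
# (the host/port/password values are placeholders, not defaults):
#
#   dual_cache = DualCache(redis_cache=RedisCache(host="localhost", port=6379, password=None))
#   dual_cache.set_cache("k", "v")   # written to the in-memory cache and to Redis
#   dual_cache.get_cache("k")        # checks the in-memory cache, then Redis; a Redis hit refreshes the in-memory copy

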
class Cache:
    def __init__(
        self,
        type="local",
        host=None,
        port=None,
        password=None,
        **kwargs
    ):
        """
        Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Defaults to "local".
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (int, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
            **kwargs: Additional keyword arguments for redis.Redis() cache

        Raises:
            ValueError: If an invalid cache type is provided.

        Returns:
            None
        """
        if type == "redis":
            self.cache = RedisCache(host, port, password, **kwargs)
        elif type == "local":
            self.cache = InMemoryCache()
        else:
            raise ValueError(f"Invalid cache type: {type}")
        # register the cache with litellm's callbacks so completion()/embedding() check and populate it
        if "cache" not in litellm.input_callback:
            litellm.input_callback.append("cache")
        if "cache" not in litellm.success_callback:
            litellm.success_callback.append("cache")

    def get_cache_key(self, *args, **kwargs):
        """
        Get the cache key for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments, or None if no cache key could be generated.
        """
        cache_key = ""
        for param in kwargs:
            # only include the arguments that actually affect the model's output
            if param in set(["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]):
                if param == "model" and kwargs.get("metadata", None) is not None and kwargs["metadata"].get("model_group", None) is not None:
                    # when routing across deployments, key on the model group rather than the specific deployment
                    param_value = kwargs["metadata"].get("model_group", None)
                else:
                    param_value = kwargs[param]
                cache_key += f"{str(param)}: {str(param_value)}"
        return cache_key

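    # Illustrative example: get_cache_key(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}], temperature=0.2)
    # returns the concatenated string
    #   "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'hi'}]temperature: 0.2"
    # (kwargs insertion order determines the order of the segments; no separator is added between them)
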
    def generate_streaming_content(self, content):
        # re-chunk a cached response so it can be replayed to callers that requested stream=True
        chunk_size = 5  # characters per chunk
        for i in range(0, len(content), chunk_size):
            yield {"choices": [{"delta": {"role": "assistant", "content": content[i:i + chunk_size]}}]}
            time.sleep(0.02)

    def get_cache(self, *args, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
            if cache_key is not None:
                cached_result = self.cache.get_cache(cache_key)
                if cached_result is not None and kwargs.get("stream", False):
                    # replay the cached completion as a stream of delta chunks
                    return self.generate_streaming_content(cached_result["choices"][0]["message"]["content"])
                return cached_result
        except Exception:
            logging.debug(f"An exception occurred: {traceback.format_exc()}")
            return None

    def add_cache(self, result, *args, **kwargs):
        """
        Adds a result to the cache.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
            if cache_key is not None:
                if isinstance(result, litellm.ModelResponse):
                    # serialize model responses to JSON so they round-trip through Redis
                    result = result.model_dump_json()
                self.cache.set_cache(cache_key, result, **kwargs)
        except Exception:
            # caching is best-effort: never let a cache write failure break the request
            pass
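
# Typical usage sketch (illustrative; assumes litellm consults the module-level
# `litellm.cache` attribute, as described in the litellm caching docs):
#
#   litellm.cache = Cache()  # in-memory cache; pass type="redis", host=..., port=..., password=... for Redis
#   response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
#   cached = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
#   # the second identical call can be served from the cache (dict responses carry cached["cache"] == True)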