"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""

import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, TypedDict

from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router

    litellm_router = Router
    Span = _Span
else:
    Span = Any
    litellm_router = Any


class PromptCachingCacheValue(TypedDict):
    model_id: str


class PromptCachingCache:
    def __init__(self, cache: DualCache):
        self.cache = cache
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Serialize Pydantic objects, dictionaries, and lists consistently, falling back to str() for anything else."""
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization
        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)
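
    # Illustrative note: serialize_object({"b": 1, "a": 2}) yields the canonical
    # JSON string '{"a":2,"b":1}', and a list of message dicts becomes a list of
    # such strings, so equal inputs always produce the same cache key below.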

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        if messages is None and tools is None:
            return None

        # Use serialize_object for consistent and stable serialization
        data_to_hash = {}
        if messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(messages)
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools

        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )

        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        Look up the deployment that previously served this prompt.

        If messages is not None, check the full message list first, then
        progressively shorter prefixes (messages[:-1], messages[:-2],
        messages[:-3]) so a follow-up turn in the same conversation still
        matches. All candidate keys are fetched in one call via
        self.cache.async_batch_get_cache(keys=potential_cache_keys).
        """
        if messages is None and tools is None:
            return None

        # Generate potential cache keys by slicing messages
        potential_cache_keys = []

        if messages is not None:
            full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                messages, tools
            )
            potential_cache_keys.append(full_cache_key)

            # Check progressively shorter message slices
            for i in range(1, min(4, len(messages))):
                partial_messages = messages[:-i]
                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                    partial_messages, tools
                )
                potential_cache_keys.append(partial_cache_key)

        # Perform batch cache lookup
        cache_results = await self.cache.async_batch_get_cache(
            keys=potential_cache_keys
        )

        if cache_results is None:
            return None

        # Return the first non-None cache result
        for result in cache_results:
            if result is not None:
                return result

        return None

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        return self.cache.get_cache(cache_key)
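

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library API): shows how a caller
# might record the deployment that served a prompt-caching request and later
# recover it for a follow-up turn. The deployment id and messages below are
# made up for the example, and a bare DualCache() (in-memory only) is assumed
# to be sufficient here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio  # imported here so the library import surface is unchanged

    async def _demo() -> None:
        prompt_cache = PromptCachingCache(cache=DualCache())
        messages = [{"role": "user", "content": "hello"}]

        # Remember which deployment handled the original request.
        await prompt_cache.async_add_model_id(
            model_id="example-deployment-1",  # hypothetical deployment id
            messages=messages,
            tools=None,
        )

        # A follow-up turn appends one message; the prefix lookup in
        # async_get_model_id (messages[:-1]) still finds the stored entry.
        follow_up = messages + [{"role": "user", "content": "and a follow-up"}]
        cached = await prompt_cache.async_get_model_id(
            messages=follow_up, tools=None
        )
        print(cached)  # expected on a cache hit: {'model_id': 'example-deployment-1'}

    asyncio.run(_demo())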