"""
Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
from typing import List, Tuple
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message
from ..common_utils import get_supports_system_message
from ..gemini.transformation import (
_gemini_convert_messages_with_history,
_transform_system_message,
)


def get_first_continuous_block_idx(
filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
) -> int:
"""
Find the array index that ends the first continuous sequence of message blocks.
Args:
filtered_messages: List of tuples containing (index, message) pairs
Returns:
int: The array index where the first continuous sequence ends
"""
if not filtered_messages:
return -1
if len(filtered_messages) == 1:
return 0
current_value = filtered_messages[0][0]
    # Walk the original indices; the first gap ends the continuous run
for i in range(1, len(filtered_messages)):
if filtered_messages[i][0] != current_value + 1:
return i - 1
current_value = filtered_messages[i][0]
# If we made it through the whole list, return the last index
return len(filtered_messages) - 1
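
# Illustrative sketch (hypothetical values, not part of the module): for
# cached messages found at original indices 1, 2, 3 and 7, the first
# continuous run is (1, 2, 3), so the function returns 2, the position in
# filtered_messages of that run's last element:
#
#   pairs = [(1, msg_a), (2, msg_b), (3, msg_c), (7, msg_d)]
#   get_first_continuous_block_idx(pairs)  # -> 2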


def separate_cached_messages(
messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
"""
Returns separated cached and non-cached messages.
Args:
messages: List of messages to be separated.
Returns:
Tuple containing:
- cached_messages: List of cached messages.
- non_cached_messages: List of non-cached messages.
"""
cached_messages: List[AllMessageValues] = []
non_cached_messages: List[AllMessageValues] = []
# Extract cached messages and their indices
filtered_messages: List[Tuple[int, AllMessageValues]] = []
for idx, message in enumerate(messages):
if is_cached_message(message=message):
filtered_messages.append((idx, message))
    # Keep only the first continuous block of cached messages
    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
# Separate messages based on the block of cached messages
    # get_first_continuous_block_idx returns -1 (never None) for an empty list
    if filtered_messages and last_continuous_block_idx >= 0:
first_cached_idx = filtered_messages[0][0]
last_cached_idx = filtered_messages[last_continuous_block_idx][0]
cached_messages = messages[first_cached_idx : last_cached_idx + 1]
non_cached_messages = (
messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
)
else:
non_cached_messages = messages
return cached_messages, non_cached_messages
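
# Illustrative sketch (hypothetical shapes): if only messages[1] and
# messages[2] carry a cache marker (per is_cached_message), the split is:
#
#   cached, non_cached = separate_cached_messages(messages)
#   # cached     == messages[1:3]
#   # non_cached == [messages[0]] + messages[3:]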


def transform_openai_messages_to_gemini_context_caching(
    model: str, messages: List[AllMessageValues], cache_key: str
) -> CachedContentRequestBody:
    """
    Build a Gemini CachedContentRequestBody from OpenAI-format messages,
    splitting out system instructions when the model supports them.
    """
supports_system_message = get_supports_system_message(
model=model, custom_llm_provider="gemini"
)
transformed_system_messages, new_messages = _transform_system_message(
supports_system_message=supports_system_message, messages=messages
)
transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
data = CachedContentRequestBody(
contents=transformed_messages,
model="models/{}".format(model),
displayName=cache_key,
)
if transformed_system_messages is not None:
data["system_instruction"] = transformed_system_messages
return data
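
# Illustrative usage (hypothetical model name and cache key):
#
#   body = transform_openai_messages_to_gemini_context_caching(
#       model="gemini-1.5-pro-002",
#       messages=cached_messages,
#       cache_key="litellm-cache-demo",
#   )
#   # body["model"] == "models/gemini-1.5-pro-002"
#   # body["displayName"] == "litellm-cache-demo"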