"""
Transformation logic for context caching.

Why a separate file? To make it easy to see how the transformation works.
"""

from typing import List, Tuple

from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
from litellm.utils import is_cached_message

from ..common_utils import get_supports_system_message
from ..gemini.transformation import (
    _gemini_convert_messages_with_history,
    _transform_system_message,
)
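
# Illustrative note (hedged; the exact marker is whatever is_cached_message checks
# for): a "cached" message is an OpenAI-format message whose content block carries
# a cache-control marker, e.g.
#
#   {
#       "role": "user",
#       "content": [
#           {
#               "type": "text",
#               "text": "<large reusable context>",
#               "cache_control": {"type": "ephemeral"},
#           }
#       ],
#   }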


def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]
) -> int:
    """
    Find the array index that ends the first continuous sequence of message blocks.

    Args:
        filtered_messages: List of (original index, message) tuples for the cached messages.

    Returns:
        int: The array index where the first continuous sequence ends, or -1 if the
            list is empty.
    """
    if not filtered_messages:
        return -1

    if len(filtered_messages) == 1:
        return 0

    current_value = filtered_messages[0][0]

    # Walk the original-message indices; the first gap ends the continuous run.
    for i in range(1, len(filtered_messages)):
        if filtered_messages[i][0] != current_value + 1:
            return i - 1
        current_value = filtered_messages[i][0]

    return len(filtered_messages) - 1
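
# Illustrative sketch (hypothetical data, not part of the module API): if cached
# messages sit at original positions 1, 2, 3 and 7, the first continuous run is
# (1, 2, 3), so the function returns 2, the array index of that run's last element.
#
#   filtered = [(1, msg_a), (2, msg_b), (3, msg_c), (7, msg_d)]
#   get_first_continuous_block_idx(filtered)  # -> 2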


def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Returns separated cached and non-cached messages.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: List of cached messages.
        - non_cached_messages: List of non-cached messages.
    """
    cached_messages: List[AllMessageValues] = []
    non_cached_messages: List[AllMessageValues] = []

    # Collect (original_index, message) pairs for every message marked as cached.
    filtered_messages: List[Tuple[int, AllMessageValues]] = []
    for idx, message in enumerate(messages):
        if is_cached_message(message=message):
            filtered_messages.append((idx, message))

    # Only the first continuous block of cached messages is treated as cacheable.
    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)

    if filtered_messages and last_continuous_block_idx >= 0:
        first_cached_idx = filtered_messages[0][0]
        last_cached_idx = filtered_messages[last_continuous_block_idx][0]

        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
        non_cached_messages = (
            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
        )
    else:
        non_cached_messages = messages

    return cached_messages, non_cached_messages
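
# Illustrative sketch (hypothetical messages, not part of the module API): assuming
# only the second and third messages carry a cache marker recognized by
# is_cached_message, the contiguous cached block is pulled out and the rest is
# returned in its original order.
#
#   messages = [system, user_cached_1, user_cached_2, user, assistant]
#   cached, non_cached = separate_cached_messages(messages=messages)
#   # cached     -> [user_cached_1, user_cached_2]
#   # non_cached -> [system, user, assistant]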


def transform_openai_messages_to_gemini_context_caching(
    model: str, messages: List[AllMessageValues], cache_key: str
) -> CachedContentRequestBody:
    """
    Transform OpenAI-format messages into a Gemini CachedContent request body.
    """
    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider="gemini"
    )

    # Pull the system message out separately when the model supports it.
    transformed_system_messages, new_messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )

    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
    data = CachedContentRequestBody(
        contents=transformed_messages,
        model="models/{}".format(model),
        displayName=cache_key,
    )
    if transformed_system_messages is not None:
        data["system_instruction"] = transformed_system_messages

    return data
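
# Hedged usage sketch (hypothetical model name and cache key; real message objects
# would come from an OpenAI-format request):
#
#   request_body = transform_openai_messages_to_gemini_context_caching(
#       model="gemini-1.5-pro",
#       messages=cached_messages,
#       cache_key="my-cache-key",
#   )
#   # request_body is a CachedContentRequestBody with "contents", "model",
#   # "displayName", and, when a system message was extracted, "system_instruction".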