# What is this?
## This hook checks for LiteLLM managed files in the request body and replaces them with model-specific file ids
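#
# Rough lifecycle implemented below:
#   1. acreate_file() uploads the file once per target model and returns a single
#      base64-encoded "unified" file id to the caller.
#   2. store_unified_file_id() writes the unified id -> {model_id: provider_file_id}
#      mapping through to the internal usage cache and the Prisma-backed DB table.
#   3. async_pre_call_hook() detects unified file ids inside chat messages and
#      attaches the mapping to the request body so the model-specific file id can
#      be substituted downstream.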
import base64
import json
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast

from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionFileObject,
    CreateFileRequest,
    OpenAIFileObject,
    OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.proxy.utils import PrismaClient as _PrismaClient

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    PrismaClient = _PrismaClient
else:
    Span = Any
    InternalUsageCache = Any
    PrismaClient = Any


class BaseFileEndpoints(ABC):
    @abstractmethod
    async def afile_retrieve(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
    ) -> OpenAIFileObject:
        pass

    @abstractmethod
    async def afile_list(
        self, custom_llm_provider: str, **data: dict
    ) -> List[OpenAIFileObject]:
        pass

    @abstractmethod
    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        pass


class _PROXY_LiteLLMManagedFiles(CustomLogger):
    def __init__(
        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
    ):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client

    async def store_unified_file_id(
        self,
        file_id: str,
        file_object: OpenAIFileObject,
        litellm_parent_otel_span: Optional[Span],
        model_mappings: Dict[str, str],
    ) -> None:
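        # model_mappings maps a router deployment's model_id to the file id the
        # provider returned for that deployment, e.g. {"<model_id>": "file-abc123"}
        # (values here are illustrative).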
        verbose_logger.info(
            f"Storing LiteLLM Managed File object with id={file_id} in cache"
        )
        litellm_managed_file_object = LiteLLM_ManagedFileTable(
            unified_file_id=file_id,
            file_object=file_object,
            model_mappings=model_mappings,
        )
        await self.internal_usage_cache.async_set_cache(
            key=file_id,
            value=litellm_managed_file_object.model_dump(),
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        await self.prisma_client.db.litellm_managedfiletable.create(
            data={
                "unified_file_id": file_id,
                "file_object": file_object.model_dump_json(),
                "model_mappings": json.dumps(model_mappings),
            }
        )

    async def get_unified_file_id(
        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
    ) -> Optional[LiteLLM_ManagedFileTable]:
        ## CHECK CACHE
        result = cast(
            Optional[dict],
            await self.internal_usage_cache.async_get_cache(
                key=file_id,
                litellm_parent_otel_span=litellm_parent_otel_span,
            ),
        )
        if result:
            return LiteLLM_ManagedFileTable(**result)

        ## CHECK DB
        db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": file_id}
        )
        if db_object:
            return LiteLLM_ManagedFileTable(**db_object.model_dump())
        return None

    async def delete_unified_file_id(
        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
    ) -> OpenAIFileObject:
        ## get old value
        initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": file_id}
        )
        if initial_value is None:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

        ## delete old value
        await self.internal_usage_cache.async_set_cache(
            key=file_id,
            value=None,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        await self.prisma_client.db.litellm_managedfiletable.delete(
            where={"unified_file_id": file_id}
        )
        return initial_value.file_object

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: Dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, Dict, None]:
        """
        - Detect litellm_proxy/ file_ids in the request messages
        - Attach a mapping of litellm_proxy/ file_id -> model_id -> provider_file_id
          to the request body, e.g. {"litellm_proxy/file_id": {"model_id": "provider_file_id"}}
        """
        if call_type == CallTypes.completion.value:
            messages = data.get("messages")
            if messages:
                file_ids = self.get_file_ids_from_messages(messages)
                if file_ids:
                    model_file_id_mapping = await self.get_model_file_id_mapping(
                        file_ids, user_api_key_dict.parent_otel_span
                    )
                    data["model_file_id_mapping"] = model_file_id_mapping
        return data

    def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
        """
        Gets file ids from messages
        """
        file_ids = []
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content")
                if content:
                    if isinstance(content, str):
                        continue
                    for c in content:
                        if c["type"] == "file":
                            file_object = cast(ChatCompletionFileObject, c)
                            file_object_file_field = file_object["file"]
                            file_id = file_object_file_field.get("file_id")
                            if file_id:
                                file_ids.append(file_id)
        return file_ids
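
    # Example of a user message this method scans (illustrative shape; the
    # "file_id" value is a made-up base64 unified id):
    #   {
    #       "role": "user",
    #       "content": [
    #           {"type": "text", "text": "Summarize this document."},
    #           {"type": "file", "file": {"file_id": "bGl0ZWxsbV9wcm94eS4uLg"}},
    #       ],
    #   }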

    @staticmethod
    def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
        is_base64_unified_file_id = (
            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
        )
        if is_base64_unified_file_id:
            return is_base64_unified_file_id
        else:
            return b64_uid

    @staticmethod
    def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
        # Add padding back if needed
        padded = b64_uid + "=" * (-len(b64_uid) % 4)
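        # `-len(b64_uid) % 4` is the number of '=' characters needed to reach a
        # multiple of 4, e.g. len 22 -> 2, len 23 -> 1, len 24 -> 0.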
        # Decode from base64
        try:
            decoded = base64.urlsafe_b64decode(padded).decode()
            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
                return decoded
            else:
                return False
        except Exception:
            return False

    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
        # Instance-method wrapper around the static helper above.
        return self._convert_b64_uid_to_unified_uid(b64_uid)

    async def get_model_file_id_mapping(
        self, file_ids: List[str], litellm_parent_otel_span: Span
    ) -> dict:
        """
        Get model-specific file IDs for a list of proxy file IDs.
        Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id

        1. Get all the litellm_proxy/ file_ids from the messages
        2. For each file_id, search for cache keys matching the pattern file_id:*
        3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id

        Example:
        {
            "litellm_proxy/file_id": {
                "model_id": "model_file_id"
            }
        }
        """
        file_id_mapping: Dict[str, Dict[str, str]] = {}
        litellm_managed_file_ids = []

        for file_id in file_ids:
            ## CHECK IF FILE ID IS MANAGED BY LITELLM
            is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
            if is_base64_unified_file_id:
                litellm_managed_file_ids.append(file_id)

        if litellm_managed_file_ids:
            # Get all cache keys matching the pattern file_id:*
            for file_id in litellm_managed_file_ids:
                # Search for any cache key starting with this file_id
                unified_file_object = await self.get_unified_file_id(
                    file_id, litellm_parent_otel_span
                )
                if unified_file_object:
                    file_id_mapping[file_id] = unified_file_object.model_mappings
        return file_id_mapping

    async def create_file_for_each_model(
        self,
        llm_router: Optional[Router],
        _create_file_request: CreateFileRequest,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> List[OpenAIFileObject]:
        if llm_router is None:
            raise Exception(
                "LLM Router not initialized. Ensure models are added to the proxy."
            )
        responses = []
        for model in target_model_names_list:
            individual_response = await llm_router.acreate_file(
                model=model, **_create_file_request
            )
            responses.append(individual_response)
        return responses

    async def acreate_file(
        self,
        create_file_request: CreateFileRequest,
        llm_router: Router,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> OpenAIFileObject:
        responses = await self.create_file_for_each_model(
            llm_router=llm_router,
            _create_file_request=create_file_request,
            target_model_names_list=target_model_names_list,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
            file_objects=responses,
            create_file_request=create_file_request,
            internal_usage_cache=self.internal_usage_cache,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )

        ## STORE MODEL MAPPINGS IN DB
        model_mappings: Dict[str, str] = {}
        for file_object in responses:
            model_id = file_object._hidden_params.get("model_id")
            if model_id is None:
                verbose_logger.warning(
                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
                )
                continue
            file_id = file_object.id
            model_mappings[model_id] = file_id

        await self.store_unified_file_id(
            file_id=response.id,
            file_object=response,
            litellm_parent_otel_span=litellm_parent_otel_span,
            model_mappings=model_mappings,
        )
        return response
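
    # Illustrative call site (surrounding names are assumptions, not part of this
    # module):
    #   managed_files = _PROXY_LiteLLMManagedFiles(internal_usage_cache, prisma_client)
    #   unified_file = await managed_files.acreate_file(
    #       create_file_request=CreateFileRequest(file=..., purpose="batch"),
    #       llm_router=llm_router,
    #       target_model_names_list=["gpt-4o-mini", "claude-3-5-sonnet"],
    #       litellm_parent_otel_span=None,
    #   )
    #   unified_file.id  # one base64 unified id, valid across all target models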

    @staticmethod
    async def return_unified_file_id(
        file_objects: List[OpenAIFileObject],
        create_file_request: CreateFileRequest,
        internal_usage_cache: InternalUsageCache,
        litellm_parent_otel_span: Span,
    ) -> OpenAIFileObject:
        ## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
        file_data = extract_file_data(create_file_request["file"])
        file_type = file_data["content_type"]

        unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
            file_type, str(uuid.uuid4())
        )

        # Convert to URL-safe base64 and strip padding
        base64_unified_file_id = (
            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
        )
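
        # Round trip (illustrative): _is_base64_encoded_unified_file_id() re-adds the
        # stripped '=' padding before decoding, so:
        #   encoded = base64.urlsafe_b64encode(b"litellm_proxy:...").decode().rstrip("=")
        #   base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))  # original bytes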
        ## CREATE RESPONSE OBJECT
        response = OpenAIFileObject(
            id=base64_unified_file_id,
            object="file",
            purpose=create_file_request["purpose"],
            created_at=file_objects[0].created_at,
            bytes=file_objects[0].bytes,
            filename=file_objects[0].filename,
            status="uploaded",
        )
        return response

    async def afile_retrieve(
        self, file_id: str, litellm_parent_otel_span: Optional[Span]
    ) -> OpenAIFileObject:
        stored_file_object = await self.get_unified_file_id(
            file_id, litellm_parent_otel_span
        )
        if stored_file_object:
            return stored_file_object.file_object
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

    async def afile_list(
        self,
        purpose: Optional[OpenAIFilesPurpose],
        litellm_parent_otel_span: Optional[Span],
        **data: Dict,
    ) -> List[OpenAIFileObject]:
        return []

    async def afile_delete(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> OpenAIFileObject:
        file_id = self.convert_b64_uid_to_unified_uid(file_id)
        model_file_id_mapping = await self.get_model_file_id_mapping(
            [file_id], litellm_parent_otel_span
        )
        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
        if specific_model_file_id_mapping:
            # Use a distinct loop variable so the unified `file_id` is not shadowed
            # before the delete_unified_file_id() call below.
            for model_id, provider_file_id in specific_model_file_id_mapping.items():
                await llm_router.afile_delete(model=model_id, file_id=provider_file_id, **data)  # type: ignore

        stored_file_object = await self.delete_unified_file_id(
            file_id, litellm_parent_otel_span
        )
        if stored_file_object:
            return stored_file_object
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

    async def afile_content(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> str:
        """
        Get the content of the file from the first model that has it
        """
        model_file_id_mapping = await self.get_model_file_id_mapping(
            [file_id], litellm_parent_otel_span
        )
        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)

        if specific_model_file_id_mapping:
            exception_dict = {}
            # Use a distinct loop variable so the unified `file_id` stays intact for
            # the error message below.
            for model_id, provider_file_id in specific_model_file_id_mapping.items():
                try:
                    return await llm_router.afile_content(model=model_id, file_id=provider_file_id, **data)  # type: ignore
                except Exception as e:
                    exception_dict[model_id] = str(e)
            raise Exception(
                f"LiteLLM Managed File object with id={file_id} not found. Checked model ids: {list(specific_model_file_id_mapping.keys())}. Errors: {exception_dict}"
            )
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")