# What is this? ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id import asyncio import base64 import json import uuid from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast from fastapi import HTTPException from litellm import Router, verbose_logger from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data from litellm.llms.base_llm.files.transformation import BaseFileEndpoints from litellm.proxy._types import ( CallTypes, LiteLLM_ManagedFileTable, LiteLLM_ManagedObjectTable, UserAPIKeyAuth, ) from litellm.proxy.openai_files_endpoints.common_utils import ( _is_base64_encoded_unified_file_id, convert_b64_uid_to_unified_uid, ) from litellm.types.llms.openai import ( AllMessageValues, AsyncCursorPage, ChatCompletionFileObject, CreateFileRequest, FileObject, OpenAIFileObject, OpenAIFilesPurpose, ) from litellm.types.utils import ( LiteLLMBatch, LiteLLMFineTuningJob, LLMResponseTypes, SpecialEnums, ) if TYPE_CHECKING: from opentelemetry.trace import Span as _Span from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache from litellm.proxy.utils import PrismaClient as _PrismaClient Span = Union[_Span, Any] InternalUsageCache = _InternalUsageCache PrismaClient = _PrismaClient else: Span = Any InternalUsageCache = Any PrismaClient = Any class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): # Class variables or attributes def __init__( self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient ): self.internal_usage_cache = internal_usage_cache self.prisma_client = prisma_client async def store_unified_file_id( self, file_id: str, file_object: OpenAIFileObject, litellm_parent_otel_span: Optional[Span], model_mappings: Dict[str, str], user_api_key_dict: UserAPIKeyAuth, ) -> None: verbose_logger.info( f"Storing LiteLLM Managed File object with id={file_id} in cache" ) litellm_managed_file_object = LiteLLM_ManagedFileTable( unified_file_id=file_id, file_object=file_object, model_mappings=model_mappings, flat_model_file_ids=list(model_mappings.values()), created_by=user_api_key_dict.user_id, updated_by=user_api_key_dict.user_id, ) await self.internal_usage_cache.async_set_cache( key=file_id, value=litellm_managed_file_object.model_dump(), litellm_parent_otel_span=litellm_parent_otel_span, ) await self.prisma_client.db.litellm_managedfiletable.create( data={ "unified_file_id": file_id, "file_object": file_object.model_dump_json(), "model_mappings": json.dumps(model_mappings), "flat_model_file_ids": list(model_mappings.values()), "created_by": user_api_key_dict.user_id, "updated_by": user_api_key_dict.user_id, } ) async def store_unified_object_id( self, unified_object_id: str, file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob], litellm_parent_otel_span: Optional[Span], model_object_id: str, file_purpose: Literal["batch", "fine-tune"], user_api_key_dict: UserAPIKeyAuth, ) -> None: verbose_logger.info( f"Storing LiteLLM Managed {file_purpose} object with id={unified_object_id} in cache" ) litellm_managed_object = LiteLLM_ManagedObjectTable( unified_object_id=unified_object_id, model_object_id=model_object_id, file_purpose=file_purpose, file_object=file_object, ) await self.internal_usage_cache.async_set_cache( key=unified_object_id, value=litellm_managed_object.model_dump(), litellm_parent_otel_span=litellm_parent_otel_span, ) await self.prisma_client.db.litellm_managedobjecttable.create( data={ "unified_object_id": unified_object_id, "file_object": file_object.model_dump_json(), "model_object_id": model_object_id, "file_purpose": file_purpose, "created_by": user_api_key_dict.user_id, "updated_by": user_api_key_dict.user_id, } ) async def get_unified_file_id( self, file_id: str, litellm_parent_otel_span: Optional[Span] = None ) -> Optional[LiteLLM_ManagedFileTable]: ## CHECK CACHE result = cast( Optional[dict], await self.internal_usage_cache.async_get_cache( key=file_id, litellm_parent_otel_span=litellm_parent_otel_span, ), ) if result: return LiteLLM_ManagedFileTable(**result) ## CHECK DB db_object = await self.prisma_client.db.litellm_managedfiletable.find_first( where={"unified_file_id": file_id} ) if db_object: return LiteLLM_ManagedFileTable(**db_object.model_dump()) return None async def delete_unified_file_id( self, file_id: str, litellm_parent_otel_span: Optional[Span] = None ) -> OpenAIFileObject: ## get old value initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first( where={"unified_file_id": file_id} ) if initial_value is None: raise Exception(f"LiteLLM Managed File object with id={file_id} not found") ## delete old value await self.internal_usage_cache.async_set_cache( key=file_id, value=None, litellm_parent_otel_span=litellm_parent_otel_span, ) await self.prisma_client.db.litellm_managedfiletable.delete( where={"unified_file_id": file_id} ) return initial_value.file_object async def can_user_call_unified_file_id( self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth ) -> bool: ## check if the user has access to the unified file id user_id = user_api_key_dict.user_id managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first( where={"unified_file_id": unified_file_id} ) if managed_file: return managed_file.created_by == user_id return False async def can_user_call_unified_object_id( self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth ) -> bool: ## check if the user has access to the unified object id ## check if the user has access to the unified object id user_id = user_api_key_dict.user_id managed_object = ( await self.prisma_client.db.litellm_managedobjecttable.find_first( where={"unified_object_id": unified_object_id} ) ) if managed_object: return managed_object.created_by == user_id return False async def get_user_created_file_ids( self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str] ) -> List[OpenAIFileObject]: """ Get all file ids created by the user for a list of model object ids Returns: - List of OpenAIFileObject's """ file_ids = await self.prisma_client.db.litellm_managedfiletable.find_many( where={ "created_by": user_api_key_dict.user_id, "flat_model_file_ids": {"hasSome": model_object_ids}, } ) return [OpenAIFileObject(**file_object.file_object) for file_object in file_ids] async def check_managed_file_id_access( self, data: Dict, user_api_key_dict: UserAPIKeyAuth ) -> bool: retrieve_file_id = cast(Optional[str], data.get("file_id")) potential_file_id = ( _is_base64_encoded_unified_file_id(retrieve_file_id) if retrieve_file_id else False ) if potential_file_id and retrieve_file_id: if await self.can_user_call_unified_file_id( retrieve_file_id, user_api_key_dict ): return True else: raise HTTPException( status_code=403, detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}", ) return False async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: Dict, call_type: Literal[ "completion", "text_completion", "embeddings", "image_generation", "moderation", "audio_transcription", "pass_through_endpoint", "rerank", "acreate_batch", "aretrieve_batch", "acreate_file", "afile_list", "afile_delete", "afile_content", "acreate_fine_tuning_job", "aretrieve_fine_tuning_job", "alist_fine_tuning_jobs", "acancel_fine_tuning_job", ], ) -> Union[Exception, str, Dict, None]: """ - Detect litellm_proxy/ file_id - add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}} """ ### HANDLE FILE ACCESS ### - ensure user has access to the file if ( call_type == CallTypes.afile_content.value or call_type == CallTypes.afile_delete.value ): await self.check_managed_file_id_access(data, user_api_key_dict) ### HANDLE TRANSFORMATIONS ### if call_type == CallTypes.completion.value: messages = data.get("messages") if messages: file_ids = self.get_file_ids_from_messages(messages) if file_ids: model_file_id_mapping = await self.get_model_file_id_mapping( file_ids, user_api_key_dict.parent_otel_span ) data["model_file_id_mapping"] = model_file_id_mapping elif call_type == CallTypes.afile_content.value: retrieve_file_id = cast(Optional[str], data.get("file_id")) potential_file_id = ( _is_base64_encoded_unified_file_id(retrieve_file_id) if retrieve_file_id else False ) if potential_file_id: model_id = self.get_model_id_from_unified_file_id(potential_file_id) if model_id: data["model"] = model_id data["file_id"] = self.get_output_file_id_from_unified_file_id( potential_file_id ) elif call_type == CallTypes.acreate_batch.value: input_file_id = cast(Optional[str], data.get("input_file_id")) if input_file_id: model_file_id_mapping = await self.get_model_file_id_mapping( [input_file_id], user_api_key_dict.parent_otel_span ) data["model_file_id_mapping"] = model_file_id_mapping elif ( call_type == CallTypes.aretrieve_batch.value or call_type == CallTypes.acancel_fine_tuning_job.value or call_type == CallTypes.aretrieve_fine_tuning_job.value ): accessor_key: Optional[str] = None retrieve_object_id: Optional[str] = None if call_type == CallTypes.aretrieve_batch.value: accessor_key = "batch_id" elif ( call_type == CallTypes.acancel_fine_tuning_job.value or call_type == CallTypes.aretrieve_fine_tuning_job.value ): accessor_key = "fine_tuning_job_id" if accessor_key: retrieve_object_id = cast(Optional[str], data.get(accessor_key)) potential_llm_object_id = ( _is_base64_encoded_unified_file_id(retrieve_object_id) if retrieve_object_id else False ) if potential_llm_object_id and retrieve_object_id: ## VALIDATE USER HAS ACCESS TO THE OBJECT ## if not await self.can_user_call_unified_object_id( retrieve_object_id, user_api_key_dict ): raise HTTPException( status_code=403, detail=f"User {user_api_key_dict.user_id} does not have access to the object {retrieve_object_id}", ) ## for managed batch id - get the model id potential_model_id = self.get_model_id_from_unified_batch_id( potential_llm_object_id ) if potential_model_id is None: raise Exception( f"LiteLLM Managed {accessor_key} with id={retrieve_object_id} is invalid - does not contain encoded model_id." ) data["model"] = potential_model_id data[accessor_key] = self.get_batch_id_from_unified_batch_id( potential_llm_object_id ) elif call_type == CallTypes.acreate_fine_tuning_job.value: input_file_id = cast(Optional[str], data.get("training_file")) if input_file_id: model_file_id_mapping = await self.get_model_file_id_mapping( [input_file_id], user_api_key_dict.parent_otel_span ) return data async def async_pre_call_deployment_hook( self, kwargs: Dict[str, Any], call_type: Optional[CallTypes] ) -> Optional[dict]: """ Allow modifying the request just before it's sent to the deployment. """ accessor_key: Optional[str] = None if call_type and call_type == CallTypes.acreate_batch: accessor_key = "input_file_id" elif call_type and call_type == CallTypes.acreate_fine_tuning_job: accessor_key = "training_file" else: return kwargs if accessor_key: input_file_id = cast(Optional[str], kwargs.get(accessor_key)) model_file_id_mapping = cast( Optional[Dict[str, Dict[str, str]]], kwargs.get("model_file_id_mapping") ) model_id = cast(Optional[str], kwargs.get("model_info", {}).get("id", None)) mapped_file_id: Optional[str] = None if input_file_id and model_file_id_mapping and model_id: mapped_file_id = model_file_id_mapping.get(input_file_id, {}).get( model_id, None ) if mapped_file_id: kwargs[accessor_key] = mapped_file_id return kwargs def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]: """ Gets file ids from messages """ file_ids = [] for message in messages: if message.get("role") == "user": content = message.get("content") if content: if isinstance(content, str): continue for c in content: if c["type"] == "file": file_object = cast(ChatCompletionFileObject, c) file_object_file_field = file_object["file"] file_id = file_object_file_field.get("file_id") if file_id: file_ids.append(file_id) return file_ids async def get_model_file_id_mapping( self, file_ids: List[str], litellm_parent_otel_span: Span ) -> dict: """ Get model-specific file IDs for a list of proxy file IDs. Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id 1. Get all the litellm_proxy/ file_ids from the messages 2. For each file_id, search for cache keys matching the pattern file_id:* 3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id Example: { "litellm_proxy/file_id": { "model_id": "model_file_id" } } """ file_id_mapping: Dict[str, Dict[str, str]] = {} litellm_managed_file_ids = [] for file_id in file_ids: ## CHECK IF FILE ID IS MANAGED BY LITELM is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id) if is_base64_unified_file_id: litellm_managed_file_ids.append(file_id) if litellm_managed_file_ids: # Get all cache keys matching the pattern file_id:* for file_id in litellm_managed_file_ids: # Search for any cache key starting with this file_id unified_file_object = await self.get_unified_file_id( file_id, litellm_parent_otel_span ) if unified_file_object: file_id_mapping[file_id] = unified_file_object.model_mappings return file_id_mapping async def create_file_for_each_model( self, llm_router: Optional[Router], _create_file_request: CreateFileRequest, target_model_names_list: List[str], litellm_parent_otel_span: Span, ) -> List[OpenAIFileObject]: if llm_router is None: raise Exception("LLM Router not initialized. Ensure models added to proxy.") responses = [] for model in target_model_names_list: individual_response = await llm_router.acreate_file( model=model, **_create_file_request ) responses.append(individual_response) return responses async def acreate_file( self, create_file_request: CreateFileRequest, llm_router: Router, target_model_names_list: List[str], litellm_parent_otel_span: Span, user_api_key_dict: UserAPIKeyAuth, ) -> OpenAIFileObject: responses = await self.create_file_for_each_model( llm_router=llm_router, _create_file_request=create_file_request, target_model_names_list=target_model_names_list, litellm_parent_otel_span=litellm_parent_otel_span, ) response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id( file_objects=responses, create_file_request=create_file_request, internal_usage_cache=self.internal_usage_cache, litellm_parent_otel_span=litellm_parent_otel_span, target_model_names_list=target_model_names_list, ) ## STORE MODEL MAPPINGS IN DB model_mappings: Dict[str, str] = {} for file_object in responses: model_id = file_object._hidden_params.get("model_id") if model_id is None: verbose_logger.warning( f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None" ) continue file_id = file_object.id model_mappings[model_id] = file_id await self.store_unified_file_id( file_id=response.id, file_object=response, litellm_parent_otel_span=litellm_parent_otel_span, model_mappings=model_mappings, user_api_key_dict=user_api_key_dict, ) return response @staticmethod async def return_unified_file_id( file_objects: List[OpenAIFileObject], create_file_request: CreateFileRequest, internal_usage_cache: InternalUsageCache, litellm_parent_otel_span: Span, target_model_names_list: List[str], ) -> OpenAIFileObject: ## GET THE FILE TYPE FROM THE CREATE FILE REQUEST file_data = extract_file_data(create_file_request["file"]) file_type = file_data["content_type"] output_file_id = file_objects[0].id model_id = file_objects[0]._hidden_params.get("model_id") unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format( file_type, str(uuid.uuid4()), ",".join(target_model_names_list), output_file_id, model_id, ) # Convert to URL-safe base64 and strip padding base64_unified_file_id = ( base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=") ) ## CREATE RESPONSE OBJECT response = OpenAIFileObject( id=base64_unified_file_id, object="file", purpose=create_file_request["purpose"], created_at=file_objects[0].created_at, bytes=file_objects[0].bytes, filename=file_objects[0].filename, status="uploaded", ) return response def get_unified_generic_response_id( self, model_id: str, generic_response_id: str ) -> str: unified_generic_response_id = ( SpecialEnums.LITELLM_MANAGED_GENERIC_RESPONSE_COMPLETE_STR.value.format( model_id, generic_response_id ) ) return ( base64.urlsafe_b64encode(unified_generic_response_id.encode()) .decode() .rstrip("=") ) def get_unified_batch_id(self, batch_id: str, model_id: str) -> str: unified_batch_id = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format( model_id, batch_id ) return base64.urlsafe_b64encode(unified_batch_id.encode()).decode().rstrip("=") def get_unified_output_file_id( self, output_file_id: str, model_id: str, model_name: str ) -> str: unified_output_file_id = ( SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format( "application/json", str(uuid.uuid4()), model_name, output_file_id, model_id, ) ) return ( base64.urlsafe_b64encode(unified_output_file_id.encode()) .decode() .rstrip("=") ) def get_model_id_from_unified_file_id(self, file_id: str) -> str: return file_id.split("llm_output_file_model_id,")[1].split(";")[0] def get_output_file_id_from_unified_file_id(self, file_id: str) -> str: return file_id.split("llm_output_file_id,")[1].split(";")[0] def get_model_id_from_unified_batch_id(self, file_id: str) -> Optional[str]: """ Get the model_id from the file_id Expected format: litellm_proxy;model_id:{};llm_batch_id:{};llm_output_file_id:{} """ ## use regex to get the model_id from the file_id try: return file_id.split("model_id:")[1].split(";")[0] except Exception: return None def get_batch_id_from_unified_batch_id(self, file_id: str) -> str: ## use regex to get the batch_id from the file_id if "llm_batch_id" in file_id: return file_id.split("llm_batch_id:")[1].split(",")[0] else: return file_id.split("generic_response_id:")[1].split(",")[0] async def async_post_call_success_hook( self, data: Dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes ) -> Any: if isinstance(response, LiteLLMBatch): ## Check if unified_file_id is in the response unified_file_id = response._hidden_params.get( "unified_file_id" ) # managed file id unified_batch_id = response._hidden_params.get( "unified_batch_id" ) # managed batch id model_id = cast(Optional[str], response._hidden_params.get("model_id")) model_name = cast(Optional[str], response._hidden_params.get("model_name")) original_response_id = response.id if (unified_batch_id or unified_file_id) and model_id: response.id = self.get_unified_batch_id( batch_id=response.id, model_id=model_id ) if ( response.output_file_id and model_name and model_id ): # return a file id with the model_id and output_file_id response.output_file_id = self.get_unified_output_file_id( output_file_id=response.output_file_id, model_id=model_id, model_name=model_name, ) asyncio.create_task( self.store_unified_object_id( unified_object_id=response.id, file_object=response, litellm_parent_otel_span=user_api_key_dict.parent_otel_span, model_object_id=original_response_id, file_purpose="batch", user_api_key_dict=user_api_key_dict, ) ) elif isinstance(response, LiteLLMFineTuningJob): ## Check if unified_file_id is in the response unified_file_id = response._hidden_params.get( "unified_file_id" ) # managed file id unified_finetuning_job_id = response._hidden_params.get( "unified_finetuning_job_id" ) # managed finetuning job id model_id = cast(Optional[str], response._hidden_params.get("model_id")) model_name = cast(Optional[str], response._hidden_params.get("model_name")) original_response_id = response.id if (unified_file_id or unified_finetuning_job_id) and model_id: response.id = self.get_unified_generic_response_id( model_id=model_id, generic_response_id=response.id ) asyncio.create_task( self.store_unified_object_id( unified_object_id=response.id, file_object=response, litellm_parent_otel_span=user_api_key_dict.parent_otel_span, model_object_id=original_response_id, file_purpose="fine-tune", user_api_key_dict=user_api_key_dict, ) ) elif isinstance(response, AsyncCursorPage): """ For listing files, filter for the ones created by the user """ ## check if file object if hasattr(response, "data") and isinstance(response.data, list): if all( isinstance(file_object, FileObject) for file_object in response.data ): ## Get all file id's ## Check which file id's were created by the user ## Filter the response to only include the files created by the user ## Return the filtered response file_ids = [ file_object.id for file_object in cast(List[FileObject], response.data) # type: ignore ] user_created_file_ids = await self.get_user_created_file_ids( user_api_key_dict, file_ids ) ## Filter the response to only include the files created by the user response.data = user_created_file_ids # type: ignore return response return response return response async def afile_retrieve( self, file_id: str, litellm_parent_otel_span: Optional[Span] ) -> OpenAIFileObject: stored_file_object = await self.get_unified_file_id( file_id, litellm_parent_otel_span ) if stored_file_object: return stored_file_object.file_object else: raise Exception(f"LiteLLM Managed File object with id={file_id} not found") async def afile_list( self, purpose: Optional[OpenAIFilesPurpose], litellm_parent_otel_span: Optional[Span], **data: Dict, ) -> List[OpenAIFileObject]: """Handled in files_endpoints.py""" return [] async def afile_delete( self, file_id: str, litellm_parent_otel_span: Optional[Span], llm_router: Router, **data: Dict, ) -> OpenAIFileObject: file_id = convert_b64_uid_to_unified_uid(file_id) model_file_id_mapping = await self.get_model_file_id_mapping( [file_id], litellm_parent_otel_span ) specific_model_file_id_mapping = model_file_id_mapping.get(file_id) if specific_model_file_id_mapping: for model_id, file_id in specific_model_file_id_mapping.items(): await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore stored_file_object = await self.delete_unified_file_id( file_id, litellm_parent_otel_span ) if stored_file_object: return stored_file_object else: raise Exception(f"LiteLLM Managed File object with id={file_id} not found") async def afile_content( self, file_id: str, litellm_parent_otel_span: Optional[Span], llm_router: Router, **data: Dict, ) -> str: """ Get the content of a file from first model that has it """ model_file_id_mapping = await self.get_model_file_id_mapping( [file_id], litellm_parent_otel_span ) specific_model_file_id_mapping = model_file_id_mapping.get(file_id) if specific_model_file_id_mapping: exception_dict = {} for model_id, file_id in specific_model_file_id_mapping.items(): try: return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore except Exception as e: exception_dict[model_id] = str(e) raise Exception( f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}" ) else: raise Exception(f"LiteLLM Managed File object with id={file_id} not found")