# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file ids
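
# High-level flow (as implemented below):
#   1. `acreate_file` uploads the file to every deployment in `target_model_names_list`
#      and returns a single base64-encoded "unified" file id to the client.
#   2. The unified id -> per-deployment file id mapping is cached and persisted
#      (LiteLLM_ManagedFileTable / LiteLLM_ManagedObjectTable).
#   3. On later requests (chat completions, batches, fine-tuning), the pre-call hooks
#      detect the unified id and swap in the provider-specific id for the chosen deployment.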
import asyncio
import base64
import json
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast

from fastapi import HTTPException

from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import BaseFileEndpoints
from litellm.proxy._types import (
    CallTypes,
    LiteLLM_ManagedFileTable,
    LiteLLM_ManagedObjectTable,
    UserAPIKeyAuth,
)
from litellm.proxy.openai_files_endpoints.common_utils import (
    _is_base64_encoded_unified_file_id,
    convert_b64_uid_to_unified_uid,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    AsyncCursorPage,
    ChatCompletionFileObject,
    CreateFileRequest,
    FileObject,
    OpenAIFileObject,
    OpenAIFilesPurpose,
)
from litellm.types.utils import (
    LiteLLMBatch,
    LiteLLMFineTuningJob,
    LLMResponseTypes,
    SpecialEnums,
)

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.proxy.utils import PrismaClient as _PrismaClient

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    PrismaClient = _PrismaClient
else:
    Span = Any
    InternalUsageCache = Any
    PrismaClient = Any


class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
    # Class variables or attributes
    def __init__(
        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
    ):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client

    async def store_unified_file_id(
        self,
        file_id: str,
        file_object: OpenAIFileObject,
        litellm_parent_otel_span: Optional[Span],
        model_mappings: Dict[str, str],
        user_api_key_dict: UserAPIKeyAuth,
    ) -> None:
        verbose_logger.info(
            f"Storing LiteLLM Managed File object with id={file_id} in cache"
        )
        litellm_managed_file_object = LiteLLM_ManagedFileTable(
            unified_file_id=file_id,
            file_object=file_object,
            model_mappings=model_mappings,
            flat_model_file_ids=list(model_mappings.values()),
            created_by=user_api_key_dict.user_id,
            updated_by=user_api_key_dict.user_id,
        )
        await self.internal_usage_cache.async_set_cache(
            key=file_id,
            value=litellm_managed_file_object.model_dump(),
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        await self.prisma_client.db.litellm_managedfiletable.create(
            data={
                "unified_file_id": file_id,
                "file_object": file_object.model_dump_json(),
                "model_mappings": json.dumps(model_mappings),
                "flat_model_file_ids": list(model_mappings.values()),
                "created_by": user_api_key_dict.user_id,
                "updated_by": user_api_key_dict.user_id,
            }
        )
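
    # Illustrative shape of `model_mappings` stored above (a sketch with made-up ids):
    # the key is the router deployment's model_id, the value is the provider-returned
    # file id for that deployment, e.g.
    #     {
    #         "deployment-id-1": "file-abc123",
    #         "deployment-id-2": "file-def456",
    #     }
    # `flat_model_file_ids` is just `list(model_mappings.values())`, used for reverse lookups.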

    async def store_unified_object_id(
        self,
        unified_object_id: str,
        file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob],
        litellm_parent_otel_span: Optional[Span],
        model_object_id: str,
        file_purpose: Literal["batch", "fine-tune"],
        user_api_key_dict: UserAPIKeyAuth,
    ) -> None:
        verbose_logger.info(
            f"Storing LiteLLM Managed {file_purpose} object with id={unified_object_id} in cache"
        )
        litellm_managed_object = LiteLLM_ManagedObjectTable(
            unified_object_id=unified_object_id,
            model_object_id=model_object_id,
            file_purpose=file_purpose,
            file_object=file_object,
        )
        await self.internal_usage_cache.async_set_cache(
            key=unified_object_id,
            value=litellm_managed_object.model_dump(),
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        await self.prisma_client.db.litellm_managedobjecttable.create(
            data={
                "unified_object_id": unified_object_id,
                "file_object": file_object.model_dump_json(),
                "model_object_id": model_object_id,
                "file_purpose": file_purpose,
                "created_by": user_api_key_dict.user_id,
                "updated_by": user_api_key_dict.user_id,
            }
        )
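
    # A sketch of what one litellm_managedobjecttable row ties together (values are placeholders):
    #   unified_object_id -> the base64-encoded id handed back to the client (see get_unified_batch_id)
    #   model_object_id   -> the provider's original batch / fine-tuning job id, e.g. "batch_xyz"
    # so later retrieve/cancel calls can be routed back to the right deployment.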

    async def get_unified_file_id(
        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
    ) -> Optional[LiteLLM_ManagedFileTable]:
        ## CHECK CACHE
        result = cast(
            Optional[dict],
            await self.internal_usage_cache.async_get_cache(
                key=file_id,
                litellm_parent_otel_span=litellm_parent_otel_span,
            ),
        )
        if result:
            return LiteLLM_ManagedFileTable(**result)

        ## CHECK DB
        db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": file_id}
        )
        if db_object:
            return LiteLLM_ManagedFileTable(**db_object.model_dump())
        return None

    async def delete_unified_file_id(
        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
    ) -> OpenAIFileObject:
        ## get old value
        initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": file_id}
        )
        if initial_value is None:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

        ## delete old value
        await self.internal_usage_cache.async_set_cache(
            key=file_id,
            value=None,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        await self.prisma_client.db.litellm_managedfiletable.delete(
            where={"unified_file_id": file_id}
        )
        return initial_value.file_object

    async def can_user_call_unified_file_id(
        self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth
    ) -> bool:
        ## check if the user has access to the unified file id
        user_id = user_api_key_dict.user_id
        managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": unified_file_id}
        )
        if managed_file:
            return managed_file.created_by == user_id
        return False

    async def can_user_call_unified_object_id(
        self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth
    ) -> bool:
        ## check if the user has access to the unified object id
        user_id = user_api_key_dict.user_id
        managed_object = (
            await self.prisma_client.db.litellm_managedobjecttable.find_first(
                where={"unified_object_id": unified_object_id}
            )
        )
        if managed_object:
            return managed_object.created_by == user_id
        return False

    async def get_user_created_file_ids(
        self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str]
    ) -> List[OpenAIFileObject]:
        """
        Get all file ids created by the user for a list of model object ids.

        Returns:
        - List of OpenAIFileObject's
        """
        file_ids = await self.prisma_client.db.litellm_managedfiletable.find_many(
            where={
                "created_by": user_api_key_dict.user_id,
                "flat_model_file_ids": {"hasSome": model_object_ids},
            }
        )

        return [OpenAIFileObject(**file_object.file_object) for file_object in file_ids]

    async def check_managed_file_id_access(
        self, data: Dict, user_api_key_dict: UserAPIKeyAuth
    ) -> bool:
        retrieve_file_id = cast(Optional[str], data.get("file_id"))
        potential_file_id = (
            _is_base64_encoded_unified_file_id(retrieve_file_id)
            if retrieve_file_id
            else False
        )
        if potential_file_id and retrieve_file_id:
            if await self.can_user_call_unified_file_id(
                retrieve_file_id, user_api_key_dict
            ):
                return True
            else:
                raise HTTPException(
                    status_code=403,
                    detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}",
                )
        return False

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: Dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
            "acreate_batch",
            "aretrieve_batch",
            "acreate_file",
            "afile_list",
            "afile_delete",
            "afile_content",
            "acreate_fine_tuning_job",
            "aretrieve_fine_tuning_job",
            "alist_fine_tuning_jobs",
            "acancel_fine_tuning_job",
        ],
    ) -> Union[Exception, str, Dict, None]:
        """
        - Detect litellm_proxy/ file_id
        - add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
        """
        ### HANDLE FILE ACCESS ### - ensure user has access to the file
        if (
            call_type == CallTypes.afile_content.value
            or call_type == CallTypes.afile_delete.value
        ):
            await self.check_managed_file_id_access(data, user_api_key_dict)

        ### HANDLE TRANSFORMATIONS ###
        if call_type == CallTypes.completion.value:
            messages = data.get("messages")
            if messages:
                file_ids = self.get_file_ids_from_messages(messages)
                if file_ids:
                    model_file_id_mapping = await self.get_model_file_id_mapping(
                        file_ids, user_api_key_dict.parent_otel_span
                    )
                    data["model_file_id_mapping"] = model_file_id_mapping
        elif call_type == CallTypes.afile_content.value:
            retrieve_file_id = cast(Optional[str], data.get("file_id"))
            potential_file_id = (
                _is_base64_encoded_unified_file_id(retrieve_file_id)
                if retrieve_file_id
                else False
            )
            if potential_file_id:
                model_id = self.get_model_id_from_unified_file_id(potential_file_id)
                if model_id:
                    data["model"] = model_id
                    data["file_id"] = self.get_output_file_id_from_unified_file_id(
                        potential_file_id
                    )
        elif call_type == CallTypes.acreate_batch.value:
            input_file_id = cast(Optional[str], data.get("input_file_id"))
            if input_file_id:
                model_file_id_mapping = await self.get_model_file_id_mapping(
                    [input_file_id], user_api_key_dict.parent_otel_span
                )
                data["model_file_id_mapping"] = model_file_id_mapping
        elif (
            call_type == CallTypes.aretrieve_batch.value
            or call_type == CallTypes.acancel_fine_tuning_job.value
            or call_type == CallTypes.aretrieve_fine_tuning_job.value
        ):
            accessor_key: Optional[str] = None
            retrieve_object_id: Optional[str] = None
            if call_type == CallTypes.aretrieve_batch.value:
                accessor_key = "batch_id"
            elif (
                call_type == CallTypes.acancel_fine_tuning_job.value
                or call_type == CallTypes.aretrieve_fine_tuning_job.value
            ):
                accessor_key = "fine_tuning_job_id"

            if accessor_key:
                retrieve_object_id = cast(Optional[str], data.get(accessor_key))
                potential_llm_object_id = (
                    _is_base64_encoded_unified_file_id(retrieve_object_id)
                    if retrieve_object_id
                    else False
                )
                if potential_llm_object_id and retrieve_object_id:
                    ## VALIDATE USER HAS ACCESS TO THE OBJECT ##
                    if not await self.can_user_call_unified_object_id(
                        retrieve_object_id, user_api_key_dict
                    ):
                        raise HTTPException(
                            status_code=403,
                            detail=f"User {user_api_key_dict.user_id} does not have access to the object {retrieve_object_id}",
                        )

                    ## for managed batch id - get the model id
                    potential_model_id = self.get_model_id_from_unified_batch_id(
                        potential_llm_object_id
                    )
                    if potential_model_id is None:
                        raise Exception(
                            f"LiteLLM Managed {accessor_key} with id={retrieve_object_id} is invalid - does not contain encoded model_id."
                        )
                    data["model"] = potential_model_id
                    data[accessor_key] = self.get_batch_id_from_unified_batch_id(
                        potential_llm_object_id
                    )
        elif call_type == CallTypes.acreate_fine_tuning_job.value:
            input_file_id = cast(Optional[str], data.get("training_file"))
            if input_file_id:
                model_file_id_mapping = await self.get_model_file_id_mapping(
                    [input_file_id], user_api_key_dict.parent_otel_span
                )
                data["model_file_id_mapping"] = model_file_id_mapping

        return data
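
    # Rough illustration of what `data["model_file_id_mapping"]` looks like after this hook,
    # assuming one managed file that was uploaded to two deployments (ids are made up):
    #     {
    #         "bGl0ZWxsbV9wcm94eTo...": {            # the base64 unified file id sent by the client
    #             "deployment-id-1": "file-abc123",  # provider file id for deployment 1
    #             "deployment-id-2": "file-def456",  # provider file id for deployment 2
    #         }
    #     }
    # async_pre_call_deployment_hook below picks the entry matching the chosen deployment.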

    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
    ) -> Optional[dict]:
        """
        Allow modifying the request just before it's sent to the deployment.
        """
        accessor_key: Optional[str] = None
        if call_type and call_type == CallTypes.acreate_batch:
            accessor_key = "input_file_id"
        elif call_type and call_type == CallTypes.acreate_fine_tuning_job:
            accessor_key = "training_file"
        else:
            return kwargs

        if accessor_key:
            input_file_id = cast(Optional[str], kwargs.get(accessor_key))
            model_file_id_mapping = cast(
                Optional[Dict[str, Dict[str, str]]], kwargs.get("model_file_id_mapping")
            )
            model_id = cast(Optional[str], kwargs.get("model_info", {}).get("id", None))
            mapped_file_id: Optional[str] = None
            if input_file_id and model_file_id_mapping and model_id:
                mapped_file_id = model_file_id_mapping.get(input_file_id, {}).get(
                    model_id, None
                )
            if mapped_file_id:
                kwargs[accessor_key] = mapped_file_id

        return kwargs
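
    # Sketch of the substitution above (hypothetical values): for an `acreate_batch` call the
    # client sends `input_file_id="bGl0ZWxsbV9wcm94eTo..."` (the unified id); once the router
    # picks deployment "deployment-id-1", this hook rewrites
    #     kwargs["input_file_id"] -> "file-abc123"
    # i.e. the provider file id previously stored for that deployment.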

    def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
        """
        Gets file ids from messages
        """
        file_ids = []
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content")
                if content:
                    if isinstance(content, str):
                        continue
                    for c in content:
                        if c["type"] == "file":
                            file_object = cast(ChatCompletionFileObject, c)
                            file_object_file_field = file_object["file"]
                            file_id = file_object_file_field.get("file_id")
                            if file_id:
                                file_ids.append(file_id)
        return file_ids
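
    # Example of a user message this method extracts from (OpenAI-style "file" content part;
    # the id shown is a placeholder):
    #     {
    #         "role": "user",
    #         "content": [
    #             {"type": "text", "text": "Summarize this document"},
    #             {"type": "file", "file": {"file_id": "bGl0ZWxsbV9wcm94eTo..."}},
    #         ],
    #     }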

    async def get_model_file_id_mapping(
        self, file_ids: List[str], litellm_parent_otel_span: Span
    ) -> dict:
        """
        Get model-specific file IDs for a list of proxy file IDs.
        Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id

        1. Get all the litellm_proxy/ file_ids from the messages
        2. For each file_id, search for cache keys matching the pattern file_id:*
        3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id

        Example:
        {
            "litellm_proxy/file_id": {
                "model_id": "model_file_id"
            }
        }
        """
        file_id_mapping: Dict[str, Dict[str, str]] = {}
        litellm_managed_file_ids = []

        for file_id in file_ids:
            ## CHECK IF FILE ID IS MANAGED BY LITELLM
            is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
            if is_base64_unified_file_id:
                litellm_managed_file_ids.append(file_id)

        if litellm_managed_file_ids:
            # Get all cache keys matching the pattern file_id:*
            for file_id in litellm_managed_file_ids:
                # Search for any cache key starting with this file_id
                unified_file_object = await self.get_unified_file_id(
                    file_id, litellm_parent_otel_span
                )
                if unified_file_object:
                    file_id_mapping[file_id] = unified_file_object.model_mappings
        return file_id_mapping

    async def create_file_for_each_model(
        self,
        llm_router: Optional[Router],
        _create_file_request: CreateFileRequest,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> List[OpenAIFileObject]:
        if llm_router is None:
            raise Exception("LLM Router not initialized. Ensure models added to proxy.")
        responses = []
        for model in target_model_names_list:
            individual_response = await llm_router.acreate_file(
                model=model, **_create_file_request
            )
            responses.append(individual_response)

        return responses

    async def acreate_file(
        self,
        create_file_request: CreateFileRequest,
        llm_router: Router,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> OpenAIFileObject:
        responses = await self.create_file_for_each_model(
            llm_router=llm_router,
            _create_file_request=create_file_request,
            target_model_names_list=target_model_names_list,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
            file_objects=responses,
            create_file_request=create_file_request,
            internal_usage_cache=self.internal_usage_cache,
            litellm_parent_otel_span=litellm_parent_otel_span,
            target_model_names_list=target_model_names_list,
        )

        ## STORE MODEL MAPPINGS IN DB
        model_mappings: Dict[str, str] = {}
        for file_object in responses:
            model_id = file_object._hidden_params.get("model_id")
            if model_id is None:
                verbose_logger.warning(
                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
                )
                continue
            file_id = file_object.id
            model_mappings[model_id] = file_id

        await self.store_unified_file_id(
            file_id=response.id,
            file_object=response,
            litellm_parent_otel_span=litellm_parent_otel_span,
            model_mappings=model_mappings,
            user_api_key_dict=user_api_key_dict,
        )
        return response

    @staticmethod
    async def return_unified_file_id(
        file_objects: List[OpenAIFileObject],
        create_file_request: CreateFileRequest,
        internal_usage_cache: InternalUsageCache,
        litellm_parent_otel_span: Span,
        target_model_names_list: List[str],
    ) -> OpenAIFileObject:
        ## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
        file_data = extract_file_data(create_file_request["file"])
        file_type = file_data["content_type"]

        output_file_id = file_objects[0].id
        model_id = file_objects[0]._hidden_params.get("model_id")

        unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
            file_type,
            str(uuid.uuid4()),
            ",".join(target_model_names_list),
            output_file_id,
            model_id,
        )

        # Convert to URL-safe base64 and strip padding
        base64_unified_file_id = (
            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
        )

        ## CREATE RESPONSE OBJECT
        response = OpenAIFileObject(
            id=base64_unified_file_id,
            object="file",
            purpose=create_file_request["purpose"],
            created_at=file_objects[0].created_at,
            bytes=file_objects[0].bytes,
            filename=file_objects[0].filename,
            status="uploaded",
        )
        return response
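
    # The decoded unified id is a template string (SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR)
    # that interpolates, in order: content type, a fresh uuid, the comma-joined target model names,
    # the provider output file id, and the model_id. The client only ever sees the URL-safe base64
    # encoding of that string with "=" padding stripped, so the id is opaque but still recognizable
    # to _is_base64_encoded_unified_file_id on later requests.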

    def get_unified_generic_response_id(
        self, model_id: str, generic_response_id: str
    ) -> str:
        unified_generic_response_id = (
            SpecialEnums.LITELLM_MANAGED_GENERIC_RESPONSE_COMPLETE_STR.value.format(
                model_id, generic_response_id
            )
        )
        return (
            base64.urlsafe_b64encode(unified_generic_response_id.encode())
            .decode()
            .rstrip("=")
        )

    def get_unified_batch_id(self, batch_id: str, model_id: str) -> str:
        unified_batch_id = SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
            model_id, batch_id
        )
        return base64.urlsafe_b64encode(unified_batch_id.encode()).decode().rstrip("=")

    def get_unified_output_file_id(
        self, output_file_id: str, model_id: str, model_name: str
    ) -> str:
        unified_output_file_id = (
            SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
                "application/json",
                str(uuid.uuid4()),
                model_name,
                output_file_id,
                model_id,
            )
        )
        return (
            base64.urlsafe_b64encode(unified_output_file_id.encode())
            .decode()
            .rstrip("=")
        )

    def get_model_id_from_unified_file_id(self, file_id: str) -> str:
        return file_id.split("llm_output_file_model_id,")[1].split(";")[0]

    def get_output_file_id_from_unified_file_id(self, file_id: str) -> str:
        return file_id.split("llm_output_file_id,")[1].split(";")[0]

    def get_model_id_from_unified_batch_id(self, file_id: str) -> Optional[str]:
        """
        Get the model_id from the file_id

        Expected format: litellm_proxy;model_id:{};llm_batch_id:{};llm_output_file_id:{}
        """
        ## parse the model_id out of the decoded id
        try:
            return file_id.split("model_id:")[1].split(";")[0]
        except Exception:
            return None

    def get_batch_id_from_unified_batch_id(self, file_id: str) -> str:
        ## parse the batch_id out of the decoded id
        if "llm_batch_id" in file_id:
            return file_id.split("llm_batch_id:")[1].split(",")[0]
        else:
            return file_id.split("generic_response_id:")[1].split(",")[0]

    async def async_post_call_success_hook(
        self, data: Dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes
    ) -> Any:
        if isinstance(response, LiteLLMBatch):
            ## Check if unified_file_id is in the response
            unified_file_id = response._hidden_params.get(
                "unified_file_id"
            )  # managed file id
            unified_batch_id = response._hidden_params.get(
                "unified_batch_id"
            )  # managed batch id
            model_id = cast(Optional[str], response._hidden_params.get("model_id"))
            model_name = cast(Optional[str], response._hidden_params.get("model_name"))
            original_response_id = response.id
            if (unified_batch_id or unified_file_id) and model_id:
                response.id = self.get_unified_batch_id(
                    batch_id=response.id, model_id=model_id
                )

                if (
                    response.output_file_id and model_name and model_id
                ):  # return a file id with the model_id and output_file_id
                    response.output_file_id = self.get_unified_output_file_id(
                        output_file_id=response.output_file_id,
                        model_id=model_id,
                        model_name=model_name,
                    )

                asyncio.create_task(
                    self.store_unified_object_id(
                        unified_object_id=response.id,
                        file_object=response,
                        litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
                        model_object_id=original_response_id,
                        file_purpose="batch",
                        user_api_key_dict=user_api_key_dict,
                    )
                )
        elif isinstance(response, LiteLLMFineTuningJob):
            ## Check if unified_file_id is in the response
            unified_file_id = response._hidden_params.get(
                "unified_file_id"
            )  # managed file id
            unified_finetuning_job_id = response._hidden_params.get(
                "unified_finetuning_job_id"
            )  # managed finetuning job id
            model_id = cast(Optional[str], response._hidden_params.get("model_id"))
            model_name = cast(Optional[str], response._hidden_params.get("model_name"))
            original_response_id = response.id
            if (unified_file_id or unified_finetuning_job_id) and model_id:
                response.id = self.get_unified_generic_response_id(
                    model_id=model_id, generic_response_id=response.id
                )

                asyncio.create_task(
                    self.store_unified_object_id(
                        unified_object_id=response.id,
                        file_object=response,
                        litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
                        model_object_id=original_response_id,
                        file_purpose="fine-tune",
                        user_api_key_dict=user_api_key_dict,
                    )
                )
        elif isinstance(response, AsyncCursorPage):
            """
            For listing files, filter for the ones created by the user
            """
            ## check if file object
            if hasattr(response, "data") and isinstance(response.data, list):
                if all(
                    isinstance(file_object, FileObject) for file_object in response.data
                ):
                    ## Get all file id's
                    ## Check which file id's were created by the user
                    ## Filter the response to only include the files created by the user
                    ## Return the filtered response
                    file_ids = [
                        file_object.id
                        for file_object in cast(List[FileObject], response.data)  # type: ignore
                    ]
                    user_created_file_ids = await self.get_user_created_file_ids(
                        user_api_key_dict, file_ids
                    )

                    ## Filter the response to only include the files created by the user
                    response.data = user_created_file_ids  # type: ignore
                    return response
            return response
        return response

    async def afile_retrieve(
        self, file_id: str, litellm_parent_otel_span: Optional[Span]
    ) -> OpenAIFileObject:
        stored_file_object = await self.get_unified_file_id(
            file_id, litellm_parent_otel_span
        )
        if stored_file_object:
            return stored_file_object.file_object
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

    async def afile_list(
        self,
        purpose: Optional[OpenAIFilesPurpose],
        litellm_parent_otel_span: Optional[Span],
        **data: Dict,
    ) -> List[OpenAIFileObject]:
        """Handled in files_endpoints.py"""
        return []

    async def afile_delete(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> OpenAIFileObject:
        file_id = convert_b64_uid_to_unified_uid(file_id)
        model_file_id_mapping = await self.get_model_file_id_mapping(
            [file_id], litellm_parent_otel_span
        )
        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
        if specific_model_file_id_mapping:
            for model_id, model_file_id in specific_model_file_id_mapping.items():
                await llm_router.afile_delete(model=model_id, file_id=model_file_id, **data)  # type: ignore

        stored_file_object = await self.delete_unified_file_id(
            file_id, litellm_parent_otel_span
        )
        if stored_file_object:
            return stored_file_object
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

    async def afile_content(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> str:
        """
        Get the content of a file from the first model that has it
        """
        model_file_id_mapping = await self.get_model_file_id_mapping(
            [file_id], litellm_parent_otel_span
        )
        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)

        if specific_model_file_id_mapping:
            exception_dict = {}
            for model_id, model_file_id in specific_model_file_id_mapping.items():
                try:
                    return await llm_router.afile_content(model=model_id, file_id=model_file_id, **data)  # type: ignore
                except Exception as e:
                    exception_dict[model_id] = str(e)
            raise Exception(
                f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
            )
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")