# litellm proxy: HTTP request parsing utilities (request-body JSON parsing,
# parsed-body caching on the ASGI scope, header access, file-size checks,
# and OpenAI-style form-data normalization).
import json
import re
from typing import Any, Dict, List, Optional

import orjson
from fastapi import Request, UploadFile, status

from litellm._logging import verbose_proxy_logger
from litellm.types.router import Deployment
# Precompiled once at import time: the original code re-imported `re` and
# recompiled both patterns inside the exception handler on every malformed
# payload, which is wasteful on a hot request path.
# High surrogate NOT followed by a low surrogate (incomplete pair).
_UNPAIRED_HIGH_SURROGATE = re.compile(r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])")
# Low surrogate NOT preceded by a high surrogate (orphaned second half).
_UNPAIRED_LOW_SURROGATE = re.compile(r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]")


def _strip_unpaired_surrogates(body_str: str) -> str:
    """Remove unpaired UTF-16 surrogate code points that make a payload invalid JSON."""
    body_str = _UNPAIRED_HIGH_SURROGATE.sub("", body_str)
    return _UNPAIRED_LOW_SURROGATE.sub("", body_str)


async def _read_request_body(request: Optional[Request]) -> Dict:
    """
    Safely read the request body and parse it as JSON.

    Parameters:
    - request: The request object to read the body from

    Returns:
    - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails
    """
    try:
        if request is None:
            return {}

        # Check if we already read and parsed the body
        _cached_request_body: Optional[dict] = _safe_get_request_parsed_body(
            request=request
        )
        if _cached_request_body is not None:
            return _cached_request_body

        _request_headers: dict = _safe_get_request_headers(request=request)
        content_type = _request_headers.get("content-type", "")

        if "form" in content_type:
            parsed_body = dict(await request.form())
        else:
            # Read the raw request body.
            body = await request.body()
            if not body:
                # Empty body -> empty dict, mirroring the error paths below.
                parsed_body = {}
            else:
                try:
                    parsed_body = orjson.loads(body)
                except orjson.JSONDecodeError:
                    # orjson is strict; fall back to the standard json module,
                    # which is more forgiving, after stripping unpaired
                    # surrogates that would otherwise break parsing.
                    body_str = (
                        body.decode("utf-8") if isinstance(body, bytes) else body
                    )
                    parsed_body = json.loads(_strip_unpaired_surrogates(body_str))

        # Cache the parsed result so subsequent readers skip re-parsing.
        _safe_set_request_parsed_body(request=request, parsed_body=parsed_body)
        return parsed_body

    except (json.JSONDecodeError, orjson.JSONDecodeError):
        verbose_proxy_logger.exception("Invalid JSON payload received.")
        return {}
    except Exception as e:
        # Catch unexpected errors to avoid crashing the request path.
        verbose_proxy_logger.exception(
            "Unexpected error reading request body - {}".format(e)
        )
        return {}
def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]:
    """Return the cached parsed body stored on the request scope, or None.

    The cache entry is a ``(accepted_keys, parsed_body)`` tuple; only the
    accepted keys are projected back out of the stored dict.
    """
    if request is None or not hasattr(request, "scope"):
        return None
    cached = request.scope.get("parsed_body")
    if not isinstance(cached, tuple):
        # Nothing cached yet (or an unexpected value) -> treat as a miss.
        return None
    accepted_keys, stored_body = cached
    return {key: stored_body[key] for key in accepted_keys}
def _safe_set_request_parsed_body(
    request: Optional[Request],
    parsed_body: dict,
) -> None:
    """Cache ``parsed_body`` on the request scope; never raises."""
    if request is None:
        return
    try:
        # Store the accepted keys alongside the dict so a later read can
        # reconstruct exactly the keys that were parsed.
        request.scope["parsed_body"] = (tuple(parsed_body.keys()), parsed_body)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error setting request parsed body - {}".format(e)
        )
def _safe_get_request_headers(request: Optional[Request]) -> dict:
    """
    [Non-Blocking] Safely get the request headers
    """
    if request is None:
        return {}
    try:
        return dict(request.headers)
    except Exception as e:
        # Best-effort: a header read failure degrades to an empty mapping.
        verbose_proxy_logger.debug(
            "Unexpected error reading request headers - {}".format(e)
        )
        return {}
def check_file_size_under_limit(
    request_data: dict,
    file: UploadFile,
    router_model_names: List[str],
) -> bool:
    """
    Check if any files passed in request are under max_file_size_mb

    Returns True -> when file size is under max_file_size_mb limit
    Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
    """
    # Imported lazily to avoid a circular import with the proxy server module.
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        ProxyException,
        llm_router,
        premium_user,
    )

    file_contents_size = file.size or 0
    file_content_size_in_mb = file_contents_size / (1024 * 1024)
    # Record the observed size in metadata for downstream consumers.
    if "metadata" not in request_data:
        request_data["metadata"] = {}
    request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb

    max_file_size_mb = None
    # Use .get(): not every request body carries a "model" key, and a missing
    # key must not raise KeyError out of a size check.
    request_model = request_data.get("model")

    if llm_router is not None and request_model in router_model_names:
        try:
            deployment: Optional[Deployment] = (
                llm_router.get_deployment_by_model_group_name(
                    model_group_name=request_model
                )
            )
            if (
                deployment
                and deployment.litellm_params is not None
                and deployment.litellm_params.max_file_size_mb is not None
            ):
                max_file_size_mb = deployment.litellm_params.max_file_size_mb
        except Exception as e:
            # Best-effort lookup: a router error means no limit is enforced.
            verbose_proxy_logger.error(
                "Got error when checking file size: %s", (str(e))
            )

    if max_file_size_mb is not None:
        verbose_proxy_logger.debug(
            "Checking file size, file content size=%s, max_file_size_mb=%s",
            file_content_size_in_mb,
            max_file_size_mb,
        )
        if not premium_user:
            # Per-deployment file-size limits are a premium-only feature.
            raise ProxyException(
                message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
        if file_content_size_in_mb > max_file_size_mb:
            raise ProxyException(
                message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )

    return True
async def get_form_data(request: Request) -> Dict[str, Any]:
    """
    Read form data from request

    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
    """
    raw_fields = dict(await request.form())
    normalized: dict[str, Any] = {}
    for field_name, field_value in raw_fields.items():
        if not field_name.endswith("[]"):
            normalized[field_name] = field_value
            continue
        # Array-style key, e.g. `timestamp_granularities[]` — accumulate
        # repeated values into a list under the bare key name.
        bare_key = field_name[:-2]
        normalized.setdefault(bare_key, []).append(field_value)
    return normalized
|