import asyncio
import io
import traceback
from typing import List, Optional

import orjson
from fastapi import (
    APIRouter,
    Depends,
    File,
    HTTPException,
    Request,
    Response,
    UploadFile,
    status,
)
from fastapi.responses import ORJSONResponse

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.route_llm_request import route_request

router = APIRouter()

async def uploadfile_to_bytesio(upload: UploadFile) -> io.BytesIO:
    """
    Read a FastAPI UploadFile into a BytesIO and set .name so the OpenAI SDK
    infers filename/content-type correctly.
    """
    data = await upload.read()
    buffer = io.BytesIO(data)
    buffer.name = upload.filename
    return buffer

async def batch_to_bytesio(
    uploads: Optional[List[UploadFile]],
) -> Optional[List[io.BytesIO]]:
    """
    Convert a list of UploadFiles to a list of BytesIO buffers, or None.
    """
    if not uploads:
        return None
    return [await uploadfile_to_bytesio(u) for u in uploads]
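
# Minimal usage sketch for the helpers above (`uploaded_images` is a
# hypothetical List[UploadFile], not part of this module). A named BytesIO is
# what the downstream OpenAI SDK expects for multipart file uploads:
#
#     buffers = await batch_to_bytesio(uploaded_images)  # List[io.BytesIO] | None
#     if buffers:
#         assert buffers[0].name == uploaded_images[0].filename
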
@router.post("/v1/images/generations", response_class=ORJSONResponse, tags=["images"])
@router.post("/images/generations", response_class=ORJSONResponse, tags=["images"])
@router.post(
    "/openai/deployments/{model}/images/generations",  # azure compatible endpoint
    response_class=ORJSONResponse,
    tags=["images"],
)
async def image_generation(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = None,
):
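    """
    Follows the OpenAI Images API spec: https://platform.openai.com/docs/api-reference/images/create

    Example request (a sketch; the model name and API key are placeholders):

    ```bash
    curl -X POST "http://localhost:4000/v1/images/generations" \
        -H "Authorization: Bearer sk-1234" \
        -H "Content-Type: application/json" \
        -d '{"model": "dall-e-3", "prompt": "A cute baby sea otter"}'
    ```
    """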
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_model,
        version,
    )
    data = {}
    try:
        # Parse the JSON body with orjson, which is significantly faster than
        # the stdlib json module
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )
data["model"] = ( | |
model | |
or general_settings.get("image_generation_model", None) # server default | |
or user_model # model name passed via cli args | |
or data.get("model", None) # default passed in http request | |
) | |
if user_model: | |
data["model"] = user_model | |
        ### MODEL ALIAS MAPPING ###
        # check if model name in model alias map
        # get the actual model name
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]
        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
        )

        ## ROUTE TO CORRECT ENDPOINT ##
        llm_call = await route_request(
            data=data,
            route_type="aimage_generation",
            llm_router=llm_router,
            user_model=user_model,
        )
        response = await llm_call

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )
        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        response_cost = hidden_params.get("response_cost", None) or ""
        litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
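        # These values are surfaced to clients as custom response headers
        # (the litellm proxy typically prefixes them with `x-litellm-`, e.g.
        # `x-litellm-call-id`), so callers can correlate cost and routing info.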
        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                call_id=litellm_call_id,
                request_data=data,
                hidden_params=hidden_params,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.image_generation(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                openai_code=getattr(e, "code", None),
                code=getattr(e, "status_code", 500),
            )

@router.post("/v1/images/edits", response_class=ORJSONResponse, tags=["images"])
@router.post("/images/edits", response_class=ORJSONResponse, tags=["images"])
@router.post(
    "/openai/deployments/{model}/images/edits",  # azure compatible endpoint
    response_class=ORJSONResponse,
    tags=["images"],
)
async def image_edit_api(
    request: Request,
    fastapi_response: Response,
    image: List[UploadFile] = File(...),
    mask: Optional[List[UploadFile]] = File(None),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = None,
):
""" | |
Follows the OpenAI Images API spec: https://platform.openai.com/docs/api-reference/images/create | |
```bash | |
curl -s -D >(grep -i x-request-id >&2) \ | |
-o >(jq -r '.data[0].b64_json' | base64 --decode > gift-basket.png) \ | |
-X POST "http://localhost:4000/v1/images/edits" \ | |
-H "Authorization: Bearer sk-1234" \ | |
-F "model=gpt-image-1" \ | |
-F "image[][email protected]" \ | |
-F 'prompt=Create a studio ghibli image of this' | |
``` | |
""" | |
    from litellm.proxy.proxy_server import (
        _read_request_body,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )
    #########################################################
    # Read request body and convert UploadFiles to BytesIO
    #########################################################
    data = await _read_request_body(request=request)
    image_files = await batch_to_bytesio(image)
    mask_files = await batch_to_bytesio(mask)
    if image_files:
        data["image"] = image_files
    if mask_files:
        data["mask"] = mask_files
data["model"] = ( | |
model | |
or general_settings.get("image_generation_model", None) # server default | |
or user_model # model name passed via cli args | |
or data.get("model", None) # default passed in http request | |
) | |
    #########################################################
    # Process request
    #########################################################
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aimage_edit",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )