DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
import asyncio
import traceback
from typing import List
import orjson
from fastapi import APIRouter, Depends, File, HTTPException, Request, Response, status
from fastapi.responses import ORJSONResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.route_llm_request import route_request
router = APIRouter()
import io
from fastapi import UploadFile
async def uploadfile_to_bytesio(upload: UploadFile) -> io.BytesIO:
    """
    Copy the full contents of a FastAPI UploadFile into an in-memory BytesIO.

    The upload's filename is attached to the returned stream as ``.name`` so
    the OpenAI SDK can infer filename/content-type from it.
    """
    stream = io.BytesIO(await upload.read())
    stream.name = upload.filename
    return stream
async def batch_to_bytesio(
    uploads: Optional[List[UploadFile]],
) -> Optional[List[io.BytesIO]]:
    """
    Turn a batch of UploadFiles into in-memory BytesIO buffers.

    Returns None when ``uploads`` is None or empty. Otherwise each upload is
    read in turn (one at a time) and the buffers are returned in input order.
    """
    if uploads:
        buffers: List[io.BytesIO] = []
        for item in uploads:
            buffers.append(await uploadfile_to_bytesio(item))
        return buffers
    return None
@router.post(
    "/v1/images/generations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)
@router.post(
    "/images/generations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)
@router.post(
    "/openai/deployments/{model:path}/images/generations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)  # azure compatible endpoint
async def image_generation(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = None,
):
    # Image generation endpoint. Flow: parse the JSON body, enrich it with
    # proxy metadata, resolve the target model, run the pre-call hooks, route
    # the call, attach custom response headers and return the response.
    # Any failure is reported to the failure hook and re-raised as a
    # ProxyException.
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_model,
        version,
    )

    # Bound before the try so the failure hook always has a request_data value.
    data = {}
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        raw_body = await request.body()
        data = orjson.loads(raw_body)

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        resolved_model = (
            model
            or general_settings.get("image_generation_model", None)  # server default
            or user_model  # model name passed via cli args
            or data.get("model", None)  # default passed in http request
        )
        # A model passed via cli args overrides every other source.
        if user_model:
            resolved_model = user_model
        data["model"] = resolved_model

        ### MODEL ALIAS MAPPING ###
        # if the model name is a known alias, swap in the actual model name
        alias_map = litellm.model_alias_map
        if resolved_model in alias_map:
            data["model"] = alias_map[resolved_model]

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            data=data,
            call_type="image_generation",
        )

        ## ROUTE TO CORRECT ENDPOINT ##
        pending_call = await route_request(
            data=data,
            route_type="aimage_generation",
            llm_router=llm_router,
            user_model=user_model,
        )
        response = await pending_call

        ### ALERTING ###
        # Scheduled as a background task; the response is not delayed by it.
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        # Missing or falsy hidden params are normalised to "".
        model_id, cache_key, api_base, response_cost, litellm_call_id = (
            hidden_params.get(field, None) or ""
            for field in (
                "model_id",
                "cache_key",
                "api_base",
                "response_cost",
                "litellm_call_id",
            )
        )

        custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,
            version=version,
            response_cost=response_cost,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            call_id=litellm_call_id,
            request_data=data,
            hidden_params=hidden_params,
        )
        fastapi_response.headers.update(custom_headers)

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        error_text = str(e)
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.image_generation(): Exception occured - {}".format(
                error_text
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())

        # Fields common to both ProxyException shapes raised below.
        shared_fields = dict(
            message=getattr(e, "message", error_text),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
        )
        if isinstance(e, HTTPException):
            # HTTPExceptions default to 400 when no status_code is present.
            raise ProxyException(
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
                **shared_fields,
            )
        # Everything else defaults to 500 and forwards the exception's `code`.
        raise ProxyException(
            openai_code=getattr(e, "code", None),
            code=getattr(e, "status_code", 500),
            **shared_fields,
        )
@router.post(
    "/v1/images/edits",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)
@router.post(
    "/images/edits",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)
@router.post(
    "/openai/deployments/{model:path}/images/edits",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)  # azure compatible endpoint
async def image_edit_api(
    request: Request,
    fastapi_response: Response,
    image: List[UploadFile] = File(...),
    mask: Optional[List[UploadFile]] = File(None),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = None,
):
    """
    Image edit endpoint. Follows the OpenAI Images API spec:
    https://platform.openai.com/docs/api-reference/images/createEdit

    The uploaded `image` (required) and `mask` (optional) files are read into
    in-memory buffers and forwarded with the rest of the request body.

    ```bash
    curl -s -D >(grep -i x-request-id >&2) \
    -o >(jq -r '.data[0].b64_json' | base64 --decode > gift-basket.png) \
    -X POST "http://localhost:4000/v1/images/edits" \
    -H "Authorization: Bearer sk-1234" \
    -F "model=gpt-image-1" \
    -F "image=@input.png" \
    -F 'prompt=Create a studio ghibli image of this'
    ```
    """
    from litellm.proxy.proxy_server import (
        _read_request_body,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    #########################################################
    # Read request body and convert UploadFiles to BytesIO
    #########################################################
    data = await _read_request_body(request=request)
    image_files = await batch_to_bytesio(image)
    mask_files = await batch_to_bytesio(mask)
    # Only set the keys when files were actually provided.
    if image_files:
        data["image"] = image_files
    if mask_files:
        data["mask"] = mask_files

    data["model"] = (
        model
        or general_settings.get("image_generation_model", None)  # server default
        or user_model  # model name passed via cli args
        or data.get("model", None)  # default passed in http request
    )

    #########################################################
    # Process request
    #########################################################
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aimage_edit",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            # The route's `model` was already folded into data["model"] above.
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        # The processor builds the exception to raise for any failure.
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )