import asyncio
import io
import traceback
from typing import List, Optional

import orjson
from fastapi import (
    APIRouter,
    Depends,
    File,
    HTTPException,
    Request,
    Response,
    UploadFile,
    status,
)
from fastapi.responses import ORJSONResponse

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.route_llm_request import route_request

router = APIRouter()


async def uploadfile_to_bytesio(upload: UploadFile) -> io.BytesIO:
    """
    Read a FastAPI UploadFile into a BytesIO and set .name so the
    OpenAI SDK infers filename/content-type correctly.
    """
    data = await upload.read()
    buffer = io.BytesIO(data)
    buffer.name = upload.filename
    return buffer


async def batch_to_bytesio(
    uploads: Optional[List[UploadFile]],
) -> Optional[List[io.BytesIO]]:
    """
    Convert a list of UploadFiles to a list of BytesIO buffers, or None.
    """
    if not uploads:
        return None
    return [await uploadfile_to_bytesio(u) for u in uploads]


@router.post(
    "/v1/images/generations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)
@router.post(
    "/images/generations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)
@router.post(
    "/openai/deployments/{model:path}/images/generations",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)  # azure compatible endpoint
async def image_generation(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = None,
):
    """
    Follows the OpenAI Images API spec: https://platform.openai.com/docs/api-reference/images/create
    """
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_model,
        version,
    )

    data = {}
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        data["model"] = (
            model
            or general_settings.get("image_generation_model", None)  # server default
            or user_model  # model name passed via cli args
            or data.get("model", None)  # default passed in http request
        )
        if user_model:
            data["model"] = user_model

        ### MODEL ALIAS MAPPING ###
        # check if model name in model alias map
        # get the actual model name
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
        )

        ## ROUTE TO CORRECT ENDPOINT ##
        llm_call = await route_request(
            data=data,
            route_type="aimage_generation",
            llm_router=llm_router,
            user_model=user_model,
        )
        response = await llm_call

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        response_cost = hidden_params.get("response_cost", None) or ""
        litellm_call_id = hidden_params.get("litellm_call_id", None) or ""

        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                call_id=litellm_call_id,
                request_data=data,
                hidden_params=hidden_params,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.image_generation(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                openai_code=getattr(e, "code", None),
                code=getattr(e, "status_code", 500),
            )


@router.post(
    "/v1/images/edits",
    dependencies=[Depends(user_api_key_auth)],
    tags=["images"],
)
@router.post(
    "/images/edits",
    dependencies=[Depends(user_api_key_auth)],
    tags=["images"],
)
@router.post(
    "/openai/deployments/{model:path}/images/edits",
    dependencies=[Depends(user_api_key_auth)],
    response_class=ORJSONResponse,
    tags=["images"],
)  # azure compatible endpoint
async def image_edit_api(
    request: Request,
    fastapi_response: Response,
    image: List[UploadFile] = File(...),
    mask: Optional[List[UploadFile]] = File(None),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = None,
):
    """
    Follows the OpenAI Images API spec: https://platform.openai.com/docs/api-reference/images/create

    ```bash
    curl -s -D >(grep -i x-request-id >&2) \
        -o >(jq -r '.data[0].b64_json' | base64 --decode > gift-basket.png) \
        -X POST "http://localhost:4000/v1/images/edits" \
        -H "Authorization: Bearer sk-1234" \
        -F "model=gpt-image-1" \
        -F "image[]=@soap.png" \
        -F 'prompt=Create a studio ghibli image of this'
    ```
    """
    from litellm.proxy.proxy_server import (
        _read_request_body,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    #########################################################
    # Read request body and convert UploadFiles to BytesIO
    #########################################################
    data = await _read_request_body(request=request)
    image_files = await batch_to_bytesio(image)
    mask_files = await batch_to_bytesio(mask)
    if image_files:
        data["image"] = image_files
    if mask_files:
        data["mask"] = mask_files

    data["model"] = (
        model
        or general_settings.get("image_generation_model", None)  # server default
        or user_model  # model name passed via cli args
        or data.get("model", None)  # default passed in http request
    )

    #########################################################
    # Process request
    #########################################################
    processor = ProxyBaseLLMRequestProcessing(data=data)
    try:
        return await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aimage_edit",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )
    except Exception as e:
        raise await processor._handle_llm_api_exception(
            e=e,
            user_api_key_dict=user_api_key_dict,
            proxy_logging_obj=proxy_logging_obj,
            version=version,
        )
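# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the proxy): because the
# `/v1/images/edits` route above accepts the same multipart payload as the
# OpenAI Images API, the official `openai` SDK can be pointed at the proxy
# directly. The api_key, base_url, model, and filename below mirror the curl
# example in the `image_edit_api` docstring and are placeholders, not real
# credentials or files.
#
#   from openai import OpenAI
#
#   client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")
#   result = client.images.edit(
#       model="gpt-image-1",
#       image=open("soap.png", "rb"),
#       prompt="Create a studio ghibli image of this",
#   )
#   # result.data[0].b64_json holds the edited image, base64-encoded
# ---------------------------------------------------------------------------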