import asyncio import contextvars from functools import partial from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast import httpx import litellm from litellm import Logging, client, exception_type, get_litellm_params from litellm.constants import DEFAULT_IMAGE_ENDPOINT_MODEL from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT from litellm.exceptions import LiteLLMUnknownProvider from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.mock_functions import mock_image_generation from litellm.llms.base_llm import BaseImageEditConfig, BaseImageGenerationConfig from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_llm import CustomLLM #################### Initialize provider clients #################### from litellm.main import ( azure_chat_completions, base_llm_aiohttp_handler, base_llm_http_handler, bedrock_image_generation, openai_chat_completions, openai_image_variations, vertex_image_generation, ) from litellm.secret_managers.main import get_secret_str from litellm.types.images.main import ImageEditOptionalRequestParams from litellm.types.llms.openai import ImageGenerationRequestQuality from litellm.types.router import GenericLiteLLMParams from litellm.types.utils import ( LITELLM_IMAGE_VARIATION_PROVIDERS, FileTypes, LlmProviders, all_litellm_params, ) from litellm.utils import ( ImageResponse, ProviderConfigManager, get_llm_provider, get_optional_params_image_gen, ) from .utils import ImageEditRequestUtils ##### Image Generation ####################### @client async def aimage_generation(*args, **kwargs) -> ImageResponse: """ Asynchronously calls the `image_generation` function with the given arguments and keyword arguments. Parameters: - `args` (tuple): Positional arguments to be passed to the `image_generation` function. - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function. Returns: - `response` (Any): The response returned by the `image_generation` function. """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] ### PASS ARGS TO Image Generation ### kwargs["aimg_generation"] = True custom_llm_provider = None try: # Use a partial function to pass your keyword arguments func = partial(image_generation, *args, **kwargs) # Add the context to the function ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) _, custom_llm_provider, _, _ = get_llm_provider( model=model, api_base=kwargs.get("api_base", None) ) # Await normally init_response = await loop.run_in_executor(None, func_with_context) if isinstance(init_response, dict) or isinstance( init_response, ImageResponse ): ## CACHING SCENARIO if isinstance(init_response, dict): init_response = ImageResponse(**init_response) response = init_response elif asyncio.iscoroutine(init_response): response = await init_response # type: ignore else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) return response except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, extra_kwargs=kwargs, ) @client def image_generation( # noqa: PLR0915 prompt: str, model: Optional[str] = None, n: Optional[int] = None, quality: Optional[Union[str, ImageGenerationRequestQuality]] = None, response_format: Optional[str] = None, size: Optional[str] = None, style: Optional[str] = None, user: Optional[str] = None, timeout=600, # default to 10 minutes api_key: Optional[str] = None, api_base: Optional[str] = None, api_version: Optional[str] = None, custom_llm_provider=None, **kwargs, ) -> ImageResponse: """ Maps the https://api.openai.com/v1/images/generations endpoint. Currently supports just Azure + OpenAI. """ try: args = locals() aimg_generation = kwargs.get("aimg_generation", False) litellm_call_id = kwargs.get("litellm_call_id", None) logger_fn = kwargs.get("logger_fn", None) mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore proxy_server_request = kwargs.get("proxy_server_request", None) azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None) model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore client = kwargs.get("client", None) extra_headers = kwargs.get("extra_headers", None) headers: dict = kwargs.get("headers", None) or {} base_model = kwargs.get("base_model", None) if extra_headers is not None: headers.update(extra_headers) model_response: ImageResponse = litellm.utils.ImageResponse() dynamic_api_key: Optional[str] = None if model is not None or custom_llm_provider is not None: model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( model=model, # type: ignore custom_llm_provider=custom_llm_provider, api_base=api_base, ) else: model = "dall-e-2" custom_llm_provider = "openai" # default to dall-e-2 on openai model_response._hidden_params["model"] = model openai_params = [ "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style", ] litellm_params = all_litellm_params default_params = openai_params + litellm_params non_default_params = { k: v for k, v in kwargs.items() if k not in default_params } # model-specific params - pass them straight to the model/provider image_generation_config: Optional[BaseImageGenerationConfig] = None if ( custom_llm_provider is not None and custom_llm_provider in LlmProviders._member_map_.values() ): image_generation_config = ( ProviderConfigManager.get_provider_image_generation_config( model=base_model or model, provider=LlmProviders(custom_llm_provider), ) ) optional_params = get_optional_params_image_gen( model=base_model or model, n=n, quality=quality, response_format=response_format, size=size, style=style, user=user, custom_llm_provider=custom_llm_provider, provider_config=image_generation_config, **non_default_params, ) litellm_params_dict = get_litellm_params(**kwargs) logging: Logging = litellm_logging_obj logging.update_environment_variables( model=model, user=user, optional_params=optional_params, litellm_params={ "timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}, }, custom_llm_provider=custom_llm_provider, ) if "custom_llm_provider" not in logging.model_call_details: logging.model_call_details["custom_llm_provider"] = custom_llm_provider if mock_response is not None: return mock_image_generation(model=model, mock_response=mock_response) if custom_llm_provider == "azure": # azure configs api_type = get_secret_str("AZURE_API_TYPE") or "azure" api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") api_version = ( api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION") ) api_key = ( api_key or litellm.api_key or litellm.azure_key or get_secret_str("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_API_KEY") ) azure_ad_token = optional_params.pop( "azure_ad_token", None ) or get_secret_str("AZURE_AD_TOKEN") default_headers = { "Content-Type": "application/json;", "api-key": api_key, } for k, v in default_headers.items(): if k not in headers: headers[k] = v model_response = azure_chat_completions.image_generation( model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, azure_ad_token=azure_ad_token, azure_ad_token_provider=azure_ad_token_provider, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response, api_version=api_version, aimg_generation=aimg_generation, client=client, headers=headers, litellm_params=litellm_params_dict, ) elif ( custom_llm_provider == "openai" or custom_llm_provider in litellm.openai_compatible_providers ): model_response = openai_chat_completions.image_generation( model=model, prompt=prompt, timeout=timeout, api_key=api_key or dynamic_api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response, aimg_generation=aimg_generation, client=client, ) elif custom_llm_provider == "bedrock": if model is None: raise Exception("Model needs to be set for bedrock") model_response = bedrock_image_generation.image_generation( # type: ignore model=model, prompt=prompt, timeout=timeout, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response, aimg_generation=aimg_generation, client=client, ) elif custom_llm_provider == "vertex_ai": vertex_ai_project = ( optional_params.pop("vertex_project", None) or optional_params.pop("vertex_ai_project", None) or litellm.vertex_project or get_secret_str("VERTEXAI_PROJECT") ) vertex_ai_location = ( optional_params.pop("vertex_location", None) or optional_params.pop("vertex_ai_location", None) or litellm.vertex_location or get_secret_str("VERTEXAI_LOCATION") ) vertex_credentials = ( optional_params.pop("vertex_credentials", None) or optional_params.pop("vertex_ai_credentials", None) or get_secret_str("VERTEXAI_CREDENTIALS") ) api_base = ( api_base or litellm.api_base or get_secret_str("VERTEXAI_API_BASE") or get_secret_str("VERTEX_API_BASE") ) model_response = vertex_image_generation.image_generation( model=model, prompt=prompt, timeout=timeout, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response=model_response, vertex_project=vertex_ai_project, vertex_location=vertex_ai_location, vertex_credentials=vertex_credentials, aimg_generation=aimg_generation, api_base=api_base, client=client, ) elif ( custom_llm_provider in litellm._custom_providers ): # Assume custom LLM provider # Get the Custom Handler custom_handler: Optional[CustomLLM] = None for item in litellm.custom_provider_map: if item["provider"] == custom_llm_provider: custom_handler = item["custom_handler"] if custom_handler is None: raise LiteLLMUnknownProvider( model=model, custom_llm_provider=custom_llm_provider ) ## ROUTE LLM CALL ## if aimg_generation is True: async_custom_client: Optional[AsyncHTTPHandler] = None if client is not None and isinstance(client, AsyncHTTPHandler): async_custom_client = client ## CALL FUNCTION model_response = custom_handler.aimage_generation( # type: ignore model=model, prompt=prompt, api_key=api_key, api_base=api_base, model_response=model_response, optional_params=optional_params, logging_obj=litellm_logging_obj, timeout=timeout, client=async_custom_client, ) else: custom_client: Optional[HTTPHandler] = None if client is not None and isinstance(client, HTTPHandler): custom_client = client ## CALL FUNCTION model_response = custom_handler.image_generation( model=model, prompt=prompt, api_key=api_key, api_base=api_base, model_response=model_response, optional_params=optional_params, logging_obj=litellm_logging_obj, timeout=timeout, client=custom_client, ) return model_response except Exception as e: ## Map to OpenAI Exception raise exception_type( model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=locals(), extra_kwargs=kwargs, ) @client async def aimage_variation(*args, **kwargs) -> ImageResponse: """ Asynchronously calls the `image_variation` function with the given arguments and keyword arguments. Parameters: - `args` (tuple): Positional arguments to be passed to the `image_variation` function. - `kwargs` (dict): Keyword arguments to be passed to the `image_variation` function. Returns: - `response` (Any): The response returned by the `image_variation` function. """ loop = asyncio.get_event_loop() model = kwargs.get("model", None) custom_llm_provider = kwargs.get("custom_llm_provider", None) ### PASS ARGS TO Image Generation ### kwargs["async_call"] = True try: # Use a partial function to pass your keyword arguments func = partial(image_variation, *args, **kwargs) # Add the context to the function ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) if custom_llm_provider is None and model is not None: _, custom_llm_provider, _, _ = get_llm_provider( model=model, api_base=kwargs.get("api_base", None) ) # Await normally init_response = await loop.run_in_executor(None, func_with_context) if isinstance(init_response, dict) or isinstance( init_response, ImageResponse ): ## CACHING SCENARIO if isinstance(init_response, dict): init_response = ImageResponse(**init_response) response = init_response elif asyncio.iscoroutine(init_response): response = await init_response # type: ignore else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) return response except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, extra_kwargs=kwargs, ) @client def image_variation( image: FileTypes, model: str = "dall-e-2", # set to dall-e-2 by default - like OpenAI. n: int = 1, response_format: Literal["url", "b64_json"] = "url", size: Optional[str] = None, user: Optional[str] = None, **kwargs, ) -> ImageResponse: # get non-default params client = kwargs.get("client", None) # get logging object litellm_logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj")) # get the litellm params litellm_params = get_litellm_params(**kwargs) # get the custom llm provider model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( model=model, custom_llm_provider=litellm_params.get("custom_llm_provider", None), api_base=litellm_params.get("api_base", None), api_key=litellm_params.get("api_key", None), ) # route to the correct provider w/ the params try: llm_provider = LlmProviders(custom_llm_provider) image_variation_provider = LITELLM_IMAGE_VARIATION_PROVIDERS(llm_provider) except ValueError: raise ValueError( f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}" ) model_response = ImageResponse() response: Optional[ImageResponse] = None provider_config = ProviderConfigManager.get_provider_model_info( model=model or "", # openai defaults to dall-e-2 provider=llm_provider, ) if provider_config is None: raise ValueError( f"image variation provider has no known model info config - required for getting api keys, etc.: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}" ) api_key = provider_config.get_api_key(litellm_params.get("api_key", None)) api_base = provider_config.get_api_base(litellm_params.get("api_base", None)) if image_variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.OPENAI: if api_key is None: raise ValueError("API key is required for OpenAI image variations") if api_base is None: raise ValueError("API base is required for OpenAI image variations") response = openai_image_variations.image_variations( model_response=model_response, api_key=api_key, api_base=api_base, model=model, image=image, timeout=litellm_params.get("timeout", None), custom_llm_provider=custom_llm_provider, logging_obj=litellm_logging_obj, optional_params={}, litellm_params=litellm_params, ) elif image_variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.TOPAZ: if api_key is None: raise ValueError("API key is required for Topaz image variations") if api_base is None: raise ValueError("API base is required for Topaz image variations") response = base_llm_aiohttp_handler.image_variations( model_response=model_response, api_key=api_key, api_base=api_base, model=model, image=image, timeout=litellm_params.get("timeout", None) or DEFAULT_REQUEST_TIMEOUT, custom_llm_provider=custom_llm_provider, logging_obj=litellm_logging_obj, optional_params={}, litellm_params=litellm_params, client=client, ) # return the response if response is None: raise ValueError( f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}" ) return response @client def image_edit( image: FileTypes, prompt: str, model: Optional[str] = None, mask: Optional[str] = None, n: Optional[int] = None, quality: Optional[Union[str, ImageGenerationRequestQuality]] = None, response_format: Optional[str] = None, size: Optional[str] = None, user: Optional[str] = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Dict[str, Any]] = None, extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, # LiteLLM specific params, custom_llm_provider: Optional[str] = None, **kwargs, ) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse]]: """ Maps the image edit functionality, similar to OpenAI's images/edits endpoint. """ local_vars = locals() try: litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) _is_async = kwargs.pop("async_call", False) is True # get llm provider logic litellm_params = GenericLiteLLMParams(**kwargs) model, custom_llm_provider, _, _ = get_llm_provider( model=model or DEFAULT_IMAGE_ENDPOINT_MODEL, custom_llm_provider=custom_llm_provider, ) # get provider config image_edit_provider_config: Optional[ BaseImageEditConfig ] = ProviderConfigManager.get_provider_image_edit_config( model=model, provider=litellm.LlmProviders(custom_llm_provider), ) if image_edit_provider_config is None: raise ValueError(f"image edit is not supported for {custom_llm_provider}") local_vars.update(kwargs) # Get ImageEditOptionalRequestParams with only valid parameters image_edit_optional_params: ImageEditOptionalRequestParams = ( ImageEditRequestUtils.get_requested_image_edit_optional_param(local_vars) ) # Get optional parameters for the responses API image_edit_request_params: Dict = ( ImageEditRequestUtils.get_optional_params_image_edit( model=model, image_edit_provider_config=image_edit_provider_config, image_edit_optional_params=image_edit_optional_params, ) ) # Pre Call logging litellm_logging_obj.update_environment_variables( model=model, user=user, optional_params=dict(image_edit_request_params), litellm_params={ "litellm_call_id": litellm_call_id, **image_edit_request_params, }, custom_llm_provider=custom_llm_provider, ) # Call the handler with _is_async flag instead of directly calling the async handler return base_llm_http_handler.image_edit_handler( model=model, image=image, prompt=prompt, image_edit_provider_config=image_edit_provider_config, image_edit_optional_request_params=image_edit_request_params, custom_llm_provider=custom_llm_provider, litellm_params=litellm_params, logging_obj=litellm_logging_obj, extra_headers=extra_headers, extra_body=extra_body, timeout=timeout or DEFAULT_REQUEST_TIMEOUT, _is_async=_is_async, client=kwargs.get("client"), ) except Exception as e: raise litellm.exception_type( model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=local_vars, extra_kwargs=kwargs, ) @client async def aimage_edit( image: FileTypes, model: str, prompt: str, mask: Optional[str] = None, n: Optional[int] = None, quality: Optional[Union[str, ImageGenerationRequestQuality]] = None, response_format: Optional[str] = None, size: Optional[str] = None, user: Optional[str] = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Dict[str, Any]] = None, extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, # LiteLLM specific params, custom_llm_provider: Optional[str] = None, **kwargs, ) -> ImageResponse: """ Asynchronously calls the `image_edit` function with the given arguments and keyword arguments. Parameters: - `args` (tuple): Positional arguments to be passed to the `image_edit` function. - `kwargs` (dict): Keyword arguments to be passed to the `image_edit` function. Returns: - `response` (Any): The response returned by the `image_edit` function. """ local_vars = locals() try: loop = asyncio.get_event_loop() kwargs["async_call"] = True # get custom llm provider so we can use this for mapping exceptions if custom_llm_provider is None: _, custom_llm_provider, _, _ = litellm.get_llm_provider( model=model, api_base=local_vars.get("base_url", None) ) func = partial( image_edit, image=image, prompt=prompt, mask=mask, model=model, n=n, quality=quality, response_format=response_format, size=size, user=user, timeout=timeout, custom_llm_provider=custom_llm_provider, **kwargs, ) ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) init_response = await loop.run_in_executor(None, func_with_context) if asyncio.iscoroutine(init_response): response = await init_response else: response = init_response return response except Exception as e: raise litellm.exception_type( model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=local_vars, extra_kwargs=kwargs, )