import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
import httpx
import litellm
from litellm import Logging, client, exception_type, get_litellm_params
from litellm.constants import DEFAULT_IMAGE_ENDPOINT_MODEL
from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT
from litellm.exceptions import LiteLLMUnknownProvider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.mock_functions import mock_image_generation
from litellm.llms.base_llm import BaseImageEditConfig, BaseImageGenerationConfig
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_llm import CustomLLM
#################### Initialize provider clients ####################
from litellm.main import (
azure_chat_completions,
base_llm_aiohttp_handler,
base_llm_http_handler,
bedrock_image_generation,
openai_chat_completions,
openai_image_variations,
vertex_image_generation,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.llms.openai import ImageGenerationRequestQuality
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import (
LITELLM_IMAGE_VARIATION_PROVIDERS,
FileTypes,
LlmProviders,
all_litellm_params,
)
from litellm.utils import (
ImageResponse,
ProviderConfigManager,
get_llm_provider,
get_optional_params_image_gen,
)
from .utils import ImageEditRequestUtils
##### Image Generation #######################
@client
async def aimage_generation(*args, **kwargs) -> ImageResponse:
"""
Asynchronously calls the `image_generation` function with the given arguments and keyword arguments.
Parameters:
- `args` (tuple): Positional arguments to be passed to the `image_generation` function.
- `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function.
Returns:
    - `response` (ImageResponse): The response returned by the `image_generation` function.
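
    Example (a minimal sketch; assumes OPENAI_API_KEY is set in the environment):
        import asyncio
        import litellm

        async def main():
            response = await litellm.aimage_generation(
                model="dall-e-3",
                prompt="A watercolor painting of a fox",
            )
            print(response.data[0].url)

        asyncio.run(main())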
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO Image Generation ###
kwargs["aimg_generation"] = True
custom_llm_provider = None
try:
# Use a partial function to pass your keyword arguments
func = partial(image_generation, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
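        # The executor call may return: a cached dict/ImageResponse served by the
        # @client wrapper, a coroutine (async provider path), or a plain response.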
if isinstance(init_response, dict) or isinstance(
init_response, ImageResponse
): ## CACHING SCENARIO
if isinstance(init_response, dict):
init_response = ImageResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@client
def image_generation( # noqa: PLR0915
prompt: str,
model: Optional[str] = None,
n: Optional[int] = None,
quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
response_format: Optional[str] = None,
size: Optional[str] = None,
style: Optional[str] = None,
user: Optional[str] = None,
timeout=600, # default to 10 minutes
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
custom_llm_provider=None,
**kwargs,
) -> ImageResponse:
"""
Maps the https://api.openai.com/v1/images/generations endpoint.
    Routes to Azure, OpenAI (and OpenAI-compatible providers), Bedrock, Vertex AI, and registered custom providers.
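
    Example (a minimal sketch; assumes OPENAI_API_KEY is set in the environment):
        import litellm

        response = litellm.image_generation(
            prompt="A watercolor painting of a fox",
            model="dall-e-3",
            size="1024x1024",
        )
        print(response.data[0].url)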
"""
try:
args = locals()
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
client = kwargs.get("client", None)
extra_headers = kwargs.get("extra_headers", None)
headers: dict = kwargs.get("headers", None) or {}
base_model = kwargs.get("base_model", None)
if extra_headers is not None:
headers.update(extra_headers)
model_response: ImageResponse = litellm.utils.ImageResponse()
dynamic_api_key: Optional[str] = None
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, # type: ignore
custom_llm_provider=custom_llm_provider,
api_base=api_base,
)
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
model_response._hidden_params["model"] = model
openai_params = [
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"max_retries",
"n",
"quality",
"size",
"style",
]
litellm_params = all_litellm_params
default_params = openai_params + litellm_params
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
image_generation_config: Optional[BaseImageGenerationConfig] = None
if (
custom_llm_provider is not None
and custom_llm_provider in LlmProviders._member_map_.values()
):
image_generation_config = (
ProviderConfigManager.get_provider_image_generation_config(
model=base_model or model,
provider=LlmProviders(custom_llm_provider),
)
)
optional_params = get_optional_params_image_gen(
model=base_model or model,
n=n,
quality=quality,
response_format=response_format,
size=size,
style=style,
user=user,
custom_llm_provider=custom_llm_provider,
provider_config=image_generation_config,
**non_default_params,
)
litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
user=user,
optional_params=optional_params,
litellm_params={
"timeout": timeout,
"azure": False,
"litellm_call_id": litellm_call_id,
"logger_fn": logger_fn,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
},
custom_llm_provider=custom_llm_provider,
)
if "custom_llm_provider" not in logging.model_call_details:
logging.model_call_details["custom_llm_provider"] = custom_llm_provider
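        # Short-circuit with a canned response when mock_response is provided
        # (useful for tests; no provider call is made).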
if mock_response is not None:
return mock_image_generation(model=model, mock_response=mock_response)
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
azure_ad_token = optional_params.pop(
"azure_ad_token", None
) or get_secret_str("AZURE_AD_TOKEN")
default_headers = {
"Content-Type": "application/json;",
"api-key": api_key,
}
for k, v in default_headers.items():
if k not in headers:
headers[k] = v
model_response = azure_chat_completions.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
api_key=api_key,
api_base=api_base,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
api_version=api_version,
aimg_generation=aimg_generation,
client=client,
headers=headers,
litellm_params=litellm_params_dict,
)
elif (
custom_llm_provider == "openai"
or custom_llm_provider in litellm.openai_compatible_providers
):
model_response = openai_chat_completions.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
api_key=api_key or dynamic_api_key,
api_base=api_base,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
client=client,
)
elif custom_llm_provider == "bedrock":
if model is None:
raise Exception("Model needs to be set for bedrock")
model_response = bedrock_image_generation.image_generation( # type: ignore
model=model,
prompt=prompt,
timeout=timeout,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
client=client,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret_str("VERTEXAI_CREDENTIALS")
)
api_base = (
api_base
or litellm.api_base
or get_secret_str("VERTEXAI_API_BASE")
or get_secret_str("VERTEX_API_BASE")
)
model_response = vertex_image_generation.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aimg_generation=aimg_generation,
api_base=api_base,
client=client,
)
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
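            # litellm.custom_provider_map is expected to hold entries shaped like
            # {"provider": "my-provider", "custom_handler": MyCustomLLM()} (a sketch;
            # MyCustomLLM is a hypothetical CustomLLM subclass).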
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
## ROUTE LLM CALL ##
if aimg_generation is True:
async_custom_client: Optional[AsyncHTTPHandler] = None
if client is not None and isinstance(client, AsyncHTTPHandler):
async_custom_client = client
## CALL FUNCTION
model_response = custom_handler.aimage_generation( # type: ignore
model=model,
prompt=prompt,
api_key=api_key,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
timeout=timeout,
client=async_custom_client,
)
else:
custom_client: Optional[HTTPHandler] = None
if client is not None and isinstance(client, HTTPHandler):
custom_client = client
## CALL FUNCTION
model_response = custom_handler.image_generation(
model=model,
prompt=prompt,
api_key=api_key,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
timeout=timeout,
client=custom_client,
)
return model_response
except Exception as e:
## Map to OpenAI Exception
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=locals(),
extra_kwargs=kwargs,
)
@client
async def aimage_variation(*args, **kwargs) -> ImageResponse:
"""
Asynchronously calls the `image_variation` function with the given arguments and keyword arguments.
Parameters:
- `args` (tuple): Positional arguments to be passed to the `image_variation` function.
- `kwargs` (dict): Keyword arguments to be passed to the `image_variation` function.
Returns:
    - `response` (ImageResponse): The response returned by the `image_variation` function.
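
    Example (a minimal sketch; assumes OPENAI_API_KEY is set and "image.png" exists):
        import asyncio
        import litellm

        async def main():
            with open("image.png", "rb") as f:
                response = await litellm.aimage_variation(image=f.read(), model="dall-e-2")
            print(response.data[0].url)

        asyncio.run(main())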
"""
loop = asyncio.get_event_loop()
model = kwargs.get("model", None)
custom_llm_provider = kwargs.get("custom_llm_provider", None)
    ### PASS ARGS TO Image Variation ###
kwargs["async_call"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(image_variation, *args, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
if custom_llm_provider is None and model is not None:
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, ImageResponse
): ## CACHING SCENARIO
if isinstance(init_response, dict):
init_response = ImageResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@client
def image_variation(
image: FileTypes,
model: str = "dall-e-2", # set to dall-e-2 by default - like OpenAI.
n: int = 1,
response_format: Literal["url", "b64_json"] = "url",
size: Optional[str] = None,
user: Optional[str] = None,
**kwargs,
) -> ImageResponse:
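    """
    Maps the https://api.openai.com/v1/images/variations endpoint.

    Example (a minimal sketch; assumes OPENAI_API_KEY is set and "image.png" exists):
        import litellm

        with open("image.png", "rb") as f:
            response = litellm.image_variation(image=f.read(), model="dall-e-2")
        print(response.data[0].url)
    """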
# get non-default params
client = kwargs.get("client", None)
# get logging object
litellm_logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj"))
# get the litellm params
litellm_params = get_litellm_params(**kwargs)
# get the custom llm provider
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model,
custom_llm_provider=litellm_params.get("custom_llm_provider", None),
api_base=litellm_params.get("api_base", None),
api_key=litellm_params.get("api_key", None),
)
# route to the correct provider w/ the params
try:
llm_provider = LlmProviders(custom_llm_provider)
image_variation_provider = LITELLM_IMAGE_VARIATION_PROVIDERS(llm_provider)
except ValueError:
raise ValueError(
f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
)
model_response = ImageResponse()
response: Optional[ImageResponse] = None
provider_config = ProviderConfigManager.get_provider_model_info(
model=model or "", # openai defaults to dall-e-2
provider=llm_provider,
)
if provider_config is None:
raise ValueError(
f"image variation provider has no known model info config - required for getting api keys, etc.: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
)
api_key = provider_config.get_api_key(litellm_params.get("api_key", None))
api_base = provider_config.get_api_base(litellm_params.get("api_base", None))
if image_variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.OPENAI:
if api_key is None:
raise ValueError("API key is required for OpenAI image variations")
if api_base is None:
raise ValueError("API base is required for OpenAI image variations")
response = openai_image_variations.image_variations(
model_response=model_response,
api_key=api_key,
api_base=api_base,
model=model,
image=image,
timeout=litellm_params.get("timeout", None),
custom_llm_provider=custom_llm_provider,
logging_obj=litellm_logging_obj,
optional_params={},
litellm_params=litellm_params,
)
elif image_variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.TOPAZ:
if api_key is None:
raise ValueError("API key is required for Topaz image variations")
if api_base is None:
raise ValueError("API base is required for Topaz image variations")
response = base_llm_aiohttp_handler.image_variations(
model_response=model_response,
api_key=api_key,
api_base=api_base,
model=model,
image=image,
timeout=litellm_params.get("timeout", None) or DEFAULT_REQUEST_TIMEOUT,
custom_llm_provider=custom_llm_provider,
logging_obj=litellm_logging_obj,
optional_params={},
litellm_params=litellm_params,
client=client,
)
# return the response
if response is None:
raise ValueError(
f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
)
return response
@client
def image_edit(
image: FileTypes,
prompt: str,
model: Optional[str] = None,
mask: Optional[str] = None,
n: Optional[int] = None,
quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
response_format: Optional[str] = None,
size: Optional[str] = None,
user: Optional[str] = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse]]:
"""
    Maps the https://api.openai.com/v1/images/edits endpoint.
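
    Example (a minimal sketch; assumes OPENAI_API_KEY is set, "image.png" exists,
    and "gpt-image-1" is available on your account):
        import litellm

        with open("image.png", "rb") as f:
            response = litellm.image_edit(
                image=f.read(),
                prompt="Add a red hat to the subject",
                model="gpt-image-1",
            )
        print(response.data[0])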
"""
local_vars = locals()
try:
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
_is_async = kwargs.pop("async_call", False) is True
# get llm provider logic
litellm_params = GenericLiteLLMParams(**kwargs)
model, custom_llm_provider, _, _ = get_llm_provider(
model=model or DEFAULT_IMAGE_ENDPOINT_MODEL,
custom_llm_provider=custom_llm_provider,
)
# get provider config
image_edit_provider_config: Optional[
BaseImageEditConfig
] = ProviderConfigManager.get_provider_image_edit_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
if image_edit_provider_config is None:
raise ValueError(f"image edit is not supported for {custom_llm_provider}")
local_vars.update(kwargs)
# Get ImageEditOptionalRequestParams with only valid parameters
image_edit_optional_params: ImageEditOptionalRequestParams = (
ImageEditRequestUtils.get_requested_image_edit_optional_param(local_vars)
)
# Get optional parameters for the responses API
image_edit_request_params: Dict = (
ImageEditRequestUtils.get_optional_params_image_edit(
model=model,
image_edit_provider_config=image_edit_provider_config,
image_edit_optional_params=image_edit_optional_params,
)
)
# Pre Call logging
litellm_logging_obj.update_environment_variables(
model=model,
user=user,
optional_params=dict(image_edit_request_params),
litellm_params={
"litellm_call_id": litellm_call_id,
**image_edit_request_params,
},
custom_llm_provider=custom_llm_provider,
)
# Call the handler with _is_async flag instead of directly calling the async handler
return base_llm_http_handler.image_edit_handler(
model=model,
image=image,
prompt=prompt,
image_edit_provider_config=image_edit_provider_config,
image_edit_optional_request_params=image_edit_request_params,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
logging_obj=litellm_logging_obj,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
_is_async=_is_async,
client=kwargs.get("client"),
)
except Exception as e:
raise litellm.exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
@client
async def aimage_edit(
image: FileTypes,
model: str,
prompt: str,
mask: Optional[str] = None,
n: Optional[int] = None,
quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
response_format: Optional[str] = None,
size: Optional[str] = None,
user: Optional[str] = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> ImageResponse:
"""
Asynchronously calls the `image_edit` function with the given arguments and keyword arguments.
    Parameters:
    - Accepts the same parameters as `image_edit`; all arguments are forwarded to it.
    Returns:
    - `response` (ImageResponse): The response returned by the `image_edit` function.
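
    Example (a minimal sketch; assumes OPENAI_API_KEY is set, "image.png" exists,
    and "gpt-image-1" is available on your account):
        import asyncio
        import litellm

        async def main():
            with open("image.png", "rb") as f:
                response = await litellm.aimage_edit(
                    image=f.read(),
                    model="gpt-image-1",
                    prompt="Add a red hat to the subject",
                )
            print(response.data[0])

        asyncio.run(main())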
"""
local_vars = locals()
try:
loop = asyncio.get_event_loop()
kwargs["async_call"] = True
# get custom llm provider so we can use this for mapping exceptions
if custom_llm_provider is None:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, api_base=local_vars.get("base_url", None)
)
func = partial(
image_edit,
image=image,
prompt=prompt,
mask=mask,
model=model,
n=n,
quality=quality,
response_format=response_format,
size=size,
user=user,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise litellm.exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)