Raju2024's picture
Upload 1072 files
e3278e4 verified
import base64
import time
from io import BytesIO
from typing import Any, List, Mapping, Optional, Tuple, Union
from aiohttp import ClientResponse
from httpx import Headers, Response
from litellm.llms.base_llm.chat.transformation import (
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageVariationOptionalParams,
)
from litellm.types.utils import (
FileTypes,
HttpHandlerRequestFields,
ImageObject,
ImageResponse,
)
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import TopazException
class TopazImageVariationConfig(BaseImageVariationConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageVariationOptionalParams]:
return ["response_format", "size"]
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(
"API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
)
return {
# "Content-Type": "multipart/form-data",
"Accept": "image/jpeg",
"X-API-Key": api_key,
}
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
api_base = api_base or "https://api.topazlabs.com"
return f"{api_base}/image/v1/enhance"
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for k, v in non_default_params.items():
if k == "response_format":
optional_params["output_format"] = v
elif k == "size":
split_v = v.split("x")
assert len(split_v) == 2, "size must be in the format of widthxheight"
optional_params["output_width"] = split_v[0]
optional_params["output_height"] = split_v[1]
return optional_params
def prepare_file_tuple(
self,
file_data: FileTypes,
) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
"""
Convert various file input formats to a consistent tuple format for HTTPX
Returns: (filename, file_content, content_type, headers)
"""
# Default values
filename = "image.png"
content: Optional[FileTypes] = None
content_type = "image/png"
headers: Mapping[str, str] = {}
if isinstance(file_data, (bytes, BytesIO)):
# Case 1: Just file content
content = file_data
elif isinstance(file_data, tuple):
if len(file_data) == 2:
# Case 2: (filename, content)
filename = file_data[0] or filename
content = file_data[1]
elif len(file_data) == 3:
# Case 3: (filename, content, content_type)
filename = file_data[0] or filename
content = file_data[1]
content_type = file_data[2] or content_type
elif len(file_data) == 4:
# Case 4: (filename, content, content_type, headers)
filename = file_data[0] or filename
content = file_data[1]
content_type = file_data[2] or content_type
headers = file_data[3]
return (filename, content, content_type, headers)
def transform_request_image_variation(
self,
model: Optional[str],
image: FileTypes,
optional_params: dict,
headers: dict,
) -> HttpHandlerRequestFields:
request_params = HttpHandlerRequestFields(
files={"image": self.prepare_file_tuple(image)},
data=optional_params,
)
return request_params
def _common_transform_response_image_variation(
self,
image_content: bytes,
response_ms: float,
) -> ImageResponse:
# Convert to base64
base64_image = base64.b64encode(image_content).decode("utf-8")
return ImageResponse(
created=int(time.time()),
data=[
ImageObject(
b64_json=base64_image,
url=None,
revised_prompt=None,
)
],
response_ms=response_ms,
)
async def async_transform_response_image_variation(
self,
model: Optional[str],
raw_response: ClientResponse,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
image: FileTypes,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
) -> ImageResponse:
image_content = await raw_response.read()
response_ms = logging_obj.get_response_ms()
return self._common_transform_response_image_variation(
image_content, response_ms
)
def transform_response_image_variation(
self,
model: Optional[str],
raw_response: Response,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
image: FileTypes,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
) -> ImageResponse:
image_content = raw_response.content
response_ms = (
raw_response.elapsed.total_seconds() * 1000
) # Convert to milliseconds
return self._common_transform_response_image_variation(
image_content, response_ms
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return TopazException(
status_code=status_code,
message=error_message,
headers=headers,
)