|
import base64 |
|
import time |
|
from io import BytesIO |
|
from typing import Any, List, Mapping, Optional, Tuple, Union |
|
|
|
from aiohttp import ClientResponse |
|
from httpx import Headers, Response |
|
|
|
from litellm.llms.base_llm.chat.transformation import ( |
|
BaseLLMException, |
|
LiteLLMLoggingObj, |
|
) |
|
from litellm.types.llms.openai import ( |
|
AllMessageValues, |
|
OpenAIImageVariationOptionalParams, |
|
) |
|
from litellm.types.utils import ( |
|
FileTypes, |
|
HttpHandlerRequestFields, |
|
ImageObject, |
|
ImageResponse, |
|
) |
|
|
|
from ...base_llm.image_variations.transformation import BaseImageVariationConfig |
|
from ..common_utils import TopazException |
|
|
|
|
|
class TopazImageVariationConfig(BaseImageVariationConfig):
    """Request/response transformations for Topaz image variation calls.

    Maps OpenAI-style image variation parameters onto the form fields sent to
    the Topaz ``/image/v1/enhance`` endpoint and wraps the raw image bytes of
    the reply in an ``ImageResponse``.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """Return the OpenAI image variation params this config can translate."""
        return ["response_format", "size"]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build the HTTP headers for a Topaz request.

        Only ``api_key`` is consulted; the incoming ``headers`` are not merged
        into the result.

        Returns:
            A new dict holding the ``Accept`` and ``X-API-Key`` headers.

        Raises:
            ValueError: If ``api_key`` is ``None``.
        """
        if api_key is None:
            raise ValueError(
                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
            )
        return {
            "Accept": "image/jpeg",
            "X-API-Key": api_key,
        }

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Return the full URL of the enhance endpoint.

        Falls back to ``https://api.topazlabs.com`` when ``api_base`` is
        empty or ``None``.
        """
        # Strip trailing slashes so a base such as "https://host/" does not
        # yield a double slash in the final URL.
        base_url = (api_base or "https://api.topazlabs.com").rstrip("/")
        return f"{base_url}/image/v1/enhance"

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI params into the provider's field names.

        ``response_format`` is copied to ``output_format``; ``size`` (given as
        ``"<width>x<height>"``) is split into ``output_width`` and
        ``output_height`` (kept as strings). Any other key is ignored.

        Returns:
            ``optional_params``, updated in place.

        Raises:
            ValueError: If ``size`` does not split into exactly two parts
                around ``"x"``.
        """
        for param, value in non_default_params.items():
            if param == "response_format":
                optional_params["output_format"] = value
            elif param == "size":
                dimensions = value.split("x")
                # Explicit raise instead of `assert`: assertions are removed
                # under `python -O`, which would let malformed sizes through.
                if len(dimensions) != 2:
                    raise ValueError("size must be in the format of widthxheight")
                optional_params["output_width"] = dimensions[0]
                optional_params["output_height"] = dimensions[1]
        return optional_params

    def prepare_file_tuple(
        self,
        file_data: FileTypes,
    ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
        """Normalize a file input into a 4-tuple suitable for HTTPX.

        Accepts raw ``bytes``/``BytesIO`` content, or a tuple of 2, 3 or 4
        items ordered as ``(filename, content[, content_type[, headers]])``.
        A falsy filename or content type falls back to the defaults
        ``"image.png"`` / ``"image/png"``.

        Returns:
            ``(filename, file_content, content_type, headers)``.
        """
        # Defaults used whenever the caller does not supply a value.
        filename = "image.png"
        content: Optional[FileTypes] = None
        content_type = "image/png"
        headers: Mapping[str, str] = {}

        # NOTE(review): any input that is not bytes/BytesIO or a 2-4 item
        # tuple leaves `content` as None - confirm callers never pass other
        # file-like objects or paths.
        if isinstance(file_data, (bytes, BytesIO)):
            # Only the file content was passed.
            content = file_data
        elif isinstance(file_data, tuple):
            if len(file_data) == 2:
                # (filename, content)
                filename = file_data[0] or filename
                content = file_data[1]
            elif len(file_data) == 3:
                # (filename, content, content_type)
                filename = file_data[0] or filename
                content = file_data[1]
                content_type = file_data[2] or content_type
            elif len(file_data) == 4:
                # (filename, content, content_type, headers)
                filename = file_data[0] or filename
                content = file_data[1]
                content_type = file_data[2] or content_type
                headers = file_data[3]

        return (filename, content, content_type, headers)

    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """Build the request fields for an image variation call.

        The image is attached under the ``"image"`` file field and
        ``optional_params`` is sent as the form data.
        """
        return HttpHandlerRequestFields(
            files={"image": self.prepare_file_tuple(image)},
            data=optional_params,
        )

    def _common_transform_response_image_variation(
        self,
        image_content: bytes,
        response_ms: float,
    ) -> ImageResponse:
        """Wrap raw image bytes in an ``ImageResponse``.

        The bytes are base64-encoded into a single ``ImageObject.b64_json``
        entry; ``created`` is set to the current Unix time in seconds.
        """
        encoded_image = base64.b64encode(image_content).decode("utf-8")

        return ImageResponse(
            created=int(time.time()),
            data=[
                ImageObject(
                    b64_json=encoded_image,
                    url=None,
                    revised_prompt=None,
                )
            ],
            response_ms=response_ms,
        )

    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Convert an aiohttp response into an ``ImageResponse``.

        Reads the full response body and takes the latency from
        ``logging_obj.get_response_ms()``.
        """
        image_content = await raw_response.read()
        response_ms = logging_obj.get_response_ms()

        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Convert an httpx response into an ``ImageResponse``.

        Uses ``raw_response.content`` as the image bytes and derives the
        latency from ``raw_response.elapsed``.
        """
        image_content = raw_response.content
        # `elapsed` is a timedelta; convert seconds to milliseconds.
        response_ms = raw_response.elapsed.total_seconds() * 1000

        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Return a ``TopazException`` carrying the given error details."""
        return TopazException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|
|