File size: 6,307 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import base64
import time
from io import BytesIO
from typing import Any, List, Mapping, Optional, Tuple, Union
from aiohttp import ClientResponse
from httpx import Headers, Response
from litellm.llms.base_llm.chat.transformation import (
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageVariationOptionalParams,
)
from litellm.types.utils import (
FileTypes,
HttpHandlerRequestFields,
ImageObject,
ImageResponse,
)
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import TopazException
class TopazImageVariationConfig(BaseImageVariationConfig):
    """
    Request/response transformation for Topaz Labs image enhancement, exposed
    through litellm's OpenAI-style image-variation interface.

    The request is sent as form data to ``{api_base}/image/v1/enhance``:
    - the image goes under the ``image`` file field;
    - mapped optional params go as form fields.

    The response body is treated as raw image bytes and returned
    base64-encoded inside an ``ImageResponse``.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """Return the OpenAI image-variation params this provider maps."""
        return ["response_format", "size"]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Build the headers for a Topaz request.

        Returns:
            A new dict containing ``Accept`` and ``X-API-Key``. The incoming
            ``headers`` argument is not merged into the result.

        Raises:
            ValueError: if ``api_key`` is None.
        """
        if api_key is None:
            raise ValueError(
                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
            )
        # Content-Type is intentionally not set. Presumably the HTTP client
        # must emit the multipart/form-data header itself, so that it carries
        # the generated boundary parameter.
        return {
            "Accept": "image/jpeg",
            "X-API-Key": api_key,
        }

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Return the enhance endpoint URL.

        Falls back to ``https://api.topazlabs.com`` when ``api_base`` is
        empty/None. A trailing slash on ``api_base`` is stripped so the path
        is not doubled.
        """
        base = (api_base or "https://api.topazlabs.com").rstrip("/")
        return f"{base}/image/v1/enhance"

    @staticmethod
    def _parse_size(size: Any) -> Tuple[str, str]:
        """Split an OpenAI ``"<width>x<height>"`` size into two digit strings."""
        parts = [part.strip() for part in str(size).split("x")]
        # Raise instead of assert: asserts are stripped under `python -O`.
        if len(parts) != 2 or not all(part.isdigit() for part in parts):
            raise ValueError("size must be in the format of widthxheight")
        return parts[0], parts[1]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate OpenAI params into Topaz form fields, mutating and
        returning ``optional_params``.

        - ``response_format`` -> ``output_format``
        - ``size`` ("WxH") -> ``output_width`` / ``output_height``
          (kept as strings, since they are sent as form fields)

        Raises:
            ValueError: if ``size`` is not of the form ``<int>x<int>``.
        """
        for key, value in non_default_params.items():
            if key == "response_format":
                # NOTE(review): forwarded verbatim. OpenAI's values
                # ("url"/"b64_json") may not be valid Topaz output formats;
                # confirm against the Topaz API.
                optional_params["output_format"] = value
            elif key == "size":
                width, height = self._parse_size(value)
                optional_params["output_width"] = width
                optional_params["output_height"] = height
        return optional_params

    def prepare_file_tuple(
        self,
        file_data: FileTypes,
    ) -> Tuple[str, Optional[FileTypes], str, Mapping[str, str]]:
        """
        Normalize the supported image inputs into one tuple for HTTPX.

        Accepted inputs:
        - bare content (bytes, BytesIO, or any other file-like object),
          passed through unchanged;
        - an ``os.PathLike``, which is read into bytes;
        - a tuple ``(filename, content[, content_type[, headers]])``.

        Missing or falsy filename / content_type default to ``"image.png"`` /
        ``"image/png"``.

        Returns:
            ``(filename, file_content, content_type, headers)``.

        Raises:
            ValueError: if a tuple input does not have 2-4 elements.
        """
        import os

        filename = "image.png"
        content_type = "image/png"
        headers: Mapping[str, str] = {}

        if isinstance(file_data, os.PathLike):
            path = os.fspath(file_data)
            with open(path, "rb") as fh:
                data = fh.read()
            return (os.path.basename(path) or filename, data, content_type, headers)

        if not isinstance(file_data, tuple):
            # Bare content. Previously, anything other than bytes/BytesIO
            # (e.g. an open file handle) was silently replaced by None.
            return (filename, file_data, content_type, headers)

        if not 2 <= len(file_data) <= 4:
            raise ValueError(
                "file tuple must be (filename, content[, content_type[, headers]])"
            )
        filename = file_data[0] or filename
        content = file_data[1]
        if len(file_data) >= 3:
            content_type = file_data[2] or content_type
        if len(file_data) == 4:
            headers = file_data[3] or headers
        return (filename, content, content_type, headers)

    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """Build multipart request fields: the image under ``image``, params as form data."""
        return HttpHandlerRequestFields(
            files={"image": self.prepare_file_tuple(image)},
            data=optional_params,
        )

    def _common_transform_response_image_variation(
        self,
        image_content: bytes,
        response_ms: float,
    ) -> ImageResponse:
        """Wrap raw image bytes in an ImageResponse holding one base64 ImageObject."""
        # base64 output is pure ASCII, so decoding as UTF-8 is lossless.
        base64_image = base64.b64encode(image_content).decode("utf-8")
        return ImageResponse(
            created=int(time.time()),
            data=[
                ImageObject(
                    b64_json=base64_image,
                    url=None,
                    revised_prompt=None,
                )
            ],
            response_ms=response_ms,
        )

    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """
        Read the aiohttp response body as image bytes and build an ImageResponse.

        Latency comes from ``logging_obj.get_response_ms()``.
        """
        image_content = await raw_response.read()
        response_ms = logging_obj.get_response_ms()
        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """
        Use the httpx response body as image bytes and build an ImageResponse.

        Latency is ``raw_response.elapsed`` converted to milliseconds.
        """
        image_content = raw_response.content
        response_ms = raw_response.elapsed.total_seconds() * 1000
        return self._common_transform_response_image_variation(
            image_content, response_ms
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap an upstream error in a TopazException."""
        return TopazException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )
|