import copy
import json
import os
from typing import TYPE_CHECKING, Any, Optional, Union

import httpx
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.utils import ImageResponse

from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError

if TYPE_CHECKING:
    from botocore.awsrequest import AWSPreparedRequest
else:
    # botocore is an optional dependency; outside type checking,
    # AWSPreparedRequest degrades to Any.
    AWSPreparedRequest = Any


class BedrockImagePreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for Bedrock image generation.
    """

    endpoint_url: str
    prepped: AWSPreparedRequest
    body: bytes
    data: dict


class BedrockImageGeneration(BaseAWSLLM):
    """
    Bedrock Image Generation handler.

    Builds a SigV4-signed request for the Bedrock runtime
    `/model/{modelId}/invoke` endpoint and transforms the provider
    response into the OpenAI-compatible `ImageResponse` format.
    """
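
    # Callers normally reach this handler through litellm's public API with
    # a "bedrock/" model prefix. Illustrative usage (model id is an example):
    #
    #   import litellm
    #   response = litellm.image_generation(
    #       model="bedrock/stability.stable-diffusion-xl-v1",
    #       prompt="A watercolor painting of a lighthouse",
    #   )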

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: LitellmLogging,
        timeout: Optional[Union[float, httpx.Timeout]],
        aimg_generation: bool = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ):
        prepared_request = self._prepare_request(
            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
            prompt=prompt,
        )

        if aimg_generation is True:
            return self.async_image_generation(
                prepared_request=prepared_request,
                timeout=timeout,
                model=model,
                logging_obj=logging_obj,
                prompt=prompt,
                model_response=model_response,
            )

        client = _get_httpx_client()
        try:
            # Send the request with the SigV4-signed headers from the
            # prepared botocore request.
            response = client.post(
                url=prepared_request.endpoint_url,
                headers=prepared_request.prepped.headers,
                data=prepared_request.body,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        model_response = self._transform_response_dict_to_openai_response(
            model_response=model_response,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
        )
        return model_response

    async def async_image_generation(
        self,
        prepared_request: BedrockImagePreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]],
        model: str,
        logging_obj: LitellmLogging,
        prompt: str,
        model_response: ImageResponse,
    ) -> ImageResponse:
        """
        Asynchronous handler for Bedrock image generation.

        Awaits the response from the Bedrock image generation endpoint.
        """
        async_client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.BEDROCK,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=prepared_request.endpoint_url,
                headers=prepared_request.prepped.headers,
                data=prepared_request.body,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        model_response = self._transform_response_dict_to_openai_response(
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
            model_response=model_response,
        )
        return model_response

    def _prepare_request(
        self,
        model: str,
        optional_params: dict,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        logging_obj: LitellmLogging,
        prompt: str,
    ) -> BedrockImagePreparedRequest:
        """
        Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API.

        Args:
            model (str): The model to use for the image generation
            optional_params (dict): The optional parameters for the image generation
            api_base (Optional[str]): The base URL for the Bedrock API
            extra_headers (Optional[dict]): The extra headers to include in the request
            logging_obj (LitellmLogging): The logging object to use for logging
            prompt (str): The prompt to use for the image generation

        Returns:
            BedrockImagePreparedRequest: The prepared request object, containing:
                endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
                prepped (AWSPreparedRequest): The SigV4-signed request
                body (bytes): The JSON-serialized request body
                data (dict): The request body before serialization
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params
        )

        modelId = model
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )

        data = self._get_request_body(
            model=model, prompt=prompt, optional_params=optional_params
        )
        body = json.dumps(data).encode("utf-8")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        # Sign the request in place; SigV4Auth adds the Authorization and
        # signing headers to `request.headers`.
        request = AWSRequest(
            method="POST", url=proxy_endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        # An explicit Authorization header from the caller takes precedence
        # over the SigV4 signature.
        if extra_headers is not None and "Authorization" in extra_headers:
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        return BedrockImagePreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )

    def _get_request_body(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
    ) -> dict:
        """
        Get the request body for the Bedrock Image Generation API.

        Checks the model/provider and transforms the request body accordingly.

        Returns:
            dict: The request body to use for the Bedrock Image Generation API
        """
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        # `user` is an OpenAI-style param that Bedrock image schemas do not accept.
        inference_params.pop("user", None)
        data = {}
        if provider == "stability":
            if litellm.AmazonStability3Config._is_stability_3_model(model):
                request_body = litellm.AmazonStability3Config.transform_request_body(
                    prompt=prompt, optional_params=optional_params
                )
                return dict(request_body)
            else:
                prompt = prompt.replace(os.linesep, " ")
                config = litellm.AmazonStabilityConfig.get_config()
                for k, v in config.items():
                    if k not in inference_params:
                        inference_params[k] = v
                data = {
                    "text_prompts": [{"text": prompt, "weight": 1}],
                    **inference_params,
                }
        else:
            raise BedrockError(
                status_code=422,
                message=f"Unsupported model={model} passed in. Only 'stability' models are supported for Bedrock image generation.",
            )
        return data

    def _transform_response_dict_to_openai_response(
        self,
        model_response: ImageResponse,
        model: str,
        logging_obj: LitellmLogging,
        prompt: str,
        response: httpx.Response,
        data: dict,
    ) -> ImageResponse:
        """
        Transforms the Image Generation response from Bedrock to OpenAI format.
        """
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        verbose_logger.debug("raw model_response: %s", response.text)
        response_dict = response.json()
        if response_dict is None:
            raise ValueError("Error in response object format, got None")

        # Stability 3 and older stability models are transformed by different
        # config classes, so dispatch on the model id.
        config_class = (
            litellm.AmazonStability3Config
            if litellm.AmazonStability3Config._is_stability_3_model(model=model)
            else litellm.AmazonStabilityConfig
        )
        config_class.transform_response_dict_to_openai_response(
            model_response=model_response,
            response_dict=response_dict,
        )
        return model_response