|
import json |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast |
|
|
|
import httpx |
|
|
|
import litellm |
|
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
_get_httpx_client, |
|
get_async_httpx_client, |
|
) |
|
from litellm.types.llms.bedrock import BedrockPreparedRequest |
|
from litellm.types.rerank import RerankRequest |
|
from litellm.types.utils import RerankResponse |
|
|
|
from ..base_aws_llm import BaseAWSLLM |
|
from ..common_utils import BedrockError |
|
from .transformation import BedrockRerankConfig |
|
|
|
# botocore is only needed for static typing here; at runtime the name is
# bound to ``Any`` so importing this module does not import botocore.
if not TYPE_CHECKING:
    AWSPreparedRequest = Any
else:
    from botocore.awsrequest import AWSPreparedRequest
|
|
|
|
|
class BedrockRerankHandler(BaseAWSLLM):
    """Send rerank requests to the Bedrock agent runtime ``/rerank`` endpoint.

    Requests are JSON-encoded, SigV4-signed and posted with the httpx clients
    returned by ``_get_httpx_client`` / ``get_async_httpx_client``. HTTP status
    failures and timeouts are re-raised as ``BedrockError``.
    """

    @staticmethod
    def _to_bedrock_error(
        err: Union[httpx.HTTPStatusError, httpx.TimeoutException],
    ) -> BedrockError:
        """Translate an httpx failure into the equivalent ``BedrockError``.

        Shared by the sync and async paths so both report errors identically.

        Args:
            err: The httpx exception caught around the POST call.

        Returns:
            A ``BedrockError`` carrying the response's status code and text for
            an ``HTTPStatusError``, or status 408 for a timeout.
        """
        if isinstance(err, httpx.HTTPStatusError):
            return BedrockError(
                status_code=err.response.status_code,
                message=err.response.text,
            )
        return BedrockError(status_code=408, message="Timeout error occurred.")

    async def arerank(
        self,
        prepared_request: BedrockPreparedRequest,
    ):
        """Asynchronously POST an already-signed rerank request.

        Args:
            prepared_request: Output of ``_prepare_request``; its
                ``endpoint_url``, ``prepped`` headers and ``body`` are sent
                as-is.

        Returns:
            The result of ``BedrockRerankConfig()._transform_response`` applied
            to the decoded JSON response.

        Raises:
            BedrockError: On a non-success HTTP status (status code and body
                text of the response) or on timeout (status 408).
        """
        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
        try:
            response = await client.post(
                url=prepared_request["endpoint_url"],
                headers=prepared_request["prepped"].headers,
                data=prepared_request["body"],
            )
            response.raise_for_status()
        except (httpx.HTTPStatusError, httpx.TimeoutException) as err:
            # Chain the httpx error so the root cause stays in the traceback.
            raise self._to_bedrock_error(err) from err

        return BedrockRerankConfig()._transform_response(response.json())

    def rerank(
        self,
        model: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        optional_params: dict,
        logging_obj: LitellmLogging,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ) -> RerankResponse:
        """Build, sign, log and send a rerank request.

        Args:
            model: Model identifier placed in the ``RerankRequest``.
            query: Query text placed in the ``RerankRequest``.
            documents: Documents to rank, as strings or dicts.
            optional_params: Passed to ``_prepare_request`` for credential and
                endpoint resolution.
            logging_obj: Receives a ``pre_call`` with the transformed payload,
                the endpoint URL and the signed headers.
            top_n: Forwarded into the ``RerankRequest``.
            rank_fields: Forwarded into the ``RerankRequest``.
            return_documents: Forwarded into the ``RerankRequest``.
            max_chunks_per_doc: Accepted for signature compatibility; it is
                not read anywhere in this method.
            _is_async: When truthy, the un-awaited coroutine from ``arerank``
                is returned instead of a response; the caller must await it.
            api_base: Optional endpoint override passed to
                ``_prepare_request``.
            extra_headers: Optional headers merged into the signed request.

        Returns:
            The transformed rerank response (or a coroutine producing it when
            ``_is_async`` is truthy).

        Raises:
            BedrockError: On a non-success HTTP status or timeout in the
                synchronous path.
            ImportError: If botocore is not installed.
        """
        request_data = RerankRequest(
            model=model,
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )
        data = BedrockRerankConfig()._transform_request(request_data)

        prepared_request = self._prepare_request(
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            data=cast(dict, data),
        )

        logging_obj.pre_call(
            input=data,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepared_request["endpoint_url"],
                "headers": prepared_request["prepped"].headers,
            },
        )

        if _is_async:
            # Hands back the coroutine itself; awaiting is the caller's job.
            return self.arerank(prepared_request)

        client = _get_httpx_client()
        try:
            response = client.post(
                url=prepared_request["endpoint_url"],
                headers=prepared_request["prepped"].headers,
                data=prepared_request["body"],
            )
            response.raise_for_status()
        except (httpx.HTTPStatusError, httpx.TimeoutException) as err:
            # Chain the httpx error so the root cause stays in the traceback.
            raise self._to_bedrock_error(err) from err

        return BedrockRerankConfig()._transform_response(response.json())

    def _prepare_request(
        self,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        data: dict,
        optional_params: dict,
    ) -> BedrockPreparedRequest:
        """Create the SigV4-signed POST request for the ``/rerank`` endpoint.

        Args:
            api_base: Optional endpoint override handed to
                ``get_runtime_endpoint``.
            extra_headers: Optional headers merged over the default
                ``Content-Type``; an ``Authorization`` entry here replaces the
                SigV4-generated one.
            data: Request payload; serialized with ``json.dumps`` as UTF-8.
            optional_params: Source for the boto credentials, region and
                runtime endpoint.

        Returns:
            A ``BedrockPreparedRequest`` holding the endpoint URL, the prepared
            (signed) request, the encoded body and the original payload.

        Raises:
            ImportError: If botocore cannot be imported.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError as err:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            ) from err

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params
        )

        # Only the second element of the returned pair is used here.
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        # Point the runtime URL at the agent-runtime host and its rerank path.
        proxy_endpoint_url = proxy_endpoint_url.replace(
            "bedrock-runtime", "bedrock-agent-runtime"
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"

        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )

        body = json.dumps(data).encode("utf-8")

        # Caller-supplied headers win over the default Content-Type.
        headers = {"Content-Type": "application/json", **(extra_headers or {})}
        request = AWSRequest(
            method="POST", url=proxy_endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        if extra_headers is not None and "Authorization" in extra_headers:
            # Signing sets its own Authorization header; restore the caller's
            # explicit value so it takes precedence.
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        return BedrockPreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )
|
|