TestLLM / litellm /llms /bedrock /base_aws_llm.py
Raju2024's picture
Upload 1072 files
e3278e4 verified
import hashlib
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import httpx
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.secret_managers.main import get_secret, get_secret_str
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
from botocore.credentials import Credentials
else:
Credentials = Any
AWSPreparedRequest = Any
class Boto3CredentialsInfo(BaseModel):
credentials: Credentials
aws_region_name: str
aws_bedrock_runtime_endpoint: Optional[str]
class AwsAuthError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class BaseAWSLLM:
def __init__(self) -> None:
self.iam_cache = DualCache()
super().__init__()
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
"""
Generate a unique cache key based on the credential arguments.
"""
# Convert credential arguments to a JSON string and hash it to create a unique key
credential_str = json.dumps(credential_args, sort_keys=True)
return hashlib.sha256(credential_str.encode()).hexdigest()
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
aws_sts_endpoint: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
## CHECK IS 'os.environ/' passed in
param_names = [
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_session_name",
"aws_profile_name",
"aws_role_name",
"aws_web_identity_token",
"aws_sts_endpoint",
]
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
elif param is None: # check if uppercase value in env
key = param_names[i]
if key.upper() in os.environ:
params_to_check[i] = os.getenv(key)
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
) = params_to_check
verbose_logger.debug(
"in get credentials\n"
"aws_access_key_id=%s\n"
"aws_secret_access_key=%s\n"
"aws_session_token=%s\n"
"aws_region_name=%s\n"
"aws_session_name=%s\n"
"aws_profile_name=%s\n"
"aws_role_name=%s\n"
"aws_web_identity_token=%s\n"
"aws_sts_endpoint=%s",
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
aws_sts_endpoint,
)
# create cache key for non-expiring auth flows
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
cache_key = self.get_cache_key(args)
_cached_credentials = self.iam_cache.get_cache(cache_key)
if _cached_credentials:
return _cached_credentials
#########################################################
# Handle diff boto3 auth flows
# for each helper
# Return:
# Credentials - boto3.Credentials
# cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time.
#########################################################
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
credentials, _cache_ttl = self._auth_with_web_identity_token(
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_region_name=aws_region_name,
aws_sts_endpoint=aws_sts_endpoint,
)
elif aws_role_name is not None and aws_session_name is not None:
credentials, _cache_ttl = self._auth_with_aws_role(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
elif aws_profile_name is not None: ### CHECK SESSION ###
credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
):
credentials, _cache_ttl = self._auth_with_aws_session_token(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_region_name is not None
):
credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
else:
credentials, _cache_ttl = self._auth_with_env_vars()
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
return credentials
def _auth_with_web_identity_token(
self,
aws_web_identity_token: str,
aws_role_name: str,
aws_session_name: str,
aws_region_name: Optional[str],
aws_sts_endpoint: Optional[str],
) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS Web Identity Token
"""
import boto3
verbose_logger.debug(
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
)
if aws_sts_endpoint is None:
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
else:
sts_endpoint = aws_sts_endpoint
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise AwsAuthError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=sts_endpoint,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
if sts_response["PackedPolicySize"] > 75:
verbose_logger.warning(
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
)
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds, self._get_default_ttl_for_boto3_credentials()
def _auth_with_aws_role(
self,
aws_access_key_id: Optional[str],
aws_secret_access_key: Optional[str],
aws_role_name: str,
aws_session_name: str,
) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS Role
"""
import boto3
from botocore.credentials import Credentials
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
sts_expiry = sts_credentials["Expiration"]
# Convert to timezone-aware datetime for comparison
current_time = datetime.now(sts_expiry.tzinfo)
sts_ttl = (sts_expiry - current_time).total_seconds() - 60
return credentials, sts_ttl
def _auth_with_aws_profile(
self, aws_profile_name: str
) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS profile
"""
import boto3
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials(), None
def _auth_with_aws_session_token(
self,
aws_access_key_id: str,
aws_secret_access_key: str,
aws_session_token: str,
) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS Session Token
"""
### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
return credentials, None
def _auth_with_access_key_and_secret_key(
self,
aws_access_key_id: str,
aws_secret_access_key: str,
aws_region_name: Optional[str],
) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS Access Key and Secret Key
"""
import boto3
# Check if credentials are already in cache. These credentials have no expiry time.
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
credentials = session.get_credentials()
return credentials, self._get_default_ttl_for_boto3_credentials()
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
"""
Authenticate with AWS Environment Variables
"""
import boto3
session = boto3.Session()
credentials = session.get_credentials()
return credentials, None
def _get_default_ttl_for_boto3_credentials(self) -> int:
"""
Get the default TTL for boto3 credentials
Returns `3600-60` which is 59 minutes
"""
return 3600 - 60
def get_runtime_endpoint(
self,
api_base: Optional[str],
aws_bedrock_runtime_endpoint: Optional[str],
aws_region_name: str,
) -> Tuple[str, str]:
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
# Determine proxy_endpoint_url
if env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
proxy_endpoint_url = aws_bedrock_runtime_endpoint
else:
proxy_endpoint_url = endpoint_url
return endpoint_url, proxy_endpoint_url
def _get_boto_credentials_from_optional_params(
self, optional_params: dict
) -> Boto3CredentialsInfo:
"""
Get boto3 credentials from optional params
Args:
optional_params (dict): Optional parameters for the model call
Returns:
Credentials: Boto3 credentials object
"""
try:
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret_str("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
return Boto3CredentialsInfo(
credentials=credentials,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
)
def get_request_headers(
self,
credentials: Credentials,
aws_region_name: str,
extra_headers: Optional[dict],
endpoint_url: str,
data: str,
headers: dict,
) -> AWSPreparedRequest:
try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
return prepped