|
import hashlib |
|
import json |
|
import os |
|
from datetime import datetime |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple |
|
|
|
import httpx |
|
from pydantic import BaseModel |
|
|
|
from litellm._logging import verbose_logger |
|
from litellm.caching.caching import DualCache |
|
from litellm.secret_managers.main import get_secret, get_secret_str |
|
|
|
if TYPE_CHECKING: |
|
from botocore.awsrequest import AWSPreparedRequest |
|
from botocore.credentials import Credentials |
|
else: |
|
Credentials = Any |
|
AWSPreparedRequest = Any |
|
|
|
|
|
class Boto3CredentialsInfo(BaseModel): |
|
credentials: Credentials |
|
aws_region_name: str |
|
aws_bedrock_runtime_endpoint: Optional[str] |
|
|
|
|
|
class AwsAuthError(Exception): |
|
def __init__(self, status_code, message): |
|
self.status_code = status_code |
|
self.message = message |
|
self.request = httpx.Request( |
|
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock" |
|
) |
|
self.response = httpx.Response(status_code=status_code, request=self.request) |
|
super().__init__( |
|
self.message |
|
) |
|
|
|
|
|
class BaseAWSLLM: |
|
def __init__(self) -> None: |
|
self.iam_cache = DualCache() |
|
super().__init__() |
|
|
|
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str: |
|
""" |
|
Generate a unique cache key based on the credential arguments. |
|
""" |
|
|
|
credential_str = json.dumps(credential_args, sort_keys=True) |
|
return hashlib.sha256(credential_str.encode()).hexdigest() |
|
|
|
def get_credentials( |
|
self, |
|
aws_access_key_id: Optional[str] = None, |
|
aws_secret_access_key: Optional[str] = None, |
|
aws_session_token: Optional[str] = None, |
|
aws_region_name: Optional[str] = None, |
|
aws_session_name: Optional[str] = None, |
|
aws_profile_name: Optional[str] = None, |
|
aws_role_name: Optional[str] = None, |
|
aws_web_identity_token: Optional[str] = None, |
|
aws_sts_endpoint: Optional[str] = None, |
|
): |
|
""" |
|
Return a boto3.Credentials object |
|
""" |
|
|
|
param_names = [ |
|
"aws_access_key_id", |
|
"aws_secret_access_key", |
|
"aws_session_token", |
|
"aws_region_name", |
|
"aws_session_name", |
|
"aws_profile_name", |
|
"aws_role_name", |
|
"aws_web_identity_token", |
|
"aws_sts_endpoint", |
|
] |
|
params_to_check: List[Optional[str]] = [ |
|
aws_access_key_id, |
|
aws_secret_access_key, |
|
aws_session_token, |
|
aws_region_name, |
|
aws_session_name, |
|
aws_profile_name, |
|
aws_role_name, |
|
aws_web_identity_token, |
|
aws_sts_endpoint, |
|
] |
|
|
|
|
|
for i, param in enumerate(params_to_check): |
|
if param and param.startswith("os.environ/"): |
|
_v = get_secret(param) |
|
if _v is not None and isinstance(_v, str): |
|
params_to_check[i] = _v |
|
elif param is None: |
|
key = param_names[i] |
|
if key.upper() in os.environ: |
|
params_to_check[i] = os.getenv(key) |
|
|
|
|
|
( |
|
aws_access_key_id, |
|
aws_secret_access_key, |
|
aws_session_token, |
|
aws_region_name, |
|
aws_session_name, |
|
aws_profile_name, |
|
aws_role_name, |
|
aws_web_identity_token, |
|
aws_sts_endpoint, |
|
) = params_to_check |
|
|
|
verbose_logger.debug( |
|
"in get credentials\n" |
|
"aws_access_key_id=%s\n" |
|
"aws_secret_access_key=%s\n" |
|
"aws_session_token=%s\n" |
|
"aws_region_name=%s\n" |
|
"aws_session_name=%s\n" |
|
"aws_profile_name=%s\n" |
|
"aws_role_name=%s\n" |
|
"aws_web_identity_token=%s\n" |
|
"aws_sts_endpoint=%s", |
|
aws_access_key_id, |
|
aws_secret_access_key, |
|
aws_session_token, |
|
aws_region_name, |
|
aws_session_name, |
|
aws_profile_name, |
|
aws_role_name, |
|
aws_web_identity_token, |
|
aws_sts_endpoint, |
|
) |
|
|
|
|
|
args = {k: v for k, v in locals().items() if k.startswith("aws_")} |
|
|
|
cache_key = self.get_cache_key(args) |
|
_cached_credentials = self.iam_cache.get_cache(cache_key) |
|
if _cached_credentials: |
|
return _cached_credentials |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ( |
|
aws_web_identity_token is not None |
|
and aws_role_name is not None |
|
and aws_session_name is not None |
|
): |
|
credentials, _cache_ttl = self._auth_with_web_identity_token( |
|
aws_web_identity_token=aws_web_identity_token, |
|
aws_role_name=aws_role_name, |
|
aws_session_name=aws_session_name, |
|
aws_region_name=aws_region_name, |
|
aws_sts_endpoint=aws_sts_endpoint, |
|
) |
|
elif aws_role_name is not None and aws_session_name is not None: |
|
credentials, _cache_ttl = self._auth_with_aws_role( |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
aws_role_name=aws_role_name, |
|
aws_session_name=aws_session_name, |
|
) |
|
|
|
elif aws_profile_name is not None: |
|
credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name) |
|
elif ( |
|
aws_access_key_id is not None |
|
and aws_secret_access_key is not None |
|
and aws_session_token is not None |
|
): |
|
credentials, _cache_ttl = self._auth_with_aws_session_token( |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
aws_session_token=aws_session_token, |
|
) |
|
elif ( |
|
aws_access_key_id is not None |
|
and aws_secret_access_key is not None |
|
and aws_region_name is not None |
|
): |
|
credentials, _cache_ttl = self._auth_with_access_key_and_secret_key( |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
aws_region_name=aws_region_name, |
|
) |
|
else: |
|
credentials, _cache_ttl = self._auth_with_env_vars() |
|
|
|
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl) |
|
return credentials |
|
|
|
def _auth_with_web_identity_token( |
|
self, |
|
aws_web_identity_token: str, |
|
aws_role_name: str, |
|
aws_session_name: str, |
|
aws_region_name: Optional[str], |
|
aws_sts_endpoint: Optional[str], |
|
) -> Tuple[Credentials, Optional[int]]: |
|
""" |
|
Authenticate with AWS Web Identity Token |
|
""" |
|
import boto3 |
|
|
|
verbose_logger.debug( |
|
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}" |
|
) |
|
|
|
if aws_sts_endpoint is None: |
|
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com" |
|
else: |
|
sts_endpoint = aws_sts_endpoint |
|
|
|
oidc_token = get_secret(aws_web_identity_token) |
|
|
|
if oidc_token is None: |
|
raise AwsAuthError( |
|
message="OIDC token could not be retrieved from secret manager.", |
|
status_code=401, |
|
) |
|
|
|
sts_client = boto3.client( |
|
"sts", |
|
region_name=aws_region_name, |
|
endpoint_url=sts_endpoint, |
|
) |
|
|
|
|
|
|
|
sts_response = sts_client.assume_role_with_web_identity( |
|
RoleArn=aws_role_name, |
|
RoleSessionName=aws_session_name, |
|
WebIdentityToken=oidc_token, |
|
DurationSeconds=3600, |
|
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}', |
|
) |
|
|
|
iam_creds_dict = { |
|
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], |
|
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"], |
|
"aws_session_token": sts_response["Credentials"]["SessionToken"], |
|
"region_name": aws_region_name, |
|
} |
|
|
|
if sts_response["PackedPolicySize"] > 75: |
|
verbose_logger.warning( |
|
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}" |
|
) |
|
|
|
session = boto3.Session(**iam_creds_dict) |
|
|
|
iam_creds = session.get_credentials() |
|
return iam_creds, self._get_default_ttl_for_boto3_credentials() |
|
|
|
def _auth_with_aws_role( |
|
self, |
|
aws_access_key_id: Optional[str], |
|
aws_secret_access_key: Optional[str], |
|
aws_role_name: str, |
|
aws_session_name: str, |
|
) -> Tuple[Credentials, Optional[int]]: |
|
""" |
|
Authenticate with AWS Role |
|
""" |
|
import boto3 |
|
from botocore.credentials import Credentials |
|
|
|
sts_client = boto3.client( |
|
"sts", |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
) |
|
|
|
sts_response = sts_client.assume_role( |
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name |
|
) |
|
|
|
|
|
sts_credentials = sts_response["Credentials"] |
|
|
|
credentials = Credentials( |
|
access_key=sts_credentials["AccessKeyId"], |
|
secret_key=sts_credentials["SecretAccessKey"], |
|
token=sts_credentials["SessionToken"], |
|
) |
|
|
|
sts_expiry = sts_credentials["Expiration"] |
|
|
|
current_time = datetime.now(sts_expiry.tzinfo) |
|
sts_ttl = (sts_expiry - current_time).total_seconds() - 60 |
|
return credentials, sts_ttl |
|
|
|
def _auth_with_aws_profile( |
|
self, aws_profile_name: str |
|
) -> Tuple[Credentials, Optional[int]]: |
|
""" |
|
Authenticate with AWS profile |
|
""" |
|
import boto3 |
|
|
|
|
|
client = boto3.Session(profile_name=aws_profile_name) |
|
return client.get_credentials(), None |
|
|
|
def _auth_with_aws_session_token( |
|
self, |
|
aws_access_key_id: str, |
|
aws_secret_access_key: str, |
|
aws_session_token: str, |
|
) -> Tuple[Credentials, Optional[int]]: |
|
""" |
|
Authenticate with AWS Session Token |
|
""" |
|
|
|
from botocore.credentials import Credentials |
|
|
|
credentials = Credentials( |
|
access_key=aws_access_key_id, |
|
secret_key=aws_secret_access_key, |
|
token=aws_session_token, |
|
) |
|
|
|
return credentials, None |
|
|
|
def _auth_with_access_key_and_secret_key( |
|
self, |
|
aws_access_key_id: str, |
|
aws_secret_access_key: str, |
|
aws_region_name: Optional[str], |
|
) -> Tuple[Credentials, Optional[int]]: |
|
""" |
|
Authenticate with AWS Access Key and Secret Key |
|
""" |
|
import boto3 |
|
|
|
|
|
|
|
session = boto3.Session( |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
region_name=aws_region_name, |
|
) |
|
|
|
credentials = session.get_credentials() |
|
return credentials, self._get_default_ttl_for_boto3_credentials() |
|
|
|
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]: |
|
""" |
|
Authenticate with AWS Environment Variables |
|
""" |
|
import boto3 |
|
|
|
session = boto3.Session() |
|
credentials = session.get_credentials() |
|
return credentials, None |
|
|
|
def _get_default_ttl_for_boto3_credentials(self) -> int: |
|
""" |
|
Get the default TTL for boto3 credentials |
|
|
|
Returns `3600-60` which is 59 minutes |
|
""" |
|
return 3600 - 60 |
|
|
|
def get_runtime_endpoint( |
|
self, |
|
api_base: Optional[str], |
|
aws_bedrock_runtime_endpoint: Optional[str], |
|
aws_region_name: str, |
|
) -> Tuple[str, str]: |
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") |
|
if api_base is not None: |
|
endpoint_url = api_base |
|
elif aws_bedrock_runtime_endpoint is not None and isinstance( |
|
aws_bedrock_runtime_endpoint, str |
|
): |
|
endpoint_url = aws_bedrock_runtime_endpoint |
|
elif env_aws_bedrock_runtime_endpoint and isinstance( |
|
env_aws_bedrock_runtime_endpoint, str |
|
): |
|
endpoint_url = env_aws_bedrock_runtime_endpoint |
|
else: |
|
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" |
|
|
|
|
|
if env_aws_bedrock_runtime_endpoint and isinstance( |
|
env_aws_bedrock_runtime_endpoint, str |
|
): |
|
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint |
|
elif aws_bedrock_runtime_endpoint is not None and isinstance( |
|
aws_bedrock_runtime_endpoint, str |
|
): |
|
proxy_endpoint_url = aws_bedrock_runtime_endpoint |
|
else: |
|
proxy_endpoint_url = endpoint_url |
|
|
|
return endpoint_url, proxy_endpoint_url |
|
|
|
def _get_boto_credentials_from_optional_params( |
|
self, optional_params: dict |
|
) -> Boto3CredentialsInfo: |
|
""" |
|
Get boto3 credentials from optional params |
|
|
|
Args: |
|
optional_params (dict): Optional parameters for the model call |
|
|
|
Returns: |
|
Credentials: Boto3 credentials object |
|
""" |
|
try: |
|
from botocore.credentials import Credentials |
|
except ImportError: |
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") |
|
|
|
|
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) |
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None) |
|
aws_session_token = optional_params.pop("aws_session_token", None) |
|
aws_region_name = optional_params.pop("aws_region_name", None) |
|
aws_role_name = optional_params.pop("aws_role_name", None) |
|
aws_session_name = optional_params.pop("aws_session_name", None) |
|
aws_profile_name = optional_params.pop("aws_profile_name", None) |
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) |
|
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) |
|
aws_bedrock_runtime_endpoint = optional_params.pop( |
|
"aws_bedrock_runtime_endpoint", None |
|
) |
|
|
|
|
|
if aws_region_name is None: |
|
|
|
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None) |
|
|
|
if litellm_aws_region_name is not None and isinstance( |
|
litellm_aws_region_name, str |
|
): |
|
aws_region_name = litellm_aws_region_name |
|
|
|
standard_aws_region_name = get_secret_str("AWS_REGION", None) |
|
if standard_aws_region_name is not None and isinstance( |
|
standard_aws_region_name, str |
|
): |
|
aws_region_name = standard_aws_region_name |
|
|
|
if aws_region_name is None: |
|
aws_region_name = "us-west-2" |
|
|
|
credentials: Credentials = self.get_credentials( |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
aws_session_token=aws_session_token, |
|
aws_region_name=aws_region_name, |
|
aws_session_name=aws_session_name, |
|
aws_profile_name=aws_profile_name, |
|
aws_role_name=aws_role_name, |
|
aws_web_identity_token=aws_web_identity_token, |
|
aws_sts_endpoint=aws_sts_endpoint, |
|
) |
|
|
|
return Boto3CredentialsInfo( |
|
credentials=credentials, |
|
aws_region_name=aws_region_name, |
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, |
|
) |
|
|
|
def get_request_headers( |
|
self, |
|
credentials: Credentials, |
|
aws_region_name: str, |
|
extra_headers: Optional[dict], |
|
endpoint_url: str, |
|
data: str, |
|
headers: dict, |
|
) -> AWSPreparedRequest: |
|
try: |
|
from botocore.auth import SigV4Auth |
|
from botocore.awsrequest import AWSRequest |
|
except ImportError: |
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") |
|
|
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) |
|
|
|
request = AWSRequest( |
|
method="POST", url=endpoint_url, data=data, headers=headers |
|
) |
|
sigv4.add_auth(request) |
|
if ( |
|
extra_headers is not None and "Authorization" in extra_headers |
|
): |
|
request.headers["Authorization"] = extra_headers["Authorization"] |
|
prepped = request.prepare() |
|
|
|
return prepped |
|
|