File size: 2,935 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Deprecated. Only PostgresSQL is supported.
"""

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import DynamoDBArgs
from litellm.proxy.db.base_client import CustomDB


class DynamoDBWrapper(CustomDB):
    from aiodynamo.credentials import Credentials, StaticCredentials

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        from aiodynamo.models import PayPerRequest, Throughput

        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            if (
                database_arguments.read_capacity_units is not None
                and isinstance(database_arguments.read_capacity_units, int)
                and database_arguments.write_capacity_units is not None
                and isinstance(database_arguments.write_capacity_units, int)
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def set_env_vars_based_on_arn(self):
        if self.database_arguments.aws_role_name is None:
            return
        verbose_proxy_logger.debug(
            f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
        )
        import os

        import boto3

        sts_client = boto3.client("sts")

        # call 1
        sts_client.assume_role_with_web_identity(
            RoleArn=self.database_arguments.aws_role_name,
            RoleSessionName=self.database_arguments.aws_session_name,
            WebIdentityToken=self.database_arguments.aws_web_identity_token,
        )

        # call 2
        assumed_role = sts_client.assume_role(
            RoleArn=self.database_arguments.assume_role_aws_role_name,
            RoleSessionName=self.database_arguments.assume_role_aws_session_name,
        )

        aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
        aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
        aws_session_token = assumed_role["Credentials"]["SessionToken"]

        verbose_proxy_logger.debug(
            f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
        )
        # set these in the env so aiodynamo can use them
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token