File size: 5,717 Bytes
51ff9e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from types import MappingProxyType
from typing import Any

from pydantic import (
    BaseModel,
    Field,
    SerializationInfo,
    field_serializer,
    model_validator,
)
from pydantic.json import pydantic_encoder

from openhands.events.stream import EventStream
from openhands.integrations.provider import (
    CUSTOM_SECRETS_TYPE,
    CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
    PROVIDER_TOKEN_TYPE,
    PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA,
    CustomSecret,
    ProviderToken,
)
from openhands.integrations.service_types import ProviderType


class UserSecrets(BaseModel):
    """Frozen container for a user's provider tokens and custom secrets.

    Both fields are held as read-only ``MappingProxyType`` mappings. When the
    model is serialized, secret values are only revealed if the serialization
    context contains a truthy ``expose_secrets`` entry; otherwise they are
    passed through ``pydantic_encoder`` instead of being unwrapped.
    """

    provider_tokens: PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA = Field(
        default_factory=lambda: MappingProxyType({})
    )

    custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Field(
        default_factory=lambda: MappingProxyType({})
    )

    model_config = {
        'frozen': True,
        'validate_assignment': True,
        'arbitrary_types_allowed': True,
    }

    @field_serializer('provider_tokens')
    def provider_tokens_serializer(
        self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
    ) -> dict[str, dict[str, str | Any]]:
        """Serialize provider tokens into plain dictionaries.

        Args:
            provider_tokens: Mapping of provider type to provider token.
            info: Serialization info; ``info.context['expose_secrets']``
                controls whether the raw token value is emitted.

        Returns:
            A dict keyed by the provider type's string value, each entry
            holding ``token``, ``host`` and ``user_id``. Entries that have no
            token are omitted.
        """
        expose_secrets = bool(
            info.context and info.context.get('expose_secrets', False)
        )
        tokens: dict[str, dict[str, str | Any]] = {}

        for token_type, provider_token in provider_tokens.items():
            # Nothing to serialize for a missing entry or an empty token.
            if not provider_token or not provider_token.token:
                continue

            token_type_str = (
                token_type.value
                if isinstance(token_type, ProviderType)
                else str(token_type)
            )

            tokens[token_type_str] = {
                'token': (
                    provider_token.token.get_secret_value()
                    if expose_secrets
                    else pydantic_encoder(provider_token.token)
                ),
                'host': provider_token.host,
                'user_id': provider_token.user_id,
            }

        return tokens

    @field_serializer('custom_secrets')
    def custom_secrets_serializer(
        self, custom_secrets: CUSTOM_SECRETS_TYPE, info: SerializationInfo
    ) -> dict[str, dict[str, Any]]:
        """Serialize custom secrets into plain dictionaries.

        Args:
            custom_secrets: Mapping of secret name to custom secret.
            info: Serialization info; ``info.context['expose_secrets']``
                controls whether the raw secret value is emitted.

        Returns:
            A dict keyed by secret name, each entry holding ``secret`` and
            ``description``. Empty when there are no custom secrets.
        """
        if not custom_secrets:
            return {}

        expose_secrets = bool(
            info.context and info.context.get('expose_secrets', False)
        )

        return {
            secret_name: {
                'secret': (
                    secret_value.secret.get_secret_value()
                    if expose_secrets
                    else pydantic_encoder(secret_value.secret)
                ),
                'description': secret_value.description,
            }
            for secret_name, secret_value in custom_secrets.items()
        }

    @model_validator(mode='before')
    @classmethod
    def convert_dict_to_mappingproxy(
        cls, data: dict[str, dict[str, Any] | MappingProxyType] | PROVIDER_TOKEN_TYPE
    ) -> dict[str, MappingProxyType | None]:
        """Normalize raw input so both fields are ``MappingProxyType`` values.

        Only the ``provider_tokens`` and ``custom_secrets`` keys are carried
        over; any other key in ``data`` is dropped. A field whose value is
        neither a ``dict`` nor a ``MappingProxyType`` is also dropped, so the
        field falls back to its default.

        Args:
            data: Raw input passed to the model.

        Returns:
            A new dict containing the converted fields that were present.

        Raises:
            ValueError: If ``data`` is not a dictionary.
        """
        if not isinstance(data, dict):
            raise ValueError('UserSecrets must be initialized with a dictionary')

        new_data: dict[str, MappingProxyType | None] = {}

        if 'provider_tokens' in data:
            tokens = data['provider_tokens']
            if isinstance(tokens, dict):
                # Conversion only happens for plain dict inputs.
                converted_tokens = {}
                for key, value in tokens.items():
                    try:
                        provider_type = (
                            ProviderType(key) if isinstance(key, str) else key
                        )
                        converted_tokens[provider_type] = ProviderToken.from_value(
                            value
                        )
                    except ValueError:
                        # Best effort: skip invalid provider types or tokens.
                        continue

                new_data['provider_tokens'] = MappingProxyType(converted_tokens)
            elif isinstance(tokens, MappingProxyType):
                # Already read-only; pass through untouched.
                new_data['provider_tokens'] = tokens

        if 'custom_secrets' in data:
            secrets = data['custom_secrets']
            if isinstance(secrets, dict):
                converted_secrets = {}
                for key, value in secrets.items():
                    try:
                        converted_secrets[key] = CustomSecret.from_value(value)
                    except ValueError:
                        # Best effort: skip secrets that cannot be converted.
                        continue

                new_data['custom_secrets'] = MappingProxyType(converted_secrets)
            elif isinstance(secrets, MappingProxyType):
                # Already read-only; pass through untouched.
                new_data['custom_secrets'] = secrets

        return new_data

    def set_event_stream_secrets(self, event_stream: EventStream) -> None:
        """Register the custom secrets with the event stream.

        The values come from ``get_env_vars()``, so only custom secrets are
        passed to ``event_stream.set_secrets``; provider tokens are not
        included. Presumably the stream uses them to mask secret values in
        events.

        Args:
            event_stream: Agent session's event stream
        """
        event_stream.set_secrets(self.get_env_vars())

    def get_env_vars(self) -> dict[str, str]:
        """Return the custom secrets as a name -> plaintext value mapping.

        Reads ``custom_secrets`` directly rather than dumping the whole model,
        so provider tokens are never unwrapped as a side effect.
        """
        return {
            secret_name: custom_secret.secret.get_secret_value()
            for secret_name, custom_secret in self.custom_secrets.items()
        }

    def get_custom_secrets_descriptions(self) -> dict[str, str]:
        """Return a mapping of custom secret name -> description."""
        return {
            secret_name: custom_secret.description
            for secret_name, custom_secret in self.custom_secrets.items()
        }