Spaces:
Build error
Build error
File size: 5,717 Bytes
51ff9e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from types import MappingProxyType
from typing import Any
from pydantic import (
BaseModel,
Field,
SerializationInfo,
field_serializer,
model_validator,
)
from pydantic.json import pydantic_encoder
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
CUSTOM_SECRETS_TYPE,
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
PROVIDER_TOKEN_TYPE,
PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA,
CustomSecret,
ProviderToken,
)
from openhands.integrations.service_types import ProviderType
# Immutable container for a user's provider tokens and custom secrets.
#
# Both fields are held as read-only ``MappingProxyType`` mappings and the model
# itself is frozen. When the model is serialized, secret values are masked
# unless the serialization context carries a truthy ``expose_secrets`` entry.
#
# NOTE: this description is a comment rather than a class docstring on purpose:
# pydantic copies a model's docstring into its generated JSON schema.
class UserSecrets(BaseModel):
    provider_tokens: PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA = Field(
        default_factory=lambda: MappingProxyType({})
    )

    custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Field(
        default_factory=lambda: MappingProxyType({})
    )

    model_config = {
        'frozen': True,
        'validate_assignment': True,
        'arbitrary_types_allowed': True,
    }

    @field_serializer('provider_tokens')
    def provider_tokens_serializer(
        self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
    ) -> dict[str, dict[str, str | Any]]:
        """Serialize provider tokens into a plain, string-keyed dictionary.

        Args:
            provider_tokens: Mapping of provider type to provider token.
            info: Serialization info; a truthy ``expose_secrets`` entry in
                ``info.context`` makes the raw token values appear in the
                output, otherwise they go through ``pydantic_encoder``.

        Returns:
            ``{provider: {'token': ..., 'host': ..., 'user_id': ...}}``.
            Entries that are falsy or have a falsy token are omitted.
        """
        expose_secrets = info.context and info.context.get('expose_secrets', False)

        tokens: dict[str, dict[str, str | Any]] = {}
        for token_type, provider_token in provider_tokens.items():
            # Nothing to serialize for a missing entry or an empty token.
            if not provider_token or not provider_token.token:
                continue

            token_type_str = (
                token_type.value
                if isinstance(token_type, ProviderType)
                else str(token_type)
            )

            # The guard above guarantees a truthy token here, so no second
            # check (and no ``None`` fallback) is needed.
            secret = provider_token.token
            tokens[token_type_str] = {
                'token': secret.get_secret_value()
                if expose_secrets
                else pydantic_encoder(secret),
                'host': provider_token.host,
                'user_id': provider_token.user_id,
            }

        return tokens

    @field_serializer('custom_secrets')
    def custom_secrets_serializer(
        self, custom_secrets: CUSTOM_SECRETS_TYPE, info: SerializationInfo
    ):
        """Serialize custom secrets into ``{name: {'secret', 'description'}}``.

        Args:
            custom_secrets: Mapping of secret name to custom secret.
            info: Serialization info; a truthy ``expose_secrets`` entry in
                ``info.context`` makes the raw secret values appear in the
                output, otherwise they go through ``pydantic_encoder``.

        Returns:
            A dictionary keyed by secret name; empty when there are no
            custom secrets.
        """
        expose_secrets = info.context and info.context.get('expose_secrets', False)

        secrets = {}
        if custom_secrets:
            for secret_name, secret_value in custom_secrets.items():
                secrets[secret_name] = {
                    'secret': secret_value.secret.get_secret_value()
                    if expose_secrets
                    else pydantic_encoder(secret_value.secret),
                    'description': secret_value.description,
                }

        return secrets

    @model_validator(mode='before')
    @classmethod
    def convert_dict_to_mappingproxy(
        cls, data: dict[str, dict[str, Any] | MappingProxyType] | PROVIDER_TOKEN_TYPE
    ) -> dict[str, MappingProxyType | None]:
        """Custom deserializer to convert dictionary into MappingProxyType.

        Only the ``provider_tokens`` and ``custom_secrets`` keys of ``data``
        are carried over. For each of them:

        * a ``dict`` is converted entry by entry (via
          ``ProviderToken.from_value`` / ``CustomSecret.from_value``) and
          wrapped in a ``MappingProxyType``; entries that raise
          ``ValueError`` during conversion are skipped;
        * an existing ``MappingProxyType`` is passed through unchanged;
        * any other value is left out of the result.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            A new dictionary containing only the converted fields.

        Raises:
            ValueError: If ``data`` is not a dictionary.
        """
        if not isinstance(data, dict):
            raise ValueError('UserSecrets must be initialized with a dictionary')

        new_data: dict[str, MappingProxyType | None] = {}

        if 'provider_tokens' in data:
            tokens = data['provider_tokens']
            if isinstance(
                tokens, dict
            ):  # Ensure conversion happens only for dict inputs
                converted_tokens = {}
                for key, value in tokens.items():
                    try:
                        # String keys are mapped onto the ProviderType enum.
                        provider_type = (
                            ProviderType(key) if isinstance(key, str) else key
                        )
                        converted_tokens[provider_type] = ProviderToken.from_value(
                            value
                        )
                    except ValueError:
                        # Skip invalid provider types or tokens
                        continue

                # Convert to MappingProxyType
                new_data['provider_tokens'] = MappingProxyType(converted_tokens)
            elif isinstance(tokens, MappingProxyType):
                new_data['provider_tokens'] = tokens

        if 'custom_secrets' in data:
            secrets = data['custom_secrets']
            if isinstance(secrets, dict):
                converted_secrets = {}
                for key, value in secrets.items():
                    try:
                        converted_secrets[key] = CustomSecret.from_value(value)
                    except ValueError:
                        # Skip secrets that cannot be converted
                        continue

                new_data['custom_secrets'] = MappingProxyType(converted_secrets)
            elif isinstance(secrets, MappingProxyType):
                new_data['custom_secrets'] = secrets

        return new_data

    def set_event_stream_secrets(self, event_stream: EventStream) -> None:
        """Hand the custom secret values to the event stream.

        Passes the mapping returned by ``get_env_vars`` (custom secrets only;
        provider tokens are not included) to ``event_stream.set_secrets``,
        presumably so that those values are masked from the event stream.

        Args:
            event_stream: Agent session's event stream
        """
        event_stream.set_secrets(self.get_env_vars())

    def get_env_vars(self) -> dict[str, str]:
        """Return the custom secrets as ``{name: raw secret value}``.

        The values are read straight from ``custom_secrets`` rather than via
        a full ``model_dump`` with ``expose_secrets``, which would also
        render every provider token in clear text only to discard it.
        """
        return {
            secret_name: custom_secret.secret.get_secret_value()
            for secret_name, custom_secret in (self.custom_secrets or {}).items()
        }

    def get_custom_secrets_descriptions(self) -> dict[str, str]:
        """Return ``{name: description}`` for every custom secret."""
        return {
            secret_name: secret.description
            for secret_name, secret in self.custom_secrets.items()
        }
|