# Spaces: Build error  (page-status residue captured with this file; not part of the module)
from __future__ import annotations | |
from types import MappingProxyType | |
from typing import Annotated, Any, Coroutine, Literal, overload | |
from pydantic import ( | |
BaseModel, | |
Field, | |
SecretStr, | |
WithJsonSchema, | |
) | |
from openhands.core.logger import openhands_logger as logger | |
from openhands.events.action.action import Action | |
from openhands.events.action.commands import CmdRunAction | |
from openhands.events.stream import EventStream | |
from openhands.integrations.github.github_service import GithubServiceImpl | |
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl | |
from openhands.integrations.service_types import ( | |
AuthenticationError, | |
Branch, | |
GitService, | |
ProviderType, | |
Repository, | |
SuggestedTask, | |
User, | |
) | |
from openhands.server.types import AppMode | |
class ProviderToken(BaseModel):
    """Immutable credentials for a single git provider.

    Attributes are documented inline. The model is frozen, so a token is
    replaced by building a new instance rather than mutating this one.
    """

    # Secret access token for the provider; None when no token is stored.
    token: SecretStr | None = Field(default=None)
    # Identifier of the user the token belongs to, if known.
    user_id: str | None = Field(default=None)
    # Provider host; handed to the service as ``base_domain`` by ProviderHandler.
    host: str | None = Field(default=None)

    model_config = {
        'frozen': True,  # Makes the entire model immutable
        'validate_assignment': True,
    }

    @classmethod
    def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
        """Factory method to create a ProviderToken from various input types.

        Args:
            token_value: Either an existing ProviderToken (returned unchanged)
                or a dict with optional 'token', 'user_id' and 'host' keys.

        Returns:
            A ProviderToken built from ``token_value``.

        Raises:
            ValueError: If ``token_value`` is neither a ProviderToken nor a dict.
        """
        if isinstance(token_value, cls):
            return token_value

        if isinstance(token_value, dict):
            # A missing key and an explicit None both become an empty string:
            # None cannot be passed to SecretStr.
            token_str = token_value.get('token') or ''
            return cls(
                token=SecretStr(token_str),
                user_id=token_value.get('user_id'),
                host=token_value.get('host'),
            )

        raise ValueError('Unsupported Provider token type')
class CustomSecret(BaseModel):
    """Immutable user-defined secret with a human-readable description."""

    # The secret value; defaults to an empty secret.
    secret: SecretStr = Field(default_factory=lambda: SecretStr(''))
    # Free-text description of what the secret is for.
    description: str = Field(default='')

    model_config = {
        'frozen': True,  # Makes the entire model immutable
        'validate_assignment': True,
    }

    @classmethod
    def from_value(cls, secret_value: CustomSecret | dict[str, str]) -> CustomSecret:
        """Factory method to create a CustomSecret from various input types.

        Args:
            secret_value: Either an existing CustomSecret (returned unchanged)
                or a dict with optional 'secret' and 'description' keys.

        Returns:
            A CustomSecret built from ``secret_value``.

        Raises:
            ValueError: If ``secret_value`` is neither a CustomSecret nor a dict.
        """
        if isinstance(secret_value, cls):
            return secret_value

        if isinstance(secret_value, dict):
            # Fall back to the field defaults when a key is missing or None,
            # mirroring ProviderToken.from_value: None cannot be passed to
            # SecretStr, and `description` is declared as a plain str.
            secret = secret_value.get('secret') or ''
            description = secret_value.get('description') or ''
            return cls(secret=SecretStr(secret), description=description)

        raise ValueError('Unsupported Custom secret type')
# Read-only mapping of provider -> credentials; ProviderHandler.__init__
# rejects anything that is not a MappingProxyType.
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
# Read-only mapping of secret name -> custom secret.
CUSTOM_SECRETS_TYPE = MappingProxyType[str, CustomSecret]

# Variants of the aliases above that carry an explicit JSON schema
# (an object whose values are strings) via pydantic's WithJsonSchema.
PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA = Annotated[
    PROVIDER_TOKEN_TYPE,
    WithJsonSchema({'type': 'object', 'additionalProperties': {'type': 'string'}}),
]
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Annotated[
    CUSTOM_SECRETS_TYPE,
    WithJsonSchema({'type': 'object', 'additionalProperties': {'type': 'string'}}),
]
class ProviderHandler:
    """Run git-provider operations across every configured provider token.

    The handler holds a read-only mapping of ProviderType -> ProviderToken and
    instantiates the matching service class on demand for each call.
    """

    def __init__(
        self,
        provider_tokens: PROVIDER_TOKEN_TYPE,
        external_auth_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        external_token_manager: bool = False,
    ):
        """Store the provider tokens and external-auth settings.

        Args:
            provider_tokens: Read-only mapping of provider to its token.
            external_auth_id: Forwarded unchanged to each service.
            external_auth_token: Forwarded unchanged to each service.
            external_token_manager: Forwarded unchanged to each service.

        Raises:
            TypeError: If ``provider_tokens`` is not a MappingProxyType.
        """
        if not isinstance(provider_tokens, MappingProxyType):
            raise TypeError(
                f'provider_tokens must be a MappingProxyType, got {type(provider_tokens).__name__}'
            )

        self.service_class_map: dict[ProviderType, type[GitService]] = {
            ProviderType.GITHUB: GithubServiceImpl,
            ProviderType.GITLAB: GitLabServiceImpl,
        }
        self.external_auth_id = external_auth_id
        self.external_auth_token = external_auth_token
        self.external_token_manager = external_token_manager
        self._provider_tokens = provider_tokens

    @property
    def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
        """Read-only access to provider tokens."""
        return self._provider_tokens

    def _get_service(self, provider: ProviderType) -> GitService:
        """Helper method to instantiate a service for a given provider.

        Raises:
            KeyError: If ``provider`` has no token or no registered service class.
        """
        token = self.provider_tokens[provider]
        service_class = self.service_class_map[provider]
        return service_class(
            user_id=token.user_id,
            external_auth_id=self.external_auth_id,
            external_auth_token=self.external_auth_token,
            token=token.token,
            external_token_manager=self.external_token_manager,
            base_domain=token.host,
        )

    async def get_user(self) -> User:
        """Get user information from the first available provider.

        Providers are tried in mapping order; a provider that raises is skipped.

        Raises:
            AuthenticationError: If no provider returns a user.
        """
        for provider in self.provider_tokens:
            try:
                service = self._get_service(provider)
                return await service.get_user()
            except Exception as e:
                # Best effort: fall through to the next provider.
                logger.debug(f'Error fetching user from {provider}: {e}')
        raise AuthenticationError('Need valid provider token')

    async def _get_latest_provider_token(
        self, provider: ProviderType
    ) -> SecretStr | None:
        """Get latest token from service."""
        service = self._get_service(provider)
        return await service.get_latest_token()

    async def get_repositories(self, sort: str, app_mode: AppMode) -> list[Repository]:
        """Get repositories from all providers.

        A provider that raises is logged and skipped; results from the other
        providers are still returned.
        """
        all_repos: list[Repository] = []
        for provider in self.provider_tokens:
            try:
                service = self._get_service(provider)
                service_repos = await service.get_repositories(sort, app_mode)
                all_repos.extend(service_repos)
            except Exception as e:
                logger.warning(f'Error fetching repos from {provider}: {e}')
        return all_repos

    async def get_suggested_tasks(self) -> list[SuggestedTask]:
        """Get suggested tasks from all providers.

        A provider that raises is logged and skipped.
        """
        tasks: list[SuggestedTask] = []
        for provider in self.provider_tokens:
            try:
                service = self._get_service(provider)
                provider_tasks = await service.get_suggested_tasks()
                tasks.extend(provider_tasks)
            except Exception as e:
                logger.warning(f'Error fetching suggested tasks from {provider}: {e}')
        return tasks

    async def search_repositories(
        self,
        query: str,
        per_page: int,
        sort: str,
        order: str,
    ) -> list[Repository]:
        """Search repositories on all providers and concatenate the results.

        The arguments are forwarded as-is to each service's
        ``search_repositories``. A provider that raises is logged and skipped.
        """
        all_repos: list[Repository] = []
        for provider in self.provider_tokens:
            try:
                service = self._get_service(provider)
                service_repos = await service.search_repositories(
                    query, per_page, sort, order
                )
                all_repos.extend(service_repos)
            except Exception as e:
                logger.warning(f'Error searching repos from {provider}: {e}')
        return all_repos

    async def set_event_stream_secrets(
        self,
        event_stream: EventStream,
        env_vars: dict[ProviderType, SecretStr] | None = None,
    ) -> None:
        """
        This ensures that the latest provider tokens are masked from the event stream
        It is called when the provider tokens are first initialized in the runtime or when tokens are re-exported with the latest working ones

        Args:
            event_stream: Agent session's event stream
            env_vars: Dict of providers and their tokens that require updating
        """
        if env_vars:
            exposed_env_vars = self.expose_env_vars(env_vars)
        else:
            # Nothing passed in (or an empty dict): fall back to the stored tokens.
            exposed_env_vars = await self.get_env_vars(expose_secrets=True)
        event_stream.set_secrets(exposed_env_vars)

    def expose_env_vars(
        self, env_secrets: dict[ProviderType, SecretStr]
    ) -> dict[str, str]:
        """
        Return string values instead of typed values for environment secrets
        Called just before exporting secrets to runtime, or setting secrets in the event stream
        """
        return {
            self.get_provider_env_key(provider): token.get_secret_value()
            for provider, token in env_secrets.items()
        }

    @overload
    def get_env_vars(
        self,
        expose_secrets: Literal[True],
        providers: list[ProviderType] | None = ...,
        get_latest: bool = False,
    ) -> Coroutine[Any, Any, dict[str, str]]: ...

    @overload
    def get_env_vars(
        self,
        expose_secrets: Literal[False],
        providers: list[ProviderType] | None = ...,
        get_latest: bool = False,
    ) -> Coroutine[Any, Any, dict[ProviderType, SecretStr]]: ...

    async def get_env_vars(
        self,
        expose_secrets: bool = False,
        providers: list[ProviderType] | None = None,
        get_latest: bool = False,
    ) -> dict[ProviderType, SecretStr] | dict[str, str]:
        """
        Retrieves the provider tokens from ProviderHandler object
        This is used when initializing/exporting new provider tokens in the runtime

        Args:
            expose_secrets: Flag which returns strings instead of secrets
            providers: Return provider tokens for the list passed in, otherwise return all available providers
            get_latest: Get the latest working token for the providers if True, otherwise get the existing ones
        """
        if not self.provider_tokens:
            return {}

        env_vars: dict[ProviderType, SecretStr] = {}
        # An empty or missing `providers` list means "every known provider".
        provider_list = providers if providers else list(ProviderType)
        for provider in provider_list:
            if provider not in self.provider_tokens:
                continue

            token: SecretStr | None = self.provider_tokens[provider].token
            if get_latest:
                token = await self._get_latest_provider_token(provider)

            # Providers without a usable (truthy) token are left out.
            if token:
                env_vars[provider] = token

        if not expose_secrets:
            return env_vars
        return self.expose_env_vars(env_vars)

    @classmethod
    def check_cmd_action_for_provider_token_ref(
        cls, event: Action
    ) -> list[ProviderType]:
        """
        Detect if agent run action is using a provider token (e.g $GITHUB_TOKEN)
        Returns a list of providers which are called by the agent
        """
        if not isinstance(event, CmdRunAction):
            return []

        # Env keys are lower-case, so compare against the lower-cased command.
        command = event.command.lower()
        return [
            provider
            for provider in ProviderType
            if cls.get_provider_env_key(provider) in command
        ]

    @classmethod
    def get_provider_env_key(cls, provider: ProviderType) -> str:
        """
        Map ProviderType value to the environment variable name in the runtime
        """
        return f'{provider.value}_token'.lower()

    async def verify_repo_provider(
        self, repository: str, specified_provider: ProviderType | None = None
    ):
        """Look up repository details, trying the specified provider first.

        If ``specified_provider`` is given and succeeds, its result is returned.
        Otherwise every configured provider is tried in mapping order and the
        first successful result is returned.

        Args:
            repository: The repository name
            specified_provider: Optional provider type to try first

        Returns:
            Whatever the service's ``get_repository_details_from_repo_name``
            returns for the first provider that succeeds.

        Raises:
            AuthenticationError: If no provider can access the repository.
        """
        if specified_provider:
            try:
                service = self._get_service(specified_provider)
                return await service.get_repository_details_from_repo_name(repository)
            except Exception as e:
                # Best effort: fall back to scanning every provider below.
                logger.debug(
                    f'Error verifying repo {repository} with {specified_provider}: {e}'
                )

        for provider in self.provider_tokens:
            try:
                service = self._get_service(provider)
                return await service.get_repository_details_from_repo_name(repository)
            except Exception as e:
                logger.debug(f'Error verifying repo {repository} with {provider}: {e}')

        raise AuthenticationError(f'Unable to access repo {repository}')

    async def get_branches(
        self, repository: str, specified_provider: ProviderType | None = None
    ) -> list[Branch]:
        """
        Get branches for a repository

        When ``specified_provider`` is given and succeeds, its branches are
        returned exactly as the service provides them. Otherwise the first
        provider that yields any branches is used, and the result is sorted by
        last push date (newest first) with main/master moved to the front.

        Args:
            repository: The repository name
            specified_provider: Optional provider type to use

        Returns:
            A list of branches for the repository
        """
        if specified_provider:
            try:
                service = self._get_service(specified_provider)
                return await service.get_branches(repository)
            except Exception as e:
                logger.warning(
                    f'Error fetching branches from {specified_provider}: {e}'
                )

        all_branches: list[Branch] = []
        for provider in self.provider_tokens:
            try:
                service = self._get_service(provider)
                branches = await service.get_branches(repository)
                all_branches.extend(branches)
                # If we found branches, no need to check other providers
                if all_branches:
                    break
            except Exception as e:
                logger.warning(f'Error fetching branches from {provider}: {e}')

        # Sort branches by last push date (newest first); a missing date sorts
        # as the empty string, i.e. last.
        all_branches.sort(
            key=lambda b: b.last_push_date if b.last_push_date else '', reverse=True
        )

        # Move main/master branch to the top if it exists
        main_branches = []
        other_branches = []
        for branch in all_branches:
            if branch.name.lower() in ('main', 'master'):
                main_branches.append(branch)
            else:
                other_branches.append(branch)
        return main_branches + other_branches