import json
import os
from datetime import datetime
from typing import Any

import httpx
from pydantic import SecretStr

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.queries import (
    suggested_task_issue_graphql_query,
    suggested_task_pr_graphql_query,
)
from openhands.integrations.service_types import (
    BaseGitService,
    Branch,
    GitService,
    ProviderType,
    Repository,
    RequestMethod,
    SuggestedTask,
    TaskType,
    UnknownException,
    User,
)
from openhands.server.types import AppMode
from openhands.utils.import_utils import get_impl


class GitHubService(BaseGitService, GitService):
    """Default implementation of GitService for GitHub integration.

    TODO: This doesn't seem like a good candidate for the get_impl() pattern.
    Which abstract methods should we actually separate out and implement here?

    This is an extension point in OpenHands that allows applications to customize GitHub
    integration behavior. Applications can substitute their own implementation by:
    1. Creating a class that inherits from GitService
    2. Implementing all required methods
    3. Setting server_config.github_service_class to the fully qualified name of the class

    The class is instantiated via get_impl() in openhands.server.shared.py.
    """

    BASE_URL = 'https://api.github.com'
    token: SecretStr = SecretStr('')
    refresh = False

    def __init__(
        self,
        user_id: str | None = None,
        external_auth_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        token: SecretStr | None = None,
        external_token_manager: bool = False,
        base_domain: str | None = None,
    ):
        self.user_id = user_id
        self.external_token_manager = external_token_manager

        if token:
            self.token = token

        # GitHub Enterprise Server exposes the REST API under /api/v3
        if base_domain and base_domain != 'github.com':
            self.BASE_URL = f'https://{base_domain}/api/v3'

        self.external_auth_id = external_auth_id
        self.external_auth_token = external_auth_token

    def provider(self) -> str:
        return ProviderType.GITHUB.value

    async def _get_github_headers(self) -> dict:
        """Retrieve the GH Token from settings store to construct the headers."""
        if not self.token:
            self.token = await self.get_latest_token()

        return {
            'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}',
            'Accept': 'application/vnd.github.v3+json',
        }

    def _has_token_expired(self, status_code: int) -> bool:
        return status_code == 401

    async def get_latest_token(self) -> SecretStr | None:
        return self.token

    async def _make_request(
        self,
        url: str,
        params: dict | None = None,
        method: RequestMethod = RequestMethod.GET,
    ) -> tuple[Any, dict]:
        try:
            async with httpx.AsyncClient() as client:
                github_headers = await self._get_github_headers()

                # Make initial request
                response = await self.execute_request(
                    client=client,
                    url=url,
                    headers=github_headers,
                    params=params,
                    method=method,
                )

                # Handle token refresh if needed
                if self.refresh and self._has_token_expired(response.status_code):
                    await self.get_latest_token()
                    github_headers = await self._get_github_headers()
                    response = await self.execute_request(
                        client=client,
                        url=url,
                        headers=github_headers,
                        params=params,
                        method=method,
                    )

                response.raise_for_status()
                headers = {}
                if 'Link' in response.headers:
                    headers['Link'] = response.headers['Link']

                return response.json(), headers
        except httpx.HTTPStatusError as e:
            raise self.handle_http_status_error(e)
        except httpx.HTTPError as e:
            raise self.handle_http_error(e)
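
    # Note on the retry branch above: with the class defaults (refresh = False and
    # get_latest_token() simply returning the cached token), the 401 retry never fires.
    # A subclass that manages credentials externally would presumably set refresh = True
    # and override get_latest_token() so that an expired token leads to one
    # re-authenticated retry.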

    async def get_user(self) -> User:
        url = f'{self.BASE_URL}/user'
        response, _ = await self._make_request(url)

        return User(
            id=response.get('id'),
            login=response.get('login'),
            avatar_url=response.get('avatar_url'),
            company=response.get('company'),
            name=response.get('name'),
            email=response.get('email'),
        )

    async def verify_access(self) -> bool:
        """Verify if the token is valid by making a simple request."""
        url = f'{self.BASE_URL}'
        await self._make_request(url)
        return True

    async def _fetch_paginated_repos(
        self, url: str, params: dict, max_repos: int, extract_key: str | None = None
    ) -> list[dict]:
        """Fetch repositories with pagination support.

        Args:
            url: The API endpoint URL
            params: Query parameters for the request
            max_repos: Maximum number of repositories to fetch
            extract_key: If provided, extract repositories from this key in the response

        Returns:
            List of repository dictionaries
        """
        repos: list[dict] = []
        page = 1

        while len(repos) < max_repos:
            page_params = {**params, 'page': str(page)}
            response, headers = await self._make_request(url, page_params)

            # Extract repositories from response
            page_repos = response.get(extract_key, []) if extract_key else response
            if not page_repos:  # No more repositories
                break

            repos.extend(page_repos)
            page += 1

            # Check if we've reached the last page
            link_header = headers.get('Link', '')
            if 'rel="next"' not in link_header:
                break

        return repos[:max_repos]  # Trim to max_repos if needed
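
    # For reference, GitHub's REST pagination is driven by the Link response header,
    # which looks roughly like (URLs shortened for illustration):
    #
    #   Link: <https://api.github.com/user/repos?per_page=100&page=2>; rel="next",
    #         <https://api.github.com/user/repos?per_page=100&page=5>; rel="last"
    #
    # _fetch_paginated_repos (and get_branches below) only check for the presence of
    # rel="next" to decide whether another page exists.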

    def parse_pushed_at_date(self, repo):
        ts = repo.get('pushed_at')
        return datetime.strptime(ts, '%Y-%m-%dT%H:%M:%SZ') if ts else datetime.min
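
    # The 'pushed_at' field returned by the GitHub API is an ISO 8601 UTC timestamp
    # such as '2011-01-26T19:06:43Z', which matches the '%Y-%m-%dT%H:%M:%SZ' format
    # above; repositories without a recorded push fall back to datetime.min and so
    # sort last when get_repositories() sorts by most recent push.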

    async def get_repositories(self, sort: str, app_mode: AppMode) -> list[Repository]:
        MAX_REPOS = 1000
        PER_PAGE = 100  # Maximum allowed by GitHub API
        all_repos: list[dict] = []

        if app_mode == AppMode.SAAS:
            # Get all installation IDs and fetch repos for each one
            installation_ids = await self.get_installation_ids()

            # Iterate through each installation ID
            for installation_id in installation_ids:
                params = {'per_page': str(PER_PAGE)}
                url = (
                    f'{self.BASE_URL}/user/installations/{installation_id}/repositories'
                )

                # Fetch repositories for this installation
                installation_repos = await self._fetch_paginated_repos(
                    url, params, MAX_REPOS - len(all_repos), extract_key='repositories'
                )
                all_repos.extend(installation_repos)

                # If we've already reached MAX_REPOS, no need to check other installations
                if len(all_repos) >= MAX_REPOS:
                    break

            if sort == 'pushed':
                all_repos.sort(key=self.parse_pushed_at_date, reverse=True)
        else:
            # Original behavior for non-SaaS mode
            params = {'per_page': str(PER_PAGE), 'sort': sort}
            url = f'{self.BASE_URL}/user/repos'

            # Fetch user repositories
            all_repos = await self._fetch_paginated_repos(url, params, MAX_REPOS)

        # Convert to Repository objects
        return [
            Repository(
                id=repo.get('id'),
                full_name=repo.get('full_name'),
                stargazers_count=repo.get('stargazers_count'),
                git_provider=ProviderType.GITHUB,
                is_public=not repo.get('private', True),
            )
            for repo in all_repos
        ]

    async def get_installation_ids(self) -> list[int]:
        url = f'{self.BASE_URL}/user/installations'
        response, _ = await self._make_request(url)
        installations = response.get('installations', [])
        return [i['id'] for i in installations]

    async def search_repositories(
        self, query: str, per_page: int, sort: str, order: str
    ) -> list[Repository]:
        url = f'{self.BASE_URL}/search/repositories'

        # Add is:public to the query to ensure we only search for public repositories
        query_with_visibility = f'{query} is:public'

        params = {
            'q': query_with_visibility,
            'per_page': per_page,
            'sort': sort,
            'order': order,
        }

        response, _ = await self._make_request(url, params)
        repo_items = response.get('items', [])

        repos = [
            Repository(
                id=repo.get('id'),
                full_name=repo.get('full_name'),
                stargazers_count=repo.get('stargazers_count'),
                git_provider=ProviderType.GITHUB,
                is_public=True,
            )
            for repo in repo_items
        ]

        return repos

    async def execute_graphql_query(
        self, query: str, variables: dict[str, Any]
    ) -> dict[str, Any]:
        """Execute a GraphQL query against the GitHub API."""
        try:
            async with httpx.AsyncClient() as client:
                github_headers = await self._get_github_headers()
                response = await client.post(
                    f'{self.BASE_URL}/graphql',
                    headers=github_headers,
                    json={'query': query, 'variables': variables},
                )
                response.raise_for_status()

                result = response.json()
                if 'errors' in result:
                    raise UnknownException(
                        f'GraphQL query error: {json.dumps(result["errors"])}'
                    )

                return dict(result)
        except httpx.HTTPStatusError as e:
            raise self.handle_http_status_error(e)
        except httpx.HTTPError as e:
            raise self.handle_http_error(e)
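
    # Illustrative call shape only: get_suggested_tasks() below invokes this helper
    # with the imported query strings and a variables dict, roughly:
    #
    #   response = await self.execute_graphql_query(
    #       suggested_task_pr_graphql_query, {'login': 'octocat'}
    #   )
    #   prs = response['data']['user']['pullRequests']['nodes']
    #
    # where 'octocat' stands in for the authenticated user's login.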

    async def get_suggested_tasks(self) -> list[SuggestedTask]:
        """Get suggested tasks for the authenticated user across all repositories.

        Returns:
            - PRs authored by the user.
            - Issues assigned to the user.

        Note: Queries are split to avoid timeout issues.
        """
        # Get user info to use in queries
        user = await self.get_user()
        login = user.login

        tasks: list[SuggestedTask] = []
        variables = {'login': login}

        try:
            pr_response = await self.execute_graphql_query(
                suggested_task_pr_graphql_query, variables
            )
            pr_data = pr_response['data']['user']

            # Process pull requests
            for pr in pr_data['pullRequests']['nodes']:
                repo_name = pr['repository']['nameWithOwner']

                # Start with default task type
                task_type = TaskType.OPEN_PR

                # Check for specific states
                if pr['mergeable'] == 'CONFLICTING':
                    task_type = TaskType.MERGE_CONFLICTS
                elif (
                    pr['commits']['nodes']
                    and pr['commits']['nodes'][0]['commit']['statusCheckRollup']
                    and pr['commits']['nodes'][0]['commit']['statusCheckRollup'][
                        'state'
                    ]
                    == 'FAILURE'
                ):
                    task_type = TaskType.FAILING_CHECKS
                elif any(
                    review['state'] in ['CHANGES_REQUESTED', 'COMMENTED']
                    for review in pr['reviews']['nodes']
                ):
                    task_type = TaskType.UNRESOLVED_COMMENTS

                # Only add the task if it's not OPEN_PR
                if task_type != TaskType.OPEN_PR:
                    tasks.append(
                        SuggestedTask(
                            git_provider=ProviderType.GITHUB,
                            task_type=task_type,
                            repo=repo_name,
                            issue_number=pr['number'],
                            title=pr['title'],
                        )
                    )
        except Exception as e:
            logger.info(
                f'Error fetching suggested task for PRs: {e}',
                extra={
                    'signal': 'github_suggested_tasks',
                    'user_id': self.external_auth_id,
                },
            )

        try:
            # Execute issue query
            issue_response = await self.execute_graphql_query(
                suggested_task_issue_graphql_query, variables
            )
            issue_data = issue_response['data']['user']

            # Process issues
            for issue in issue_data['issues']['nodes']:
                repo_name = issue['repository']['nameWithOwner']
                tasks.append(
                    SuggestedTask(
                        git_provider=ProviderType.GITHUB,
                        task_type=TaskType.OPEN_ISSUE,
                        repo=repo_name,
                        issue_number=issue['number'],
                        title=issue['title'],
                    )
                )

            return tasks
        except Exception as e:
            logger.info(
                f'Error fetching suggested task for issues: {e}',
                extra={
                    'signal': 'github_suggested_tasks',
                    'user_id': self.external_auth_id,
                },
            )
            return tasks

    async def get_repository_details_from_repo_name(
        self, repository: str
    ) -> Repository:
        url = f'{self.BASE_URL}/repos/{repository}'
        repo, _ = await self._make_request(url)

        return Repository(
            id=repo.get('id'),
            full_name=repo.get('full_name'),
            stargazers_count=repo.get('stargazers_count'),
            git_provider=ProviderType.GITHUB,
            is_public=not repo.get('private', True),
        )

    async def get_branches(self, repository: str) -> list[Branch]:
        """Get branches for a repository."""
        url = f'{self.BASE_URL}/repos/{repository}/branches'

        # Set maximum branches to fetch (10 pages with 100 per page)
        MAX_BRANCHES = 1000
        PER_PAGE = 100

        all_branches: list[Branch] = []
        page = 1

        # Fetch up to 10 pages of branches
        while page <= 10 and len(all_branches) < MAX_BRANCHES:
            params = {'per_page': str(PER_PAGE), 'page': str(page)}
            response, headers = await self._make_request(url, params)

            if not response:  # No more branches
                break

            for branch_data in response:
                # Extract the last commit date if available
                last_push_date = None
                if branch_data.get('commit') and branch_data['commit'].get('commit'):
                    commit_info = branch_data['commit']['commit']
                    if commit_info.get('committer') and commit_info['committer'].get(
                        'date'
                    ):
                        last_push_date = commit_info['committer']['date']

                branch = Branch(
                    name=branch_data.get('name'),
                    commit_sha=branch_data.get('commit', {}).get('sha', ''),
                    protected=branch_data.get('protected', False),
                    last_push_date=last_push_date,
                )
                all_branches.append(branch)

            page += 1

            # Check if we've reached the last page
            link_header = headers.get('Link', '')
            if 'rel="next"' not in link_header:
                break

        return all_branches

    async def create_pr(
        self,
        repo_name: str,
        source_branch: str,
        target_branch: str,
        title: str,
        body: str | None = None,
        draft: bool = True,
    ) -> str:
        """Creates a PR using user credentials.

        Args:
            repo_name: The full name of the repository (owner/repo)
            source_branch: The name of the branch where your changes are implemented
            target_branch: The name of the branch you want the changes pulled into
            title: The title of the pull request
            body: The body/description of the pull request (optional; a generic
                description is generated if omitted)
            draft: Whether to create the PR as a draft (optional, defaults to True)

        Returns:
            The HTML URL of the created pull request. HTTP failures are raised as
            exceptions by _make_request rather than returned.
        """
        url = f'{self.BASE_URL}/repos/{repo_name}/pulls'

        # Set default body if none provided
        if not body:
            body = f'Merging changes from {source_branch} into {target_branch}'

        # Prepare the request payload
        payload = {
            'title': title,
            'head': source_branch,
            'base': target_branch,
            'body': body,
            'draft': draft,
        }

        # Make the POST request to create the PR
        response, _ = await self._make_request(
            url=url, params=payload, method=RequestMethod.POST
        )

        # Return the HTML URL of the created PR
        return response['html_url']
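
    # Hypothetical usage sketch (values are illustrative, not part of this module):
    #
    #   service = GitHubService(token=SecretStr('ghp_example'))
    #   pr_url = await service.create_pr(
    #       repo_name='octocat/hello-world',
    #       source_branch='feature/my-change',
    #       target_branch='main',
    #       title='Add my change',
    #   )
    #   # pr_url has the shape 'https://github.com/octocat/hello-world/pull/<number>'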


github_service_cls = os.environ.get(
    'OPENHANDS_GITHUB_SERVICE_CLS',
    'openhands.integrations.github.github_service.GitHubService',
)
GithubServiceImpl = get_impl(GitHubService, github_service_cls)
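
# As described in the class docstring, deployments can swap in their own implementation.
# A minimal sketch, assuming a hypothetical module myapp.integrations that is importable
# by the server process:
#
#   # myapp/integrations.py
#   class MyGitHubService(GitHubService):
#       refresh = True
#
#       async def get_latest_token(self) -> SecretStr | None:
#           # e.g. fetch a short-lived token from an external token manager
#           ...
#
#   # environment
#   OPENHANDS_GITHUB_SERVICE_CLS=myapp.integrations.MyGitHubService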