Spaces:
Build error
Build error
import shlex
from dataclasses import dataclass
from typing import Callable
@dataclass
class CommandResult:
    """
    Represents the result of a shell command execution.

    Attributes:
        content (str): The output content of the command.
        exit_code (int): The exit code of the command execution.
    """

    content: str
    exit_code: int
class GitHandler:
    """
    A handler for executing Git-related operations via shell commands.

    Every command is run through the injected ``execute_shell_fn`` using the
    working directory set via ``set_cwd`` (``None`` until it is called).
    """

    def __init__(
        self,
        execute_shell_fn: Callable[[str, str | None], CommandResult],
    ):
        """
        Args:
            execute_shell_fn: Callable that runs a shell command string in the
                given working directory (``None`` if unset) and returns a
                CommandResult with the output and exit code.
        """
        self.execute = execute_shell_fn
        self.cwd: str | None = None

    def set_cwd(self, cwd: str) -> None:
        """
        Sets the current working directory for Git operations.

        Args:
            cwd (str): The directory path.
        """
        self.cwd = cwd

    def _is_git_repo(self) -> bool:
        """
        Checks if the current directory is a Git repository.

        Returns:
            bool: True if inside a Git repository, otherwise False.
        """
        cmd = 'git --no-pager rev-parse --is-inside-work-tree'
        output = self.execute(cmd, self.cwd)
        return output.content.strip() == 'true'

    def _get_current_file_content(self, file_path: str) -> str:
        """
        Retrieves the current (working tree) content of a given file.

        Args:
            file_path (str): Path to the file; relative paths resolve against
                ``self.cwd``.

        Returns:
            str: The file content, or an empty string if the file cannot be
            read (e.g. it was deleted).
        """
        # Quote the path so names with spaces or shell metacharacters reach
        # `cat` as one literal argument instead of being interpreted.
        output = self.execute(f'cat {shlex.quote(file_path)}', self.cwd)
        # On failure the output may hold an error message, not file content.
        return output.content if output.exit_code == 0 else ''

    def _verify_ref_exists(self, ref: str) -> bool:
        """
        Verifies whether a specific Git reference exists.

        Args:
            ref (str): The Git reference to check (may be a shell command
                substitution such as ``$(...)``, expanded by the shell).

        Returns:
            bool: True if the reference exists, otherwise False.
        """
        cmd = f'git --no-pager rev-parse --verify {ref}'
        output = self.execute(cmd, self.cwd)
        return output.exit_code == 0

    def _get_valid_ref(self) -> str | None:
        """
        Determines a valid Git reference for comparison.

        Candidates are tried in order; the first one accepted by
        ``git rev-parse --verify`` is returned:

        1. ``origin/<current branch>``
        2. the merge-base of ``HEAD`` and ``origin/<default branch>``
        3. ``origin/<default branch>``
        4. git's empty tree, so a repository without usable remote refs is
           compared against an empty tree.

        Candidates for a branch name that could not be determined are skipped.
        Candidates 2 and 4 are shell command substitutions (``$(...)``) that
        the shell expands wherever the returned ref is embedded.

        Returns:
            str | None: A valid Git reference or None if no valid reference is found.
        """
        current_branch = self._get_current_branch()
        default_branch = self._get_default_branch()

        refs: list[str] = []
        # Branch names come from the repository; quote them before splicing
        # into shell commands. shlex.quote leaves ordinary names unchanged.
        if current_branch:
            refs.append(f'origin/{shlex.quote(current_branch)}')
        if default_branch:
            quoted_default = shlex.quote(default_branch)
            refs.append(
                '$(git --no-pager merge-base HEAD '
                f'"$(git --no-pager rev-parse --abbrev-ref origin/{quoted_default})")'
            )
            refs.append(f'origin/{quoted_default}')
        # 4b825dc... is the hash of git's empty tree (compares with empty tree).
        refs.append(
            '$(git --no-pager rev-parse --verify '
            '4b825dc642cb6eb9a060e54bf8d69288fbee4904)'
        )

        for ref in refs:
            if self._verify_ref_exists(ref):
                return ref
        return None

    def _get_ref_content(self, file_path: str) -> str:
        """
        Retrieves the content of a file from a valid Git reference.

        Args:
            file_path (str): The file path in the repository.

        Returns:
            str: The content of the file from the reference, or an empty string if unavailable.
        """
        ref = self._get_valid_ref()
        if not ref:
            return ''

        # Only the path is quoted: `ref` may be a `$(...)` substitution the
        # shell must expand. Adjacent shell words concatenate, so
        # ref:'path' still forms a single argument.
        cmd = f'git --no-pager show {ref}:{shlex.quote(file_path)}'
        output = self.execute(cmd, self.cwd)
        return output.content if output.exit_code == 0 else ''

    def _get_default_branch(self) -> str:
        """
        Retrieves the primary Git branch name of the repository.

        Returns:
            str: The name of the primary branch, or an empty string if it
            cannot be determined (e.g. there is no 'origin' remote).
        """
        cmd = 'git --no-pager remote show origin | grep "HEAD branch"'
        output = self.execute(cmd, self.cwd)
        # grep exits non-zero when no "HEAD branch" line matched; any output
        # is then error text, not a branch name.
        if output.exit_code != 0:
            return ''
        fields = output.content.split()
        return fields[-1] if fields else ''

    def _get_current_branch(self) -> str:
        """
        Retrieves the currently selected Git branch.

        Returns:
            str: The name of the current branch, or an empty string if the
            command fails (e.g. a repository with no commits yet).
        """
        cmd = 'git --no-pager rev-parse --abbrev-ref HEAD'
        output = self.execute(cmd, self.cwd)
        # Never hand error text back: callers splice this into shell commands.
        if output.exit_code != 0:
            return ''
        return output.content.strip()

    def _get_changed_files(self) -> list[str]:
        """
        Retrieves a list of changed files compared to a valid Git reference.

        Returns:
            list[str]: Raw ``git diff --name-status`` lines, or an empty list
            if no valid reference exists.

        Raises:
            RuntimeError: If the ``git diff`` command fails.
        """
        ref = self._get_valid_ref()
        if not ref:
            return []

        diff_cmd = f'git --no-pager diff --name-status {ref}'
        output = self.execute(diff_cmd, self.cwd)
        if output.exit_code != 0:
            raise RuntimeError(
                f'Failed to get diff for ref {ref} in {self.cwd}. Command output: {output.content}'
            )
        return output.content.splitlines()

    def _get_untracked_files(self) -> list[dict[str, str]]:
        """
        Retrieves a list of untracked files in the repository. This is useful for detecting new files.

        Returns:
            list[dict[str, str]]: One ``{'status': 'A', 'path': ...}`` entry per
            untracked, non-ignored file; empty if the command fails.
        """
        cmd = 'git --no-pager ls-files --others --exclude-standard'
        output = self.execute(cmd, self.cwd)
        if output.exit_code != 0:
            return []
        return [{'status': 'A', 'path': path} for path in output.content.splitlines()]

    def get_git_changes(self) -> list[dict[str, str]] | None:
        """
        Retrieves the list of changed files in the Git repository.

        Tracked changes come from ``git diff --name-status`` against the ref
        chosen by ``_get_valid_ref``; untracked files are appended with
        status ``'A'``.

        Returns:
            list[dict[str, str]] | None: A list of dictionaries containing file paths and statuses. None if not a git repository.

        Raises:
            RuntimeError: If the underlying ``git diff`` command fails.
        """
        if not self._is_git_repo():
            return None

        changes_list = self._get_changed_files()
        result = parse_git_changes(changes_list)

        # join with any untracked files
        result += self._get_untracked_files()
        return result

    def get_git_diff(self, file_path: str) -> dict[str, str]:
        """
        Retrieves the original and modified content of a file in the repository.

        Args:
            file_path (str): Path to the file.

        Returns:
            dict[str, str]: ``{'modified': ..., 'original': ...}``, where
            ``modified`` is the working-tree content and ``original`` is the
            content at the comparison ref; either is an empty string when
            unavailable (e.g. a deleted or newly added file).
        """
        modified = self._get_current_file_content(file_path)
        original = self._get_ref_content(file_path)

        return {
            'modified': modified,
            'original': original,
        }
def parse_git_changes(changes_list: list[str]) -> list[dict[str, str]]:
    """
    Parses the list of changed files and extracts their statuses and paths.

    Two line shapes are accepted:

    * Tab-separated ``git diff --name-status`` lines, e.g. ``M<TAB>path``, or
      ``R100<TAB>old<TAB>new`` for renames/copies. The last field is used as
      the path, so renames report their destination.
    * A two-character status prefix followed by the path, e.g. ``" M path"``
      or ``"?? path"``.

    Blank lines and lines without any status character are skipped.

    Args:
        changes_list (list[str]): List of changed file entries.

    Returns:
        list[dict[str, str]]: Parsed list of file changes, each a dict with a
        single-character ``'status'`` and a ``'path'``.
    """
    result = []
    for line in changes_list:
        if '\t' in line:
            # name-status format: status field, then one or two paths.
            status_field, *paths = line.split('\t')
            status = status_field.strip()
            path = paths[-1]
        else:
            status = line[:2].strip()
            path = line[2:].strip()

        # Keep the first non-space character as the primary status
        # ('R100' -> 'R', ' M' -> 'M'); skip lines that have none.
        status = status.replace(' ', '')
        if not status:
            continue

        result.append(
            {
                'status': status[0],
                'path': path,
            }
        )
    return result