Spaces:
No application file
No application file
import re | |
import shlex | |
import subprocess | |
import sys | |
import time | |
from enum import Enum | |
import webui.args | |
from autodebug.autodebug import InstallFailException | |
from setup_tools.os import is_windows | |
from threading import Thread | |
valid_last: list[tuple[str, str]] = None | |
class CompareAction(Enum):
    """Comparison to apply between an installed version and a target version."""

    # Strictly lower than the target version.
    LT = -2
    # Lower than or equal to the target version.
    LEQ = -1
    # Exactly the target version.
    EQ = 0
    # Greater than or equal to the target version.
    GEQ = 1
    # Strictly greater than the target version.
    GT = 2
class Requirement: | |
def __init__(self): | |
self.running = False | |
def install_or_upgrade_if_needed(self): | |
if not self.is_installed() or not self.is_right_version(): | |
self.post_install(self.install()) | |
def post_install(self, install_output: tuple[int, str, str]): | |
exit_code, stdout, stderr = install_output | |
if exit_code != 0: | |
raise InstallFailException(exit_code, stdout, stderr) | |
def is_right_version(self): | |
raise NotImplementedError('Not implemented') | |
def is_installed(self): | |
raise NotImplementedError('Not implemented') | |
def install_check(self, package_name: str) -> bool: | |
return self.get_package_version(package_name) is not False | |
def loading_thread(status_dict, name): | |
idx = 0 | |
load_symbols = ['|', '/', '-', '\\'] | |
while status_dict['running']: | |
curr_symbol = load_symbols[idx % len(load_symbols)] | |
idx += 1 | |
print(f'\rInstalling {name} {curr_symbol}', end='') | |
time.sleep(0.25) | |
print(f'\rInstalled {name}! ' if status_dict[ | |
'success'] else f'\rFailed to install {name}. Check AutoDebug output.') | |
def install_pip(self, command, name=None) -> tuple[int, str, str]: | |
global valid_last | |
valid_last = None | |
if not name: | |
name = command | |
status_dict = { | |
'running': True | |
} | |
verbose = webui.args.args.verbose | |
if not verbose: | |
thread = Thread(target=self.loading_thread, args=[status_dict, name], daemon=True) | |
thread.start() | |
args = f'"{sys.executable}" -m pip install --upgrade {command}' | |
args = args if self.is_windows() else shlex.split(args) | |
result = subprocess.run(args, capture_output=not verbose, text=True) | |
status_dict['success'] = result.returncode == 0 | |
status_dict['running'] = False | |
if not verbose: | |
while thread.is_alive(): | |
time.sleep(0.1) | |
return result.returncode, result.stdout, result.stderr | |
def is_windows(self) -> bool: | |
return is_windows() | |
def install(self) -> tuple[int, str, str]: | |
raise NotImplementedError('Not implemented') | |
def pip_freeze(self) -> list[tuple[str, str]]: | |
global valid_last | |
if valid_last: | |
return valid_last | |
args = f'"{sys.executable}" -m pip freeze' | |
args = args if self.is_windows() else shlex.split(args) | |
result = subprocess.run(args, capture_output=True, text=True) | |
test_str = result.stdout | |
out_list = [] | |
matches = re.finditer('^(.*)(?:==| @ )(.+)$', test_str, re.MULTILINE) | |
for match in matches: | |
out_list.append((match.group(1), match.group(2))) | |
valid_last = out_list | |
return out_list | |
def get_package_version(self, name: str, freeze: dict[tuple[str, str]] | None = None) -> bool | str: | |
if freeze is None: | |
freeze = self.pip_freeze() | |
for p_name, version in freeze: | |
if name.casefold() == p_name.casefold(): | |
return version | |
return False | |
class SimpleRequirement(Requirement):
    """Requirement identified only by a pip package name; any version is accepted."""

    # Name used both for the installed check and for the pip install command.
    package_name: str

    def is_right_version(self):
        """Accept whatever version is installed."""
        return True

    def is_installed(self):
        """Return whether ``package_name`` passes ``install_check``."""
        package = self.package_name
        return self.install_check(package)

    def install(self) -> tuple[int, str, str]:
        """Install ``package_name`` via ``install_pip`` and return its result."""
        package = self.package_name
        return self.install_pip(package)
class SimpleRequirementInit(SimpleRequirement):
    """Pip requirement with an optional version constraint.

    Args:
        package_name: Name of the package to check and install.
        compare: How the installed version must relate to ``version``.
        version: Target version string; ``None`` means unconstrained.
    """

    # pip version-specifier operator for each comparison.
    _SPECIFIERS = {
        CompareAction.LT: '<',
        CompareAction.LEQ: '<=',
        CompareAction.EQ: '==',
        CompareAction.GEQ: '>=',
        CompareAction.GT: '>',
    }

    def __init__(self, package_name, compare: CompareAction = None, version: str = None):
        super().__init__()
        self.package_name = package_name
        self.compare = compare
        self.version = version

    def is_right_version(self):
        """Return whether the installed version satisfies the constraint.

        Returns ``True`` when no constraint is configured or ``compare`` is
        not a known ``CompareAction``. Returns ``False`` when the package is
        not installed or its reported version cannot be parsed (for example a
        direct-URL install), so that the caller reinstalls it.
        """
        if self.compare is None or self.version is None:
            return True
        import operator
        from packaging import version

        comparators = {
            CompareAction.LT: operator.lt,
            CompareAction.LEQ: operator.le,
            CompareAction.EQ: operator.eq,
            CompareAction.GEQ: operator.ge,
            CompareAction.GT: operator.gt,
        }
        check = comparators.get(self.compare)
        if check is None:
            return True

        installed = self.get_package_version(self.package_name)
        # get_package_version returns False when the package is not listed;
        # parsing that value would raise instead of reporting a mismatch.
        if installed is False:
            return False
        try:
            installed_obj = version.parse(installed)
        except version.InvalidVersion:
            # The freeze entry is not a version number (e.g. `name @ url`).
            return False
        return check(installed_obj, version.parse(self.version))

    def install(self) -> tuple[int, str, str]:
        """Install the package, pinned by the configured specifier if any.

        Without a target version the bare package name is installed. With a
        version but no recognised ``compare``, ``==`` is used.
        """
        if self.version is None:
            return self.install_pip(self.package_name)
        symbol = self._SPECIFIERS.get(self.compare, '==')
        return self.install_pip(f'{self.package_name}{symbol}{self.version}', self.package_name)
class SimpleGitRequirement(SimpleRequirement):
    """Requirement installed by handing ``repo`` to pip instead of the package name.

    Args:
        package_name: Name used for the installed check and as display name.
        repo: Install target passed to ``install_pip`` (presumably a
            ``git+https://...`` reference; confirm against callers).
        check_version: When true, the version reported by pip freeze must
            equal ``repo`` for the installation to count as up to date.
    """

    def __init__(self, package_name, repo, check_version=False):
        super().__init__()
        self.package_name, self.repo = package_name, repo
        self.check_version = check_version

    def is_right_version(self):
        """Return True unless version checking is on and the freeze entry differs from ``repo``."""
        if self.check_version:
            reported = self.get_package_version(self.package_name)
            return reported == self.repo
        return True

    def install(self) -> tuple[int, str, str]:
        """Install from ``repo``, showing ``package_name`` as the display name."""
        return self.install_pip(self.repo, name=self.package_name)