Spaces:
Build error
Build error
import subprocess | |
import os | |
import re | |
import sys | |
import filecmp | |
import logging | |
import shutil | |
import sysconfig | |
import datetime | |
import platform | |
import pkg_resources | |
# Running count of failed external commands; git() and pip() bump it through a
# `global` statement whenever a command exits non-zero and failures aren't ignored.
errors = 0
# Logger shared by every helper in this module; setup_logging() attaches its handlers.
log = logging.getLogger(name='sd')
# setup console and file logging | |
def setup_logging(clean=False): | |
# | |
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master | |
# | |
from rich.theme import Theme | |
from rich.logging import RichHandler | |
from rich.console import Console | |
from rich.pretty import install as pretty_install | |
from rich.traceback import install as traceback_install | |
console = Console( | |
log_time=True, | |
log_time_format='%H:%M:%S-%f', | |
theme=Theme( | |
{ | |
'traceback.border': 'black', | |
'traceback.border.syntax_error': 'black', | |
'inspect.value.border': 'black', | |
} | |
), | |
) | |
# logging.getLogger("urllib3").setLevel(logging.ERROR) | |
# logging.getLogger("httpx").setLevel(logging.ERROR) | |
current_datetime = datetime.datetime.now() | |
current_datetime_str = current_datetime.strftime('%Y%m%d-%H%M%S') | |
log_file = os.path.join( | |
os.path.dirname(__file__), | |
f'../logs/setup/kohya_ss_gui_{current_datetime_str}.log', | |
) | |
# Create directories if they don't exist | |
log_directory = os.path.dirname(log_file) | |
os.makedirs(log_directory, exist_ok=True) | |
level = logging.INFO | |
logging.basicConfig( | |
level=logging.ERROR, | |
format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', | |
filename=log_file, | |
filemode='a', | |
encoding='utf-8', | |
force=True, | |
) | |
log.setLevel( | |
logging.DEBUG | |
) # log to file is always at level debug for facility `sd` | |
pretty_install(console=console) | |
traceback_install( | |
console=console, | |
extra_lines=1, | |
width=console.width, | |
word_wrap=False, | |
indent_guides=False, | |
suppress=[], | |
) | |
rh = RichHandler( | |
show_time=True, | |
omit_repeated_times=False, | |
show_level=True, | |
show_path=False, | |
markup=False, | |
rich_tracebacks=True, | |
log_time_format='%H:%M:%S-%f', | |
level=level, | |
console=console, | |
) | |
rh.set_name(level) | |
while log.hasHandlers() and len(log.handlers) > 0: | |
log.removeHandler(log.handlers[0]) | |
log.addHandler(rh) | |
def configure_accelerate(run_accelerate=False): | |
# | |
# This function was taken and adapted from code written by jstayco | |
# | |
from pathlib import Path | |
def env_var_exists(var_name): | |
return var_name in os.environ and os.environ[var_name] != '' | |
log.info('Configuring accelerate...') | |
source_accelerate_config_file = os.path.join( | |
os.path.dirname(os.path.abspath(__file__)), | |
'..', | |
'config_files', | |
'accelerate', | |
'default_config.yaml', | |
) | |
if not os.path.exists(source_accelerate_config_file): | |
if run_accelerate: | |
run_cmd('accelerate config') | |
else: | |
log.warning( | |
f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by runningthe option in the menu.' | |
) | |
log.debug( | |
f'Source accelerate config location: {source_accelerate_config_file}' | |
) | |
target_config_location = None | |
log.debug( | |
f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, " | |
f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, " | |
f"USERPROFILE: {os.environ.get('USERPROFILE')}" | |
) | |
if env_var_exists('HF_HOME'): | |
target_config_location = Path( | |
os.environ['HF_HOME'], 'accelerate', 'default_config.yaml' | |
) | |
elif env_var_exists('LOCALAPPDATA'): | |
target_config_location = Path( | |
os.environ['LOCALAPPDATA'], | |
'huggingface', | |
'accelerate', | |
'default_config.yaml', | |
) | |
elif env_var_exists('USERPROFILE'): | |
target_config_location = Path( | |
os.environ['USERPROFILE'], | |
'.cache', | |
'huggingface', | |
'accelerate', | |
'default_config.yaml', | |
) | |
log.debug(f'Target config location: {target_config_location}') | |
if target_config_location: | |
if not target_config_location.is_file(): | |
target_config_location.parent.mkdir(parents=True, exist_ok=True) | |
log.debug( | |
f'Target accelerate config location: {target_config_location}' | |
) | |
shutil.copyfile( | |
source_accelerate_config_file, target_config_location | |
) | |
log.info( | |
f'Copied accelerate config file to: {target_config_location}' | |
) | |
else: | |
if run_accelerate: | |
run_cmd('accelerate config') | |
else: | |
log.warning( | |
'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.' | |
) | |
else: | |
if run_accelerate: | |
run_cmd('accelerate config') | |
else: | |
log.warning( | |
'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.' | |
) | |
def check_torch(): | |
# | |
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master | |
# | |
# Check for nVidia toolkit or AMD toolkit | |
if shutil.which('nvidia-smi') is not None or os.path.exists( | |
os.path.join( | |
os.environ.get('SystemRoot') or r'C:\Windows', | |
'System32', | |
'nvidia-smi.exe', | |
) | |
): | |
log.info('nVidia toolkit detected') | |
elif shutil.which('rocminfo') is not None or os.path.exists( | |
'/opt/rocm/bin/rocminfo' | |
): | |
log.info('AMD toolkit detected') | |
else: | |
log.info('Using CPU-only Torch') | |
try: | |
import torch | |
log.info(f'Torch {torch.__version__}') | |
# Check if CUDA is available | |
if not torch.cuda.is_available(): | |
log.warning('Torch reports CUDA not available') | |
else: | |
if torch.version.cuda: | |
# Log nVidia CUDA and cuDNN versions | |
log.info( | |
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' | |
) | |
elif torch.version.hip: | |
# Log AMD ROCm HIP version | |
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') | |
else: | |
log.warning('Unknown Torch backend') | |
# Log information about detected GPUs | |
for device in [ | |
torch.cuda.device(i) for i in range(torch.cuda.device_count()) | |
]: | |
log.info( | |
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' | |
) | |
return int(torch.__version__[0]) | |
except Exception as e: | |
# log.warning(f'Could not load torch: {e}') | |
return 0 | |
# report current version of code | |
def check_repo_version(): # pylint: disable=unused-argument | |
if os.path.exists('.release'): | |
with open(os.path.join('./.release'), 'r', encoding='utf8') as file: | |
release= file.read() | |
log.info(f'Version: {release}') | |
else: | |
log.debug('Could not read release...') | |
# execute git command | |
def git(arg: str, folder: str = None, ignore: bool = False): | |
# | |
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master | |
# | |
git_cmd = os.environ.get('GIT', "git") | |
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.') | |
txt = result.stdout.decode(encoding="utf8", errors="ignore") | |
if len(result.stderr) > 0: | |
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") | |
txt = txt.strip() | |
if result.returncode != 0 and not ignore: | |
global errors # pylint: disable=global-statement | |
errors += 1 | |
log.error(f'Error running git: {folder} / {arg}') | |
if 'or stash them' in txt: | |
log.error(f'Local changes detected: check log for details...') | |
log.debug(f'Git output: {txt}') | |
def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False): | |
# arg = arg.replace('>=', '==') | |
if not quiet: | |
log.info(f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace(" ", " ").strip()}') | |
log.debug(f"Running pip: {arg}") | |
if show_stdout: | |
subprocess.run(f'"{sys.executable}" -m pip {arg}', shell=True, check=False, env=os.environ) | |
else: | |
result = subprocess.run(f'"{sys.executable}" -m pip {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |
txt = result.stdout.decode(encoding="utf8", errors="ignore") | |
if len(result.stderr) > 0: | |
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") | |
txt = txt.strip() | |
if result.returncode != 0 and not ignore: | |
global errors # pylint: disable=global-statement | |
errors += 1 | |
log.error(f'Error running pip: {arg}') | |
log.debug(f'Pip output: {txt}') | |
return txt | |
def installed(package, friendly: str = None): | |
# | |
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master | |
# | |
# Remove brackets and their contents from the line using regular expressions | |
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2 | |
package = re.sub(r'\[.*?\]', '', package) | |
try: | |
if friendly: | |
pkgs = friendly.split() | |
else: | |
pkgs = [ | |
p | |
for p in package.split() | |
if not p.startswith('-') and not p.startswith('=') | |
] | |
pkgs = [ | |
p.split('/')[-1] for p in pkgs | |
] # get only package name if installing from URL | |
for pkg in pkgs: | |
if '>=' in pkg: | |
pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')] | |
elif '==' in pkg: | |
pkg_name, pkg_version = [x.strip() for x in pkg.split('==')] | |
else: | |
pkg_name, pkg_version = pkg.strip(), None | |
spec = pkg_resources.working_set.by_key.get(pkg_name, None) | |
if spec is None: | |
spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None) | |
if spec is None: | |
spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None) | |
if spec is not None: | |
version = pkg_resources.get_distribution(pkg_name).version | |
log.debug(f'Package version found: {pkg_name} {version}') | |
if pkg_version is not None: | |
if '>=' in pkg: | |
ok = version >= pkg_version | |
else: | |
ok = version == pkg_version | |
if not ok: | |
log.warning(f'Package wrong version: {pkg_name} {version} required {pkg_version}') | |
return False | |
else: | |
log.debug(f'Package version not found: {pkg_name}') | |
return False | |
return True | |
except ModuleNotFoundError: | |
log.debug(f'Package not installed: {pkgs}') | |
return False | |
# install package using pip if not already installed | |
def install( | |
# | |
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master | |
# | |
package, | |
friendly: str = None, | |
ignore: bool = False, | |
reinstall: bool = False, | |
show_stdout: bool = False, | |
): | |
# Remove anything after '#' in the package variable | |
package = package.split('#')[0].strip() | |
if reinstall: | |
global quick_allowed # pylint: disable=global-statement | |
quick_allowed = False | |
if reinstall or not installed(package, friendly): | |
pip(f'install --upgrade {package}', ignore=ignore, show_stdout=show_stdout) | |
def process_requirements_line(line, show_stdout: bool = False): | |
# Remove brackets and their contents from the line using regular expressions | |
# e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2 | |
package_name = re.sub(r'\[.*?\]', '', line) | |
install(line, package_name, show_stdout=show_stdout) | |
def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False): | |
if check_no_verify_flag: | |
log.info(f'Verifying modules instalation status from {requirements_file}...') | |
else: | |
log.info(f'Installing modules from {requirements_file}...') | |
with open(requirements_file, 'r', encoding='utf8') as f: | |
# Read lines from the requirements file, strip whitespace, and filter out empty lines, comments, and lines starting with '.' | |
if check_no_verify_flag: | |
lines = [ | |
line.strip() | |
for line in f.readlines() | |
if line.strip() != '' | |
and not line.startswith('#') | |
and line is not None | |
and 'no_verify' not in line | |
] | |
else: | |
lines = [ | |
line.strip() | |
for line in f.readlines() | |
if line.strip() != '' | |
and not line.startswith('#') | |
and line is not None | |
] | |
# Iterate over each line and install the requirements | |
for line in lines: | |
# Check if the line starts with '-r' to include another requirements file | |
if line.startswith('-r'): | |
# Get the path to the included requirements file | |
included_file = line[2:].strip() | |
# Expand the included requirements file recursively | |
install_requirements(included_file, check_no_verify_flag=check_no_verify_flag, show_stdout=show_stdout) | |
else: | |
process_requirements_line(line, show_stdout=show_stdout) | |
def ensure_base_requirements(): | |
try: | |
import rich # pylint: disable=unused-import | |
except ImportError: | |
install('--upgrade rich', 'rich') | |
def run_cmd(run_cmd): | |
try: | |
subprocess.run(run_cmd, shell=True, check=False, env=os.environ) | |
except subprocess.CalledProcessError as e: | |
print(f'Error occurred while running command: {run_cmd}') | |
print(f'Error: {e}') | |
# check python version | |
def check_python(ignore=True, skip_git=False): | |
# | |
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master | |
# | |
supported_minors = [9, 10] | |
log.info(f'Python {platform.python_version()} on {platform.system()}') | |
if not ( | |
int(sys.version_info.major) == 3 | |
and int(sys.version_info.minor) in supported_minors | |
): | |
log.error( | |
f'Incompatible Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} required 3.{supported_minors}' | |
) | |
if not ignore: | |
sys.exit(1) | |
if not skip_git: | |
git_cmd = os.environ.get('GIT', 'git') | |
if shutil.which(git_cmd) is None: | |
log.error('Git not found') | |
if not ignore: | |
sys.exit(1) | |
else: | |
git_version = git('--version', folder=None, ignore=False) | |
log.debug(f'Git {git_version.replace("git version", "").strip()}') | |
def delete_file(file_path): | |
if os.path.exists(file_path): | |
os.remove(file_path) | |
def write_to_file(file_path, content): | |
try: | |
with open(file_path, 'w') as file: | |
file.write(content) | |
except IOError as e: | |
print(f'Error occurred while writing to file: {file_path}') | |
print(f'Error: {e}') | |
def clear_screen(): | |
# Check the current operating system to execute the correct clear screen command | |
if os.name == 'nt': # If the operating system is Windows | |
os.system('cls') | |
else: # If the operating system is Linux or Mac | |
os.system('clear') | |