import subprocess
import os
import re
import sys
import filecmp
import logging
import shutil
import sysconfig
import datetime
import platform
import pkg_resources

# Count of failed git/pip invocations; incremented by git() and pip().
errors = 0
log = logging.getLogger('sd')


def setup_logging(clean=False):
    """Set up console and file logging for the ``sd`` logger.

    The root logger writes to a timestamped file under ``../logs/setup``
    (relative to this file); the ``sd`` logger additionally gets a rich
    console handler at INFO level.

    Args:
        clean: Accepted for interface compatibility; not used by this
            function.
    """
    #
    # This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    #

    # rich is imported lazily so that this module can be imported before
    # rich has been installed (see ensure_base_requirements).
    from rich.theme import Theme
    from rich.logging import RichHandler
    from rich.console import Console
    from rich.pretty import install as pretty_install
    from rich.traceback import install as traceback_install

    console = Console(
        log_time=True,
        log_time_format='%H:%M:%S-%f',
        theme=Theme(
            {
                'traceback.border': 'black',
                'traceback.border.syntax_error': 'black',
                'inspect.value.border': 'black',
            }
        ),
    )
    # logging.getLogger("urllib3").setLevel(logging.ERROR)
    # logging.getLogger("httpx").setLevel(logging.ERROR)

    current_datetime = datetime.datetime.now()
    current_datetime_str = current_datetime.strftime('%Y%m%d-%H%M%S')
    log_file = os.path.join(
        os.path.dirname(__file__),
        f'../logs/setup/kohya_ss_gui_{current_datetime_str}.log',
    )

    # Create directories if they don't exist
    log_directory = os.path.dirname(log_file)
    os.makedirs(log_directory, exist_ok=True)

    level = logging.INFO
    logging.basicConfig(
        level=logging.ERROR,
        format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
        filename=log_file,
        filemode='a',
        encoding='utf-8',
        force=True,
    )
    log.setLevel(
        logging.DEBUG
    )   # log to file is always at level debug for facility `sd`
    pretty_install(console=console)
    traceback_install(
        console=console,
        extra_lines=1,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )
    rh = RichHandler(
        show_time=True,
        omit_repeated_times=False,
        show_level=True,
        show_path=False,
        markup=False,
        rich_tracebacks=True,
        log_time_format='%H:%M:%S-%f',
        level=level,
        console=console,
    )
    rh.set_name(level)
    # Drop handlers left over from a previous call before attaching ours.
    while log.hasHandlers() and len(log.handlers) > 0:
        log.removeHandler(log.handlers[0])
    log.addHandler(rh)


def configure_accelerate(run_accelerate=False):
    """Copy the bundled accelerate config into the Hugging Face config dir.

    The target directory is derived from the first non-empty environment
    variable among ``HF_HOME``, ``LOCALAPPDATA`` and ``USERPROFILE``.

    Args:
        run_accelerate: When the config cannot be copied, run
            ``accelerate config`` interactively instead of only logging a
            warning.
    """
    #
    # This function was taken and adapted from code written by jstayco
    #

    from pathlib import Path

    def env_var_exists(var_name):
        return var_name in os.environ and os.environ[var_name] != ''

    def manual_fallback():
        # Shared fallback for every case where no file gets copied.
        if run_accelerate:
            run_cmd('accelerate config')
        else:
            log.warning(
                'Could not automatically configure accelerate. Please manually configure accelerate with the option in the menu or with: accelerate config.'
            )

    log.info('Configuring accelerate...')

    source_accelerate_config_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..',
        'config_files',
        'accelerate',
        'default_config.yaml',
    )

    if not os.path.exists(source_accelerate_config_file):
        if run_accelerate:
            run_cmd('accelerate config')
        else:
            log.warning(
                f'Could not find the accelerate configuration file in {source_accelerate_config_file}. Please configure accelerate manually by runningthe option in the menu.'
            )
        # Nothing to copy: stop here instead of falling through to
        # shutil.copyfile() on a file that does not exist.
        return

    log.debug(
        f'Source accelerate config location: {source_accelerate_config_file}'
    )

    target_config_location = None

    log.debug(
        f"Environment variables: HF_HOME: {os.environ.get('HF_HOME')}, "
        f"LOCALAPPDATA: {os.environ.get('LOCALAPPDATA')}, "
        f"USERPROFILE: {os.environ.get('USERPROFILE')}"
    )
    if env_var_exists('HF_HOME'):
        target_config_location = Path(
            os.environ['HF_HOME'], 'accelerate', 'default_config.yaml'
        )
    elif env_var_exists('LOCALAPPDATA'):
        target_config_location = Path(
            os.environ['LOCALAPPDATA'],
            'huggingface',
            'accelerate',
            'default_config.yaml',
        )
    elif env_var_exists('USERPROFILE'):
        target_config_location = Path(
            os.environ['USERPROFILE'],
            '.cache',
            'huggingface',
            'accelerate',
            'default_config.yaml',
        )

    log.debug(f'Target config location: {target_config_location}')

    if not target_config_location:
        manual_fallback()
        return

    if target_config_location.is_file():
        # An existing target file is never overwritten.
        manual_fallback()
        return

    target_config_location.parent.mkdir(parents=True, exist_ok=True)
    log.debug(
        f'Target accelerate config location: {target_config_location}'
    )
    shutil.copyfile(source_accelerate_config_file, target_config_location)
    log.info(f'Copied accelerate config file to: {target_config_location}')


def check_torch():
    """Log the detected GPU toolkit and torch backend details.

    Returns:
        The major version of the installed torch as an int, or 0 when torch
        cannot be imported or inspected.
    """
    #
    # This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    #

    # Check for nVidia toolkit or AMD toolkit
    if shutil.which('nvidia-smi') is not None or os.path.exists(
        os.path.join(
            os.environ.get('SystemRoot') or r'C:\Windows',
            'System32',
            'nvidia-smi.exe',
        )
    ):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists(
        '/opt/rocm/bin/rocminfo'
    ):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')

        # Check if CUDA is available
        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                # Log nVidia CUDA and cuDNN versions
                log.info(
                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
                )
            elif torch.version.hip:
                # Log AMD ROCm HIP version
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')

            # Log information about detected GPUs
            for device in [
                torch.cuda.device(i) for i in range(torch.cuda.device_count())
            ]:
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
                )
        # Read the whole major component rather than a single character.
        return int(torch.__version__.split('.')[0])
    except Exception as e:
        # Best effort: torch may simply not be installed yet. Record the
        # reason at debug level instead of discarding it.
        log.debug(f'Could not load torch: {e}')
        return 0


# report current version of code
def check_repo_version():
    """Log the contents of ``./.release`` if that file exists."""
    if os.path.exists('.release'):
        with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
            release = file.read()

        log.info(f'Version: {release}')
    else:
        log.debug('Could not read release...')


def _combined_output(result):
    """Return stdout followed by stderr of a captured process, stripped."""
    txt = result.stdout.decode(encoding='utf8', errors='ignore')
    if len(result.stderr) > 0:
        txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(
            encoding='utf8', errors='ignore'
        )
    return txt.strip()


# execute git command
def git(arg: str, folder: str = None, ignore: bool = False):
    """Run a git command and return its combined output.

    The executable is taken from the ``GIT`` environment variable and
    defaults to ``git``.

    Args:
        arg: Arguments appended verbatim to the git command line.
        folder: Working directory for the command; defaults to ``.``.
        ignore: When True, a non-zero exit code is neither logged nor
            counted in the module-level ``errors`` counter.

    Returns:
        The stripped stdout of the command, followed by stderr if any.
    """
    #
    # This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    #
    global errors   # pylint: disable=global-statement

    git_cmd = os.environ.get('GIT', 'git')
    result = subprocess.run(
        f'"{git_cmd}" {arg}',
        check=False,
        shell=True,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=folder or '.',
    )
    txt = _combined_output(result)
    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running git: {folder} / {arg}')
        if 'or stash them' in txt:
            log.error('Local changes detected: check log for details...')
        log.debug(f'Git output: {txt}')
    # Callers such as check_python() use the output, so it must be returned.
    return txt


def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
    """Run ``python -m pip <arg>`` with the current interpreter.

    Args:
        arg: Arguments appended verbatim to the pip command line.
        ignore: When True, a non-zero exit code is neither logged nor
            counted in the module-level ``errors`` counter.
        quiet: When True, skip the "Installing package" info message.
        show_stdout: When True, let pip write directly to the terminal
            instead of capturing its output.

    Returns:
        The captured pip output, or an empty string when ``show_stdout``
        is True (nothing is captured in that mode).
    """
    global errors   # pylint: disable=global-statement

    # arg = arg.replace('>=', '==')
    if not quiet:
        shown = arg
        for token in ('install', '--upgrade', '--no-deps', '--force'):
            shown = shown.replace(token, '')
        # Collapse the whitespace left behind by the removed options.
        log.info(f'Installing package: {" ".join(shown.split())}')
    log.debug(f'Running pip: {arg}')

    command = f'"{sys.executable}" -m pip {arg}'
    if show_stdout:
        result = subprocess.run(
            command, shell=True, check=False, env=os.environ
        )
        txt = ''
    else:
        result = subprocess.run(
            command,
            shell=True,
            check=False,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        txt = _combined_output(result)

    # Failures are accounted for in both modes.
    if result.returncode != 0 and not ignore:
        errors += 1
        log.error(f'Error running pip: {arg}')
        log.debug(f'Pip output: {txt}')
    return txt


def _version_matches(found: str, required: str, at_least: bool) -> bool:
    """Compare two version strings by version semantics, not as text."""
    try:
        found_v = pkg_resources.parse_version(found)
        required_v = pkg_resources.parse_version(required)
    except ValueError:
        # Unparseable version text: fall back to comparing the raw strings.
        found_v, required_v = found, required
    return found_v >= required_v if at_least else found_v == required_v


def installed(package, friendly: str = None):
    """Check whether the given requirement(s) are already installed.

    Only the ``>=`` and ``==`` version specifiers are understood; any other
    token is looked up as a bare package name.

    Args:
        package: Requirement text, possibly containing several
            whitespace-separated entries, pip options and URLs.
        friendly: When given, its whitespace-separated tokens are checked
            instead of ``package``.

    Returns:
        True when every checked entry is installed with an acceptable
        version, False otherwise.
    """
    #
    # This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    #

    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package = re.sub(r'\[.*?\]', '', package)

    if friendly:
        pkgs = friendly.split()
    else:
        pkgs = [
            p
            for p in package.split()
            if not p.startswith('-') and not p.startswith('=')
        ]
        pkgs = [
            p.split('/')[-1] for p in pkgs
        ]   # get only package name if installing from URL

    for pkg in pkgs:
        at_least = '>=' in pkg
        if at_least:
            pkg_name, pkg_version = [x.strip() for x in pkg.split('>=', 1)]
        elif '==' in pkg:
            pkg_name, pkg_version = [x.strip() for x in pkg.split('==', 1)]
        else:
            pkg_name, pkg_version = pkg.strip(), None

        # Try the spellings under which the distribution may be registered.
        spec = None
        for candidate in (
            pkg_name,
            pkg_name.lower(),
            pkg_name.replace('_', '-'),
            pkg_name.lower().replace('_', '-'),
        ):
            spec = pkg_resources.working_set.by_key.get(candidate, None)
            if spec is not None:
                break

        if spec is None:
            log.debug(f'Package version not found: {pkg_name}')
            return False

        version = spec.version
        log.debug(f'Package version found: {pkg_name} {version}')

        if pkg_version is not None and not _version_matches(
            version, pkg_version, at_least
        ):
            log.warning(
                f'Package wrong version: {pkg_name} {version} required {pkg_version}'
            )
            return False
    return True


# install package using pip if not already installed
def install(
    package,
    friendly: str = None,
    ignore: bool = False,
    reinstall: bool = False,
    show_stdout: bool = False,
):
    """Install ``package`` with pip unless it is already installed.

    Args:
        package: Requirement text passed to ``pip install --upgrade``;
            anything after a ``#`` is dropped first.
        friendly: Alternative text used for the "already installed" check.
        ignore: Forwarded to :func:`pip`.
        reinstall: When True, install even if the package is present.
        show_stdout: Forwarded to :func:`pip`.
    """
    #
    # This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    #

    # Remove anything after '#' in the package variable
    package = package.split('#')[0].strip()

    if reinstall:
        global quick_allowed   # pylint: disable=global-statement
        quick_allowed = False
    if reinstall or not installed(package, friendly):
        pip(f'install --upgrade {package}', ignore=ignore, show_stdout=show_stdout)


def process_requirements_line(line, show_stdout: bool = False):
    """Install the requirement described by one requirements-file line."""
    # Remove brackets and their contents from the line using regular expressions
    # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2
    package_name = re.sub(r'\[.*?\]', '', line)
    install(line, package_name, show_stdout=show_stdout)


def install_requirements(requirements_file, check_no_verify_flag=False, show_stdout: bool = False):
    """Install every requirement listed in ``requirements_file``.

    Lines of the form ``-r <file>`` are expanded recursively.

    Args:
        requirements_file: Path of the requirements file to read.
        check_no_verify_flag: When True, skip lines containing
            ``no_verify``.
        show_stdout: Forwarded to the pip invocations.
    """
    if check_no_verify_flag:
        log.info(f'Verifying modules instalation status from {requirements_file}...')
    else:
        log.info(f'Installing modules from {requirements_file}...')

    # Keep non-empty lines that are not comments. The comment test runs on
    # the stripped text so indented comment lines are skipped as well.
    lines = []
    with open(requirements_file, 'r', encoding='utf8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            if check_no_verify_flag and 'no_verify' in line:
                continue
            lines.append(line)

    # Iterate over each line and install the requirements
    for line in lines:
        # Check if the line starts with '-r' to include another requirements file
        if line.startswith('-r'):
            # Get the path to the included requirements file
            included_file = line[2:].strip()
            # Expand the included requirements file recursively
            install_requirements(
                included_file,
                check_no_verify_flag=check_no_verify_flag,
                show_stdout=show_stdout,
            )
        else:
            process_requirements_line(line, show_stdout=show_stdout)


def ensure_base_requirements():
    """Install ``rich`` if it cannot be imported."""
    try:
        import rich   # pylint: disable=unused-import
    except ImportError:
        install('--upgrade rich', 'rich')


def run_cmd(run_cmd):
    """Run a shell command with the current environment.

    Args:
        run_cmd: Command line passed to the shell.
    """
    try:
        subprocess.run(run_cmd, shell=True, check=False, env=os.environ)
    except subprocess.CalledProcessError as e:
        # NOTE: with check=False subprocess.run() does not raise this for a
        # non-zero exit code; the handler is kept as a safeguard only.
        print(f'Error occurred while running command: {run_cmd}')
        print(f'Error: {e}')


# check python version
def check_python(ignore=True, skip_git=False):
    """Log the Python and git versions and verify they are usable.

    Args:
        ignore: When False, exit the process with status 1 on an
            unsupported Python version or a missing git executable.
        skip_git: When True, do not look for git at all.
    """
    #
    # This function was adapted from code written by vladmandic: https://github.com/vladmandic/automatic/commits/master
    #

    supported_minors = [9, 10]
    log.info(f'Python {platform.python_version()} on {platform.system()}')
    if not (
        int(sys.version_info.major) == 3
        and int(sys.version_info.minor) in supported_minors
    ):
        log.error(
            f'Incompatible Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} required 3.{supported_minors}'
        )
        if not ignore:
            sys.exit(1)
    if not skip_git:
        git_cmd = os.environ.get('GIT', 'git')
        if shutil.which(git_cmd) is None:
            log.error('Git not found')
            if not ignore:
                sys.exit(1)
    else:
        git_version = git('--version', folder=None, ignore=False)
        log.debug(f'Git {git_version.replace("git version", "").strip()}')


def delete_file(file_path):
    """Remove ``file_path`` if it exists."""
    if os.path.exists(file_path):
        os.remove(file_path)


def write_to_file(file_path, content):
    """Write ``content`` to ``file_path``, printing any I/O error."""
    try:
        with open(file_path, 'w') as file:
            file.write(content)
    except IOError as e:
        print(f'Error occurred while writing to file: {file_path}')
        print(f'Error: {e}')


def clear_screen():
    """Clear the terminal using the command for the current OS."""
    # Check the current operating system to execute the correct clear screen command
    if os.name == 'nt':  # If the operating system is Windows
        os.system('cls')
    else:  # If the operating system is Linux or Mac
        os.system('clear')