File size: 4,313 Bytes
0163a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
import launch
import platform
import os
import shutil
import site
import glob
import re

# Folder that contains this script; the sd-scripts checkout is kept in its
# "kohya_ss" subfolder.
dirname = os.path.split(__file__)[0]
repo_dir = os.path.join(dirname, "kohya_ss")


def _requirement_names(text):
    """Return the unpinned requirement entries found in requirements-file *text*.

    Blank lines, comment lines, the bare ``.`` entry and option lines
    (anything starting with ``-``, e.g. ``-e .``) are skipped. For every
    remaining line the trailing ``# comment``, the version specifier
    (``==``, ``~=``, ``!=``, ``<=``, ``>=``, ``<``, ``>``) and any
    ``; marker`` are stripped so only the part before them is kept.
    """
    names = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line == "." or line.startswith(("#", "-")):
            continue
        # Drop an inline comment (whitespace followed by '#').
        line = re.split(r"\s+#", line, maxsplit=1)[0]
        # Cut at the first version operator or environment marker.
        name = re.split(r"\s*(?:[=~!<>]=|[<>;])", line, maxsplit=1)[0].strip()
        if name:
            names.append(name)
    return names


def prepare_environment():
    """Install the dependencies and the sd-scripts checkout used by this extension.

    Settings read from environment variables:

    * ``TORCH_COMMAND`` - command run as ``python -m <command>`` to install
      torch and torchvision.
    * ``SD_SCRIPTS_REPO`` / ``SD_SCRIPTS_BRANCH`` - git URL that is cloned
      into ``repo_dir`` and the branch that is checked out.
    * ``REQS_FILE`` - requirements file inside the checkout
      (default ``requirements.txt``).

    Flags taken out of ``sys.argv`` through ``launch.extract_arg``:

    * ``--skip-install`` - return right after argument parsing.
    * ``--disable-strict-version`` - do not install torch here and install
      the requirements without their version pins.
    * ``--skip-torch-cuda-test`` - skip the ``torch.cuda.is_available()`` check.
    * ``--skip-checkout-repo`` - do not run ``git checkout`` on the branch.
    * ``--update`` - fetch and hard-reset an existing checkout.
    * ``--reinstall-xformers`` / ``--reinstall-torch`` - force reinstallation.

    ``--xformers`` and ``--ngrok`` are only looked up in ``sys.argv`` and are
    left in place.
    """
    torch_command = os.environ.get(
        "TORCH_COMMAND",
        "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
    )
    sd_scripts_repo = os.environ.get("SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git")
    sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
    requirements_file = os.environ.get("REQS_FILE", "requirements.txt")

    sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
    sys.argv, disable_strict_version = launch.extract_arg(
        sys.argv, "--disable-strict-version"
    )
    sys.argv, skip_torch_cuda_test = launch.extract_arg(
        sys.argv, "--skip-torch-cuda-test"
    )
    sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
    sys.argv, update = launch.extract_arg(sys.argv, "--update")
    sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
    xformers = "--xformers" in sys.argv
    ngrok = "--ngrok" in sys.argv

    if skip_install:
        return

    torch_missing = not launch.is_installed("torch") or not launch.is_installed(
        "torchvision"
    )
    if (reinstall_torch or torch_missing) and not disable_strict_version:
        launch.run(
            f'"{launch.python}" -m {torch_command}',
            "Installing torch and torchvision",
            "Couldn't install torch",
        )

    if not skip_torch_cuda_test:
        launch.run_python(
            "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
        )

    if (not launch.is_installed("xformers") or reinstall_xformers) and xformers:
        launch.run_pip("install xformers --pre", "xformers")

    if update and os.path.exists(repo_dir):
        launch.run(f'cd "{repo_dir}" && {launch.git} fetch --prune')
        # NOTE(review): the reset target is always origin/main, even when
        # SD_SCRIPTS_BRANCH names another branch - confirm this is intended.
        launch.run(f'cd "{repo_dir}" && {launch.git} reset --hard origin/main')
    elif not os.path.exists(repo_dir):
        launch.run(
            f'{launch.git} clone {sd_scripts_repo} "{repo_dir}"'
        )

    if not skip_checkout_repo:
        launch.run(f'cd "{repo_dir}" && {launch.git} checkout {sd_scripts_branch}')

    if not launch.is_installed("gradio"):
        launch.run_pip("install gradio==3.16.2", "gradio")

    if not launch.is_installed("pyngrok") and ngrok:
        launch.run_pip("install pyngrok", "ngrok")

    if platform.system() == "Linux":
        if not launch.is_installed("triton"):
            launch.run_pip("install triton", "triton")

    if disable_strict_version:
        reqs_path = os.path.join(repo_dir, requirements_file)
        with open(reqs_path, "r", encoding="utf-8") as f:
            names = _requirement_names(f.read())
        # Quote every requirement (and the checkout path) separately so pip
        # receives one argument per package rather than a single blob.
        pip_args = " ".join(f'"{arg}"' for arg in [*names, repo_dir])
        launch.run_pip(
            f"install {pip_args}",
            "requirements for kohya sd-scripts",
        )
    else:
        # Run from inside the checkout so relative entries in the
        # requirements file resolve against it; honour REQS_FILE here too.
        launch.run(
            f'cd "{repo_dir}" && "{launch.python}" -m pip install -r "{requirements_file}"',
            desc="Installing requirements for kohya sd-scripts",
            errdesc="Couldn't install requirements for kohya sd-scripts",
        )

    if platform.system() == "Windows":
        # Copy the files shipped in the checkout's "bitsandbytes_windows"
        # folder into every installed bitsandbytes package.
        for src in glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")):
            filename = os.path.basename(src)
            for site_dir in site.getsitepackages():
                package_dir = os.path.join(site_dir, "bitsandbytes")
                # main.py belongs in the cuda_setup subfolder; everything
                # else goes into the package root.
                if filename == "main.py":
                    target_dir = os.path.join(package_dir, "cuda_setup")
                else:
                    target_dir = package_dir
                # Skip site-packages folders without that target directory.
                if not os.path.exists(target_dir):
                    continue
                shutil.copy(src, os.path.join(target_dir, filename))


# Run the installation steps only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    prepare_environment()