File size: 7,302 Bytes
11c2c17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import subprocess
import os
import filecmp
import logging
import shutil
import sysconfig
import setup_common

# Module-level error counter. It is not read or updated anywhere else in this
# script, but it is kept so external references to it keep working.
errors = 0
# Logger shared by the setup helpers below (presumably configured by
# setup_common.setup_logging() — confirm there).
log = logging.getLogger('sd')

# ANSI escape sequences: switch console text to yellow, then back to default.
YELLOW, RESET_COLOR = '\033[93m', '\033[0m'


def cudann_install():
    """Copy the bundled cuDNN files into the installed torch ``lib`` folder.

    Source: the ``cudnn_windows`` directory one level above this script.
    Destination: ``<purelib>/torch/lib`` of the running interpreter.
    A file is copied only when it is missing at the destination or its
    contents differ (byte-for-byte comparison). Problems are logged, not
    raised.
    """
    # Join '..' and the folder name as separate components: a single
    # '..\cudnn_windows' literal is an invalid escape sequence and only
    # resolves correctly on Windows.
    cudnn_src = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), '..', 'cudnn_windows'
    )
    cudnn_dest = os.path.join(sysconfig.get_paths()['purelib'], 'torch', 'lib')

    log.info(f'Checking for CUDNN files in {cudnn_dest}...')
    if not os.path.exists(cudnn_src):
        log.error(f'Installation Failed: "{cudnn_src}" could not be found.')
        return
    if not os.path.exists(cudnn_dest):
        log.warning(f'Destination directory {cudnn_dest} does not exist')
        return

    # Drop cached comparison results so stale entries cannot mask changes.
    filecmp.clear_cache()
    copied = 0
    for file in os.listdir(cudnn_src):
        src_file = os.path.join(cudnn_src, file)
        if not os.path.isfile(src_file):
            # Sub-directories cannot be compared/copied file-wise; skip them.
            continue
        dest_file = os.path.join(cudnn_dest, file)
        # shallow=False compares file contents, not just os.stat() signatures.
        if os.path.exists(dest_file) and filecmp.cmp(
            src_file, dest_file, shallow=False
        ):
            continue
        shutil.copy2(src_file, cudnn_dest)
        copied += 1

    if copied:
        log.info('Copied CUDNN 8.6 files to destination')
    else:
        log.info('CUDNN files are already up to date')


def sync_bits_and_bytes_files():
    """
    Check for "different" bitsandbytes Files and copy only if necessary.
    This function is specific for Windows OS.

    Files from ``<cwd>/bitsandbytes_windows`` are copied into the installed
    ``bitsandbytes`` package under ``purelib``; ``main.py`` and ``paths.py``
    go into its ``cuda_setup`` sub-directory. A file is copied only when it
    is missing at the destination or its contents differ. On non-Windows
    systems a message is printed and nothing is copied. Errors are logged,
    not raised.
    """
    # Only execute on Windows
    if os.name != 'nt':
        print('This function is only applicable to Windows OS.')
        return

    try:
        log.info('Copying bitsandbytes files...')
        # Define source and destination directories
        source_dir = os.path.join(os.getcwd(), 'bitsandbytes_windows')
        dest_dir_base = os.path.join(
            sysconfig.get_paths()['purelib'], 'bitsandbytes'
        )

        # Clear file comparison cache so stale results cannot mask changes.
        filecmp.clear_cache()

        for file in os.listdir(source_dir):
            source_file_path = os.path.join(source_dir, file)
            if not os.path.isfile(source_file_path):
                # Sub-directories cannot be compared/copied file-wise.
                continue

            # Decide the destination directory based on file name
            if file in ('main.py', 'paths.py'):
                dest_dir = os.path.join(dest_dir_base, 'cuda_setup')
            else:
                dest_dir = dest_dir_base

            # shutil.copy2(src, dir) writes a *file* named `dir` when the
            # directory is missing, so never copy into a non-existent one.
            if not os.path.isdir(dest_dir):
                log.warning(
                    f'Destination directory {dest_dir} does not exist, '
                    f'skipping {source_file_path}'
                )
                continue

            dest_file_path = os.path.join(dest_dir, file)

            # Compare contents (shallow=False), not only os.stat() signatures.
            if os.path.exists(dest_file_path) and filecmp.cmp(
                source_file_path, dest_file_path, shallow=False
            ):
                log.debug(
                    f'Skipping {source_file_path} as it already exists in {dest_dir}'
                )
            else:
                # Copy to an explicit file path, preserving file metadata.
                log.debug(f'Copy {source_file_path} to {dest_dir}')
                shutil.copy2(source_file_path, dest_file_path)

    except FileNotFoundError as fnf_error:
        log.error(f'File not found error: {fnf_error}')
    except PermissionError as perm_error:
        log.error(f'Permission error: {perm_error}')
    except Exception as e:
        log.error(f'An unexpected error occurred: {e}')


def _install_kohya_ss(torch_version, conflicting_version, requirements_file):
    """Install the kohya_ss dependencies for one Torch major version.

    Runs the repository and Python checks and upgrades pip. If
    ``setup_common.check_torch()`` reports ``conflicting_version``, the user
    is told to recreate the venv and nothing is installed. Otherwise it:
    - installs ``requirements_file``;
    - syncs the bitsandbytes Windows files;
    - configures accelerate.

    Args:
        torch_version: Torch major version being installed (used in the prompt).
        conflicting_version: ``check_torch()`` result that aborts the install.
        requirements_file: pip requirements file passed to setup_common.
    """
    setup_common.check_repo_version()
    setup_common.check_python()

    # Upgrade pip if needed
    setup_common.install('--upgrade pip')

    if setup_common.check_torch() == conflicting_version:
        # input() only returns on Enter, so the prompt must ask for Enter.
        input(
            f'{YELLOW}\nTorch {conflicting_version} is already installed in the venv. '
            f'To install Torch {torch_version} delete the venv and re-run setup.bat'
            f'\n\nHit enter to continue...{RESET_COLOR}'
        )
        return

    setup_common.install_requirements(requirements_file, check_no_verify_flag=False)
    sync_bits_and_bytes_files()
    setup_common.configure_accelerate(run_accelerate=True)


def install_kohya_ss_torch1():
    """Install kohya_ss with the legacy Torch 1 requirements (aborts if Torch 2 is present)."""
    _install_kohya_ss(1, 2, 'requirements_windows_torch1.txt')


def install_kohya_ss_torch2():
    """Install kohya_ss with the Torch 2 requirements (aborts if Torch 1 is present)."""
    _install_kohya_ss(2, 1, 'requirements_windows_torch2.txt')


def _torch_version_menu():
    """Ask which Torch flavour to install and run the matching installer."""
    while True:
        print('1. Torch 1 (legacy)')
        print('2. Torch 2 (recommended)')
        print('3. Cancel')
        choice_torch = input('\nEnter your choice: ').strip()
        print('')

        if choice_torch == '1':
            install_kohya_ss_torch1()
            return
        if choice_torch == '2':
            install_kohya_ss_torch2()
            return
        if choice_torch == '3':
            return
        print('Invalid choice. Please enter a number between 1-3.')


def main_menu():
    """Run the interactive setup menu until the user chooses to quit.

    Each option dispatches to an installer in this script or to
    ``setup_common``. Unrecognised input prints a hint and shows the
    menu again.
    """
    setup_common.clear_screen()
    while True:
        print('\nKohya_ss GUI setup menu:\n')
        print('1. Install kohya_ss gui')
        print('2. (Optional) Install cudann files')
        print('3. (Optional) Install bitsandbytes-windows')
        print('4. (Optional) Manually configure accelerate')
        print('5. (Optional) Start Kohya_ss GUI in browser')
        print('6. Quit')

        # Strip so stray spaces around the number are still accepted.
        choice = input('\nEnter your choice: ').strip()
        print('')

        if choice == '1':
            _torch_version_menu()
        elif choice == '2':
            cudann_install()
        elif choice == '3':
            setup_common.install('--upgrade bitsandbytes-windows', reinstall=True)
        elif choice == '4':
            setup_common.run_cmd('accelerate config')
        elif choice == '5':
            # Raw string: a plain '.\gui.bat' literal contains an invalid escape
            # sequence. /k keeps the new terminal open when the GUI exits;
            # /c would close it instead.
            subprocess.Popen(r'start cmd /k .\gui.bat --inbrowser', shell=True)
        elif choice == '6':
            print('Quitting the program.')
            break
        else:
            print('Invalid choice. Please enter a number between 1-6.')


if __name__ == '__main__':
    # Script entry point. The calls run in this order: base requirements are
    # handled by setup_common first (presumably so later helpers can rely on
    # them — confirm in setup_common), then logging is set up, then the
    # interactive menu starts.
    setup_common.ensure_base_requirements()
    setup_common.setup_logging()
    main_menu()