File size: 18,123 Bytes
11c2c17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
import argparse
import gc
import math
import os
from types import SimpleNamespace
from typing import Any
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import open_clip
from library import model_util, sdxl_model_util, train_util
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

TOKENIZER_PATH = "openai/clip-vit-large-patch14"

DEFAULT_NOISE_OFFSET = 0.0357


# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
def load_target_model(args, accelerator, model_version: str, weight_dtype):
    # load models for each process
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

            (
                load_stable_diffusion_format,
                text_encoder1,
                text_encoder2,
                vae,
                unet,
                logit_scale,
                ckpt_info,
            ) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")

            # work on low-ram device
            if args.lowram:
                text_encoder1.to(accelerator.device)
                text_encoder2.to(accelerator.device)
                unet.to(accelerator.device)
                vae.to(accelerator.device)

            gc.collect()
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
    # only supports StableDiffusion
    name_or_path = args.pretrained_model_name_or_path
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
    assert (
        load_stable_diffusion_format
    ), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}"

    print(f"load StableDiffusion checkpoint: {name_or_path}")
    (
        text_encoder1,
        text_encoder2,
        vae,
        unet,
        logit_scale,
        ckpt_info,
    ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)

    # VAEを読み込む
    if args.vae is not None:
        vae = model_util.load_vae(args.vae, weight_dtype)
        print("additional VAE loaded")

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


class WrapperTokenizer:
    # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
    # make open clip tokenizer compatible with HuggingFace tokenizer
    def __init__(self):
        open_clip_tokenizer = open_clip.tokenizer._tokenizer
        self.model_max_length = 77
        self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
        self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
        self.pad_token_id = 0  # 結果から推定している assumption from result

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.tokenize(*args, **kwds)

    def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
        if padding == "max_length":
            # for training
            assert max_length is not None
            assert truncation == True
            assert return_tensors == "pt"
            input_ids = open_clip.tokenize(text, context_length=max_length)
            return SimpleNamespace(**{"input_ids": input_ids})

        # for weighted prompt
        assert isinstance(text, str), f"input must be str: {text}"

        input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0]  # tokenizer returns list

        # find eos
        eos_index = (input_ids == self.eos_token_id).nonzero().max()
        input_ids = input_ids[: eos_index + 1]  # include eos
        return SimpleNamespace(**{"input_ids": input_ids})

    # for Textual Inversion
    # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?

    def encode(self, text, add_special_tokens=False):
        assert not add_special_tokens
        input_ids = open_clip.tokenizer._tokenizer.encode(text)
        return input_ids

    def add_tokens(self, new_tokens):
        tokens_to_add = []
        for token in new_tokens:
            token = token.lower()
            if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
                tokens_to_add.append(token)

        # open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
        for token in tokens_to_add:
            open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
            open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
            open_clip.tokenizer._tokenizer.vocab_size += 1

            # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
            # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
            # set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
            # this is very rough, so it will not work if the specification of open clip tokenizer changes
            open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"

        return len(tokens_to_add)

    def convert_tokens_to_ids(self, tokens):
        input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
        return input_ids

    def __len__(self):
        return open_clip.tokenizer._tokenizer.vocab_size


def load_tokenizers(args: argparse.Namespace):
    print("prepare tokenizers")
    original_path = TOKENIZER_PATH

    tokenizer1: CLIPTokenizer = None
    if args.tokenizer_cache_dir:
        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
        if os.path.exists(local_tokenizer_path):
            print(f"load tokenizer from cache: {local_tokenizer_path}")
            tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)

    if tokenizer1 is None:
        tokenizer1 = CLIPTokenizer.from_pretrained(original_path)

    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
        print(f"save Tokenizer to cache: {local_tokenizer_path}")
        tokenizer1.save_pretrained(local_tokenizer_path)

    if hasattr(args, "max_token_length") and args.max_token_length is not None:
        print(f"update token length: {args.max_token_length}")

    # tokenizer2 is from open_clip
    # TODO caching
    tokenizer2 = WrapperTokenizer()

    return [tokenizer1, tokenizer2]


def get_hidden_states(
    args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
):
    # input_ids: b,n,77 -> b*n, 77
    b_size = input_ids1.size()[0]
    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77

    # text_encoder1
    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = enc_out["hidden_states"][11]

    # text_encoder2
    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = enc_out["hidden_states"][-2]  # penuultimate layer
    pool2 = enc_out["text_embeds"]

    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
    n_size = 1 if args.max_token_length is None else args.max_token_length // 75
    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))

    if args.max_token_length is not None:
        # bs*3, 77, 768 or 1024
        # encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
        for i in range(1, args.max_token_length, tokenizer1.model_max_length):
            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # <BOS> の後から <EOS> の前まで
        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
        hidden_states1 = torch.cat(states_list, dim=1)

        # v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
        for i in range(1, args.max_token_length, tokenizer2.model_max_length):
            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # <BOS> の後から 最後の前まで
            # this causes an error:
            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
            # if i > 1:
            #     for j in range(len(chunk)):  # batch_size
            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # 空、つまり <BOS> <EOS> <PAD> ...のパターン
            #             chunk[j, 0] = chunk[j, 1]  # 次の <PAD> の値をコピーする
            states_list.append(chunk)  # <BOS> の後から <EOS> の前まで
        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # <EOS> か <PAD> のどちらか
        hidden_states2 = torch.cat(states_list, dim=1)

        # pool はnの最初のものを使う
        pool2 = pool2[::n_size]

    if weight_dtype is not None:
        # this is required for additional network training
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def get_timestep_embedding(x, outdim):
    assert len(x.shape) == 2
    b, dims = x.shape[0], x.shape[1]
    x = torch.flatten(x)
    emb = timestep_embedding(x, outdim)
    emb = torch.reshape(emb, (b, dims * outdim))
    return emb


def get_size_embeddings(orig_size, crop_size, target_size, device):
    emb1 = get_timestep_embedding(orig_size, 256)
    emb2 = get_timestep_embedding(crop_size, 256)
    emb3 = get_timestep_embedding(target_size, 256)
    vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
    return vector


def save_sd_model_on_train_end(
    args: argparse.Namespace,
    src_path: str,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    text_encoder1,
    text_encoder2,
    unet,
    vae,
    logit_scale,
    ckpt_info,
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sdxl_model_util.save_stable_diffusion_checkpoint(
            ckpt_file,
            text_encoder1,
            text_encoder2,
            unet,
            epoch_no,
            global_step,
            ckpt_info,
            vae,
            logit_scale,
            save_dtype,
        )

    def diffusers_saver(out_dir):
        raise NotImplementedError("diffusers_saver is not implemented")

    train_util.save_sd_model_on_train_end_common(
        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
    )


# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_sd_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    src_path,
    save_stable_diffusion_format: bool,
    use_safetensors: bool,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    text_encoder1,
    text_encoder2,
    unet,
    vae,
    logit_scale,
    ckpt_info,
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sdxl_model_util.save_stable_diffusion_checkpoint(
            ckpt_file,
            text_encoder1,
            text_encoder2,
            unet,
            epoch_no,
            global_step,
            ckpt_info,
            vae,
            logit_scale,
            save_dtype,
        )

    def diffusers_saver(out_dir):
        raise NotImplementedError("diffusers_saver is not implemented")

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        save_stable_diffusion_format,
        use_safetensors,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        diffusers_saver,
    )


# TextEncoderの出力をキャッシュする
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
    print("caching text encoder outputs")

    tokenizer1, tokenizer2 = tokenizers
    text_encoder1, text_encoder2 = text_encoders
    text_encoder1.to(accelerator.device)
    text_encoder2.to(accelerator.device)
    if weight_dtype is not None:
        text_encoder1.to(dtype=weight_dtype)
        text_encoder2.to(dtype=weight_dtype)

    text_encoder1_cache = {}
    text_encoder2_cache = {}
    for batch in tqdm(dataset):
        input_ids1_batch = batch["input_ids"].to(accelerator.device)
        input_ids2_batch = batch["input_ids2"].to(accelerator.device)

        # split batch to avoid OOM
        # TODO specify batch size by args
        for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
            # remove input_ids already in cache
            input_id1_cache_key = tuple(input_id1.flatten().tolist())
            input_id2_cache_key = tuple(input_id2.flatten().tolist())
            if input_id1_cache_key in text_encoder1_cache:
                assert input_id2_cache_key in text_encoder2_cache
                continue

            with torch.no_grad():
                encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
                    args,
                    input_id1,
                    input_id2,
                    tokenizer1,
                    tokenizer2,
                    text_encoder1,
                    text_encoder2,
                    None if not args.full_fp16 else weight_dtype,
                )
            encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0)  # n*75+2,768
            encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0)  # n*75+2,1280
            pool2 = pool2.detach().to("cpu").squeeze(0)  # 1280
            text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
            text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
    return text_encoder1_cache, text_encoder2_cache


def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )


def verify_sdxl_training_args(args: argparse.Namespace):
    assert (
        not args.v2 and not args.v_parameterization
    ), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません"
    if args.clip_skip is not None:
        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

    if args.multires_noise_iterations:
        print(
            f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
        )
    else:
        if args.noise_offset is None:
            args.noise_offset = DEFAULT_NOISE_OFFSET
        elif args.noise_offset != DEFAULT_NOISE_OFFSET:
            print(
                f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
            )
        print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")

    assert (
        not hasattr(args, "weighted_captions") or not args.weighted_captions
    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"


def sample_images(*args, **kwargs):
    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)