RohitGandikota committed on
Commit
66982a6
·
1 Parent(s): 6916a7c

pushing training code

Browse files
trainscripts/textsliders/demo_train.py DELETED
@@ -1,434 +0,0 @@
1
- # ref:
2
- # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
- # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
-
5
- from typing import List, Optional
6
- import argparse
7
- import ast
8
- from pathlib import Path
9
- import gc
10
-
11
- import torch
12
- from tqdm import tqdm
13
-
14
-
15
- from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
- import train_util
17
- import model_util
18
- import prompt_util
19
- from prompt_util import (
20
- PromptEmbedsCache,
21
- PromptEmbedsPair,
22
- PromptSettings,
23
- PromptEmbedsXL,
24
- )
25
- import debug_util
26
- import config_util
27
- from config_util import RootConfig
28
-
29
- import wandb
30
-
31
- NUM_IMAGES_PER_PROMPT = 1
32
-
33
-
34
- def flush():
35
- torch.cuda.empty_cache()
36
- gc.collect()
37
-
38
-
39
- def train(
40
- config: RootConfig,
41
- prompts: list[PromptSettings],
42
- device,
43
- ):
44
- metadata = {
45
- "prompts": ",".join([prompt.json() for prompt in prompts]),
46
- "config": config.json(),
47
- }
48
- save_path = Path(config.save.path)
49
-
50
- modules = DEFAULT_TARGET_REPLACE
51
- if config.network.type == "c3lier":
52
- modules += UNET_TARGET_REPLACE_MODULE_CONV
53
-
54
- if config.logging.verbose:
55
- print(metadata)
56
-
57
- if config.logging.use_wandb:
58
- wandb.init(project=f"LECO_{config.save.name}", config=metadata)
59
-
60
- weight_dtype = config_util.parse_precision(config.train.precision)
61
- save_weight_dtype = config_util.parse_precision(config.train.precision)
62
-
63
- (
64
- tokenizers,
65
- text_encoders,
66
- unet,
67
- noise_scheduler,
68
- ) = model_util.load_models_xl(
69
- config.pretrained_model.name_or_path,
70
- scheduler_name=config.train.noise_scheduler,
71
- )
72
-
73
- for text_encoder in text_encoders:
74
- text_encoder.to(device, dtype=weight_dtype)
75
- text_encoder.requires_grad_(False)
76
- text_encoder.eval()
77
-
78
- unet.to(device, dtype=weight_dtype)
79
- if config.other.use_xformers:
80
- unet.enable_xformers_memory_efficient_attention()
81
- unet.requires_grad_(False)
82
- unet.eval()
83
-
84
- network = LoRANetwork(
85
- unet,
86
- rank=config.network.rank,
87
- multiplier=1.0,
88
- alpha=config.network.alpha,
89
- train_method=config.network.training_method,
90
- ).to(device, dtype=weight_dtype)
91
-
92
- optimizer_module = train_util.get_optimizer(config.train.optimizer)
93
- #optimizer_args
94
- optimizer_kwargs = {}
95
- if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
96
- for arg in config.train.optimizer_args.split(" "):
97
- key, value = arg.split("=")
98
- value = ast.literal_eval(value)
99
- optimizer_kwargs[key] = value
100
-
101
- optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
102
- lr_scheduler = train_util.get_lr_scheduler(
103
- config.train.lr_scheduler,
104
- optimizer,
105
- max_iterations=config.train.iterations,
106
- lr_min=config.train.lr / 100,
107
- )
108
- criteria = torch.nn.MSELoss()
109
-
110
- print("Prompts")
111
- for settings in prompts:
112
- print(settings)
113
-
114
- # debug
115
- debug_util.check_requires_grad(network)
116
- debug_util.check_training_mode(network)
117
-
118
- cache = PromptEmbedsCache()
119
- prompt_pairs: list[PromptEmbedsPair] = []
120
-
121
- with torch.no_grad():
122
- for settings in prompts:
123
- print(settings)
124
- for prompt in [
125
- settings.target,
126
- settings.positive,
127
- settings.neutral,
128
- settings.unconditional,
129
- ]:
130
- if cache[prompt] == None:
131
- tex_embs, pool_embs = train_util.encode_prompts_xl(
132
- tokenizers,
133
- text_encoders,
134
- [prompt],
135
- num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
136
- )
137
- cache[prompt] = PromptEmbedsXL(
138
- tex_embs,
139
- pool_embs
140
- )
141
-
142
- prompt_pairs.append(
143
- PromptEmbedsPair(
144
- criteria,
145
- cache[settings.target],
146
- cache[settings.positive],
147
- cache[settings.unconditional],
148
- cache[settings.neutral],
149
- settings,
150
- )
151
- )
152
-
153
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
154
- del tokenizer, text_encoder
155
-
156
- flush()
157
-
158
- pbar = tqdm(range(config.train.iterations))
159
-
160
- loss = None
161
-
162
- for i in pbar:
163
- with torch.no_grad():
164
- noise_scheduler.set_timesteps(
165
- config.train.max_denoising_steps, device=device
166
- )
167
-
168
- optimizer.zero_grad()
169
-
170
- prompt_pair: PromptEmbedsPair = prompt_pairs[
171
- torch.randint(0, len(prompt_pairs), (1,)).item()
172
- ]
173
-
174
- # 1 ~ 49 からランダム
175
- timesteps_to = torch.randint(
176
- 1, config.train.max_denoising_steps, (1,)
177
- ).item()
178
-
179
- height, width = prompt_pair.resolution, prompt_pair.resolution
180
- if prompt_pair.dynamic_resolution:
181
- height, width = train_util.get_random_resolution_in_bucket(
182
- prompt_pair.resolution
183
- )
184
-
185
- if config.logging.verbose:
186
- print("gudance_scale:", prompt_pair.guidance_scale)
187
- print("resolution:", prompt_pair.resolution)
188
- print("dynamic_resolution:", prompt_pair.dynamic_resolution)
189
- if prompt_pair.dynamic_resolution:
190
- print("bucketed resolution:", (height, width))
191
- print("batch_size:", prompt_pair.batch_size)
192
- print("dynamic_crops:", prompt_pair.dynamic_crops)
193
-
194
- latents = train_util.get_initial_latents(
195
- noise_scheduler, prompt_pair.batch_size, height, width, 1
196
- ).to(device, dtype=weight_dtype)
197
-
198
- add_time_ids = train_util.get_add_time_ids(
199
- height,
200
- width,
201
- dynamic_crops=prompt_pair.dynamic_crops,
202
- dtype=weight_dtype,
203
- ).to(device, dtype=weight_dtype)
204
-
205
- with network:
206
- # ちょっとデノイズされれたものが返る
207
- denoised_latents = train_util.diffusion_xl(
208
- unet,
209
- noise_scheduler,
210
- latents, # 単純なノイズのlatentsを渡す
211
- text_embeddings=train_util.concat_embeddings(
212
- prompt_pair.unconditional.text_embeds,
213
- prompt_pair.target.text_embeds,
214
- prompt_pair.batch_size,
215
- ),
216
- add_text_embeddings=train_util.concat_embeddings(
217
- prompt_pair.unconditional.pooled_embeds,
218
- prompt_pair.target.pooled_embeds,
219
- prompt_pair.batch_size,
220
- ),
221
- add_time_ids=train_util.concat_embeddings(
222
- add_time_ids, add_time_ids, prompt_pair.batch_size
223
- ),
224
- start_timesteps=0,
225
- total_timesteps=timesteps_to,
226
- guidance_scale=3,
227
- )
228
-
229
- noise_scheduler.set_timesteps(1000)
230
-
231
- current_timestep = noise_scheduler.timesteps[
232
- int(timesteps_to * 1000 / config.train.max_denoising_steps)
233
- ]
234
-
235
- # with network: の外では空のLoRAのみが有効になる
236
- positive_latents = train_util.predict_noise_xl(
237
- unet,
238
- noise_scheduler,
239
- current_timestep,
240
- denoised_latents,
241
- text_embeddings=train_util.concat_embeddings(
242
- prompt_pair.unconditional.text_embeds,
243
- prompt_pair.positive.text_embeds,
244
- prompt_pair.batch_size,
245
- ),
246
- add_text_embeddings=train_util.concat_embeddings(
247
- prompt_pair.unconditional.pooled_embeds,
248
- prompt_pair.positive.pooled_embeds,
249
- prompt_pair.batch_size,
250
- ),
251
- add_time_ids=train_util.concat_embeddings(
252
- add_time_ids, add_time_ids, prompt_pair.batch_size
253
- ),
254
- guidance_scale=1,
255
- ).to(device, dtype=weight_dtype)
256
- neutral_latents = train_util.predict_noise_xl(
257
- unet,
258
- noise_scheduler,
259
- current_timestep,
260
- denoised_latents,
261
- text_embeddings=train_util.concat_embeddings(
262
- prompt_pair.unconditional.text_embeds,
263
- prompt_pair.neutral.text_embeds,
264
- prompt_pair.batch_size,
265
- ),
266
- add_text_embeddings=train_util.concat_embeddings(
267
- prompt_pair.unconditional.pooled_embeds,
268
- prompt_pair.neutral.pooled_embeds,
269
- prompt_pair.batch_size,
270
- ),
271
- add_time_ids=train_util.concat_embeddings(
272
- add_time_ids, add_time_ids, prompt_pair.batch_size
273
- ),
274
- guidance_scale=1,
275
- ).to(device, dtype=weight_dtype)
276
- unconditional_latents = train_util.predict_noise_xl(
277
- unet,
278
- noise_scheduler,
279
- current_timestep,
280
- denoised_latents,
281
- text_embeddings=train_util.concat_embeddings(
282
- prompt_pair.unconditional.text_embeds,
283
- prompt_pair.unconditional.text_embeds,
284
- prompt_pair.batch_size,
285
- ),
286
- add_text_embeddings=train_util.concat_embeddings(
287
- prompt_pair.unconditional.pooled_embeds,
288
- prompt_pair.unconditional.pooled_embeds,
289
- prompt_pair.batch_size,
290
- ),
291
- add_time_ids=train_util.concat_embeddings(
292
- add_time_ids, add_time_ids, prompt_pair.batch_size
293
- ),
294
- guidance_scale=1,
295
- ).to(device, dtype=weight_dtype)
296
-
297
- if config.logging.verbose:
298
- print("positive_latents:", positive_latents[0, 0, :5, :5])
299
- print("neutral_latents:", neutral_latents[0, 0, :5, :5])
300
- print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
301
-
302
- with network:
303
- target_latents = train_util.predict_noise_xl(
304
- unet,
305
- noise_scheduler,
306
- current_timestep,
307
- denoised_latents,
308
- text_embeddings=train_util.concat_embeddings(
309
- prompt_pair.unconditional.text_embeds,
310
- prompt_pair.target.text_embeds,
311
- prompt_pair.batch_size,
312
- ),
313
- add_text_embeddings=train_util.concat_embeddings(
314
- prompt_pair.unconditional.pooled_embeds,
315
- prompt_pair.target.pooled_embeds,
316
- prompt_pair.batch_size,
317
- ),
318
- add_time_ids=train_util.concat_embeddings(
319
- add_time_ids, add_time_ids, prompt_pair.batch_size
320
- ),
321
- guidance_scale=1,
322
- ).to(device, dtype=weight_dtype)
323
-
324
- if config.logging.verbose:
325
- print("target_latents:", target_latents[0, 0, :5, :5])
326
-
327
- positive_latents.requires_grad = False
328
- neutral_latents.requires_grad = False
329
- unconditional_latents.requires_grad = False
330
-
331
- loss = prompt_pair.loss(
332
- target_latents=target_latents,
333
- positive_latents=positive_latents,
334
- neutral_latents=neutral_latents,
335
- unconditional_latents=unconditional_latents,
336
- )
337
-
338
- # 1000倍しないとずっと0.000...になってしまって見た目的に面白くない
339
- pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}")
340
- if config.logging.use_wandb:
341
- wandb.log(
342
- {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]}
343
- )
344
-
345
- loss.backward()
346
- optimizer.step()
347
- lr_scheduler.step()
348
-
349
- del (
350
- positive_latents,
351
- neutral_latents,
352
- unconditional_latents,
353
- target_latents,
354
- latents,
355
- )
356
- flush()
357
-
358
- # if (
359
- # i % config.save.per_steps == 0
360
- # and i != 0
361
- # and i != config.train.iterations - 1
362
- # ):
363
- # print("Saving...")
364
- # save_path.mkdir(parents=True, exist_ok=True)
365
- # network.save_weights(
366
- # save_path / f"{config.save.name}_{i}steps.pt",
367
- # dtype=save_weight_dtype,
368
- # )
369
-
370
- print("Saving...")
371
- save_path.mkdir(parents=True, exist_ok=True)
372
- network.save_weights(
373
- save_path / f"{config.save.name}",
374
- dtype=save_weight_dtype,
375
- )
376
-
377
- del (
378
- unet,
379
- noise_scheduler,
380
- loss,
381
- optimizer,
382
- network,
383
- )
384
-
385
- flush()
386
-
387
- print("Done.")
388
-
389
-
390
- # def main(args):
391
- # config_file = args.config_file
392
-
393
- # config = config_util.load_config_from_yaml(config_file)
394
- # if args.name is not None:
395
- # config.save.name = args.name
396
- # attributes = []
397
- # if args.attributes is not None:
398
- # attributes = args.attributes.split(',')
399
- # attributes = [a.strip() for a in attributes]
400
-
401
- # config.network.alpha = args.alpha
402
- # config.network.rank = args.rank
403
- # config.save.name += f'_alpha{args.alpha}'
404
- # config.save.name += f'_rank{config.network.rank }'
405
- # config.save.name += f'_{config.network.training_method}'
406
- # config.save.path += f'/{config.save.name}'
407
-
408
- # prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
409
-
410
- # device = torch.device(f"cuda:{args.device}")
411
- # train(config, prompts, device)
412
-
413
-
414
- def train_xl(target, postive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
415
-
416
- config = config_util.load_config_from_yaml(config_file)
417
- randn = torch.randint(1, 10000000, (1,)).item()
418
- config.save.name = save_name
419
-
420
- config.train.lr = float(lr)
421
- config.train.iterations=int(iterations)
422
-
423
- if attributes is not None:
424
- attributes = attributes.split(',')
425
- attributes = [a.strip() for a in attributes]
426
- config.network.alpha = 1.0
427
- config.network.rank = rank
428
-
429
- config.save.path += f'/{config.save.name}'
430
-
431
- prompts = prompt_util.load_prompts_from_yaml(path=config.prompts_file, target=target, positive=positive, negative=negative, attributes=attributes)
432
-
433
- device = torch.device(f"cuda:{device}")
434
- train(config, prompts, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trainscripts/textsliders/demotrain.py CHANGED
@@ -12,7 +12,7 @@ import torch
12
  from tqdm import tqdm
13
 
14
 
15
- from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
  import train_util
17
  import model_util
18
  import prompt_util
 
12
  from tqdm import tqdm
13
 
14
 
15
+ from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
  import train_util
17
  import model_util
18
  import prompt_util