jbilcke-hf (HF Staff) committed
Commit e8c26e7 · 1 Parent(s): 66c6879
finetrainers/finetrainers__lib__trainer.py DELETED
@@ -1,1235 +0,0 @@
1
- import json
2
- import logging
3
- import math
4
- import os
5
- import gc
6
- import random
- import resource
7
- from datetime import datetime, timedelta
8
- from pathlib import Path
9
- from typing import Any, Dict, List
10
-
11
- import diffusers
12
- import torch
13
- import torch.backends
14
- import transformers
15
- import wandb
16
- from accelerate import Accelerator, DistributedType
17
- from accelerate.logging import get_logger
18
- from accelerate.utils import (
19
- DistributedDataParallelKwargs,
20
- InitProcessGroupKwargs,
21
- ProjectConfiguration,
22
- gather_object,
23
- set_seed,
24
- )
25
- from diffusers import DiffusionPipeline
26
- from diffusers.configuration_utils import FrozenDict
27
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
28
- from diffusers.optimization import get_scheduler
29
- from diffusers.training_utils import cast_training_params
30
- from diffusers.utils import export_to_video, load_image, load_video
31
- from huggingface_hub import create_repo, upload_folder
32
- from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
33
- from tqdm import tqdm
34
-
35
- from .args import Args, validate_args
36
- from .constants import (
37
- FINETRAINERS_LOG_LEVEL,
38
- PRECOMPUTED_CONDITIONS_DIR_NAME,
39
- PRECOMPUTED_DIR_NAME,
40
- PRECOMPUTED_LATENTS_DIR_NAME,
41
- )
42
- from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
43
- from .hooks import apply_layerwise_upcasting
44
- from .models import get_config_from_model_name
45
- from .patches import perform_peft_patches
46
- from .state import State
47
- from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
48
- from .utils.data_utils import should_perform_precomputation
49
- from .utils.diffusion_utils import (
50
- get_scheduler_alphas,
51
- get_scheduler_sigmas,
52
- prepare_loss_weights,
53
- prepare_sigmas,
54
- prepare_target,
55
- )
56
- from .utils.file_utils import string_to_filename
57
- from .utils.hub_utils import save_model_card
58
- from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
59
- from .utils.model_utils import resolve_vae_cls_from_ckpt_path
60
- from .utils.optimizer_utils import get_optimizer
61
- from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
62
-
63
-
64
- logger = get_logger("finetrainers")
65
- logger.setLevel(FINETRAINERS_LOG_LEVEL)
66
-
67
-
68
- class Trainer:
69
- def __init__(self, args: Args) -> None:
70
- validate_args(args)
71
-
72
- self.args = args
73
- self.args.seed = self.args.seed or datetime.now().year
74
- self.state = State()
75
-
76
- # Tokenizers
77
- self.tokenizer = None
78
- self.tokenizer_2 = None
79
- self.tokenizer_3 = None
80
-
81
- # Text encoders
82
- self.text_encoder = None
83
- self.text_encoder_2 = None
84
- self.text_encoder_3 = None
85
-
86
- # Denoisers
87
- self.transformer = None
88
- self.unet = None
89
-
90
- # Autoencoders
91
- self.vae = None
92
-
93
- # Scheduler
94
- self.scheduler = None
95
-
96
- self.transformer_config = None
97
- self.vae_config = None
98
-
99
- self._init_distributed()
100
- self._init_logging()
101
- self._init_directories_and_repositories()
102
- self._init_config_options()
103
-
104
- # Perform any patches needed for training
105
- if len(self.args.layerwise_upcasting_modules) > 0:
106
- perform_peft_patches()
107
- # TODO(aryan): handle text encoders
108
- # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
109
- # perform_text_encoder_patches()
110
-
111
- self.state.model_name = self.args.model_name
112
- self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
113
-
114
- def prepare_dataset(self) -> None:
115
- # TODO(aryan): Make a background process for fetching
116
- logger.info("Initializing dataset and dataloader")
117
-
118
- self.dataset = ImageOrVideoDatasetWithResizing(
119
- data_root=self.args.data_root,
120
- caption_column=self.args.caption_column,
121
- video_column=self.args.video_column,
122
- resolution_buckets=self.args.video_resolution_buckets,
123
- dataset_file=self.args.dataset_file,
124
- id_token=self.args.id_token,
125
- remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes,
126
- )
127
- self.dataloader = torch.utils.data.DataLoader(
128
- self.dataset,
129
- batch_size=1,
130
- sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
131
- collate_fn=self.model_config.get("collate_fn"),
132
- num_workers=self.args.dataloader_num_workers,
133
- pin_memory=self.args.pin_memory,
134
- )
135
-
136
- def prepare_models(self) -> None:
137
- logger.info("Initializing models")
138
-
139
- load_components_kwargs = self._get_load_components_kwargs()
140
- condition_components, latent_components, diffusion_components = {}, {}, {}
141
- if not self.args.precompute_conditions:
142
- # Download the model files on the main process first (if not already present),
143
- # then let the other processes load them from the local cache.
144
- with self.state.accelerator.main_process_first():
145
- condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
146
- latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
147
- diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)
148
-
149
- components = {}
150
- components.update(condition_components)
151
- components.update(latent_components)
152
- components.update(diffusion_components)
153
- self._set_components(components)
154
-
155
- if self.vae is not None:
156
- if self.args.enable_slicing:
157
- self.vae.enable_slicing()
158
- if self.args.enable_tiling:
159
- self.vae.enable_tiling()
160
-
161
- def prepare_precomputations(self) -> None:
162
- if not self.args.precompute_conditions:
163
- return
164
-
165
- logger.info("Initializing precomputations")
166
-
167
- if self.args.batch_size != 1:
168
- raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.")
169
-
170
- def collate_fn(batch):
171
- latent_conditions = [x["latent_conditions"] for x in batch]
172
- text_conditions = [x["text_conditions"] for x in batch]
173
- batched_latent_conditions = {}
174
- batched_text_conditions = {}
175
- for key in list(latent_conditions[0].keys()):
176
- if torch.is_tensor(latent_conditions[0][key]):
177
- batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
178
- else:
179
- # TODO(aryan): implement batch sampler for precomputed latents
180
- batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
181
- for key in list(text_conditions[0].keys()):
182
- if torch.is_tensor(text_conditions[0][key]):
183
- batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
184
- else:
185
- # TODO(aryan): implement batch sampler for precomputed latents
186
- batched_text_conditions[key] = [x[key] for x in text_conditions][0]
187
- return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}
188
-
189
- cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
190
- precomputation_dir = (
191
- Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
192
- )
193
- should_precompute = should_perform_precomputation(precomputation_dir)
194
- if not should_precompute:
195
- logger.info("Precomputed conditions and latents found. Loading precomputed data.")
196
- self.dataloader = torch.utils.data.DataLoader(
197
- PrecomputedDataset(
198
- data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
199
- ),
200
- batch_size=self.args.batch_size,
201
- shuffle=True,
202
- collate_fn=collate_fn,
203
- num_workers=self.args.dataloader_num_workers,
204
- pin_memory=self.args.pin_memory,
205
- )
206
- return
207
-
208
- logger.info("Precomputed conditions and latents not found. Running precomputation.")
209
-
210
- # At this point, no models are loaded, so we need to load and precompute conditions and latents
211
- with self.state.accelerator.main_process_first():
212
- condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
213
- self._set_components(condition_components)
214
- self._move_components_to_device()
215
- self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
216
-
217
- if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
218
- logger.warning(
219
- "Caption dropout is not supported with precomputation yet. This will be supported in the future."
220
- )
221
-
222
- conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
223
- latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
224
- conditions_dir.mkdir(parents=True, exist_ok=True)
225
- latents_dir.mkdir(parents=True, exist_ok=True)
226
-
227
- accelerator = self.state.accelerator
228
-
229
- # Precompute conditions
230
- progress_bar = tqdm(
231
- range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
232
- desc="Precomputing conditions",
233
- disable=not accelerator.is_local_main_process,
234
- )
235
- index = 0
236
- for i, data in enumerate(self.dataset):
237
- if i % accelerator.num_processes != accelerator.process_index:
238
- continue
239
-
240
- logger.debug(
241
- f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
242
- )
243
-
244
- text_conditions = self.model_config["prepare_conditions"](
245
- tokenizer=self.tokenizer,
246
- tokenizer_2=self.tokenizer_2,
247
- tokenizer_3=self.tokenizer_3,
248
- text_encoder=self.text_encoder,
249
- text_encoder_2=self.text_encoder_2,
250
- text_encoder_3=self.text_encoder_3,
251
- prompt=data["prompt"],
252
- device=accelerator.device,
253
- dtype=self.args.transformer_dtype,
254
- )
255
- filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
256
- torch.save(text_conditions, filename.as_posix())
257
- index += 1
258
- progress_bar.update(1)
259
- self._delete_components()
260
-
261
- memory_statistics = get_memory_statistics()
262
- logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
263
- torch.cuda.reset_peak_memory_stats(accelerator.device)
264
-
265
- # Precompute latents
266
- with self.state.accelerator.main_process_first():
267
- latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
268
- self._set_components(latent_components)
269
- self._move_components_to_device()
270
- self._disable_grad_for_components([self.vae])
271
-
272
- if self.vae is not None:
273
- if self.args.enable_slicing:
274
- self.vae.enable_slicing()
275
- if self.args.enable_tiling:
276
- self.vae.enable_tiling()
277
-
278
- progress_bar = tqdm(
279
- range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
280
- desc="Precomputing latents",
281
- disable=not accelerator.is_local_main_process,
282
- )
283
- index = 0
284
- for i, data in enumerate(self.dataset):
285
- if i % accelerator.num_processes != accelerator.process_index:
286
- continue
287
-
288
- logger.debug(
289
- f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
290
- )
291
-
292
- latent_conditions = self.model_config["prepare_latents"](
293
- vae=self.vae,
294
- image_or_video=data["video"].unsqueeze(0),
295
- device=accelerator.device,
296
- dtype=self.args.transformer_dtype,
297
- generator=self.state.generator,
298
- precompute=True,
299
- )
300
- filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
301
- torch.save(latent_conditions, filename.as_posix())
302
- index += 1
303
- progress_bar.update(1)
304
- self._delete_components()
305
-
306
- accelerator.wait_for_everyone()
307
- logger.info("Precomputation complete")
308
-
309
- memory_statistics = get_memory_statistics()
310
- logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
311
- torch.cuda.reset_peak_memory_stats(accelerator.device)
312
-
313
- # Update dataloader to use precomputed conditions and latents
314
- self.dataloader = torch.utils.data.DataLoader(
315
- PrecomputedDataset(
316
- data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
317
- ),
318
- batch_size=self.args.batch_size,
319
- shuffle=True,
320
- collate_fn=collate_fn,
321
- num_workers=self.args.dataloader_num_workers,
322
- pin_memory=self.args.pin_memory,
323
- )
324
-
325
- def prepare_trainable_parameters(self) -> None:
326
- logger.info("Initializing trainable parameters")
327
-
328
- with self.state.accelerator.main_process_first():
329
- diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
330
- self._set_components(diffusion_components)
331
-
332
- components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
333
- self._disable_grad_for_components(components)
334
-
335
- if self.args.training_type == "full-finetune":
336
- logger.info("Finetuning transformer with no additional parameters")
337
- self._enable_grad_for_components([self.transformer])
338
- else:
339
- logger.info("Finetuning transformer with PEFT parameters")
340
- self._disable_grad_for_components([self.transformer])
341
-
342
- # Layerwise upcasting must be applied before adding the LoRA adapter.
343
- # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
344
- # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
345
- if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
346
- apply_layerwise_upcasting(
347
- self.transformer,
348
- storage_dtype=self.args.layerwise_upcasting_storage_dtype,
349
- compute_dtype=self.args.transformer_dtype,
350
- skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
351
- non_blocking=True,
352
- )
353
-
354
- self._move_components_to_device()
355
-
356
- if self.args.gradient_checkpointing:
357
- self.transformer.enable_gradient_checkpointing()
358
-
359
- if self.args.training_type == "lora":
360
- transformer_lora_config = LoraConfig(
361
- r=self.args.rank,
362
- lora_alpha=self.args.lora_alpha,
363
- init_lora_weights=True,
364
- target_modules=self.args.target_modules,
365
- )
366
- self.transformer.add_adapter(transformer_lora_config)
367
- else:
368
- transformer_lora_config = None
369
-
370
- # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
371
- # even when layerwise upcasting is enabled. Would be nice to have a test as well
372
-
373
- self.register_saving_loading_hooks(transformer_lora_config)
374
-
375
- def register_saving_loading_hooks(self, transformer_lora_config):
376
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
377
- def save_model_hook(models, weights, output_dir):
378
- if self.state.accelerator.is_main_process:
379
- transformer_lora_layers_to_save = None
380
-
381
- for model in models:
382
- if isinstance(
383
- unwrap_model(self.state.accelerator, model),
384
- type(unwrap_model(self.state.accelerator, self.transformer)),
385
- ):
386
- model = unwrap_model(self.state.accelerator, model)
387
- if self.args.training_type == "lora":
388
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
389
- else:
390
- raise ValueError(f"Unexpected save model: {model.__class__}")
391
-
392
- # make sure to pop weight so that corresponding model is not saved again
393
- if weights:
394
- weights.pop()
395
-
396
- if self.args.training_type == "lora":
397
- self.model_config["pipeline_cls"].save_lora_weights(
398
- output_dir,
399
- transformer_lora_layers=transformer_lora_layers_to_save,
400
- )
401
- else:
402
- model.save_pretrained(os.path.join(output_dir, "transformer"))
403
-
404
- # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
405
- # to able to load all diffusion components from a specific checkpoint folder during validation, we need to
406
- # ensure the scheduler config is serialized as well.
407
- self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
408
-
409
- def load_model_hook(models, input_dir):
410
- if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
411
- while len(models) > 0:
412
- model = models.pop()
413
- if isinstance(
414
- unwrap_model(self.state.accelerator, model),
415
- type(unwrap_model(self.state.accelerator, self.transformer)),
416
- ):
417
- transformer_ = unwrap_model(self.state.accelerator, model)
418
- else:
419
- raise ValueError(
420
- f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
421
- )
422
- else:
423
- transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
424
-
425
- if self.args.training_type == "lora":
426
- transformer_ = transformer_cls_.from_pretrained(
427
- self.args.pretrained_model_name_or_path, subfolder="transformer"
428
- )
429
- transformer_.add_adapter(transformer_lora_config)
430
- lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
431
- transformer_state_dict = {
432
- f'{k.replace("transformer.", "")}': v
433
- for k, v in lora_state_dict.items()
434
- if k.startswith("transformer.")
435
- }
436
- incompatible_keys = set_peft_model_state_dict(
437
- transformer_, transformer_state_dict, adapter_name="default"
438
- )
439
- if incompatible_keys is not None:
440
- # check only for unexpected keys
441
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
442
- if unexpected_keys:
443
- logger.warning(
444
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
445
- f" {unexpected_keys}. "
446
- )
447
- else:
448
- transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
449
-
450
- self.state.accelerator.register_save_state_pre_hook(save_model_hook)
451
- self.state.accelerator.register_load_state_pre_hook(load_model_hook)
452
-
453
- def prepare_optimizer(self) -> None:
454
- logger.info("Initializing optimizer and lr scheduler")
455
-
456
- self.state.train_epochs = self.args.train_epochs
457
- self.state.train_steps = self.args.train_steps
458
-
459
- # Make sure the trainable params are in float32
460
- if self.args.training_type == "lora":
461
- cast_training_params([self.transformer], dtype=torch.float32)
462
-
463
- self.state.learning_rate = self.args.lr
464
- if self.args.scale_lr:
465
- self.state.learning_rate = (
466
- self.state.learning_rate
467
- * self.args.gradient_accumulation_steps
468
- * self.args.batch_size
469
- * self.state.accelerator.num_processes
470
- )
471
-
472
- transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
473
- transformer_parameters_with_lr = {
474
- "params": transformer_trainable_parameters,
475
- "lr": self.state.learning_rate,
476
- }
477
- params_to_optimize = [transformer_parameters_with_lr]
478
- self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters)
479
-
480
- use_deepspeed_opt = (
481
- self.state.accelerator.state.deepspeed_plugin is not None
482
- and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
483
- )
484
- optimizer = get_optimizer(
485
- params_to_optimize=params_to_optimize,
486
- optimizer_name=self.args.optimizer,
487
- learning_rate=self.state.learning_rate,
488
- beta1=self.args.beta1,
489
- beta2=self.args.beta2,
490
- beta3=self.args.beta3,
491
- epsilon=self.args.epsilon,
492
- weight_decay=self.args.weight_decay,
493
- use_8bit=self.args.use_8bit_bnb,
494
- use_deepspeed=use_deepspeed_opt,
495
- )
496
-
497
- num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
498
- if self.state.train_steps is None:
499
- self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
500
- self.state.overwrote_max_train_steps = True
501
-
502
- use_deepspeed_lr_scheduler = (
503
- self.state.accelerator.state.deepspeed_plugin is not None
504
- and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
505
- )
506
- total_training_steps = self.state.train_steps * self.state.accelerator.num_processes
507
- num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes
508
-
509
- if use_deepspeed_lr_scheduler:
510
- from accelerate.utils import DummyScheduler
511
-
512
- lr_scheduler = DummyScheduler(
513
- name=self.args.lr_scheduler,
514
- optimizer=optimizer,
515
- total_num_steps=total_training_steps,
516
- num_warmup_steps=num_warmup_steps,
517
- )
518
- else:
519
- lr_scheduler = get_scheduler(
520
- name=self.args.lr_scheduler,
521
- optimizer=optimizer,
522
- num_warmup_steps=num_warmup_steps,
523
- num_training_steps=total_training_steps,
524
- num_cycles=self.args.lr_num_cycles,
525
- power=self.args.lr_power,
526
- )
527
-
528
- self.optimizer = optimizer
529
- self.lr_scheduler = lr_scheduler
530
-
531
- def prepare_for_training(self) -> None:
532
- self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
533
- self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
534
- )
535
-
536
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
537
- num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
538
- if self.state.overwrote_max_train_steps:
539
- self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
540
- # Afterwards we recalculate our number of training epochs
541
- self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
542
- self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
543
-
544
- def prepare_trackers(self) -> None:
545
- logger.info("Initializing trackers")
546
-
547
- tracker_name = self.args.tracker_name or "finetrainers-experiment"
548
- self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info())
549
-
550
- def train(self) -> None:
551
- logger.info("Starting training")
552
-
553
-
554
- # Try to raise the file descriptor limit to avoid 'too many open files' errors during data loading
555
- if hasattr(resource, 'RLIMIT_NOFILE'):
556
- try:
557
- soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
558
- logger.info(f"Current file descriptor limits in trainer: soft={soft}, hard={hard}")
559
- # Try to increase to hard limit if possible
560
- if soft < hard:
561
- resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
562
- new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
563
- logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
564
- except Exception as e:
565
- logger.warning(f"Could not check or update file descriptor limits: {e}")
566
-
567
- memory_statistics = get_memory_statistics()
568
- logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
569
-
570
- if self.vae_config is None:
571
- # If we've precomputed conditions and latents already, and are now reusing them, we will never load
572
- # the VAE so self.vae_config will not be set. So, we need to load it here.
573
- vae_cls = resolve_vae_cls_from_ckpt_path(
574
- self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
575
- )
576
- vae_config = vae_cls.load_config(
577
- self.args.pretrained_model_name_or_path,
578
- subfolder="vae",
579
- revision=self.args.revision,
580
- cache_dir=self.args.cache_dir,
581
- )
582
- self.vae_config = FrozenDict(**vae_config)
583
-
584
- # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
585
- # to be able to load all diffusion components from a specific checkpoint folder during validation, we need to
586
- # ensure the scheduler config is serialized as well.
587
- if self.args.training_type == "full-finetune":
588
- self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))
589
-
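- # Effective train batch size = per-device batch size x number of processes x gradient accumulation steps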
590
- self.state.train_batch_size = (
591
- self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
592
- )
593
- info = {
594
- "trainable parameters": self.state.num_trainable_parameters,
595
- "total samples": len(self.dataset),
596
- "train epochs": self.state.train_epochs,
597
- "train steps": self.state.train_steps,
598
- "batches per device": self.args.batch_size,
599
- "total batches observed per epoch": len(self.dataloader),
600
- "train batch size": self.state.train_batch_size,
601
- "gradient accumulation steps": self.args.gradient_accumulation_steps,
602
- }
603
- logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
604
-
605
- global_step = 0
606
- first_epoch = 0
607
- initial_global_step = 0
608
-
609
- # Potentially load in the weights and states from a previous save
610
- (
611
- resume_from_checkpoint_path,
612
- initial_global_step,
613
- global_step,
614
- first_epoch,
615
- ) = get_latest_ckpt_path_to_resume_from(
616
- resume_from_checkpoint=self.args.resume_from_checkpoint,
617
- num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
618
- output_dir=self.args.output_dir,
619
- )
620
- if resume_from_checkpoint_path:
621
- self.state.accelerator.load_state(resume_from_checkpoint_path)
622
-
623
- progress_bar = tqdm(
624
- range(0, self.state.train_steps),
625
- initial=initial_global_step,
626
- desc="Training steps",
627
- disable=not self.state.accelerator.is_local_main_process,
628
- )
629
-
630
- accelerator = self.state.accelerator
631
- generator = torch.Generator(device=accelerator.device)
632
- if self.args.seed is not None:
633
- generator = generator.manual_seed(self.args.seed)
634
- self.state.generator = generator
635
-
636
- scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
637
- scheduler_sigmas = (
638
- scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
639
- if scheduler_sigmas is not None
640
- else None
641
- )
642
- scheduler_alphas = get_scheduler_alphas(self.scheduler)
643
- scheduler_alphas = (
644
- scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
645
- if scheduler_alphas is not None
646
- else None
647
- )
648
-
649
- for epoch in range(first_epoch, self.state.train_epochs):
650
- logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")
651
-
652
- self.transformer.train()
653
- models_to_accumulate = [self.transformer]
654
- epoch_loss = 0.0
655
- num_loss_updates = 0
656
-
657
- for step, batch in enumerate(self.dataloader):
658
- logger.debug(f"Starting step {step + 1}")
659
- logs = {}
660
-
661
- with accelerator.accumulate(models_to_accumulate):
662
- if not self.args.precompute_conditions:
663
- videos = batch["videos"]
664
- prompts = batch["prompts"]
665
- batch_size = len(prompts)
666
-
667
- if self.args.caption_dropout_technique == "empty":
668
- if random.random() < self.args.caption_dropout_p:
669
- prompts = [""] * batch_size
670
-
671
- latent_conditions = self.model_config["prepare_latents"](
672
- vae=self.vae,
673
- image_or_video=videos,
674
- patch_size=self.transformer_config.patch_size,
675
- patch_size_t=self.transformer_config.patch_size_t,
676
- device=accelerator.device,
677
- dtype=self.args.transformer_dtype,
678
- generator=self.state.generator,
679
- )
680
- text_conditions = self.model_config["prepare_conditions"](
681
- tokenizer=self.tokenizer,
682
- text_encoder=self.text_encoder,
683
- tokenizer_2=self.tokenizer_2,
684
- text_encoder_2=self.text_encoder_2,
685
- prompt=prompts,
686
- device=accelerator.device,
687
- dtype=self.args.transformer_dtype,
688
- )
689
- else:
690
- latent_conditions = batch["latent_conditions"]
691
- text_conditions = batch["text_conditions"]
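- # Precomputed "latents" hold the VAE posterior parameters; sample an actual latent from the diagonal Gaussian here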
692
- latent_conditions["latents"] = DiagonalGaussianDistribution(
693
- latent_conditions["latents"]
694
- ).sample(self.state.generator)
695
-
696
- # This method should only be called for precomputed latents.
697
- # TODO(aryan): rename this in separate PR
698
- latent_conditions = self.model_config["post_latent_preparation"](
699
- vae_config=self.vae_config,
700
- patch_size=self.transformer_config.patch_size,
701
- patch_size_t=self.transformer_config.patch_size_t,
702
- **latent_conditions,
703
- )
704
- align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
705
- align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
706
- batch_size = latent_conditions["latents"].shape[0]
707
-
708
- latent_conditions = make_contiguous(latent_conditions)
709
- text_conditions = make_contiguous(text_conditions)
710
-
711
- if self.args.caption_dropout_technique == "zero":
712
- if random.random() < self.args.caption_dropout_p:
713
- text_conditions["prompt_embeds"].fill_(0)
714
- text_conditions["prompt_attention_mask"].fill_(False)
715
-
716
- # TODO(aryan): refactor later
717
- if "pooled_prompt_embeds" in text_conditions:
718
- text_conditions["pooled_prompt_embeds"].fill_(0)
719
-
720
- sigmas = prepare_sigmas(
721
- scheduler=self.scheduler,
722
- sigmas=scheduler_sigmas,
723
- batch_size=batch_size,
724
- num_train_timesteps=self.scheduler.config.num_train_timesteps,
725
- flow_weighting_scheme=self.args.flow_weighting_scheme,
726
- flow_logit_mean=self.args.flow_logit_mean,
727
- flow_logit_std=self.args.flow_logit_std,
728
- flow_mode_scale=self.args.flow_mode_scale,
729
- device=accelerator.device,
730
- generator=self.state.generator,
731
- )
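- # sigmas are expected in [0, 1]; scaling by 1000 recovers integer timesteps in the scheduler's training range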
732
- timesteps = (sigmas * 1000.0).long()
733
-
734
- noise = torch.randn(
735
- latent_conditions["latents"].shape,
736
- generator=self.state.generator,
737
- device=accelerator.device,
738
- dtype=self.args.transformer_dtype,
739
- )
740
- sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
741
-
742
- # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of
743
- # scheduler and calculate the noisy latents accordingly. Look into this later.
744
- if "calculate_noisy_latents" in self.model_config.keys():
745
- noisy_latents = self.model_config["calculate_noisy_latents"](
746
- scheduler=self.scheduler,
747
- noise=noise,
748
- latents=latent_conditions["latents"],
749
- timesteps=timesteps,
750
- )
751
- else:
752
- # Default to flow-matching noise addition
753
- noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
754
- noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)
755
-
756
- latent_conditions.update({"noisy_latents": noisy_latents})
757
-
758
- weights = prepare_loss_weights(
759
- scheduler=self.scheduler,
760
- alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
761
- sigmas=sigmas,
762
- flow_weighting_scheme=self.args.flow_weighting_scheme,
763
- )
764
- weights = expand_tensor_dims(weights, noise.ndim)
765
-
766
- pred = self.model_config["forward_pass"](
767
- transformer=self.transformer,
768
- scheduler=self.scheduler,
769
- timesteps=timesteps,
770
- **latent_conditions,
771
- **text_conditions,
772
- )
773
- target = prepare_target(
774
- scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
775
- )
776
-
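- # Weighted MSE between the model prediction and the scheduler-specific target (reduced per sample below)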
777
- loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
778
- # Average loss across all but batch dimension
779
- loss = loss.mean(list(range(1, loss.ndim)))
780
- # Average loss across batch dimension
781
- loss = loss.mean()
782
- accelerator.backward(loss)
783
-
784
- if accelerator.sync_gradients:
785
- if accelerator.distributed_type == DistributedType.DEEPSPEED:
786
- grad_norm = self.transformer.get_global_grad_norm()
787
- # In some cases the grad norm may not return a float
788
- if torch.is_tensor(grad_norm):
789
- grad_norm = grad_norm.item()
790
- else:
791
- grad_norm = accelerator.clip_grad_norm_(
792
- self.transformer.parameters(), self.args.max_grad_norm
793
- )
794
- if torch.is_tensor(grad_norm):
795
- grad_norm = grad_norm.item()
796
-
797
- logs["grad_norm"] = grad_norm
798
-
799
- self.optimizer.step()
800
- self.lr_scheduler.step()
801
- self.optimizer.zero_grad()
802
-
803
- # Checks if the accelerator has performed an optimization step behind the scenes
804
- if accelerator.sync_gradients:
805
- progress_bar.update(1)
806
- global_step += 1
807
-
808
- # Checkpointing
809
- if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
810
- if global_step % self.args.checkpointing_steps == 0:
811
- save_path = get_intermediate_ckpt_path(
812
- checkpointing_limit=self.args.checkpointing_limit,
813
- step=global_step,
814
- output_dir=self.args.output_dir,
815
- )
816
- accelerator.save_state(save_path)
817
-
818
- # Maybe run validation
819
- should_run_validation = (
820
- self.args.validation_every_n_steps is not None
821
- and global_step % self.args.validation_every_n_steps == 0
822
- )
823
- if should_run_validation:
824
- self.validate(global_step)
825
-
826
- loss_item = loss.detach().item()
827
- epoch_loss += loss_item
828
- num_loss_updates += 1
829
- logs["step_loss"] = loss_item
830
- logs["lr"] = self.lr_scheduler.get_last_lr()[0]
831
- progress_bar.set_postfix(logs)
832
- accelerator.log(logs, step=global_step)
833
-
834
- if global_step % 100 == 0: # Every 100 steps
835
- # Force garbage collection to clean up any lingering resources
836
- gc.collect()
837
-
838
- if global_step >= self.state.train_steps:
839
- break
840
-
841
-
842
-
843
- if num_loss_updates > 0:
844
- epoch_loss /= num_loss_updates
845
- accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
846
- memory_statistics = get_memory_statistics()
847
- logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
848
-
849
- # Maybe run validation
850
- should_run_validation = (
851
- self.args.validation_every_n_epochs is not None
852
- and (epoch + 1) % self.args.validation_every_n_epochs == 0
853
- )
854
- if should_run_validation:
855
- self.validate(global_step)
856
-
857
- if epoch % 3 == 0: # Every 3 epochs
858
- logger.info("Performing periodic resource cleanup")
859
- free_memory()
860
- gc.collect()
861
- torch.cuda.empty_cache()
862
- torch.cuda.synchronize(accelerator.device)
863
-
864
- accelerator.wait_for_everyone()
865
- if accelerator.is_main_process:
866
- transformer = unwrap_model(accelerator, self.transformer)
867
-
868
- if self.args.training_type == "lora":
869
- transformer_lora_layers = get_peft_model_state_dict(transformer)
870
-
871
- self.model_config["pipeline_cls"].save_lora_weights(
872
- save_directory=self.args.output_dir,
873
- transformer_lora_layers=transformer_lora_layers,
874
- )
875
- else:
876
- transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))
877
- accelerator.wait_for_everyone()
878
- self.validate(step=global_step, final_validation=True)
879
-
880
- if accelerator.is_main_process:
881
- if self.args.push_to_hub:
882
- upload_folder(
883
- repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
884
- )
885
-
886
- self._delete_components()
887
- memory_statistics = get_memory_statistics()
888
- logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
889
-
890
- accelerator.end_training()
891
-
892
- def validate(self, step: int, final_validation: bool = False) -> None:
893
- logger.info("Starting validation")
894
-
895
- accelerator = self.state.accelerator
896
- num_validation_samples = len(self.args.validation_prompts)
897
-
898
- if num_validation_samples == 0:
899
- logger.warning("No validation samples found. Skipping validation.")
900
- if accelerator.is_main_process:
901
- if self.args.push_to_hub:
902
- save_model_card(
903
- args=self.args,
904
- repo_id=self.state.repo_id,
905
- videos=None,
906
- validation_prompts=None,
907
- )
908
- return
909
-
910
- self.transformer.eval()
911
-
912
- memory_statistics = get_memory_statistics()
913
- logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
914
-
915
- pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)
916
-
917
- all_processes_artifacts = []
918
- prompts_to_filenames = {}
919
- for i in range(num_validation_samples):
920
- # Skip current validation on all processes but one
921
- if i % accelerator.num_processes != accelerator.process_index:
922
- continue
923
-
924
- prompt = self.args.validation_prompts[i]
925
- image = self.args.validation_images[i]
926
- video = self.args.validation_videos[i]
927
- height = self.args.validation_heights[i]
928
- width = self.args.validation_widths[i]
929
- num_frames = self.args.validation_num_frames[i]
930
- frame_rate = self.args.validation_frame_rate
931
- if image is not None:
932
- image = load_image(image)
933
- if video is not None:
934
- video = load_video(video)
935
-
936
- logger.debug(
937
- f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
938
- main_process_only=False,
939
- )
940
- validation_artifacts = self.model_config["validation"](
941
- pipeline=pipeline,
942
- prompt=prompt,
943
- image=image,
944
- video=video,
945
- height=height,
946
- width=width,
947
- num_frames=num_frames,
948
- frame_rate=frame_rate,
949
- num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
950
- generator=torch.Generator(device=accelerator.device).manual_seed(
951
- self.args.seed if self.args.seed is not None else 0
952
- ),
953
- # TODO: support passing `fps` for supported pipelines.
954
- )
955
-
956
- prompt_filename = string_to_filename(prompt)[:25]
957
- artifacts = {
958
- "image": {"type": "image", "value": image},
959
- "video": {"type": "video", "value": video},
960
- }
961
- for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
962
- if artifact_value:
963
- artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
964
- logger.debug(
965
- f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
966
- main_process_only=False,
967
- )
968
-
969
- for index, (key, value) in enumerate(list(artifacts.items())):
970
- artifact_type = value["type"]
971
- artifact_value = value["value"]
972
- if artifact_type not in ["image", "video"] or artifact_value is None:
973
- continue
974
-
975
- extension = "png" if artifact_type == "image" else "mp4"
976
- filename = "validation-" if not final_validation else "final-"
977
- filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
978
- if accelerator.is_main_process and extension == "mp4":
979
- prompts_to_filenames[prompt] = filename
980
- filename = os.path.join(self.args.output_dir, filename)
981
-
982
- if artifact_type == "image" and artifact_value:
983
- logger.debug(f"Saving image to {filename}")
984
- artifact_value.save(filename)
985
- artifact_value = wandb.Image(filename)
986
- elif artifact_type == "video" and artifact_value:
987
- logger.debug(f"Saving video to {filename}")
988
- # TODO: the fps should be configurable here, as well as in validation runs where the pipeline accepts `fps`.
989
- export_to_video(artifact_value, filename, fps=frame_rate)
990
- artifact_value = wandb.Video(filename, caption=prompt)
991
-
992
- all_processes_artifacts.append(artifact_value)
993
-
994
- all_artifacts = gather_object(all_processes_artifacts)
995
-
996
- if accelerator.is_main_process:
997
- tracker_key = "final" if final_validation else "validation"
998
- for tracker in accelerator.trackers:
999
- if tracker.name == "wandb":
1000
- artifact_log_dict = {}
1001
-
1002
- image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
1003
- if len(image_artifacts) > 0:
1004
- artifact_log_dict["images"] = image_artifacts
1005
- video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
1006
- if len(video_artifacts) > 0:
1007
- artifact_log_dict["videos"] = video_artifacts
1008
- tracker.log({tracker_key: artifact_log_dict}, step=step)
1009
-
1010
- if self.args.push_to_hub and final_validation:
1011
- video_filenames = list(prompts_to_filenames.values())
1012
- prompts = list(prompts_to_filenames.keys())
1013
- save_model_card(
1014
- args=self.args,
1015
- repo_id=self.state.repo_id,
1016
- videos=video_filenames,
1017
- validation_prompts=prompts,
1018
- )
1019
-
1020
- # Remove all hooks that might have been added during pipeline initialization to the models
1021
- pipeline.remove_all_hooks()
1022
- del pipeline
1023
-
1024
- accelerator.wait_for_everyone()
1025
-
1026
- free_memory()
1027
- memory_statistics = get_memory_statistics()
1028
- logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
1029
- torch.cuda.reset_peak_memory_stats(accelerator.device)
1030
-
1031
- if not final_validation:
1032
- self.transformer.train()
1033
-
1034
- def evaluate(self) -> None:
1035
- raise NotImplementedError("Evaluation has not been implemented yet.")
1036
-
1037
- def _init_distributed(self) -> None:
1038
- logging_dir = Path(self.args.output_dir, self.args.logging_dir)
1039
- project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
1040
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
1041
- init_process_group_kwargs = InitProcessGroupKwargs(
1042
- backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
1043
- )
1044
- report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
1045
-
1046
- accelerator = Accelerator(
1047
- project_config=project_config,
1048
- gradient_accumulation_steps=self.args.gradient_accumulation_steps,
1049
- log_with=report_to,
1050
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
1051
- )
1052
-
1053
- # Disable AMP for MPS.
1054
- if torch.backends.mps.is_available():
1055
- accelerator.native_amp = False
1056
-
1057
- self.state.accelerator = accelerator
1058
-
1059
- if self.args.seed is not None:
1060
- self.state.seed = self.args.seed
1061
- set_seed(self.args.seed)
1062
-
1063
- def _init_logging(self) -> None:
1064
- logging.basicConfig(
1065
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1066
- datefmt="%m/%d/%Y %H:%M:%S",
1067
- level=FINETRAINERS_LOG_LEVEL,
1068
- )
1069
- if self.state.accelerator.is_local_main_process:
1070
- transformers.utils.logging.set_verbosity_warning()
1071
- diffusers.utils.logging.set_verbosity_info()
1072
- else:
1073
- transformers.utils.logging.set_verbosity_error()
1074
- diffusers.utils.logging.set_verbosity_error()
1075
-
1076
- logger.info("Initialized FineTrainers")
1077
- logger.info(self.state.accelerator.state, main_process_only=False)
1078
-
1079
- def _init_directories_and_repositories(self) -> None:
1080
- if self.state.accelerator.is_main_process:
1081
- self.args.output_dir = Path(self.args.output_dir)
1082
- self.args.output_dir.mkdir(parents=True, exist_ok=True)
1083
- self.state.output_dir = Path(self.args.output_dir)
1084
-
1085
- if self.args.push_to_hub:
1086
- repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
1087
- self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
1088
-
1089
- def _init_config_options(self) -> None:
1090
- # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1091
- if self.args.allow_tf32 and torch.cuda.is_available():
1092
- torch.backends.cuda.matmul.allow_tf32 = True
1093
-
1094
- def _move_components_to_device(self):
1095
- if self.text_encoder is not None:
1096
- self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
1097
- if self.text_encoder_2 is not None:
1098
- self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
1099
- if self.text_encoder_3 is not None:
1100
- self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
1101
- if self.transformer is not None:
1102
- self.transformer = self.transformer.to(self.state.accelerator.device)
1103
- if self.unet is not None:
1104
- self.unet = self.unet.to(self.state.accelerator.device)
1105
- if self.vae is not None:
1106
- self.vae = self.vae.to(self.state.accelerator.device)
1107
-
1108
- def _get_load_components_kwargs(self) -> Dict[str, Any]:
1109
- load_component_kwargs = {
1110
- "text_encoder_dtype": self.args.text_encoder_dtype,
1111
- "text_encoder_2_dtype": self.args.text_encoder_2_dtype,
1112
- "text_encoder_3_dtype": self.args.text_encoder_3_dtype,
1113
- "transformer_dtype": self.args.transformer_dtype,
1114
- "vae_dtype": self.args.vae_dtype,
1115
- "shift": self.args.flow_shift,
1116
- "revision": self.args.revision,
1117
- "cache_dir": self.args.cache_dir,
1118
- }
1119
- if self.args.pretrained_model_name_or_path is not None:
1120
- load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
1121
- return load_component_kwargs
1122
-
1123
- def _set_components(self, components: Dict[str, Any]) -> None:
1124
- # Set models
1125
- self.tokenizer = components.get("tokenizer", self.tokenizer)
1126
- self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
1127
- self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
1128
- self.text_encoder = components.get("text_encoder", self.text_encoder)
1129
- self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
1130
- self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
1131
- self.transformer = components.get("transformer", self.transformer)
1132
- self.unet = components.get("unet", self.unet)
1133
- self.vae = components.get("vae", self.vae)
1134
- self.scheduler = components.get("scheduler", self.scheduler)
1135
-
1136
- # Set configs
1137
- self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
1138
- self.vae_config = self.vae.config if self.vae is not None else self.vae_config
1139
-
1140
- def _delete_components(self) -> None:
1141
- self.tokenizer = None
1142
- self.tokenizer_2 = None
1143
- self.tokenizer_3 = None
1144
- self.text_encoder = None
1145
- self.text_encoder_2 = None
1146
- self.text_encoder_3 = None
1147
- self.transformer = None
1148
- self.unet = None
1149
- self.vae = None
1150
- self.scheduler = None
1151
- free_memory()
1152
- torch.cuda.synchronize(self.state.accelerator.device)
1153
-
1154
- def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
1155
- accelerator = self.state.accelerator
1156
- if not final_validation:
1157
- pipeline = self.model_config["initialize_pipeline"](
1158
- model_id=self.args.pretrained_model_name_or_path,
1159
- tokenizer=self.tokenizer,
1160
- text_encoder=self.text_encoder,
1161
- tokenizer_2=self.tokenizer_2,
1162
- text_encoder_2=self.text_encoder_2,
1163
- transformer=unwrap_model(accelerator, self.transformer),
1164
- vae=self.vae,
1165
- device=accelerator.device,
1166
- revision=self.args.revision,
1167
- cache_dir=self.args.cache_dir,
1168
- enable_slicing=self.args.enable_slicing,
1169
- enable_tiling=self.args.enable_tiling,
1170
- enable_model_cpu_offload=self.args.enable_model_cpu_offload,
1171
- is_training=True,
1172
- )
1173
- else:
1174
- self._delete_components()
1175
-
1176
- # Load the transformer weights from the final checkpoint if performing full-finetune
1177
- transformer = None
1178
- if self.args.training_type == "full-finetune":
1179
- transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]
1180
-
1181
- pipeline = self.model_config["initialize_pipeline"](
1182
- model_id=self.args.pretrained_model_name_or_path,
1183
- transformer=transformer,
1184
- device=accelerator.device,
1185
- revision=self.args.revision,
1186
- cache_dir=self.args.cache_dir,
1187
- enable_slicing=self.args.enable_slicing,
1188
- enable_tiling=self.args.enable_tiling,
1189
- enable_model_cpu_offload=self.args.enable_model_cpu_offload,
1190
- is_training=False,
1191
- )
1192
-
1193
- # Load the LoRA weights if performing LoRA finetuning
1194
- if self.args.training_type == "lora":
1195
- pipeline.load_lora_weights(self.args.output_dir)
1196
-
1197
- return pipeline
1198
-
1199
- def _disable_grad_for_components(self, components: List[torch.nn.Module]):
1200
- for component in components:
1201
- if component is not None:
1202
- component.requires_grad_(False)
1203
-
1204
- def _enable_grad_for_components(self, components: List[torch.nn.Module]):
1205
- for component in components:
1206
- if component is not None:
1207
- component.requires_grad_(True)
1208
-
1209
- def _get_training_info(self) -> dict:
1210
- args = self.args.to_dict()
1211
-
1212
- training_args = args.get("training_arguments", {})
1213
- training_type = training_args.get("training_type", "")
1214
-
1215
- # LoRA/non-LoRA stuff.
1216
- if training_type == "full-finetune":
1217
- filtered_training_args = {
1218
- k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"}
1219
- }
1220
- else:
1221
- filtered_training_args = training_args
1222
-
1223
- # Diffusion/flow stuff.
1224
- diffusion_args = args.get("diffusion_arguments", {})
1225
- scheduler_name = self.scheduler.__class__.__name__
1226
- if scheduler_name != "FlowMatchEulerDiscreteScheduler":
1227
- filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
1228
- else:
1229
- filtered_diffusion_args = diffusion_args
1230
-
1231
- # Rest of the stuff.
1232
- updated_training_info = args.copy()
1233
- updated_training_info["training_arguments"] = filtered_training_args
1234
- updated_training_info["diffusion_arguments"] = filtered_diffusion_args
1235
- return updated_training_info
 
vms/services/captioner.py CHANGED
@@ -508,15 +508,15 @@ class CaptioningService:
508
  break
509
 
510
  try:
511
- print(f"we are in file_path {str(file_path)}")
512
  # Choose appropriate processing method based on file type
513
  if is_video_file(file_path):
514
  process_gen = self.process_video(file_path, prompt, prompt_prefix)
515
  else:
516
  process_gen = self.process_image(file_path, prompt, prompt_prefix)
517
- print("got process_gen = ", process_gen)
518
  async for progress, caption in process_gen:
519
- print(f"process_gen contains this caption = {caption}")
520
  if caption and prompt_prefix and not caption.startswith(prompt_prefix):
521
  caption = f"{prompt_prefix}{caption}"
522
 
@@ -525,7 +525,7 @@ class CaptioningService:
525
  txt_path = file_path.with_suffix('.txt')
526
  txt_path.write_text(caption)
527
 
528
- logger.debug(f"Progress update: {progress.status}")
529
 
530
  # Store progress info
531
  status_update[file_path.name] = {
 
508
  break
509
 
510
  try:
511
+ #print(f"we are in file_path {str(file_path)}")
512
  # Choose appropriate processing method based on file type
513
  if is_video_file(file_path):
514
  process_gen = self.process_video(file_path, prompt, prompt_prefix)
515
  else:
516
  process_gen = self.process_image(file_path, prompt, prompt_prefix)
517
+ #print("got process_gen = ", process_gen)
518
  async for progress, caption in process_gen:
519
+ #print(f"process_gen contains this caption = {caption}")
520
  if caption and prompt_prefix and not caption.startswith(prompt_prefix):
521
  caption = f"{prompt_prefix}{caption}"
522
 
 
525
  txt_path = file_path.with_suffix('.txt')
526
  txt_path.write_text(caption)
527
 
528
+ #logger.debug(f"Progress update: {progress.status}")
529
 
530
  # Store progress info
531
  status_update[file_path.name] = {
vms/tabs/manage_tab.py CHANGED
@@ -56,12 +56,12 @@ class ManageTab(BaseTab):
56
  gr.Markdown("## Storage management")
57
  with gr.Row():
58
  self.components["download_dataset_btn"] = gr.DownloadButton(
59
- "Download dataset",
60
  variant="secondary",
61
  size="lg"
62
  )
63
  self.components["download_model_btn"] = gr.DownloadButton(
64
- "Download model",
65
  variant="secondary",
66
  size="lg"
67
  )
 
56
  gr.Markdown("## Storage management")
57
  with gr.Row():
58
  self.components["download_dataset_btn"] = gr.DownloadButton(
59
+ "Download dataset (click again if DL doesn't start)",
60
  variant="secondary",
61
  size="lg"
62
  )
63
  self.components["download_model_btn"] = gr.DownloadButton(
64
+ "Download model (click again if DL doesn't start)",
65
  variant="secondary",
66
  size="lg"
67
  )