jbilcke-hf (HF Staff) committed
Commit c5911ab · 1 Parent(s): 956cf49

delete finetrainers

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. finetrainers/__init__.py +0 -5
  2. finetrainers/args.py +0 -865
  3. finetrainers/config.py +0 -58
  4. finetrainers/constants.py +0 -83
  5. finetrainers/data/__init__.py +0 -27
  6. finetrainers/data/_artifact.py +0 -29
  7. finetrainers/data/dataloader.py +0 -40
  8. finetrainers/data/dataset.py +0 -978
  9. finetrainers/data/precomputation.py +0 -376
  10. finetrainers/data/sampler.py +0 -58
  11. finetrainers/data/utils.py +0 -20
  12. finetrainers/functional/__init__.py +0 -16
  13. finetrainers/functional/diffusion.py +0 -11
  14. finetrainers/functional/image.py +0 -54
  15. finetrainers/functional/text.py +0 -26
  16. finetrainers/functional/video.py +0 -94
  17. finetrainers/logging.py +0 -111
  18. finetrainers/models/__init__.py +0 -1
  19. finetrainers/models/cogvideox/__init__.py +0 -1
  20. finetrainers/models/cogvideox/base_specification.py +0 -423
  21. finetrainers/models/cogvideox/utils.py +0 -51
  22. finetrainers/models/cogview4/__init__.py +0 -1
  23. finetrainers/models/cogview4/base_specification.py +0 -395
  24. finetrainers/models/hunyuan_video/__init__.py +0 -1
  25. finetrainers/models/hunyuan_video/base_specification.py +0 -410
  26. finetrainers/models/ltx_video/__init__.py +0 -1
  27. finetrainers/models/ltx_video/base_specification.py +0 -517
  28. finetrainers/models/modeling_utils.py +0 -289
  29. finetrainers/models/utils.py +0 -62
  30. finetrainers/models/wan/__init__.py +0 -1
  31. finetrainers/models/wan/base_specification.py +0 -393
  32. finetrainers/optimizer.py +0 -449
  33. finetrainers/parallel/__init__.py +0 -22
  34. finetrainers/parallel/accelerate.py +0 -218
  35. finetrainers/parallel/base.py +0 -96
  36. finetrainers/parallel/deepspeed.py +0 -7
  37. finetrainers/parallel/ptd.py +0 -228
  38. finetrainers/parallel/utils.py +0 -99
  39. finetrainers/patches/__init__.py +0 -28
  40. finetrainers/patches/dependencies/peft/patch.py +0 -25
  41. finetrainers/patches/models/ltx_video/patch.py +0 -127
  42. finetrainers/patches/models/wan/patch.py +0 -33
  43. finetrainers/patches/utils.py +0 -18
  44. finetrainers/processors/__init__.py +0 -6
  45. finetrainers/processors/base.py +0 -20
  46. finetrainers/processors/clip.py +0 -65
  47. finetrainers/processors/glm.py +0 -74
  48. finetrainers/processors/llama.py +0 -118
  49. finetrainers/processors/t5.py +0 -73
  50. finetrainers/processors/text.py +0 -22
finetrainers/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from .args import BaseArgs
2
- from .config import ModelType, TrainingType
3
- from .logging import get_logger
4
- from .models import ModelSpecification
5
- from .trainer import SFTTrainer
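Note: these five re-exports defined the package's public surface. A typical top-level import, sketched here for reference (not part of the diff), looked like:

```python
# Import sketch of the top-level API removed by this commit.
from finetrainers import BaseArgs, ModelType, TrainingType, ModelSpecification, SFTTrainer, get_logger
```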
 
finetrainers/args.py DELETED
@@ -1,865 +0,0 @@
1
- import argparse
2
- import os
3
- import pathlib
4
- import sys
5
- from typing import Any, Callable, Dict, List, Optional
6
-
7
- import torch
8
-
9
- from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType
10
- from .logging import get_logger
11
- from .parallel import ParallelBackendEnum
12
- from .utils import get_non_null_items
13
-
14
-
15
- logger = get_logger()
16
-
17
-
18
- class BaseArgs:
19
- r"""
20
- The arguments for the finetrainers training script.
21
-
22
- For helpful information about arguments, run `python train.py --help`.
23
-
24
- TODO(aryan): add `python train.py --recommend_configs --model_name <model_name>` to recommend
25
- good training configs for a model after extensive testing.
26
- TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
27
- memory requirements per model, per training type with sensible training settings.
28
-
29
- PARALLEL ARGUMENTS
30
- ------------------
31
- parallel_backend (`str`, defaults to `accelerate`):
32
- The parallel backend to use for training. Choose between ['accelerate', 'ptd'].
33
- pp_degree (`int`, defaults to `1`):
34
- The degree of pipeline parallelism.
35
- dp_degree (`int`, defaults to `1`):
36
- The degree of data parallelism (number of model replicas).
37
- dp_shards (`int`, defaults to `-1`):
38
- The number of data parallel shards (number of model partitions).
39
- cp_degree (`int`, defaults to `1`):
40
- The degree of context parallelism.
41
-
42
- MODEL ARGUMENTS
43
- ---------------
44
- model_name (`str`):
45
- Name of model to train. To get a list of models, run `python train.py --list_models`.
46
- pretrained_model_name_or_path (`str`):
47
- Path to pretrained model or model identifier from https://huggingface.co/models. The model should be
48
- loadable based on specified `model_name`.
49
- revision (`str`, defaults to `None`):
50
- If provided, the model will be loaded from a specific branch of the model repository.
51
- variant (`str`, defaults to `None`):
52
- Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk
53
- storage requirements.
54
- cache_dir (`str`, defaults to `None`):
55
- The directory where the downloaded models and datasets will be stored, or loaded from.
56
- tokenizer_id (`str`, defaults to `None`):
57
- Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
58
- tokenizer_2_id (`str`, defaults to `None`):
59
- Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
60
- tokenizer_3_id (`str`, defaults to `None`):
61
- Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
62
- text_encoder_id (`str`, defaults to `None`):
63
- Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
64
- text_encoder_2_id (`str`, defaults to `None`):
65
- Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
66
- text_encoder_3_id (`str`, defaults to `None`):
67
- Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
68
- transformer_id (`str`, defaults to `None`):
69
- Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`.
70
- vae_id (`str`, defaults to `None`):
71
- Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`.
72
- text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
73
- Data type for the text encoder when generating text embeddings.
74
- text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
75
- Data type for the text encoder 2 when generating text embeddings.
76
- text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
77
- Data type for the text encoder 3 when generating text embeddings.
78
- transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
79
- Data type for the transformer model.
80
- vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
81
- Data type for the VAE model.
82
- layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
83
- Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
84
- layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
85
- Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
86
- layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
87
-                 Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when cast to fp8 precision
88
- naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
89
- by default, and recommend adding more layers to the default list based on the model architecture.
90
-
91
- DATASET ARGUMENTS
92
- -----------------
93
- dataset_config (`str`):
94
-             Path to a dataset config file containing information about training data. This file can contain information about one or
95
- more datasets in JSON format. The file must have a key called "datasets", which is a list of dictionaries. Each
96
- dictionary must contain the following keys:
97
- - "data_root": (`str`)
98
- The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided.
99
- - "dataset_file": (`str`)
100
- Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter
101
- must be provided if `data_root` is not provided.
102
- - "dataset_type": (`str`)
103
- Type of dataset. Choose between ['image', 'video'].
104
- - "id_token": (`str`)
105
- Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training
106
- for single subject/concept/style training, but is not necessary.
107
- - "image_resolution_buckets": (`List[Tuple[int, int]]`)
108
-                     Resolution buckets for images. This should be a list of tuples containing 2 values, where each tuple
109
- represents the resolution (height, width). All images will be resized to the nearest bucket resolution.
110
- This parameter must be provided if `dataset_type` is 'image'.
111
- - "video_resolution_buckets": (`List[Tuple[int, int, int]]`)
112
-                     Resolution buckets for videos. This should be a list of tuples containing 3 values, where each tuple
113
- represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket
114
- resolution. This parameter must be provided if `dataset_type` is 'video'.
115
- - "reshape_mode": (`str`)
116
- All input images/videos are reshaped using this mode. Choose between the following:
117
- ["center_crop", "random_crop", "bicubic"].
118
- - "remove_common_llm_caption_prefixes": (`boolean`)
119
- Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes.
120
- dataset_shuffle_buffer_size (`int`, defaults to `1`):
121
- The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default
122
- value of `1` means that the dataset will not be shuffled.
123
- precomputation_items (`int`, defaults to `512`):
124
- Number of data samples to precompute at once for memory-efficient training. The higher this value,
125
-             the more disk space will be used to save the precomputed samples (conditions and latents).
126
- precomputation_dir (`str`, defaults to `None`):
127
- The directory where the precomputed samples will be stored. If not provided, the precomputed samples
128
- will be stored in a temporary directory of the output directory.
129
- precomputation_once (`bool`, defaults to `False`):
130
- Precompute embeddings from all datasets at once before training. This is useful to save time during training
131
- with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during
132
- training when required. Make sure to set `precomputation_items` to a reasonable value in line with the size
133
- of your dataset(s).
134
-
135
- DATALOADER_ARGUMENTS
136
- --------------------
137
- See https://pytorch.org/docs/stable/data.html for more information.
138
-
139
- dataloader_num_workers (`int`, defaults to `0`):
140
- Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner
141
- on the main process.
142
- pin_memory (`bool`, defaults to `False`):
143
- Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading.
144
-
145
- DIFFUSION ARGUMENTS
146
- -------------------
147
- flow_resolution_shifting (`bool`, defaults to `False`):
148
- Resolution-dependent shifting of timestep schedules.
149
- [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
150
- TODO(aryan): We don't support this yet.
151
- flow_base_seq_len (`int`, defaults to `256`):
152
- Base number of tokens for images/video when applying resolution-dependent shifting.
153
- flow_max_seq_len (`int`, defaults to `4096`):
154
- Maximum number of tokens for images/video when applying resolution-dependent shifting.
155
- flow_base_shift (`float`, defaults to `0.5`):
156
- Base shift for timestep schedules when applying resolution-dependent shifting.
157
- flow_max_shift (`float`, defaults to `1.15`):
158
- Maximum shift for timestep schedules when applying resolution-dependent shifting.
159
- flow_shift (`float`, defaults to `1.0`):
160
- Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma).
161
- Setting it higher is helpful when trying to train models for high-resolution generation or to produce better
162
-             samples in a lower number of inference steps.
163
- flow_weighting_scheme (`str`, defaults to `none`):
164
- We default to the "none" weighting scheme for uniform sampling and uniform loss.
165
- Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none'].
166
- flow_logit_mean (`float`, defaults to `0.0`):
167
- Mean to use when using the `'logit_normal'` weighting scheme.
168
- flow_logit_std (`float`, defaults to `1.0`):
169
- Standard deviation to use when using the `'logit_normal'` weighting scheme.
170
- flow_mode_scale (`float`, defaults to `1.29`):
171
- Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.
172
-
173
- TRAINING ARGUMENTS
174
- ------------------
175
- training_type (`str`, defaults to `None`):
176
- Type of training to perform. Choose between ['lora'].
177
- seed (`int`, defaults to `42`):
178
- A seed for reproducible training.
179
- batch_size (`int`, defaults to `1`):
180
- Per-device batch size.
181
- train_steps (`int`, defaults to `1000`):
182
- Total number of training steps to perform.
183
- max_data_samples (`int`, defaults to `2**64`):
184
-             Maximum number of data samples observed during training. If less than the number required by `train_steps`,
185
- the training will stop early.
186
- gradient_accumulation_steps (`int`, defaults to `1`):
187
-             Number of gradient steps to accumulate before performing an optimizer step.
188
- gradient_checkpointing (`bool`, defaults to `False`):
189
- Whether or not to use gradient/activation checkpointing to save memory at the expense of slower
190
- backward pass.
191
- checkpointing_steps (`int`, defaults to `500`):
192
- Save a checkpoint of the training state every X training steps. These checkpoints can be used both
193
-             as final checkpoints, in case they are better than the last checkpoint, and for
194
- resuming training using `resume_from_checkpoint`.
195
- checkpointing_limit (`int`, defaults to `None`):
196
- Max number of checkpoints to store.
197
- resume_from_checkpoint (`str`, defaults to `None`):
198
- Whether training should be resumed from a previous checkpoint. Use a path saved by `checkpointing_steps`,
199
- or `"latest"` to automatically select the last available checkpoint.
200
-
201
- OPTIMIZER ARGUMENTS
202
- -------------------
203
- optimizer (`str`, defaults to `adamw`):
204
- The optimizer type to use. Choose between the following:
205
- - Torch optimizers: ["adam", "adamw"]
206
- - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"]
207
- lr (`float`, defaults to `1e-4`):
208
- Initial learning rate (after the potential warmup period) to use.
209
- lr_scheduler (`str`, defaults to `cosine_with_restarts`):
210
- The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
211
- 'constant', 'constant_with_warmup'].
212
- lr_warmup_steps (`int`, defaults to `500`):
213
- Number of steps for the warmup in the lr scheduler.
214
- lr_num_cycles (`int`, defaults to `1`):
215
- Number of hard resets of the lr in cosine_with_restarts scheduler.
216
- lr_power (`float`, defaults to `1.0`):
217
- Power factor of the polynomial scheduler.
218
- beta1 (`float`, defaults to `0.9`):
219
- beta2 (`float`, defaults to `0.95`):
220
- beta3 (`float`, defaults to `0.999`):
221
- weight_decay (`float`, defaults to `0.0001`):
222
- Penalty for large weights in the model.
223
- epsilon (`float`, defaults to `1e-8`):
224
- Small value to avoid division by zero in the optimizer.
225
- max_grad_norm (`float`, defaults to `1.0`):
226
- Maximum gradient norm to clip the gradients.
227
-
228
- VALIDATION ARGUMENTS
229
- --------------------
230
- validation_dataset_file (`str`, defaults to `None`):
231
-             Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain at least the
232
- "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path"
233
- will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path"
234
- will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary.
235
- The validation dataset file may contain other attributes specific to inference/validation such as:
236
- - "height" and "width" and "num_frames": Resolution
237
- - "num_inference_steps": Number of inference steps
238
- - "guidance_scale": Classifier-free Guidance Scale
239
- - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be
240
- invoked with the sample dictionary to validate the sample.)
241
- validation_steps (`int`, defaults to `500`):
242
- Number of training steps after which a validation step is performed.
243
- enable_model_cpu_offload (`bool`, defaults to `False`):
244
- Whether or not to offload different modeling components to CPU during validation.
245
-
246
- MISCELLANEOUS ARGUMENTS
247
- -----------------------
248
- tracker_name (`str`, defaults to `finetrainers`):
249
- Name of the tracker/project to use for logging training metrics.
250
- push_to_hub (`bool`, defaults to `False`):
251
- Whether or not to push the model to the Hugging Face Hub.
252
- hub_token (`str`, defaults to `None`):
253
- The API token to use for pushing the model to the Hugging Face Hub.
254
- hub_model_id (`str`, defaults to `None`):
255
- The model identifier to use for pushing the model to the Hugging Face Hub.
256
- output_dir (`str`, defaults to `None`):
257
- The directory where the model checkpoints and logs will be stored.
258
- logging_dir (`str`, defaults to `logs`):
259
- The directory where the logs will be stored.
260
- logging_steps (`int`, defaults to `1`):
261
- Training logs will be tracked every `logging_steps` steps.
262
- allow_tf32 (`bool`, defaults to `False`):
263
- Whether or not to allow the use of TF32 matmul on compatible hardware.
264
- nccl_timeout (`int`, defaults to `1800`):
265
- Timeout for the NCCL communication.
266
- report_to (`str`, defaults to `wandb`):
267
- The name of the logger to use for logging training metrics. Choose between ['wandb'].
268
- verbose (`int`, defaults to `1`):
269
-             Verbosity level for logging.
270
- - 0: Diffusers/Transformers warning logging on local main process only
271
- - 1: Diffusers/Transformers info logging on local main process only
272
- - 2: Diffusers/Transformers debug logging on local main process only
273
- - 3: Diffusers/Transformers debug logging on all processes
274
- """
275
-
276
- # Parallel arguments
277
- parallel_backend = ParallelBackendEnum.ACCELERATE
278
- pp_degree: int = 1
279
- dp_degree: int = 1
280
- dp_shards: int = 1
281
- cp_degree: int = 1
282
- tp_degree: int = 1
283
-
284
- # Model arguments
285
- model_name: str = None
286
- pretrained_model_name_or_path: str = None
287
- revision: Optional[str] = None
288
- variant: Optional[str] = None
289
- cache_dir: Optional[str] = None
290
- tokenizer_id: Optional[str] = None
291
- tokenizer_2_id: Optional[str] = None
292
- tokenizer_3_id: Optional[str] = None
293
- text_encoder_id: Optional[str] = None
294
- text_encoder_2_id: Optional[str] = None
295
- text_encoder_3_id: Optional[str] = None
296
- transformer_id: Optional[str] = None
297
- vae_id: Optional[str] = None
298
- text_encoder_dtype: torch.dtype = torch.bfloat16
299
- text_encoder_2_dtype: torch.dtype = torch.bfloat16
300
- text_encoder_3_dtype: torch.dtype = torch.bfloat16
301
- transformer_dtype: torch.dtype = torch.bfloat16
302
- vae_dtype: torch.dtype = torch.bfloat16
303
- layerwise_upcasting_modules: List[str] = []
304
- layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
305
- layerwise_upcasting_skip_modules_pattern: List[str] = [
306
- "patch_embed",
307
- "pos_embed",
308
- "x_embedder",
309
- "context_embedder",
310
- "time_embed",
311
- "^proj_in$",
312
- "^proj_out$",
313
- "norm",
314
- ]
315
-
316
- # Dataset arguments
317
- dataset_config: str = None
318
- dataset_shuffle_buffer_size: int = 1
319
- enable_precomputation: bool = False
320
- precomputation_items: int = 512
321
- precomputation_dir: Optional[str] = None
322
- precomputation_once: bool = False
323
-
324
- # Dataloader arguments
325
- dataloader_num_workers: int = 0
326
- pin_memory: bool = False
327
-
328
- # Diffusion arguments
329
- flow_resolution_shifting: bool = False
330
- flow_base_seq_len: int = 256
331
- flow_max_seq_len: int = 4096
332
- flow_base_shift: float = 0.5
333
- flow_max_shift: float = 1.15
334
- flow_shift: float = 1.0
335
- flow_weighting_scheme: str = "none"
336
- flow_logit_mean: float = 0.0
337
- flow_logit_std: float = 1.0
338
- flow_mode_scale: float = 1.29
339
-
340
- # Training arguments
341
- training_type: str = None
342
- seed: int = 42
343
- batch_size: int = 1
344
- train_steps: int = 1000
345
- max_data_samples: int = 2**64
346
- gradient_accumulation_steps: int = 1
347
- gradient_checkpointing: bool = False
348
- checkpointing_steps: int = 500
349
- checkpointing_limit: Optional[int] = None
350
- resume_from_checkpoint: Optional[str] = None
351
- enable_slicing: bool = False
352
- enable_tiling: bool = False
353
-
354
- # Optimizer arguments
355
- optimizer: str = "adamw"
356
- lr: float = 1e-4
357
- lr_scheduler: str = "cosine_with_restarts"
358
- lr_warmup_steps: int = 0
359
- lr_num_cycles: int = 1
360
- lr_power: float = 1.0
361
- beta1: float = 0.9
362
- beta2: float = 0.95
363
- beta3: float = 0.999
364
- weight_decay: float = 0.0001
365
- epsilon: float = 1e-8
366
- max_grad_norm: float = 1.0
367
-
368
- # Validation arguments
369
- validation_dataset_file: Optional[str] = None
370
- validation_steps: int = 500
371
- enable_model_cpu_offload: bool = False
372
-
373
- # Miscellaneous arguments
374
- tracker_name: str = "finetrainers"
375
- push_to_hub: bool = False
376
- hub_token: Optional[str] = None
377
- hub_model_id: Optional[str] = None
378
- output_dir: str = None
379
- logging_dir: Optional[str] = "logs"
380
- logging_steps: int = 1
381
- allow_tf32: bool = False
382
- init_timeout: int = 300 # 5 minutes
383
- nccl_timeout: int = 600 # 10 minutes, considering that validation may be performed
384
- report_to: str = "wandb"
385
- verbose: int = 1
386
-
387
- def to_dict(self) -> Dict[str, Any]:
388
- parallel_arguments = {
389
- "pp_degree": self.pp_degree,
390
- "dp_degree": self.dp_degree,
391
- "dp_shards": self.dp_shards,
392
- "cp_degree": self.cp_degree,
393
- "tp_degree": self.tp_degree,
394
- }
395
-
396
- model_arguments = {
397
- "model_name": self.model_name,
398
- "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
399
- "revision": self.revision,
400
- "variant": self.variant,
401
- "cache_dir": self.cache_dir,
402
- "tokenizer_id": self.tokenizer_id,
403
- "tokenizer_2_id": self.tokenizer_2_id,
404
- "tokenizer_3_id": self.tokenizer_3_id,
405
- "text_encoder_id": self.text_encoder_id,
406
- "text_encoder_2_id": self.text_encoder_2_id,
407
- "text_encoder_3_id": self.text_encoder_3_id,
408
- "transformer_id": self.transformer_id,
409
- "vae_id": self.vae_id,
410
- "text_encoder_dtype": self.text_encoder_dtype,
411
- "text_encoder_2_dtype": self.text_encoder_2_dtype,
412
- "text_encoder_3_dtype": self.text_encoder_3_dtype,
413
- "transformer_dtype": self.transformer_dtype,
414
- "vae_dtype": self.vae_dtype,
415
- "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
416
- "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
417
- "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
418
- }
419
- model_arguments = get_non_null_items(model_arguments)
420
-
421
- dataset_arguments = {
422
- "dataset_config": self.dataset_config,
423
- "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
424
- "enable_precomputation": self.enable_precomputation,
425
- "precomputation_items": self.precomputation_items,
426
- "precomputation_dir": self.precomputation_dir,
427
- "precomputation_once": self.precomputation_once,
428
- }
429
- dataset_arguments = get_non_null_items(dataset_arguments)
430
-
431
- dataloader_arguments = {
432
- "dataloader_num_workers": self.dataloader_num_workers,
433
- "pin_memory": self.pin_memory,
434
- }
435
-
436
- diffusion_arguments = {
437
- "flow_resolution_shifting": self.flow_resolution_shifting,
438
- "flow_base_seq_len": self.flow_base_seq_len,
439
- "flow_max_seq_len": self.flow_max_seq_len,
440
- "flow_base_shift": self.flow_base_shift,
441
- "flow_max_shift": self.flow_max_shift,
442
- "flow_shift": self.flow_shift,
443
- "flow_weighting_scheme": self.flow_weighting_scheme,
444
- "flow_logit_mean": self.flow_logit_mean,
445
- "flow_logit_std": self.flow_logit_std,
446
- "flow_mode_scale": self.flow_mode_scale,
447
- }
448
-
449
- training_arguments = {
450
- "training_type": self.training_type,
451
- "seed": self.seed,
452
- "batch_size": self.batch_size,
453
- "train_steps": self.train_steps,
454
- "max_data_samples": self.max_data_samples,
455
- "gradient_accumulation_steps": self.gradient_accumulation_steps,
456
- "gradient_checkpointing": self.gradient_checkpointing,
457
- "checkpointing_steps": self.checkpointing_steps,
458
- "checkpointing_limit": self.checkpointing_limit,
459
- "resume_from_checkpoint": self.resume_from_checkpoint,
460
- "enable_slicing": self.enable_slicing,
461
- "enable_tiling": self.enable_tiling,
462
- }
463
- training_arguments = get_non_null_items(training_arguments)
464
-
465
- optimizer_arguments = {
466
- "optimizer": self.optimizer,
467
- "lr": self.lr,
468
- "lr_scheduler": self.lr_scheduler,
469
- "lr_warmup_steps": self.lr_warmup_steps,
470
- "lr_num_cycles": self.lr_num_cycles,
471
- "lr_power": self.lr_power,
472
- "beta1": self.beta1,
473
- "beta2": self.beta2,
474
- "beta3": self.beta3,
475
- "weight_decay": self.weight_decay,
476
- "epsilon": self.epsilon,
477
- "max_grad_norm": self.max_grad_norm,
478
- }
479
- optimizer_arguments = get_non_null_items(optimizer_arguments)
480
-
481
- validation_arguments = {
482
- "validation_dataset_file": self.validation_dataset_file,
483
- "validation_steps": self.validation_steps,
484
- "enable_model_cpu_offload": self.enable_model_cpu_offload,
485
- }
486
- validation_arguments = get_non_null_items(validation_arguments)
487
-
488
- miscellaneous_arguments = {
489
- "tracker_name": self.tracker_name,
490
- "push_to_hub": self.push_to_hub,
491
- "hub_token": self.hub_token,
492
- "hub_model_id": self.hub_model_id,
493
- "output_dir": self.output_dir,
494
- "logging_dir": self.logging_dir,
495
- "logging_steps": self.logging_steps,
496
- "allow_tf32": self.allow_tf32,
497
- "init_timeout": self.init_timeout,
498
- "nccl_timeout": self.nccl_timeout,
499
- "report_to": self.report_to,
500
- "verbose": self.verbose,
501
- }
502
- miscellaneous_arguments = get_non_null_items(miscellaneous_arguments)
503
-
504
- return {
505
- "parallel_arguments": parallel_arguments,
506
- "model_arguments": model_arguments,
507
- "dataset_arguments": dataset_arguments,
508
- "dataloader_arguments": dataloader_arguments,
509
- "diffusion_arguments": diffusion_arguments,
510
- "training_arguments": training_arguments,
511
- "optimizer_arguments": optimizer_arguments,
512
- "validation_arguments": validation_arguments,
513
- "miscellaneous_arguments": miscellaneous_arguments,
514
- }
515
-
516
- def extend_args(
517
- self,
518
- add_fn: Callable[[argparse.ArgumentParser], None],
519
- map_fn: Callable[["BaseArgs"], None],
520
- validate_fn: Callable[["BaseArgs"], None],
521
- ) -> None:
522
- if not hasattr(self, "_extended_add_arguments"):
523
- self._extended_add_arguments = []
524
- self._extended_add_arguments.append((add_fn, validate_fn, map_fn))
525
-
526
- def parse_args(self):
527
- _LIST_MODELS = "--list_models"
528
-
529
- parser = argparse.ArgumentParser()
530
-
531
- special_args = [_LIST_MODELS]
532
- if any(arg in sys.argv for arg in special_args):
533
- _add_helper_arguments(parser)
534
- args = parser.parse_args()
535
- _display_helper_messages(args)
536
- sys.exit(0)
537
- else:
538
- _add_args(parser)
539
- for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
540
- add_fn, _, _ = extended_add_arg_fns
541
- add_fn(parser)
542
-
543
- args, remaining_args = parser.parse_known_args()
544
- logger.debug(f"Remaining unparsed arguments: {remaining_args}")
545
-
546
- mapped_args = _map_to_args_type(args)
547
- for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
548
- _, _, map_fn = extended_add_arg_fns
549
- map_fn(args, mapped_args)
550
-
551
- _validate_args(mapped_args)
552
- for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
553
- _, validate_fn, _ = extended_add_arg_fns
554
- validate_fn(mapped_args)
555
-
556
- return mapped_args
557
-
558
-
559
- def _add_args(parser: argparse.ArgumentParser) -> None:
560
- _add_parallel_arguments(parser)
561
- _add_model_arguments(parser)
562
- _add_dataset_arguments(parser)
563
- _add_dataloader_arguments(parser)
564
- _add_diffusion_arguments(parser)
565
- _add_training_arguments(parser)
566
- _add_optimizer_arguments(parser)
567
- _add_validation_arguments(parser)
568
- _add_miscellaneous_arguments(parser)
569
-
570
-
571
- def _validate_args(args: BaseArgs):
572
- _validate_model_args(args)
573
- _validate_dataset_args(args)
574
- _validate_validation_args(args)
575
-
576
-
577
- def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None:
578
- parser.add_argument(
579
- "--parallel_backend",
580
- type=str,
581
- default=ParallelBackendEnum.ACCELERATE,
582
- choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD],
583
- )
584
- parser.add_argument("--pp_degree", type=int, default=1)
585
- parser.add_argument("--dp_degree", type=int, default=1)
586
- parser.add_argument("--dp_shards", type=int, default=1)
587
- parser.add_argument("--cp_degree", type=int, default=1)
588
- parser.add_argument("--tp_degree", type=int, default=1)
589
-
590
-
591
- def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
592
- parser.add_argument(
593
- "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()]
594
- )
595
- parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
596
- parser.add_argument("--revision", type=str, default=None, required=False)
597
- parser.add_argument("--variant", type=str, default=None)
598
- parser.add_argument("--cache_dir", type=str, default=None)
599
- parser.add_argument("--tokenizer_id", type=str, default=None)
600
- parser.add_argument("--tokenizer_2_id", type=str, default=None)
601
- parser.add_argument("--tokenizer_3_id", type=str, default=None)
602
- parser.add_argument("--text_encoder_id", type=str, default=None)
603
- parser.add_argument("--text_encoder_2_id", type=str, default=None)
604
- parser.add_argument("--text_encoder_3_id", type=str, default=None)
605
- parser.add_argument("--transformer_id", type=str, default=None)
606
- parser.add_argument("--vae_id", type=str, default=None)
607
- parser.add_argument("--text_encoder_dtype", type=str, default="bf16")
608
- parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16")
609
- parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16")
610
- parser.add_argument("--transformer_dtype", type=str, default="bf16")
611
- parser.add_argument("--vae_dtype", type=str, default="bf16")
612
- parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"])
613
- parser.add_argument(
614
- "--layerwise_upcasting_storage_dtype",
615
- type=str,
616
- default="float8_e4m3fn",
617
- choices=["float8_e4m3fn", "float8_e5m2"],
618
- )
619
- parser.add_argument(
620
- "--layerwise_upcasting_skip_modules_pattern",
621
- type=str,
622
- default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
623
- nargs="+",
624
- )
625
-
626
-
627
- def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
628
- parser.add_argument("--dataset_config", type=str, required=True)
629
- parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
630
- parser.add_argument("--enable_precomputation", action="store_true")
631
- parser.add_argument("--precomputation_items", type=int, default=512)
632
- parser.add_argument("--precomputation_dir", type=str, default=None)
633
- parser.add_argument("--precomputation_once", action="store_true")
634
-
635
-
636
- def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
637
- parser.add_argument("--dataloader_num_workers", type=int, default=0)
638
- parser.add_argument("--pin_memory", action="store_true")
639
-
640
-
641
- def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
642
- parser.add_argument("--flow_resolution_shifting", action="store_true")
643
- parser.add_argument("--flow_base_seq_len", type=int, default=256)
644
- parser.add_argument("--flow_max_seq_len", type=int, default=4096)
645
- parser.add_argument("--flow_base_shift", type=float, default=0.5)
646
- parser.add_argument("--flow_max_shift", type=float, default=1.15)
647
- parser.add_argument("--flow_shift", type=float, default=1.0)
648
- parser.add_argument(
649
- "--flow_weighting_scheme",
650
- type=str,
651
- default="none",
652
- choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
653
- )
654
- parser.add_argument("--flow_logit_mean", type=float, default=0.0)
655
- parser.add_argument("--flow_logit_std", type=float, default=1.0)
656
- parser.add_argument("--flow_mode_scale", type=float, default=1.29)
657
-
658
-
659
- def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
660
- parser.add_argument(
661
- "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True
662
- )
663
- parser.add_argument("--seed", type=int, default=None)
664
- parser.add_argument("--batch_size", type=int, default=1)
665
- parser.add_argument("--train_steps", type=int, default=1000)
666
- parser.add_argument("--max_data_samples", type=int, default=2**64)
667
- parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
668
- parser.add_argument("--gradient_checkpointing", action="store_true")
669
- parser.add_argument("--checkpointing_steps", type=int, default=500)
670
- parser.add_argument("--checkpointing_limit", type=int, default=None)
671
- parser.add_argument("--resume_from_checkpoint", type=str, default=None)
672
- parser.add_argument("--enable_slicing", action="store_true")
673
- parser.add_argument("--enable_tiling", action="store_true")
674
-
675
-
676
- def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
677
- parser.add_argument("--lr", type=float, default=1e-4)
678
- parser.add_argument("--lr_scheduler", type=str, default="constant")
679
- parser.add_argument("--lr_warmup_steps", type=int, default=500)
680
- parser.add_argument("--lr_num_cycles", type=int, default=1)
681
- parser.add_argument("--lr_power", type=float, default=1.0)
682
- parser.add_argument(
683
- "--optimizer",
684
- type=lambda s: s.lower(),
685
- default="adam",
686
- choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"],
687
- )
688
- parser.add_argument("--beta1", type=float, default=0.9)
689
- parser.add_argument("--beta2", type=float, default=0.95)
690
- parser.add_argument("--beta3", type=float, default=None)
691
- parser.add_argument("--weight_decay", type=float, default=1e-04)
692
- parser.add_argument("--epsilon", type=float, default=1e-8)
693
- parser.add_argument("--max_grad_norm", default=1.0, type=float)
694
-
695
-
696
- def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
697
- parser.add_argument("--validation_dataset_file", type=str, default=None)
698
- parser.add_argument("--validation_steps", type=int, default=500)
699
- parser.add_argument("--enable_model_cpu_offload", action="store_true")
700
-
701
-
702
- def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
703
- parser.add_argument("--tracker_name", type=str, default="finetrainers")
704
- parser.add_argument("--push_to_hub", action="store_true")
705
- parser.add_argument("--hub_token", type=str, default=None)
706
- parser.add_argument("--hub_model_id", type=str, default=None)
707
- parser.add_argument("--output_dir", type=str, default="finetrainers-training")
708
- parser.add_argument("--logging_dir", type=str, default="logs")
709
- parser.add_argument("--logging_steps", type=int, default=1)
710
- parser.add_argument("--allow_tf32", action="store_true")
711
- parser.add_argument("--init_timeout", type=int, default=300)
712
- parser.add_argument("--nccl_timeout", type=int, default=600)
713
- parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"])
714
- parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3])
715
-
716
-
717
- def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
718
- parser.add_argument("--list_models", action="store_true")
719
-
720
-
721
- _DTYPE_MAP = {
722
- "bf16": torch.bfloat16,
723
- "fp16": torch.float16,
724
- "fp32": torch.float32,
725
- "float8_e4m3fn": torch.float8_e4m3fn,
726
- "float8_e5m2": torch.float8_e5m2,
727
- }
728
-
729
-
730
- def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
731
- result_args = BaseArgs()
732
-
733
- # Parallel arguments
734
- result_args.parallel_backend = args.parallel_backend
735
- result_args.pp_degree = args.pp_degree
736
- result_args.dp_degree = args.dp_degree
737
- result_args.dp_shards = args.dp_shards
738
- result_args.cp_degree = args.cp_degree
739
- result_args.tp_degree = args.tp_degree
740
-
741
- # Model arguments
742
- result_args.model_name = args.model_name
743
- result_args.pretrained_model_name_or_path = args.pretrained_model_name_or_path
744
- result_args.revision = args.revision
745
- result_args.variant = args.variant
746
- result_args.cache_dir = args.cache_dir
747
- result_args.tokenizer_id = args.tokenizer_id
748
- result_args.tokenizer_2_id = args.tokenizer_2_id
749
- result_args.tokenizer_3_id = args.tokenizer_3_id
750
- result_args.text_encoder_id = args.text_encoder_id
751
- result_args.text_encoder_2_id = args.text_encoder_2_id
752
- result_args.text_encoder_3_id = args.text_encoder_3_id
753
- result_args.transformer_id = args.transformer_id
754
- result_args.vae_id = args.vae_id
755
- result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
756
- result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
757
- result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
758
- result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
759
- result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
760
- result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
761
- result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
762
- result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
763
-
764
- # Dataset arguments
765
- result_args.dataset_config = args.dataset_config
766
- result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
767
- result_args.enable_precomputation = args.enable_precomputation
768
- result_args.precomputation_items = args.precomputation_items
769
- result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
770
- result_args.precomputation_once = args.precomputation_once
771
-
772
- # Dataloader arguments
773
- result_args.dataloader_num_workers = args.dataloader_num_workers
774
- result_args.pin_memory = args.pin_memory
775
-
776
- # Diffusion arguments
777
- result_args.flow_resolution_shifting = args.flow_resolution_shifting
778
- result_args.flow_base_seq_len = args.flow_base_seq_len
779
- result_args.flow_max_seq_len = args.flow_max_seq_len
780
- result_args.flow_base_shift = args.flow_base_shift
781
- result_args.flow_max_shift = args.flow_max_shift
782
- result_args.flow_shift = args.flow_shift
783
- result_args.flow_weighting_scheme = args.flow_weighting_scheme
784
- result_args.flow_logit_mean = args.flow_logit_mean
785
- result_args.flow_logit_std = args.flow_logit_std
786
- result_args.flow_mode_scale = args.flow_mode_scale
787
-
788
- # Training arguments
789
- result_args.training_type = args.training_type
790
- result_args.seed = args.seed
791
- result_args.batch_size = args.batch_size
792
- result_args.train_steps = args.train_steps
793
- result_args.max_data_samples = args.max_data_samples
794
- result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
795
- result_args.gradient_checkpointing = args.gradient_checkpointing
796
- result_args.checkpointing_steps = args.checkpointing_steps
797
- result_args.checkpointing_limit = args.checkpointing_limit
798
- result_args.resume_from_checkpoint = args.resume_from_checkpoint
799
- result_args.enable_slicing = args.enable_slicing
800
- result_args.enable_tiling = args.enable_tiling
801
-
802
- # Optimizer arguments
803
- result_args.optimizer = args.optimizer or "adamw"
804
- result_args.lr = args.lr or 1e-4
805
- result_args.lr_scheduler = args.lr_scheduler
806
- result_args.lr_warmup_steps = args.lr_warmup_steps
807
- result_args.lr_num_cycles = args.lr_num_cycles
808
- result_args.lr_power = args.lr_power
809
- result_args.beta1 = args.beta1
810
- result_args.beta2 = args.beta2
811
- result_args.beta3 = args.beta3
812
- result_args.weight_decay = args.weight_decay
813
- result_args.epsilon = args.epsilon
814
- result_args.max_grad_norm = args.max_grad_norm
815
-
816
- # Validation arguments
817
- result_args.validation_dataset_file = args.validation_dataset_file
818
- result_args.validation_steps = args.validation_steps
819
- result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
820
-
821
- # Miscellaneous arguments
822
- result_args.tracker_name = args.tracker_name
823
- result_args.push_to_hub = args.push_to_hub
824
- result_args.hub_token = args.hub_token
825
- result_args.hub_model_id = args.hub_model_id
826
- result_args.output_dir = args.output_dir
827
- result_args.logging_dir = args.logging_dir
828
- result_args.logging_steps = args.logging_steps
829
- result_args.allow_tf32 = args.allow_tf32
830
- result_args.init_timeout = args.init_timeout
831
- result_args.nccl_timeout = args.nccl_timeout
832
- result_args.report_to = args.report_to
833
- result_args.verbose = args.verbose
834
-
835
- return result_args
836
-
837
-
838
- def _validate_model_args(args: BaseArgs):
839
- if args.training_type == "full-finetune":
840
- assert (
841
- "transformer" not in args.layerwise_upcasting_modules
842
- ), "Layerwise upcasting is not supported for full-finetune training"
843
-
844
-
845
- def _validate_dataset_args(args: BaseArgs):
846
- dataset_config = pathlib.Path(args.dataset_config)
847
- if not dataset_config.exists():
848
- raise ValueError(f"Dataset config file {args.dataset_config} does not exist.")
849
- if args.dataset_shuffle_buffer_size < 1:
850
- raise ValueError("Dataset shuffle buffer size must be greater than 0.")
851
- if args.precomputation_items < 1:
852
- raise ValueError("Precomputation items must be greater than 0.")
853
-
854
-
855
- def _validate_validation_args(args: BaseArgs):
856
- if args.enable_model_cpu_offload:
857
- if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.cp_degree, args.tp_degree]):
858
- raise ValueError("Model CPU offload is not supported on multi-GPU at the moment.")
859
-
860
-
861
- def _display_helper_messages(args: argparse.Namespace):
862
- if args.list_models:
863
- print("Supported models:")
864
- for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
865
- print(f" {index + 1}. {model_name}")
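Note: the BaseArgs docstring above describes the `--dataset_config` JSON layout only in prose. The sketch below is an illustrative config with a single video dataset; the paths, the `id_token`, and the bucket values are hypothetical.

```python
# Illustrative dataset_config following the layout described in the BaseArgs docstring.
# All concrete values here are made-up examples.
import json

dataset_config = {
    "datasets": [
        {
            "data_root": "/data/my_videos",                # or "dataset_file": "/data/metadata.csv"
            "dataset_type": "video",
            "id_token": "MYSTYLE",                          # optional trigger token for LoRA-style training
            "video_resolution_buckets": [[49, 480, 768]],   # (num_frames, height, width)
            "reshape_mode": "center_crop",
            "remove_common_llm_caption_prefixes": True,
        }
    ]
}

with open("dataset_config.json", "w") as f:
    json.dump(dataset_config, f, indent=2)
```

A training run would then pass `--dataset_config dataset_config.json` together with the other flags the parser above marks as required (`--model_name`, `--pretrained_model_name_or_path`, `--training_type`).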
 
finetrainers/config.py DELETED
@@ -1,58 +0,0 @@
1
- from enum import Enum
2
- from typing import Type
3
-
4
- from .models import ModelSpecification
5
- from .models.cogvideox import CogVideoXModelSpecification
6
- from .models.cogview4 import CogView4ModelSpecification
7
- from .models.hunyuan_video import HunyuanVideoModelSpecification
8
- from .models.ltx_video import LTXVideoModelSpecification
9
- from .models.wan import WanModelSpecification
10
-
11
-
12
- class ModelType(str, Enum):
13
- COGVIDEOX = "cogvideox"
14
- COGVIEW4 = "cogview4"
15
- HUNYUAN_VIDEO = "hunyuan_video"
16
- LTX_VIDEO = "ltx_video"
17
- WAN = "wan"
18
-
19
-
20
- class TrainingType(str, Enum):
21
- LORA = "lora"
22
- FULL_FINETUNE = "full-finetune"
23
-
24
-
25
- SUPPORTED_MODEL_CONFIGS = {
26
- ModelType.COGVIDEOX: {
27
- TrainingType.LORA: CogVideoXModelSpecification,
28
- TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
29
- },
30
- ModelType.COGVIEW4: {
31
- TrainingType.LORA: CogView4ModelSpecification,
32
- TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
33
- },
34
- ModelType.HUNYUAN_VIDEO: {
35
- TrainingType.LORA: HunyuanVideoModelSpecification,
36
- TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
37
- },
38
- ModelType.LTX_VIDEO: {
39
- TrainingType.LORA: LTXVideoModelSpecification,
40
- TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
41
- },
42
- ModelType.WAN: {
43
- TrainingType.LORA: WanModelSpecification,
44
- TrainingType.FULL_FINETUNE: WanModelSpecification,
45
- },
46
- }
47
-
48
-
49
- def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
50
- if model_name not in SUPPORTED_MODEL_CONFIGS:
51
- raise ValueError(
52
- f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
53
- )
54
- if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
55
- raise ValueError(
56
- f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
57
- )
58
- return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
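Note: a brief usage sketch for the registry and lookup helper above, assuming the deleted package were still importable as `finetrainers`:

```python
# Sketch only: resolving a model specification class from SUPPORTED_MODEL_CONFIGS.
from finetrainers.config import ModelType, TrainingType, _get_model_specifiction_cls

spec_cls = _get_model_specifiction_cls(ModelType.WAN, TrainingType.LORA)
print(spec_cls.__name__)  # -> "WanModelSpecification"

# Because ModelType and TrainingType are str-based enums, the plain string values
# accepted by the CLI ("ltx_video", "full-finetune", ...) resolve through the same dictionary.
spec_cls = _get_model_specifiction_cls("ltx_video", "full-finetune")
print(spec_cls.__name__)  # -> "LTXVideoModelSpecification"
```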
 
finetrainers/constants.py DELETED
@@ -1,83 +0,0 @@
1
- import os
2
-
3
-
4
- DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
5
- DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
6
- DEFAULT_FRAME_BUCKETS = [49]
7
-
8
- DEFAULT_IMAGE_RESOLUTION_BUCKETS = []
9
- for height in DEFAULT_HEIGHT_BUCKETS:
10
- for width in DEFAULT_WIDTH_BUCKETS:
11
- DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width))
12
-
13
- DEFAULT_VIDEO_RESOLUTION_BUCKETS = []
14
- for frames in DEFAULT_FRAME_BUCKETS:
15
- for height in DEFAULT_HEIGHT_BUCKETS:
16
- for width in DEFAULT_WIDTH_BUCKETS:
17
- DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width))
18
-
19
-
20
- FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
21
-
22
- PRECOMPUTED_DIR_NAME = "precomputed"
23
- PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
24
- PRECOMPUTED_LATENTS_DIR_NAME = "latents"
25
-
26
- MODEL_DESCRIPTION = r"""
27
- \# {model_id} {training_type} finetune
28
-
29
- <Gallery />
30
-
31
- \#\# Model Description
32
-
33
- This model is a {training_type} of the `{model_id}` model.
34
-
35
- This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers).
36
-
37
- \#\# Download model
38
-
39
- [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
40
-
41
- \#\# Usage
42
-
43
- Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed.
44
-
45
- ```python
46
- {model_example}
47
- ```
48
-
49
- For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
50
-
51
- \#\# License
52
-
53
- Please adhere to the license of the base model.
54
- """.strip()
55
-
56
- _COMMON_BEGINNING_PHRASES = (
57
- "This video",
58
- "The video",
59
- "This clip",
60
- "The clip",
61
- "The animation",
62
- "This image",
63
- "The image",
64
- "This picture",
65
- "The picture",
66
- )
67
- _COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents")
68
-
69
- COMMON_LLM_START_PHRASES = (
70
- "In the video,",
71
- "In this video,",
72
- "In this video clip,",
73
- "In the clip,",
74
- "Caption:",
75
- *(
76
- f"{beginning} {continuation}"
77
- for beginning in _COMMON_BEGINNING_PHRASES
78
- for continuation in _COMMON_CONTINUATION_WORDS
79
- ),
80
- )
81
-
82
- SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png")
83
- SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov")
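Note: the constants above only define the prefix tuples; the helpers that consume them live elsewhere in the package (likely finetrainers/functional/text.py, listed in this commit but not shown in this 50-file view). The snippet below is a hypothetical stand-in illustrating how COMMON_LLM_START_PHRASES is meant to be applied.

```python
# Hypothetical helper (not the library's own implementation) showing the intended use
# of COMMON_LLM_START_PHRASES: drop a known LLM-style lead-in from a caption.
from finetrainers.constants import COMMON_LLM_START_PHRASES

def strip_llm_prefix(caption: str) -> str:
    for phrase in COMMON_LLM_START_PHRASES:
        if caption.startswith(phrase):
            return caption[len(phrase):].lstrip()
    return caption

print(strip_llm_prefix("This video shows a cat chasing a laser pointer."))
# -> "a cat chasing a laser pointer."
```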
 
finetrainers/data/__init__.py DELETED
@@ -1,27 +0,0 @@
1
- from ._artifact import ImageArtifact, VideoArtifact
2
- from .dataloader import DPDataLoader
3
- from .dataset import (
4
- ImageCaptionFilePairDataset,
5
- ImageFileCaptionFileListDataset,
6
- ImageFolderDataset,
7
- ImageWebDataset,
8
- ValidationDataset,
9
- VideoCaptionFilePairDataset,
10
- VideoFileCaptionFileListDataset,
11
- VideoFolderDataset,
12
- VideoWebDataset,
13
- combine_datasets,
14
- initialize_dataset,
15
- wrap_iterable_dataset_for_preprocessing,
16
- )
17
- from .precomputation import (
18
- InMemoryDataIterable,
19
- InMemoryDistributedDataPreprocessor,
20
- InMemoryOnceDataIterable,
21
- PrecomputedDataIterable,
22
- PrecomputedDistributedDataPreprocessor,
23
- PrecomputedOnceDataIterable,
24
- initialize_preprocessor,
25
- )
26
- from .sampler import ResolutionSampler
27
- from .utils import find_files
 
finetrainers/data/_artifact.py DELETED
@@ -1,29 +0,0 @@
1
- # ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT =====
2
-
3
- from dataclasses import dataclass
4
- from typing import Any, List
5
-
6
- from PIL.Image import Image
7
-
8
-
9
- @dataclass
10
- class Artifact:
11
- type: str
12
- value: Any
13
- file_extension: str
14
-
15
-
16
- @dataclass
17
- class ImageArtifact(Artifact):
18
- value: Image
19
-
20
- def __init__(self, value: Image):
21
- super().__init__(type="image", value=value, file_extension="png")
22
-
23
-
24
- @dataclass
25
- class VideoArtifact(Artifact):
26
- value: List[Image]
27
-
28
- def __init__(self, value: List[Image]):
29
- super().__init__(type="video", value=value, file_extension="mp4")
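Note: a minimal construction sketch for the artifact dataclasses above (ImageArtifact and VideoArtifact are re-exported by finetrainers/data/__init__.py, which pulls in the data package's dependencies such as decord at import time):

```python
# Sketch: the custom __init__ methods fill in `type` and `file_extension` automatically.
from PIL import Image
from finetrainers.data import ImageArtifact, VideoArtifact

image_artifact = ImageArtifact(Image.new("RGB", (64, 64)))
print(image_artifact.type, image_artifact.file_extension)    # -> image png

video_artifact = VideoArtifact([Image.new("RGB", (64, 64)) for _ in range(8)])
print(video_artifact.type, video_artifact.file_extension)    # -> video mp4
```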
 
finetrainers/data/dataloader.py DELETED
@@ -1,40 +0,0 @@
1
- import pickle
2
- from typing import Any, Dict
3
-
4
- import torch.distributed.checkpoint.stateful
5
- import torchdata.stateful_dataloader
6
-
7
- from ..logging import get_logger
8
-
9
-
10
- logger = get_logger()
11
-
12
-
13
- class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
14
- def __init__(
15
- self,
16
- rank: int,
17
- dataset: torch.utils.data.IterableDataset,
18
- batch_size: int = 1,
19
- num_workers: int = 0,
20
- collate_fn=None,
21
- ) -> None:
22
- super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
23
-
24
- self._dp_rank = rank
25
- self._rank_id = f"dp_rank_{rank}"
26
-
27
- def state_dict(self) -> Dict[str, Any]:
28
- # Store state only for dp rank to avoid replicating the same state across other dimensions
29
- return {self._rank_id: pickle.dumps(super().state_dict())}
30
-
31
- def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
32
- # State being empty is valid
33
- if not state_dict:
34
- return
35
-
36
- if self._rank_id not in state_dict:
37
- logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")
38
- return
39
-
40
- super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
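Note: a short sketch of the checkpointable dataloader above. It assumes torchdata is installed (StatefulDataLoader comes from torchdata.stateful_dataloader) and uses a toy in-process dataset:

```python
# Sketch: DPDataLoader stores its iteration state under a per-data-parallel-rank key,
# so a resumed run continues the stream where it left off.
import torch
from finetrainers.data import DPDataLoader

class Counter(torch.utils.data.IterableDataset):
    def __iter__(self):
        yield from range(100)

loader = DPDataLoader(rank=0, dataset=Counter(), batch_size=4, num_workers=0)
it = iter(loader)
first_batch = next(it)

state = loader.state_dict()       # {"dp_rank_0": <pickled StatefulDataLoader state>}

resumed = DPDataLoader(rank=0, dataset=Counter(), batch_size=4, num_workers=0)
resumed.load_state_dict(state)    # iteration on `resumed` continues after `first_batch`
```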
 
finetrainers/data/dataset.py DELETED
@@ -1,978 +0,0 @@
1
- import pathlib
2
- import random
3
- from typing import Any, Dict, List, Optional, Tuple, Union
4
-
5
- import datasets
6
- import datasets.data_files
7
- import datasets.distributed
8
- import datasets.exceptions
9
- import huggingface_hub
10
- import huggingface_hub.errors
11
- import numpy as np
12
- import PIL.Image
13
- import torch
14
- import torch.distributed.checkpoint.stateful
15
- from diffusers.utils import load_image, load_video
16
- from huggingface_hub import list_repo_files, repo_exists, snapshot_download
17
- from tqdm.auto import tqdm
18
-
19
- from .. import constants
20
- from .. import functional as FF
21
- from ..logging import get_logger
22
- from . import utils
23
-
24
-
25
- import decord # isort:skip
26
-
27
- decord.bridge.set_bridge("torch")
28
-
29
- logger = get_logger()
30
-
31
-
32
- # fmt: off
33
- MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
34
- COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
35
- COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
36
- COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
37
- COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"]
38
- # fmt: on
39
-
40
-
41
- class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
42
- def __init__(self, root: str, infinite: bool = False) -> None:
43
- super().__init__()
44
-
45
- self.root = pathlib.Path(root)
46
- self.infinite = infinite
47
-
48
- data = []
49
- caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
50
- for caption_file in caption_files:
51
- data_file = self._find_data_file(caption_file)
52
- if data_file:
53
- data.append(
54
- {
55
- "caption": (self.root / caption_file).as_posix(),
56
- "image": (self.root / data_file).as_posix(),
57
- }
58
- )
59
-
60
- data = datasets.Dataset.from_list(data)
61
- data = data.cast_column("image", datasets.Image(mode="RGB"))
62
-
63
- self._data = data.to_iterable_dataset()
64
- self._sample_index = 0
65
- self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
66
-
67
- def _get_data_iter(self):
68
- if self._sample_index == 0:
69
- return iter(self._data)
70
- return iter(self._data.skip(self._sample_index))
71
-
72
- def __iter__(self):
73
- while True:
74
- for sample in self._get_data_iter():
75
- self._sample_index += 1
76
- sample["caption"] = _read_caption_from_file(sample["caption"])
77
- sample["image"] = _preprocess_image(sample["image"])
78
- yield sample
79
-
80
- if not self.infinite:
81
- logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
82
- break
83
- else:
84
- self._sample_index = 0
85
-
86
- def load_state_dict(self, state_dict):
87
- self._sample_index = state_dict["sample_index"]
88
-
89
- def state_dict(self):
90
- return {"sample_index": self._sample_index}
91
-
92
- def _find_data_file(self, caption_file: str) -> str:
93
- caption_file = pathlib.Path(caption_file)
94
- data_file = None
95
- found_data = 0
96
-
97
- for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
98
- image_filename = caption_file.with_suffix(f".{extension}")
99
- if image_filename.exists():
100
- found_data += 1
101
- data_file = image_filename
102
-
103
- if found_data == 0:
104
- return False
105
- elif found_data > 1:
106
- raise ValueError(
107
- f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
108
- f"file per caption file. The following extensions are supported:\n"
109
- f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n"
110
- )
111
-
112
- return data_file.as_posix()
113
-
114
-
115
- class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
116
- def __init__(self, root: str, infinite: bool = False) -> None:
117
- super().__init__()
118
-
119
- self.root = pathlib.Path(root)
120
- self.infinite = infinite
121
-
122
- data = []
123
- caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
124
- for caption_file in caption_files:
125
- data_file = self._find_data_file(caption_file)
126
- if data_file:
127
- data.append(
128
- {
129
- "caption": (self.root / caption_file).as_posix(),
130
- "video": (self.root / data_file).as_posix(),
131
- }
132
- )
133
-
134
- data = datasets.Dataset.from_list(data)
135
- data = data.cast_column("video", datasets.Video())
136
-
137
- self._data = data.to_iterable_dataset()
138
- self._sample_index = 0
139
- self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
140
-
141
- def _get_data_iter(self):
142
- if self._sample_index == 0:
143
- return iter(self._data)
144
- return iter(self._data.skip(self._sample_index))
145
-
146
- def __iter__(self):
147
- while True:
148
- for sample in self._get_data_iter():
149
- self._sample_index += 1
150
- sample["caption"] = _read_caption_from_file(sample["caption"])
151
- sample["video"] = _preprocess_video(sample["video"])
152
- yield sample
153
-
154
- if not self.infinite:
155
- logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
156
- break
157
- else:
158
- self._sample_index = 0
159
-
160
- def load_state_dict(self, state_dict):
161
- self._sample_index = state_dict["sample_index"]
162
-
163
- def state_dict(self):
164
- return {"sample_index": self._sample_index}
165
-
166
- def _find_data_file(self, caption_file: str) -> str:
167
- caption_file = pathlib.Path(caption_file)
168
- data_file = None
169
- found_data = 0
170
-
171
- for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
172
- video_filename = caption_file.with_suffix(f".{extension}")
173
- if video_filename.exists():
174
- found_data += 1
175
- data_file = video_filename
176
-
177
- if found_data == 0:
178
- return False
179
- elif found_data > 1:
180
- raise ValueError(
181
- f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
182
- f"file per caption file. The following extensions are supported:\n"
183
- f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n"
184
- )
185
-
186
- return data_file.as_posix()
187
-
188
-
189
- class ImageFileCaptionFileListDataset(
190
- torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
191
- ):
192
- def __init__(self, root: str, infinite: bool = False) -> None:
193
- super().__init__()
194
-
195
- VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
196
- VALID_IMAGE_FILES = ["image.txt", "images.txt"]
197
-
198
- self.root = pathlib.Path(root)
199
- self.infinite = infinite
200
-
201
- data = []
202
- existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
203
- existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()]
204
-
205
- if len(existing_caption_files) == 0:
206
- raise FileNotFoundError(
207
- f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
208
- )
209
- if len(existing_image_files) == 0:
210
- raise FileNotFoundError(
211
- f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
212
- )
213
- if len(existing_caption_files) > 1:
214
- raise ValueError(
215
- f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
216
- )
217
- if len(existing_image_files) > 1:
218
- raise ValueError(
219
- f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
220
- )
221
-
222
- caption_file = existing_caption_files[0]
223
- image_file = existing_image_files[0]
224
-
225
- with open((self.root / caption_file).as_posix(), "r") as f:
226
- captions = f.read().splitlines()
227
- with open((self.root / image_file).as_posix(), "r") as f:
228
- images = f.read().splitlines()
229
- images = [(self.root / image).as_posix() for image in images]
230
-
231
- if len(captions) != len(images):
232
- raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})")
233
-
234
- for caption, image in zip(captions, images):
235
- data.append({"caption": caption, "image": image})
236
-
237
- data = datasets.Dataset.from_list(data)
238
- data = data.cast_column("image", datasets.Image(mode="RGB"))
239
-
240
- self._data = data.to_iterable_dataset()
241
- self._sample_index = 0
242
- self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
243
-
244
- def _get_data_iter(self):
245
- if self._sample_index == 0:
246
- return iter(self._data)
247
- return iter(self._data.skip(self._sample_index))
248
-
249
- def __iter__(self):
250
- while True:
251
- for sample in self._get_data_iter():
252
- self._sample_index += 1
253
- sample["image"] = _preprocess_image(sample["image"])
254
- yield sample
255
-
256
- if not self.infinite:
257
- logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
258
- break
259
- else:
260
- self._sample_index = 0
261
-
262
- def load_state_dict(self, state_dict):
263
- self._sample_index = state_dict["sample_index"]
264
-
265
- def state_dict(self):
266
- return {"sample_index": self._sample_index}
267
-
268
-
269
- class VideoFileCaptionFileListDataset(
270
- torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
271
- ):
272
- def __init__(self, root: str, infinite: bool = False) -> None:
273
- super().__init__()
274
-
275
- VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
276
- VALID_VIDEO_FILES = ["video.txt", "videos.txt"]
277
-
278
- self.root = pathlib.Path(root)
279
- self.infinite = infinite
280
-
281
- data = []
282
- existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
283
- existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()]
284
-
285
- if len(existing_caption_files) == 0:
286
- raise FileNotFoundError(
287
- f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
288
- )
289
- if len(existing_video_files) == 0:
290
- raise FileNotFoundError(
291
- f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
292
- )
293
- if len(existing_caption_files) > 1:
294
- raise ValueError(
295
- f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
296
- )
297
- if len(existing_video_files) > 1:
298
- raise ValueError(
299
- f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
300
- )
301
-
302
- caption_file = existing_caption_files[0]
303
- video_file = existing_video_files[0]
304
-
305
- with open((self.root / caption_file).as_posix(), "r") as f:
306
- captions = f.read().splitlines()
307
- with open((self.root / video_file).as_posix(), "r") as f:
308
- videos = f.read().splitlines()
309
- videos = [(self.root / video).as_posix() for video in videos]
310
-
311
- if len(captions) != len(videos):
312
- raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})")
313
-
314
- for caption, video in zip(captions, videos):
315
- data.append({"caption": caption, "video": video})
316
-
317
- data = datasets.Dataset.from_list(data)
318
- data = data.cast_column("video", datasets.Video())
319
-
320
- self._data = data.to_iterable_dataset()
321
- self._sample_index = 0
322
- self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
323
-
324
- def _get_data_iter(self):
325
- if self._sample_index == 0:
326
- return iter(self._data)
327
- return iter(self._data.skip(self._sample_index))
328
-
329
- def __iter__(self):
330
- while True:
331
- for sample in self._get_data_iter():
332
- self._sample_index += 1
333
- sample["video"] = _preprocess_video(sample["video"])
334
- yield sample
335
-
336
- if not self.infinite:
337
- logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
338
- break
339
- else:
340
- self._sample_index = 0
341
-
342
- def load_state_dict(self, state_dict):
343
- self._sample_index = state_dict["sample_index"]
344
-
345
- def state_dict(self):
346
- return {"sample_index": self._sample_index}
347
-
348
-
349
- class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
350
- def __init__(self, root: str, infinite: bool = False) -> None:
351
- super().__init__()
352
-
353
- self.root = pathlib.Path(root)
354
- self.infinite = infinite
355
-
356
- data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train")
357
-
358
- self._data = data.to_iterable_dataset()
359
- self._sample_index = 0
360
- self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
361
-
362
- def _get_data_iter(self):
363
- if self._sample_index == 0:
364
- return iter(self._data)
365
- return iter(self._data.skip(self._sample_index))
366
-
367
- def __iter__(self):
368
- while True:
369
- for sample in self._get_data_iter():
370
- self._sample_index += 1
371
- sample["image"] = _preprocess_image(sample["image"])
372
- yield sample
373
-
374
- if not self.infinite:
375
- logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
376
- break
377
- else:
378
- self._sample_index = 0
379
-
380
- def load_state_dict(self, state_dict):
381
- self._sample_index = state_dict["sample_index"]
382
-
383
- def state_dict(self):
384
- return {"sample_index": self._sample_index}
385
-
386
-
387
- class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
388
- def __init__(self, root: str, infinite: bool = False) -> None:
389
- super().__init__()
390
-
391
- self.root = pathlib.Path(root)
392
- self.infinite = infinite
393
-
394
- data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train")
395
-
396
- self._data = data.to_iterable_dataset()
397
- self._sample_index = 0
398
- self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
399
-
400
- def _get_data_iter(self):
401
- if self._sample_index == 0:
402
- return iter(self._data)
403
- return iter(self._data.skip(self._sample_index))
404
-
405
- def __iter__(self):
406
- while True:
407
- for sample in self._get_data_iter():
408
- self._sample_index += 1
409
- sample["video"] = _preprocess_video(sample["video"])
410
- yield sample
411
-
412
- if not self.infinite:
413
- logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
414
- break
415
- else:
416
- self._sample_index = 0
417
-
418
- def load_state_dict(self, state_dict):
419
- self._sample_index = state_dict["sample_index"]
420
-
421
- def state_dict(self):
422
- return {"sample_index": self._sample_index}
423
-
424
-
425
- class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
426
- def __init__(
427
- self,
428
- dataset_name: str,
429
- infinite: bool = False,
430
- column_names: Union[str, List[str]] = "__auto__",
431
- weights: Dict[str, float] = -1,
432
- **kwargs,
433
- ) -> None:
434
- super().__init__()
435
-
436
- assert weights == -1 or isinstance(
437
- weights, dict
438
- ), "`weights` must be a dictionary of probabilities for each caption column"
439
-
440
- self.dataset_name = dataset_name
441
- self.infinite = infinite
442
-
443
- data = datasets.load_dataset(dataset_name, split="train", streaming=True)
444
-
445
- if column_names == "__auto__":
446
- if weights == -1:
447
- caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
448
- if len(caption_columns) == 0:
449
- raise ValueError(
450
- f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
451
- )
452
- weights = [1] * len(caption_columns)
453
- else:
454
- caption_columns = list(weights.keys())
455
- weights = list(weights.values())
456
- if not all(column in data.column_names for column in caption_columns):
457
- raise ValueError(
458
- f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
459
- )
460
- else:
461
- if isinstance(column_names, str):
462
- if column_names not in data.column_names:
463
- raise ValueError(
464
- f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
465
- )
466
- caption_columns = [column_names]
467
- weights = [1] if weights == -1 else [weights.get(column_names)]
468
- elif isinstance(column_names, list):
469
- if not all(column in data.column_names for column in column_names):
470
- raise ValueError(
471
- f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
472
- )
473
- caption_columns = column_names
474
- weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
475
- else:
476
- raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
477
-
478
- for column_names in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
479
- if column_names in data.column_names:
480
- data = data.cast_column(column_names, datasets.Image(mode="RGB"))
481
- data = data.rename_column(column_names, "image")
482
- break
483
-
484
- self._data = data
485
- self._sample_index = 0
486
- self._precomputable_once = False
487
- self._caption_columns = caption_columns
488
- self._weights = weights
489
-
490
- def _get_data_iter(self):
491
- if self._sample_index == 0:
492
- return iter(self._data)
493
- return iter(self._data.skip(self._sample_index))
494
-
495
- def __iter__(self):
496
- while True:
497
- for sample in self._get_data_iter():
498
- self._sample_index += 1
499
- caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
500
- sample["caption"] = sample[caption_column]
501
- sample["image"] = _preprocess_image(sample["image"])
502
- yield sample
503
-
504
- if not self.infinite:
505
- logger.warning(f"Dataset {self.dataset_name} has run out of data")
506
- break
507
- else:
508
- # Reset offset for the next iteration
509
- self._sample_index = 0
510
- logger.warning(f"Dataset {self.dataset_name} is being re-looped")
511
-
512
- def load_state_dict(self, state_dict):
513
- self._sample_index = state_dict["sample_index"]
514
-
515
- def state_dict(self):
516
- return {"sample_index": self._sample_index}
517
-
518
-
519
- class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
520
- def __init__(
521
- self,
522
- dataset_name: str,
523
- infinite: bool = False,
524
- column_names: Union[str, List[str]] = "__auto__",
525
- weights: Dict[str, float] = -1,
526
- **kwargs,
527
- ) -> None:
528
- super().__init__()
529
-
530
- assert weights == -1 or isinstance(
531
- weights, dict
532
- ), "`weights` must be a dictionary of probabilities for each caption column"
533
-
534
- self.dataset_name = dataset_name
535
- self.infinite = infinite
536
-
537
- data = datasets.load_dataset(dataset_name, split="train", streaming=True)
538
-
539
- if column_names == "__auto__":
540
- if weights == -1:
541
- caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
542
- if len(caption_columns) == 0:
543
- raise ValueError(
544
- f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
545
- )
546
- weights = [1] * len(caption_columns)
547
- else:
548
- caption_columns = list(weights.keys())
549
- weights = list(weights.values())
550
- if not all(column in data.column_names for column in caption_columns):
551
- raise ValueError(
552
- f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
553
- )
554
- else:
555
- if isinstance(column_names, str):
556
- if column_names not in data.column_names:
557
- raise ValueError(
558
- f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
559
- )
560
- caption_columns = [column_names]
561
- weights = [1] if weights == -1 else [weights.get(column_names)]
562
- elif isinstance(column_names, list):
563
- if not all(column in data.column_names for column in column_names):
564
- raise ValueError(
565
- f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
566
- )
567
- caption_columns = column_names
568
- weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
569
- else:
570
- raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
571
-
572
- for column_names in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
573
- if column_names in data.column_names:
574
- data = data.cast_column(column_names, datasets.Video())
575
- data = data.rename_column(column_names, "video")
576
- break
577
-
578
- self._data = data
579
- self._sample_index = 0
580
- self._precomputable_once = False
581
- self._caption_columns = caption_columns
582
- self._weights = weights
583
-
584
- def _get_data_iter(self):
585
- if self._sample_index == 0:
586
- return iter(self._data)
587
- return iter(self._data.skip(self._sample_index))
588
-
589
- def __iter__(self):
590
- while True:
591
- for sample in self._get_data_iter():
592
- self._sample_index += 1
593
- caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
594
- sample["caption"] = sample[caption_column]
595
- sample["video"] = _preprocess_video(sample["video"])
596
- yield sample
597
-
598
- if not self.infinite:
599
- logger.warning(f"Dataset {self.dataset_name} has run out of data")
600
- break
601
- else:
602
- # Reset offset for the next iteration
603
- self._sample_index = 0
604
- logger.warning(f"Dataset {self.dataset_name} is being re-looped")
605
-
606
- def load_state_dict(self, state_dict):
607
- self._sample_index = state_dict["sample_index"]
608
-
609
- def state_dict(self):
610
- return {"sample_index": self._sample_index}
611
-
612
-
613
- class ValidationDataset(torch.utils.data.IterableDataset):
614
- def __init__(self, filename: str):
615
- super().__init__()
616
-
617
- self.filename = pathlib.Path(filename)
618
-
619
- if not self.filename.exists():
620
- raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist")
621
-
622
- if self.filename.suffix == ".csv":
623
- data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train")
624
- elif self.filename.suffix == ".json":
625
- data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data")
626
- elif self.filename.suffix == ".parquet":
627
- data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train")
628
- elif self.filename.suffix == ".arrow":
629
- data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train")
630
- else:
631
- _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
632
- raise ValueError(
633
- f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
634
- )
635
-
636
- self._data = data.to_iterable_dataset()
637
-
638
- def __iter__(self):
639
- for sample in self._data:
640
- # For consistency reasons, we mandate that "caption" is always present in the validation dataset.
641
- # However, since the model specifications use "prompt", we create an alias here.
642
- sample["prompt"] = sample["caption"]
643
-
644
- # Load image or video if the path is provided
645
- # TODO(aryan): need to handle custom columns here for control conditions
646
- sample["image"] = None
647
- sample["video"] = None
648
-
649
- if sample.get("image_path", None) is not None:
650
- image_path = pathlib.Path(sample["image_path"])
651
- if not image_path.is_file():
652
- logger.warning(f"Image file {image_path.as_posix()} does not exist.")
653
- else:
654
- sample["image"] = load_image(sample["image_path"])
655
-
656
- if sample.get("video_path", None) is not None:
657
- video_path = pathlib.Path(sample["video_path"])
658
- if not video_path.is_file():
659
- logger.warning(f"Video file {video_path.as_posix()} does not exist.")
660
- else:
661
- sample["video"] = load_video(sample["video_path"])
662
-
663
- sample = {k: v for k, v in sample.items() if v is not None}
664
- yield sample
665
-
666
-
667
- class IterableDatasetPreprocessingWrapper(
668
- torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
669
- ):
670
- def __init__(
671
- self,
672
- dataset: torch.utils.data.IterableDataset,
673
- dataset_type: str,
674
- id_token: Optional[str] = None,
675
- image_resolution_buckets: List[Tuple[int, int]] = None,
676
- video_resolution_buckets: List[Tuple[int, int, int]] = None,
677
- reshape_mode: str = "bicubic",
678
- remove_common_llm_caption_prefixes: bool = False,
679
- **kwargs,
680
- ):
681
- super().__init__()
682
-
683
- self.dataset = dataset
684
- self.dataset_type = dataset_type
685
- self.id_token = id_token
686
- self.image_resolution_buckets = image_resolution_buckets
687
- self.video_resolution_buckets = video_resolution_buckets
688
- self.reshape_mode = reshape_mode
689
- self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes
690
-
691
- logger.info(
692
- f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n"
693
- f" - Dataset Type: {dataset_type}\n"
694
- f" - ID Token: {id_token}\n"
695
- f" - Image Resolution Buckets: {image_resolution_buckets}\n"
696
- f" - Video Resolution Buckets: {video_resolution_buckets}\n"
697
- f" - Reshape Mode: {reshape_mode}\n"
698
- f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n"
699
- )
700
-
701
- def __iter__(self):
702
- logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
703
- for sample in iter(self.dataset):
704
- if self.dataset_type == "image":
705
- if self.image_resolution_buckets:
706
- sample["_original_num_frames"] = 1
707
- sample["_original_height"] = sample["image"].size(1)
708
- sample["_original_width"] = sample["image"].size(2)
709
- sample["image"] = FF.resize_to_nearest_bucket_image(
710
- sample["image"], self.image_resolution_buckets, self.reshape_mode
711
- )
712
- elif self.dataset_type == "video":
713
- if self.video_resolution_buckets:
714
- sample["_original_num_frames"] = sample["video"].size(0)
715
- sample["_original_height"] = sample["video"].size(2)
716
- sample["_original_width"] = sample["video"].size(3)
717
- sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
718
- sample["video"], self.video_resolution_buckets, self.reshape_mode
719
- )
720
- if _first_frame_only:
721
- msg = (
722
- "The number of frames in the video is less than the minimum bucket size "
723
- "specified. The first frame is being used as a single frame video. This "
724
- "message is logged at the first occurence and for every 128th occurence "
725
- "after that."
726
- )
727
- logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
728
- sample["video"] = sample["video"][0]
729
-
730
- if self.remove_common_llm_caption_prefixes:
731
- sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES)
732
-
733
- if self.id_token is not None:
734
- sample["caption"] = f"{self.id_token} {sample['caption']}"
735
-
736
- yield sample
737
-
738
- def load_state_dict(self, state_dict):
739
- self.dataset.load_state_dict(state_dict["dataset"])
740
-
741
- def state_dict(self):
742
- return {"dataset": self.dataset.state_dict()}
743
-
744
-
745
- class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
746
- def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
747
- super().__init__()
748
-
749
- self.datasets = datasets
750
- self.buffer_size = buffer_size
751
- self.shuffle = shuffle
752
-
753
- logger.info(
754
- f"Initializing IterableCombinedDataset with the following configuration:\n"
755
- f" - Number of Datasets: {len(datasets)}\n"
756
- f" - Buffer Size: {buffer_size}\n"
757
- f" - Shuffle: {shuffle}\n"
758
- )
759
-
760
- def __iter__(self):
761
- logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
762
- iterators = [iter(dataset) for dataset in self.datasets]
763
- buffer = []
764
- per_iter = max(1, self.buffer_size // len(iterators))
765
-
766
- for index, it in enumerate(iterators):
767
- for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
768
- try:
769
- buffer.append((it, next(it)))
770
- except StopIteration:
771
- continue
772
-
773
- while len(buffer) > 0:
774
- idx = 0
775
- if self.shuffle:
776
- idx = random.randint(0, len(buffer) - 1)
777
- current_it, sample = buffer.pop(idx)
778
- yield sample
779
- try:
780
- buffer.append((current_it, next(current_it)))
781
- except StopIteration:
782
- pass
783
-
784
- def load_state_dict(self, state_dict):
785
- for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
786
- dataset.load_state_dict(dataset_state_dict)
787
-
788
- def state_dict(self):
789
- return {"datasets": [dataset.state_dict() for dataset in self.datasets]}
790
-
791
-
792
- # TODO(aryan): maybe write a test for this
793
- def initialize_dataset(
794
- dataset_name_or_root: str,
795
- dataset_type: str = "video",
796
- streaming: bool = True,
797
- infinite: bool = False,
798
- *,
799
- _caption_options: Optional[Dict[str, Any]] = None,
800
- ) -> torch.utils.data.IterableDataset:
801
- assert dataset_type in ["image", "video"]
802
-
803
- try:
804
- does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
805
- except huggingface_hub.errors.HFValidationError:
806
- does_repo_exist_on_hub = False
807
-
808
- if does_repo_exist_on_hub:
809
- return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options)
810
- else:
811
- return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
812
-
813
-
814
- def combine_datasets(
815
- datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
816
- ) -> torch.utils.data.IterableDataset:
817
- return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)
818
-
819
-
820
- def wrap_iterable_dataset_for_preprocessing(
821
- dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
822
- ) -> torch.utils.data.IterableDataset:
823
- return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
824
-
825
-
826
- def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
827
- root = pathlib.Path(dataset_name_or_root)
828
- supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
829
- metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
830
- metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]
831
-
832
- if len(metadata_files) > 1:
833
- raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
834
-
835
- if len(metadata_files) == 1:
836
- if dataset_type == "image":
837
- dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
838
- else:
839
- dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
840
- return dataset
841
-
842
- if _has_data_caption_file_pairs(root, remote=False):
843
- if dataset_type == "image":
844
- dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
845
- else:
846
- dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
847
- elif _has_data_file_caption_file_lists(root, remote=False):
848
- if dataset_type == "image":
849
- dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
850
- else:
851
- dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
852
- else:
853
- raise ValueError(
854
- f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
855
- f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
856
- f"help you set it up."
857
- )
858
-
859
- return dataset
860
-
861
-
862
- def _initialize_hub_dataset(
863
- dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None
864
- ):
865
- repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
866
- if _has_data_caption_file_pairs(repo_file_list, remote=True):
867
- return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
868
- elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
869
- return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
870
-
871
- has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
872
- if has_tar_files:
873
- return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options)
874
-
875
- # TODO(aryan): This should be improved
876
- caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
877
- if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
878
- try:
879
- dataset_root = snapshot_download(dataset_name, repo_type="dataset")
880
- if dataset_type == "image":
881
- dataset = ImageFolderDataset(dataset_root, infinite=infinite)
882
- else:
883
- dataset = VideoFolderDataset(dataset_root, infinite=infinite)
884
- return dataset
885
- except Exception:
886
- pass
887
-
888
- raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")
889
-
890
-
891
- def _initialize_data_caption_file_dataset_from_hub(
892
- dataset_name: str, dataset_type: str, infinite: bool = False
893
- ) -> torch.utils.data.IterableDataset:
894
- logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
895
- dataset_root = snapshot_download(dataset_name, repo_type="dataset")
896
- if dataset_type == "image":
897
- return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
898
- else:
899
- return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
900
-
901
-
902
- def _initialize_data_file_caption_file_dataset_from_hub(
903
- dataset_name: str, dataset_type: str, infinite: bool = False
904
- ) -> torch.utils.data.IterableDataset:
905
- logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
906
- dataset_root = snapshot_download(dataset_name, repo_type="dataset")
907
- if dataset_type == "image":
908
- return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite)
909
- else:
910
- return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite)
911
-
912
-
913
- def _initialize_webdataset(
914
- dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None
915
- ) -> torch.utils.data.IterableDataset:
916
- logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
917
- _caption_options = _caption_options or {}
918
- if dataset_type == "image":
919
- return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options)
920
- else:
921
- return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options)
922
-
923
-
924
- def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
925
- # TODO(aryan): this logic can be improved
926
- if not remote:
927
- caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
928
- for caption_file in caption_files:
929
- caption_file = pathlib.Path(caption_file)
930
- for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
931
- data_filename = caption_file.with_suffix(f".{extension}")
932
- if data_filename.exists():
933
- return True
934
- return False
935
- else:
936
- caption_files = [file for file in root if file.endswith(".txt")]
937
- for caption_file in caption_files:
938
- caption_file = pathlib.Path(caption_file)
939
- for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
940
- data_filename = caption_file.with_suffix(f".{extension}").name
941
- if data_filename in root:
942
- return True
943
- return False
944
-
945
-
946
- def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
947
- # TODO(aryan): this logic can be improved
948
- if not remote:
949
- file_list = {x.name for x in root.iterdir()}
950
- has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
951
- has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
952
- has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
953
- return has_caption_files and (has_video_files or has_image_files)
954
- else:
955
- has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
956
- has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
957
- has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
958
- return has_caption_files and (has_video_files or has_image_files)
959
-
960
-
961
- def _read_caption_from_file(filename: str) -> str:
962
- with open(filename, "r") as f:
963
- return f.read().strip()
964
-
965
-
966
- def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
967
- image = image.convert("RGB")
968
- image = np.array(image).astype(np.float32)
969
- image = torch.from_numpy(image)
970
- image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
971
- return image
972
-
973
-
974
- def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
975
- video = video.get_batch(list(range(len(video))))
976
- video = video.permute(0, 3, 1, 2).contiguous()
977
- video = video.float() / 127.5 - 1.0
978
- return video
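As a reference for the deleted dataset module, a rough sketch of the two local layouts it auto-detects and of the intended entry points (paths and bucket values are illustrative, not taken from the diff):

# Layout A: per-sample caption/data pairs
#   my_dataset/0001.txt + my_dataset/0001.mp4, my_dataset/0002.txt + my_dataset/0002.mp4, ...
# Layout B: a single caption list plus a single file list
#   my_dataset/prompts.txt (one caption per line) and my_dataset/videos.txt (one relative video path per line)

dataset = initialize_dataset("path/to/my_dataset", dataset_type="video", infinite=True)
dataset = wrap_iterable_dataset_for_preprocessing(
    dataset,
    dataset_type="video",
    config={"video_resolution_buckets": [(49, 480, 832)], "reshape_mode": "bicubic"},
)
sample = next(iter(dataset))  # dict with "caption" (str) and "video" (float tensor [F, C, H, W] in [-1, 1])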
 
finetrainers/data/precomputation.py DELETED
@@ -1,376 +0,0 @@
1
- import pathlib
2
- from typing import Any, Callable, Dict, Iterable, List, Optional, Union
3
-
4
- import torch
5
- from tqdm.auto import tqdm
6
-
7
- from .. import utils
8
- from ..logging import get_logger
9
-
10
-
11
- logger = get_logger()
12
-
13
-
14
- def initialize_preprocessor(
15
- rank: int,
16
- num_items: int,
17
- processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
18
- save_dir: Optional[str] = None,
19
- enable_precomputation: bool = False,
20
- ) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]:
21
- if enable_precomputation:
22
- return PrecomputedDistributedDataPreprocessor(rank, num_items, processor_fn, save_dir)
23
- return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn)
24
-
25
-
26
- class DistributedDataProcessorMixin:
27
- def consume(self, *args, **kwargs):
28
- raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.")
29
-
30
- def consume_once(self, *args, **kwargs):
31
- raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.")
32
-
33
- @property
34
- def requires_data(self):
35
- raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.")
36
-
37
-
38
- class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin):
39
- def __init__(
40
- self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]]
41
- ) -> None:
42
- super().__init__()
43
-
44
- self._rank = rank
45
- self._num_items = num_items
46
- self._processor_fn = processor_fn
47
-
48
- self._cached_samples = []
49
- self._buffer = InMemoryDataBuffer(num_items)
50
- self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None
51
-
52
- def consume(
53
- self,
54
- data_type: str,
55
- components: Dict[str, Any],
56
- data_iterator,
57
- generator: Optional[torch.Generator] = None,
58
- cache_samples: bool = False,
59
- use_cached_samples: bool = False,
60
- drop_samples: bool = False,
61
- ) -> Iterable[Dict[str, Any]]:
62
- if data_type not in self._processor_fn.keys():
63
- raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
64
- if cache_samples:
65
- if use_cached_samples:
66
- raise ValueError("Cannot cache and use cached samples at the same time.")
67
- if drop_samples:
68
- raise ValueError("Cannot cache and drop samples at the same time.")
69
-
70
- for i in range(self._num_items):
71
- if use_cached_samples:
72
- item = self._cached_samples[i]
73
- else:
74
- item = next(data_iterator)
75
- if cache_samples:
76
- self._cached_samples.append(item)
77
- item = self._processor_fn[data_type](**item, **components, generator=generator)
78
- self._buffer.add(data_type, item)
79
-
80
- if drop_samples:
81
- del self._cached_samples
82
- self._cached_samples = []
83
-
84
- self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer)
85
- return iter(self._preprocessed_iterator)
86
-
87
- def consume_once(
88
- self,
89
- data_type: str,
90
- components: Dict[str, Any],
91
- data_iterator,
92
- generator: Optional[torch.Generator] = None,
93
- cache_samples: bool = False,
94
- use_cached_samples: bool = False,
95
- drop_samples: bool = False,
96
- ) -> Iterable[Dict[str, Any]]:
97
- if data_type not in self._processor_fn.keys():
98
- raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
99
- if cache_samples:
100
- if use_cached_samples:
101
- raise ValueError("Cannot cache and use cached samples at the same time.")
102
- if drop_samples:
103
- raise ValueError("Cannot cache and drop samples at the same time.")
104
-
105
- for i in range(self._num_items):
106
- if use_cached_samples:
107
- item = self._cached_samples[i]
108
- else:
109
- item = next(data_iterator)
110
- if cache_samples:
111
- self._cached_samples.append(item)
112
- item = self._processor_fn[data_type](**item, **components, generator=generator)
113
- self._buffer.add(data_type, item)
114
-
115
- if drop_samples:
116
- del self._cached_samples
117
- self._cached_samples = []
118
-
119
- self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer)
120
- return iter(self._preprocessed_iterator)
121
-
122
- @property
123
- def requires_data(self):
124
- if self._preprocessed_iterator is None:
125
- return True
126
- return self._preprocessed_iterator.requires_data
127
-
128
-
129
- class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin):
130
- def __init__(
131
- self,
132
- rank: int,
133
- num_items: int,
134
- processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
135
- save_dir: str,
136
- ) -> None:
137
- super().__init__()
138
-
139
- self._rank = rank
140
- self._num_items = num_items
141
- self._processor_fn = processor_fn
142
- self._save_dir = pathlib.Path(save_dir)
143
-
144
- self._cached_samples = []
145
- self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None
146
-
147
- self._save_dir.mkdir(parents=True, exist_ok=True)
148
-
149
- subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()]
150
- utils.delete_files(subdirectories)
151
-
152
- def consume(
153
- self,
154
- data_type: str,
155
- components: Dict[str, Any],
156
- data_iterator,
157
- generator: Optional[torch.Generator] = None,
158
- cache_samples: bool = False,
159
- use_cached_samples: bool = False,
160
- drop_samples: bool = False,
161
- ) -> Iterable[Dict[str, Any]]:
162
- if data_type not in self._processor_fn.keys():
163
- raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
164
- if cache_samples:
165
- if use_cached_samples:
166
- raise ValueError("Cannot cache and use cached samples at the same time.")
167
- if drop_samples:
168
- raise ValueError("Cannot cache and drop samples at the same time.")
169
-
170
- for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
171
- if use_cached_samples:
172
- item = self._cached_samples[i]
173
- else:
174
- item = next(data_iterator)
175
- if cache_samples:
176
- self._cached_samples.append(item)
177
- item = self._processor_fn[data_type](**item, **components, generator=generator)
178
- _save_item(self._rank, i, item, self._save_dir, data_type)
179
-
180
- if drop_samples:
181
- del self._cached_samples
182
- self._cached_samples = []
183
-
184
- self._preprocessed_iterator = PrecomputedDataIterable(self._rank, self._save_dir, data_type)
185
- return iter(self._preprocessed_iterator)
186
-
187
- def consume_once(
188
- self,
189
- data_type: str,
190
- components: Dict[str, Any],
191
- data_iterator,
192
- generator: Optional[torch.Generator] = None,
193
- cache_samples: bool = False,
194
- use_cached_samples: bool = False,
195
- drop_samples: bool = False,
196
- ) -> Iterable[Dict[str, Any]]:
197
- if data_type not in self._processor_fn.keys():
198
- raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
199
- if cache_samples:
200
- if use_cached_samples:
201
- raise ValueError("Cannot cache and use cached samples at the same time.")
202
- if drop_samples:
203
- raise ValueError("Cannot cache and drop samples at the same time.")
204
-
205
- for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items):
206
- if use_cached_samples:
207
- item = self._cached_samples[i]
208
- else:
209
- item = next(data_iterator)
210
- if cache_samples:
211
- self._cached_samples.append(item)
212
- item = self._processor_fn[data_type](**item, **components, generator=generator)
213
- _save_item(self._rank, i, item, self._save_dir, data_type)
214
-
215
- if drop_samples:
216
- del self._cached_samples
217
- self._cached_samples = []
218
-
219
- self._preprocessed_iterator = PrecomputedOnceDataIterable(self._rank, self._save_dir, data_type)
220
- return iter(self._preprocessed_iterator)
221
-
222
- @property
223
- def requires_data(self):
224
- if self._preprocessed_iterator is None:
225
- return True
226
- return self._preprocessed_iterator.requires_data
227
-
228
-
229
- class InMemoryDataIterable:
230
- """
231
- An iterator that loads data items from an in-memory buffer. Once all the data is consumed,
232
- `requires_data` is set to True, indicating that the more data is required and the preprocessor's
233
- consume method should be called again.
234
- """
235
-
236
- def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
237
- self._rank = rank
238
- self._data_type = data_type
239
- self._buffer = buffer
240
-
241
- self._requires_data = False
242
-
243
- def __iter__(self) -> Iterable[Dict[str, Any]]:
244
- while (length := self._buffer.get_length(self._data_type)) > 0:
245
- if length <= 1:
246
- self._requires_data = True
247
- yield self._buffer.get(self._data_type)
248
-
249
- def __len__(self) -> int:
250
- return self._buffer.get_length(self._data_type)
251
-
252
- @property
253
- def requires_data(self):
254
- return self._requires_data
255
-
256
-
257
- class InMemoryOnceDataIterable:
258
- """
259
- An iterator that loads data items from an in-memory buffer. This iterator will never set
260
- `requires_data` to True, as it is assumed that all the data was configured to be preprocessed
261
- by the user. The data will indefinitely be cycled from the buffer.
262
- """
263
-
264
- def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
265
- self._rank = rank
266
- self._data_type = data_type
267
- self._buffer = buffer
268
-
269
- self._requires_data = False
270
-
271
- def __iter__(self) -> Iterable[Dict[str, Any]]:
272
- assert len(self) > 0, "No data available in the buffer."
273
- while True:
274
- item = self._buffer.get(self._data_type)
275
- yield item
276
- self._buffer.add(self._data_type, item)
277
-
278
- def __len__(self) -> int:
279
- return self._buffer.get_length(self._data_type)
280
-
281
- @property
282
- def requires_data(self):
283
- return self._requires_data
284
-
285
-
286
- class PrecomputedDataIterable:
287
- """
288
- An iterator that loads preconfigured number of data items from disk. Once all the data is
289
- loaded, `requires_data` is set to True, indicating that the more data is required and
290
- the preprocessor's consume method should be called again.
291
- """
292
-
293
- def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
294
- self._rank = rank
295
- self._save_dir = pathlib.Path(save_dir)
296
- self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
297
- self._data_type = data_type
298
-
299
- self._requires_data = False
300
-
301
- def __iter__(self) -> Iterable[Dict[str, Any]]:
302
- for i in range(self._num_items):
303
- if i == self._num_items - 1:
304
- self._requires_data = True
305
- yield _load_item(self._rank, i, self._save_dir, self._data_type)
306
-
307
- def __len__(self) -> int:
308
- return self._num_items
309
-
310
- @property
311
- def requires_data(self):
312
- return self._requires_data
313
-
314
-
315
- class PrecomputedOnceDataIterable:
316
- """
317
- An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator
318
- will never set `requires_data` to True, as it is assumed that all the data was configured to
319
- be preprocessed by the user.
320
- """
321
-
322
- def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
323
- self._rank = rank
324
- self._save_dir = pathlib.Path(save_dir)
325
- self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
326
- self._data_type = data_type
327
-
328
- self._requires_data = False
329
-
330
- def __iter__(self) -> Iterable[Dict[str, Any]]:
331
- index = 0
332
- while True:
333
- yield _load_item(self._rank, index, self._save_dir, self._data_type)
334
- index = (index + 1) % self._num_items
335
-
336
- def __len__(self) -> int:
337
- return self._num_items
338
-
339
- @property
340
- def requires_data(self):
341
- return self._requires_data
342
-
343
-
344
- class InMemoryDataBuffer:
345
- def __init__(self, max_limit: int = -1) -> None:
346
- self.max_limit = max_limit
347
- self.buffer: Dict[str, List[str]] = {}
348
-
349
- def add(self, data_type: str, item: Dict[str, Any]) -> None:
350
- if data_type not in self.buffer:
351
- self.buffer[data_type] = []
352
- if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit:
353
- logger.log_freq(
354
- "WARN",
355
- "IN_MEMORY_DATA_BUFFER_FULL",
356
- "Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.",
357
- 64,
358
- )
359
- self.buffer[data_type].pop(0)
360
- self.buffer[data_type].append(item)
361
-
362
- def get(self, data_type: str) -> Dict[str, Any]:
363
- return self.buffer[data_type].pop(0)
364
-
365
- def get_length(self, data_type: str) -> int:
366
- return len(self.buffer[data_type])
367
-
368
-
369
- def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
370
- filename = directory / f"{data_type}-{rank}-{index}.pt"
371
- torch.save(item, filename.as_posix())
372
-
373
-
374
- def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]:
375
- filename = directory / f"{data_type}-{rank}-{index}.pt"
376
- return torch.load(filename.as_posix(), weights_only=True)
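A rough sketch of how the deleted preprocessor was driven; the processor function, save directory, and data iterator below are hypothetical stand-ins:

import torch

# Hypothetical processor: in practice these compute text embeddings or VAE latents.
def encode_text(caption, generator=None, **unused):
    return {"caption_embeds": torch.randn(1, 16)}

preprocessor = initialize_preprocessor(
    rank=0,
    num_items=4,
    processor_fn={"condition": encode_text},
    save_dir="/tmp/precomputed",       # only used when enable_precomputation=True
    enable_precomputation=True,
)

data_iterator = iter(dataset)          # placeholder: any iterator yielding dicts with a "caption" key
condition_iter = preprocessor.consume(
    "condition", components={}, data_iterator=data_iterator, cache_samples=True
)
item = next(condition_iter)            # read back from disk as condition-0-0.pt, condition-0-1.pt, ...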
 
finetrainers/data/sampler.py DELETED
@@ -1,58 +0,0 @@
- from typing import Any, Dict, List, Tuple
-
- import torch
-
-
- class ResolutionSampler:
-     def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None:
-         self.batch_size = batch_size
-         self.dim_keys = dim_keys
-         assert dim_keys is not None, "dim_keys must be provided"
-
-         self._chosen_leader_key = None
-         self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {}
-         self._satisfied_buckets: List[Dict[Any, Any]] = []
-
-     def consume(self, *dict_items: Dict[Any, Any]) -> None:
-         if self._chosen_leader_key is None:
-             self._determine_leader_item(*dict_items)
-         self._update_buckets(*dict_items)
-
-     def get_batch(self) -> List[Dict[str, Any]]:
-         return list(zip(*self._satisfied_buckets.pop(-1)))
-
-     @property
-     def is_ready(self) -> bool:
-         return len(self._satisfied_buckets) > 0
-
-     def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None:
-         num_observed = 0
-         for dict_item in dict_items:
-             for key in self.dim_keys.keys():
-                 if key in dict_item.keys():
-                     self._chosen_leader_key = key
-                     if not torch.is_tensor(dict_item[key]):
-                         raise ValueError(f"Leader key {key} must be a tensor")
-                     num_observed += 1
-         if num_observed > 1:
-             raise ValueError(
-                 f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys"
-             )
-         if self._chosen_leader_key is None:
-             raise ValueError("No leader key found in provided list of data dictionaries")
-
-     def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None:
-         chosen_value = [
-             dict_item[self._chosen_leader_key]
-             for dict_item in dict_items
-             if self._chosen_leader_key in dict_item.keys()
-         ]
-         if len(chosen_value) == 0:
-             raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries")
-         chosen_value = chosen_value[0]
-         dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key])
-         if dims not in self._unsatisfied_buckets:
-             self._unsatisfied_buckets[dims] = []
-         self._unsatisfied_buckets[dims].append(dict_items)
-         if len(self._unsatisfied_buckets[dims]) == self.batch_size:
-             self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims))
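A small sketch of how ResolutionSampler buckets samples by tensor shape before forming a batch (the shapes and the dim_keys choice are made up for illustration):

import torch

sampler = ResolutionSampler(batch_size=2, dim_keys={"latents": (2, 3)})  # bucket on the (H, W) dims

for shape in [(1, 4, 32, 32), (1, 4, 64, 64), (1, 4, 32, 32)]:
    sampler.consume({"latents": torch.zeros(shape)})

if sampler.is_ready:             # True once two samples land in the same (32, 32) bucket
    batch = sampler.get_batch()  # the two shape-matched samples grouped together for collation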
 
finetrainers/data/utils.py DELETED
@@ -1,20 +0,0 @@
- import pathlib
- from typing import List
-
-
- def find_files(root: str, pattern: str, depth: int = 0) -> List[str]:
-     root_path = pathlib.Path(root)
-     result_files = []
-
-     def within_depth(path: pathlib.Path) -> bool:
-         return len(path.relative_to(root_path).parts) <= depth
-
-     if depth == 0:
-         result_files.extend([str(file) for file in root_path.glob(pattern)])
-     else:
-         # rglob matches all levels, but we filter by depth
-         for file in root_path.rglob(pattern):
-             if file.is_file() and within_depth(file.parent):
-                 result_files.append(str(file))
-
-     return result_files
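For example, with the depth semantics above (the directory is illustrative):

caption_files = find_files("/data/my_dataset", "*.txt", depth=0)  # top level only
nested_videos = find_files("/data/my_dataset", "*.mp4", depth=1)  # also one directory deep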
 
finetrainers/functional/__init__.py DELETED
@@ -1,16 +0,0 @@
- from .diffusion import flow_match_target, flow_match_xt
- from .image import (
-     bicubic_resize_image,
-     center_crop_image,
-     find_nearest_resolution_image,
-     resize_crop_image,
-     resize_to_nearest_bucket_image,
- )
- from .text import dropout_caption, dropout_embeddings_to_zero, remove_prefix
- from .video import (
-     bicubic_resize_video,
-     center_crop_video,
-     find_nearest_video_resolution,
-     resize_crop_video,
-     resize_to_nearest_bucket_video,
- )
 
 
finetrainers/functional/diffusion.py DELETED
@@ -1,11 +0,0 @@
1
- import torch
2
-
3
-
4
- def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
5
- r"""Forward process of flow matching."""
6
- return (1.0 - t) * x0 + t * n
7
-
8
-
9
- def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
10
- r"""Loss target for flow matching."""
11
- return n - x0
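These two helpers implement the usual rectified-flow parameterization, x_t = (1 - t) * x_0 + t * n with velocity target n - x_0. A minimal sketch of how they combine into a training loss (the shapes and the zero model output are illustrative):

import torch
from finetrainers.functional.diffusion import flow_match_target, flow_match_xt

x0 = torch.randn(2, 16, 32, 32)        # clean latents
noise = torch.randn_like(x0)           # n ~ N(0, I)
t = torch.rand(2).view(-1, 1, 1, 1)    # per-sample timestep in [0, 1]

xt = flow_match_xt(x0, noise, t)       # point on the straight path between x0 and the noise
target = flow_match_target(noise, x0)  # velocity target n - x0

model_pred = torch.zeros_like(xt)      # stand-in for the transformer prediction
loss = torch.nn.functional.mse_loss(model_pred, target)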
 
 
finetrainers/functional/image.py DELETED
@@ -1,54 +0,0 @@
1
- from typing import List, Literal, Tuple
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
8
- num_channels, height, width = image.shape
9
- crop_h, crop_w = size
10
- top = (height - crop_h) // 2
11
- left = (width - crop_w) // 2
12
- return image[:, top : top + crop_h, left : left + crop_w]
13
-
14
-
15
- def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
16
- num_channels, height, width = image.shape
17
- target_h, target_w = size
18
- scale = max(target_h / height, target_w / width)
19
- new_h, new_w = int(height * scale), int(width * scale)
20
- image = F.interpolate(image.unsqueeze(0), size=(new_h, new_w), mode="bilinear", align_corners=False)[0]  # bilinear interpolation needs a 4D (N, C, H, W) input for a 3D image tensor
21
- return center_crop_image(image, size)
22
-
23
-
24
- def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
25
- return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]
26
-
27
-
28
- def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
29
- num_channels, height, width = image.shape
30
- aspect_ratio = width / height
31
-
32
- def aspect_ratio_diff(bucket):
33
- return abs((bucket[1] / bucket[0]) - aspect_ratio)
34
-
35
- return min(resolution_buckets, key=aspect_ratio_diff)
36
-
37
-
38
- def resize_to_nearest_bucket_image(
39
- image: torch.Tensor,
40
- resolution_buckets: List[Tuple[int, int]],
41
- resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
42
- ) -> torch.Tensor:
43
- target_size = find_nearest_resolution_image(image, resolution_buckets)
44
-
45
- if resize_mode == "center_crop":
46
- return center_crop_image(image, target_size)
47
- elif resize_mode == "resize_crop":
48
- return resize_crop_image(image, target_size)
49
- elif resize_mode == "bicubic":
50
- return bicubic_resize_image(image, target_size)
51
- else:
52
- raise ValueError(
53
- f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
54
- )
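A short sketch of the image bucketing helper above (the buckets and shapes are illustrative):

import torch
from finetrainers.functional.image import resize_to_nearest_bucket_image

image = torch.randn(3, 512, 768)                 # (C, H, W)
buckets = [(480, 640), (480, 854), (720, 1280)]  # (height, width) buckets
resized = resize_to_nearest_bucket_image(image, buckets, resize_mode="bicubic")
print(resized.shape)  # torch.Size([3, 480, 640]) -- the bucket whose aspect ratio is closest to 768/512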
 
 
finetrainers/functional/text.py DELETED
@@ -1,26 +0,0 @@
1
- import random
2
- from typing import List, Union
3
-
4
- import torch
5
-
6
-
7
- def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
8
- if random.random() >= dropout_p:
9
- return caption
10
- if isinstance(caption, str):
11
- return ""
12
- return [""] * len(caption)
13
-
14
-
15
- def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
16
- if random.random() >= dropout_p:
17
- return embed
18
- embed = torch.zeros_like(embed)
19
- return embed
20
-
21
-
22
- def remove_prefix(text: str, prefixes: List[str]) -> str:
23
- for prefix in prefixes:
24
- if text.startswith(prefix):
25
- return text.removeprefix(prefix).strip()
26
- return text
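A quick sketch of the caption helpers above (the strings are illustrative):

from finetrainers.functional.text import dropout_caption, remove_prefix

print(remove_prefix("Caption: a cat on a sofa", ["Caption:", "Prompt:"]))  # "a cat on a sofa"
print(dropout_caption("a cat on a sofa", dropout_p=1.0))                   # "" -- the caption is always dropped at p=1.0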
 
 
finetrainers/functional/video.py DELETED
@@ -1,94 +0,0 @@
1
- from typing import List, Literal, Tuple
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
8
- num_frames, num_channels, height, width = video.shape
9
- crop_h, crop_w = size
10
- top = (height - crop_h) // 2
11
- left = (width - crop_w) // 2
12
- return video[:, :, top : top + crop_h, left : left + crop_w]
13
-
14
-
15
- def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
16
- num_frames, num_channels, height, width = video.shape
17
- target_h, target_w = size
18
- scale = max(target_h / height, target_w / width)
19
- new_h, new_w = int(height * scale), int(width * scale)
20
- video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False)
21
- return center_crop_video(video, size)
22
-
23
-
24
- def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
25
- num_frames, num_channels, height, width = video.shape
26
- video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
27
- return video
28
-
29
-
30
- def find_nearest_video_resolution(
31
- video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]]
32
- ) -> Tuple[int, int, int]:
33
- num_frames, num_channels, height, width = video.shape
34
- aspect_ratio = width / height
35
- possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames]
36
-
37
- if not possible_buckets:
38
- best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames))
39
- else:
40
- best_frame_match = max(possible_buckets, key=lambda b: b[0])
41
-
42
- frame_filtered_buckets = [b for b in resolution_buckets if b[0] == best_frame_match[0]]
43
-
44
- def aspect_ratio_diff(bucket):
45
- return abs((bucket[2] / bucket[1]) - aspect_ratio)
46
-
47
- return min(frame_filtered_buckets, key=aspect_ratio_diff)
48
-
49
-
50
- def resize_to_nearest_bucket_video(
51
- video: torch.Tensor,
52
- resolution_buckets: List[Tuple[int, int, int]],
53
- resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
54
- ) -> Tuple[torch.Tensor, bool]:
55
- """
56
- Resizes a video tensor to the nearest resolution bucket using the specified mode.
57
- - It first finds a frame match with <= T frames.
58
- - Then, it selects the closest height/width bucket.
59
-
60
- Args:
61
- video (`torch.Tensor`):
62
- Input video tensor of shape `(F, C, H, W)`.
63
- resolution_buckets (`List[Tuple[int, int, int]]`):
64
- Available (num_frames, height, width) resolution buckets.
65
- resize_mode (`str`):
66
- One of ["center_crop", "resize_crop", "bicubic"].
67
-
68
- Returns:
69
- `torch.Tensor`:
70
- Resized video tensor of the nearest bucket resolution.
71
- """
72
- target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets)
73
-
74
- # Adjust frame count: only interpolate frames if no lesser/equal frame count exists
75
- num_frames, num_channels, height, width = video.shape
76
- _first_frame_only = False
77
- if num_frames > target_frames:
78
- # Downsample: Select frames evenly
79
- indices = torch.linspace(0, num_frames - 1, target_frames).long()
80
- video = video[indices, :, :, :]
81
- elif num_frames < target_frames:
82
- _first_frame_only = False
83
-
84
- # Resize spatial resolution
85
- if resize_mode == "center_crop":
86
- return center_crop_video(video, (target_h, target_w)), _first_frame_only
87
- elif resize_mode == "resize_crop":
88
- return resize_crop_video(video, (target_h, target_w)), _first_frame_only
89
- elif resize_mode == "bicubic":
90
- return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only
91
- else:
92
- raise ValueError(
93
- f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
94
- )
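A short sketch of the video bucketing helper above (the buckets and shapes are illustrative; note the second return value is the `_first_frame_only` flag):

import torch
from finetrainers.functional.video import resize_to_nearest_bucket_video

video = torch.randn(49, 3, 512, 768)                        # (F, C, H, W)
buckets = [(17, 480, 640), (49, 480, 640), (49, 480, 854)]  # (frames, height, width) buckets
resized, first_frame_only = resize_to_nearest_bucket_video(video, buckets, resize_mode="bicubic")
print(resized.shape)  # torch.Size([49, 3, 480, 640])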
 
 
finetrainers/logging.py DELETED
@@ -1,111 +0,0 @@
1
- import logging
2
- import os
3
- from typing import TYPE_CHECKING, Union
4
-
5
- from .constants import FINETRAINERS_LOG_LEVEL
6
-
7
-
8
- if TYPE_CHECKING:
9
- from .parallel import ParallelBackendType
10
-
11
-
12
- class FinetrainersLoggerAdapter(logging.LoggerAdapter):
13
- def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None:
14
- super().__init__(logger, {})
15
- self.parallel_backend = parallel_backend
16
- self._log_freq = {}
17
- self._log_freq_counter = {}
18
-
19
- def log(
20
- self,
21
- level,
22
- msg,
23
- *args,
24
- main_process_only: bool = False,
25
- local_main_process_only: bool = True,
26
- in_order: bool = False,
27
- **kwargs,
28
- ):
29
- # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
30
- kwargs.setdefault("stacklevel", 2)
31
-
32
- if not self.isEnabledFor(level):
33
- return
34
-
35
- if self.parallel_backend is None:
36
- if int(os.environ.get("RANK", 0)) == 0:
37
- msg, kwargs = self.process(msg, kwargs)
38
- self.logger.log(level, msg, *args, **kwargs)
39
- return
40
-
41
- if (main_process_only or local_main_process_only) and in_order:
42
- raise ValueError(
43
- "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
44
- )
45
-
46
- if (main_process_only and self.parallel_backend.is_main_process) or (
47
- local_main_process_only and self.parallel_backend.is_local_main_process
48
- ):
49
- msg, kwargs = self.process(msg, kwargs)
50
- self.logger.log(level, msg, *args, **kwargs)
51
- return
52
-
53
- if in_order:
54
- for i in range(self.parallel_backend.world_size):
55
- if self.rank == i:
56
- msg, kwargs = self.process(msg, kwargs)
57
- self.logger.log(level, msg, *args, **kwargs)
58
- self.parallel_backend.wait_for_everyone()
59
- return
60
-
61
- if not main_process_only and not local_main_process_only:
62
- msg, kwargs = self.process(msg, kwargs)
63
- self.logger.log(level, msg, *args, **kwargs)
64
- return
65
-
66
- def log_freq(
67
- self,
68
- level: int,
69
- name: str,
70
- msg: str,
71
- frequency: int,
72
- *,
73
- main_process_only: bool = False,
74
- local_main_process_only: bool = True,
75
- in_order: bool = False,
76
- **kwargs,
77
- ) -> None:
78
- if frequency <= 0:
79
- return
80
- if name not in self._log_freq_counter:
81
- self._log_freq[name] = frequency
82
- self._log_freq_counter[name] = 0
83
- if self._log_freq_counter[name] % self._log_freq[name] == 0:
84
- self.log(
85
- level,
86
- msg,
87
- main_process_only=main_process_only,
88
- local_main_process_only=local_main_process_only,
89
- in_order=in_order,
90
- **kwargs,
91
- )
92
- self._log_freq_counter[name] += 1
93
-
94
-
95
- def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
96
- global _logger
97
- return _logger
98
-
99
-
100
- def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> None:
101
- _logger.parallel_backend = parallel_backend
102
-
103
-
104
- _logger = logging.getLogger("finetrainers")
105
- _logger.setLevel(FINETRAINERS_LOG_LEVEL)
106
- _console_handler = logging.StreamHandler()
107
- _console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
108
- _formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
109
- _console_handler.setFormatter(_formatter)
110
- _logger.addHandler(_console_handler)
111
- _logger = FinetrainersLoggerAdapter(_logger)
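A minimal sketch of how the logger above was used (assuming no parallel backend has been attached yet, in which case only rank 0 emits records):

import logging

from finetrainers.logging import get_logger

logger = get_logger()
logger.info("starting preprocessing")
logger.log_freq(logging.INFO, "step_time", "step completed", frequency=50)  # logs on the first call and every 50th call after that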
 
 
finetrainers/models/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .modeling_utils import ModelSpecification
 
 
finetrainers/models/cogvideox/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .base_specification import CogVideoXModelSpecification
 
 
finetrainers/models/cogvideox/base_specification.py DELETED
@@ -1,423 +0,0 @@
1
- import os
2
- from typing import Any, Dict, List, Optional, Tuple
3
-
4
- import torch
5
- from accelerate import init_empty_weights
6
- from diffusers import (
7
- AutoencoderKLCogVideoX,
8
- CogVideoXDDIMScheduler,
9
- CogVideoXImageToVideoPipeline,
10
- CogVideoXPipeline,
11
- CogVideoXTransformer3DModel,
12
- )
13
- from PIL.Image import Image
14
- from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
15
-
16
- from ... import data
17
- from ...logging import get_logger
18
- from ...processors import ProcessorMixin, T5Processor
19
- from ...typing import ArtifactType, SchedulerType
20
- from ...utils import get_non_null_items
21
- from ..modeling_utils import ModelSpecification
22
- from ..utils import DiagonalGaussianDistribution
23
- from .utils import prepare_rotary_positional_embeddings
24
-
25
-
26
- logger = get_logger()
27
-
28
-
29
- class CogVideoXLatentEncodeProcessor(ProcessorMixin):
30
- r"""
31
- Processor to encode image/video into latents using the CogVideoX VAE.
32
-
33
- Args:
34
- output_names (`List[str]`):
35
- The names of the outputs that the processor returns. The outputs are in the following order:
36
- - latents: The latents of the input image/video.
37
- """
38
-
39
- def __init__(self, output_names: List[str]):
40
- super().__init__()
41
- self.output_names = output_names
42
- assert len(self.output_names) == 1
43
-
44
- def forward(
45
- self,
46
- vae: AutoencoderKLCogVideoX,
47
- image: Optional[torch.Tensor] = None,
48
- video: Optional[torch.Tensor] = None,
49
- generator: Optional[torch.Generator] = None,
50
- compute_posterior: bool = True,
51
- ) -> Dict[str, torch.Tensor]:
52
- device = vae.device
53
- dtype = vae.dtype
54
-
55
- if image is not None:
56
- video = image.unsqueeze(1)
57
-
58
- assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
59
- video = video.to(device=device, dtype=vae.dtype)
60
- video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
61
-
62
- if compute_posterior:
63
- latents = vae.encode(video).latent_dist.sample(generator=generator)
64
- latents = latents.to(dtype=dtype)
65
- else:
66
- if vae.use_slicing and video.shape[0] > 1:
67
- encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
68
- moments = torch.cat(encoded_slices)
69
- else:
70
- moments = vae._encode(video)
71
- latents = moments.to(dtype=dtype)
72
-
73
- latents = latents.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] -> [B, F, C, H, W]
74
- return {self.output_names[0]: latents}
75
-
76
-
77
- class CogVideoXModelSpecification(ModelSpecification):
78
- def __init__(
79
- self,
80
- pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b",
81
- tokenizer_id: Optional[str] = None,
82
- text_encoder_id: Optional[str] = None,
83
- transformer_id: Optional[str] = None,
84
- vae_id: Optional[str] = None,
85
- text_encoder_dtype: torch.dtype = torch.bfloat16,
86
- transformer_dtype: torch.dtype = torch.bfloat16,
87
- vae_dtype: torch.dtype = torch.bfloat16,
88
- revision: Optional[str] = None,
89
- cache_dir: Optional[str] = None,
90
- condition_model_processors: List[ProcessorMixin] = None,
91
- latent_model_processors: List[ProcessorMixin] = None,
92
- **kwargs,
93
- ) -> None:
94
- super().__init__(
95
- pretrained_model_name_or_path=pretrained_model_name_or_path,
96
- tokenizer_id=tokenizer_id,
97
- text_encoder_id=text_encoder_id,
98
- transformer_id=transformer_id,
99
- vae_id=vae_id,
100
- text_encoder_dtype=text_encoder_dtype,
101
- transformer_dtype=transformer_dtype,
102
- vae_dtype=vae_dtype,
103
- revision=revision,
104
- cache_dir=cache_dir,
105
- )
106
-
107
- if condition_model_processors is None:
108
- condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
109
- if latent_model_processors is None:
110
- latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
111
-
112
- self.condition_model_processors = condition_model_processors
113
- self.latent_model_processors = latent_model_processors
114
-
115
- @property
116
- def _resolution_dim_keys(self):
117
- return {"latents": (1, 3, 4)}
118
-
119
- def load_condition_models(self) -> Dict[str, torch.nn.Module]:
120
- if self.tokenizer_id is not None:
121
- tokenizer = AutoTokenizer.from_pretrained(
122
- self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
123
- )
124
- else:
125
- tokenizer = T5Tokenizer.from_pretrained(
126
- self.pretrained_model_name_or_path,
127
- subfolder="tokenizer",
128
- revision=self.revision,
129
- cache_dir=self.cache_dir,
130
- )
131
-
132
- if self.text_encoder_id is not None:
133
- text_encoder = AutoModel.from_pretrained(
134
- self.text_encoder_id,
135
- torch_dtype=self.text_encoder_dtype,
136
- revision=self.revision,
137
- cache_dir=self.cache_dir,
138
- )
139
- else:
140
- text_encoder = T5EncoderModel.from_pretrained(
141
- self.pretrained_model_name_or_path,
142
- subfolder="text_encoder",
143
- torch_dtype=self.text_encoder_dtype,
144
- revision=self.revision,
145
- cache_dir=self.cache_dir,
146
- )
147
-
148
- return {"tokenizer": tokenizer, "text_encoder": text_encoder}
149
-
150
- def load_latent_models(self) -> Dict[str, torch.nn.Module]:
151
- if self.vae_id is not None:
152
- vae = AutoencoderKLCogVideoX.from_pretrained(
153
- self.vae_id,
154
- torch_dtype=self.vae_dtype,
155
- revision=self.revision,
156
- cache_dir=self.cache_dir,
157
- )
158
- else:
159
- vae = AutoencoderKLCogVideoX.from_pretrained(
160
- self.pretrained_model_name_or_path,
161
- subfolder="vae",
162
- torch_dtype=self.vae_dtype,
163
- revision=self.revision,
164
- cache_dir=self.cache_dir,
165
- )
166
-
167
- return {"vae": vae}
168
-
169
- def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
170
- if self.transformer_id is not None:
171
- transformer = CogVideoXTransformer3DModel.from_pretrained(
172
- self.transformer_id,
173
- torch_dtype=self.transformer_dtype,
174
- revision=self.revision,
175
- cache_dir=self.cache_dir,
176
- )
177
- else:
178
- transformer = CogVideoXTransformer3DModel.from_pretrained(
179
- self.pretrained_model_name_or_path,
180
- subfolder="transformer",
181
- torch_dtype=self.transformer_dtype,
182
- revision=self.revision,
183
- cache_dir=self.cache_dir,
184
- )
185
-
186
- scheduler = CogVideoXDDIMScheduler.from_pretrained(
187
- self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir
188
- )
189
-
190
- return {"transformer": transformer, "scheduler": scheduler}
191
-
192
- def load_pipeline(
193
- self,
194
- tokenizer: Optional[T5Tokenizer] = None,
195
- text_encoder: Optional[T5EncoderModel] = None,
196
- transformer: Optional[CogVideoXTransformer3DModel] = None,
197
- vae: Optional[AutoencoderKLCogVideoX] = None,
198
- scheduler: Optional[CogVideoXDDIMScheduler] = None,
199
- enable_slicing: bool = False,
200
- enable_tiling: bool = False,
201
- enable_model_cpu_offload: bool = False,
202
- training: bool = False,
203
- **kwargs,
204
- ) -> CogVideoXPipeline:
205
- components = {
206
- "tokenizer": tokenizer,
207
- "text_encoder": text_encoder,
208
- "transformer": transformer,
209
- "vae": vae,
210
- "scheduler": scheduler,
211
- }
212
- components = get_non_null_items(components)
213
-
214
- pipe = CogVideoXPipeline.from_pretrained(
215
- self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
216
- )
217
- pipe.text_encoder.to(self.text_encoder_dtype)
218
- pipe.vae.to(self.vae_dtype)
219
-
220
- if not training:
221
- pipe.transformer.to(self.transformer_dtype)
222
-
223
- if enable_slicing:
224
- pipe.vae.enable_slicing()
225
- if enable_tiling:
226
- pipe.vae.enable_tiling()
227
- if enable_model_cpu_offload:
228
- pipe.enable_model_cpu_offload()
229
-
230
- return pipe
231
-
232
- @torch.no_grad()
233
- def prepare_conditions(
234
- self,
235
- tokenizer: T5Tokenizer,
236
- text_encoder: T5EncoderModel,
237
- caption: str,
238
- max_sequence_length: int = 226,
239
- **kwargs,
240
- ) -> Dict[str, Any]:
241
- conditions = {
242
- "tokenizer": tokenizer,
243
- "text_encoder": text_encoder,
244
- "caption": caption,
245
- "max_sequence_length": max_sequence_length,
246
- **kwargs,
247
- }
248
- input_keys = set(conditions.keys())
249
- conditions = super().prepare_conditions(**conditions)
250
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
251
- conditions.pop("prompt_attention_mask", None)
252
- return conditions
253
-
254
- @torch.no_grad()
255
- def prepare_latents(
256
- self,
257
- vae: AutoencoderKLCogVideoX,
258
- image: Optional[torch.Tensor] = None,
259
- video: Optional[torch.Tensor] = None,
260
- generator: Optional[torch.Generator] = None,
261
- compute_posterior: bool = True,
262
- **kwargs,
263
- ) -> Dict[str, torch.Tensor]:
264
- conditions = {
265
- "vae": vae,
266
- "image": image,
267
- "video": video,
268
- "generator": generator,
269
- "compute_posterior": compute_posterior,
270
- **kwargs,
271
- }
272
- input_keys = set(conditions.keys())
273
- conditions = super().prepare_latents(**conditions)
274
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
275
- return conditions
276
-
277
- def forward(
278
- self,
279
- transformer: CogVideoXTransformer3DModel,
280
- scheduler: CogVideoXDDIMScheduler,
281
- condition_model_conditions: Dict[str, torch.Tensor],
282
- latent_model_conditions: Dict[str, torch.Tensor],
283
- sigmas: torch.Tensor,
284
- generator: Optional[torch.Generator] = None,
285
- compute_posterior: bool = True,
286
- **kwargs,
287
- ) -> Tuple[torch.Tensor, ...]:
288
- # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
289
- VAE_SPATIAL_SCALE_FACTOR = 8
290
- rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
291
- rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
292
- patch_size = self.transformer_config.patch_size
293
- patch_size_t = getattr(self.transformer_config, "patch_size_t", None)
294
-
295
- if compute_posterior:
296
- latents = latent_model_conditions.pop("latents")
297
- else:
298
- posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2)
299
- latents = posterior.sample(generator=generator)
300
- del posterior
301
-
302
- if not getattr(self.vae_config, "invert_scale_latents", False):
303
- latents = latents * self.vae_config.scaling_factor
304
-
305
- if patch_size_t is not None:
306
- latents = self._pad_frames(latents, patch_size_t)
307
-
308
- timesteps = (sigmas.flatten() * 1000.0).long()
309
-
310
- noise = torch.zeros_like(latents).normal_(generator=generator)
311
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
312
-
313
- batch_size, num_frames, num_channels, height, width = latents.shape
314
- ofs_emb = (
315
- None
316
- if getattr(self.transformer_config, "ofs_embed_dim", None) is None
317
- else latents.new_full((batch_size,), fill_value=2.0)
318
- )
319
-
320
- image_rotary_emb = (
321
- prepare_rotary_positional_embeddings(
322
- height=height * VAE_SPATIAL_SCALE_FACTOR,
323
- width=width * VAE_SPATIAL_SCALE_FACTOR,
324
- num_frames=num_frames,
325
- vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
326
- patch_size=patch_size,
327
- patch_size_t=patch_size_t,
328
- attention_head_dim=self.transformer_config.attention_head_dim,
329
- device=transformer.device,
330
- base_height=rope_base_height,
331
- base_width=rope_base_width,
332
- )
333
- if self.transformer_config.use_rotary_positional_embeddings
334
- else None
335
- )
336
-
337
- latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
338
- latent_model_conditions["image_rotary_emb"] = image_rotary_emb
339
- latent_model_conditions["ofs"] = ofs_emb
340
-
341
- velocity = transformer(
342
- **latent_model_conditions,
343
- **condition_model_conditions,
344
- timestep=timesteps,
345
- return_dict=False,
346
- )[0]
347
- # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
348
- # code paths as scheduler.get_velocity(), which can be confusing to understand.
349
- pred = scheduler.get_velocity(velocity, noisy_latents, timesteps)
350
- target = latents
351
-
352
- return pred, target, sigmas
353
-
354
- def validation(
355
- self,
356
- pipeline: CogVideoXPipeline,
357
- prompt: str,
358
- image: Optional[Image] = None,
359
- height: Optional[int] = None,
360
- width: Optional[int] = None,
361
- num_frames: Optional[int] = None,
362
- num_inference_steps: int = 50,
363
- generator: Optional[torch.Generator] = None,
364
- **kwargs,
365
- ) -> List[ArtifactType]:
366
- # TODO(aryan): add support for more parameters
367
- if image is not None:
368
- pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline)
369
-
370
- generation_kwargs = {
371
- "prompt": prompt,
372
- "image": image,
373
- "height": height,
374
- "width": width,
375
- "num_frames": num_frames,
376
- "num_inference_steps": num_inference_steps,
377
- "generator": generator,
378
- "return_dict": True,
379
- "output_type": "pil",
380
- }
381
- generation_kwargs = get_non_null_items(generation_kwargs)
382
- video = pipeline(**generation_kwargs).frames[0]
383
- return [data.VideoArtifact(value=video)]
384
-
385
- def _save_lora_weights(
386
- self,
387
- directory: str,
388
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
389
- scheduler: Optional[SchedulerType] = None,
390
- *args,
391
- **kwargs,
392
- ) -> None:
393
- # TODO(aryan): this needs refactoring
394
- if transformer_state_dict is not None:
395
- CogVideoXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
396
- if scheduler is not None:
397
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
398
-
399
- def _save_model(
400
- self,
401
- directory: str,
402
- transformer: CogVideoXTransformer3DModel,
403
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
404
- scheduler: Optional[SchedulerType] = None,
405
- ) -> None:
406
- # TODO(aryan): this needs refactoring
407
- if transformer_state_dict is not None:
408
- with init_empty_weights():
409
- transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config)
410
- transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
411
- transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
412
- if scheduler is not None:
413
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
414
-
415
- @staticmethod
416
- def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor:
417
- num_frames = latents.size(1)
418
- additional_frames = (patch_size_t - num_frames % patch_size_t) % patch_size_t  # no padding when num_frames is already divisible by patch_size_t
419
- if additional_frames > 0:
420
- last_frame = latents[:, -1:]
421
- padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1)
422
- latents = torch.cat([latents, padding_frames], dim=1)
423
- return latents
 
 
finetrainers/models/cogvideox/utils.py DELETED
@@ -1,51 +0,0 @@
1
- from typing import Optional, Tuple
2
-
3
- import torch
4
- from diffusers.models.embeddings import get_3d_rotary_pos_embed
5
- from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
6
-
7
-
8
- def prepare_rotary_positional_embeddings(
9
- height: int,
10
- width: int,
11
- num_frames: int,
12
- vae_scale_factor_spatial: int = 8,
13
- patch_size: int = 2,
14
- patch_size_t: Optional[int] = None,
15
- attention_head_dim: int = 64,
16
- device: Optional[torch.device] = None,
17
- base_height: int = 480,
18
- base_width: int = 720,
19
- ) -> Tuple[torch.Tensor, torch.Tensor]:
20
- grid_height = height // (vae_scale_factor_spatial * patch_size)
21
- grid_width = width // (vae_scale_factor_spatial * patch_size)
22
- base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
23
- base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
24
-
25
- if patch_size_t is None:
26
- # CogVideoX 1.0
27
- grid_crops_coords = get_resize_crop_region_for_grid(
28
- (grid_height, grid_width), base_size_width, base_size_height
29
- )
30
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
31
- embed_dim=attention_head_dim,
32
- crops_coords=grid_crops_coords,
33
- grid_size=(grid_height, grid_width),
34
- temporal_size=num_frames,
35
- )
36
- else:
37
- # CogVideoX 1.5
38
- base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
39
-
40
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
41
- embed_dim=attention_head_dim,
42
- crops_coords=None,
43
- grid_size=(grid_height, grid_width),
44
- temporal_size=base_num_frames,
45
- grid_type="slice",
46
- max_size=(base_size_height, base_size_width),
47
- )
48
-
49
- freqs_cos = freqs_cos.to(device=device)
50
- freqs_sin = freqs_sin.to(device=device)
51
- return freqs_cos, freqs_sin
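A small standalone sketch of the RoPE helper above, using its default 480x720 base resolution (the frame count here is illustrative):

from finetrainers.models.cogvideox.utils import prepare_rotary_positional_embeddings

# CogVideoX 1.0-style RoPE (patch_size_t left as None) for a 13-latent-frame, 480x720 sample
freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
    height=480, width=720, num_frames=13, attention_head_dim=64, device="cpu"
)
print(freqs_cos.shape, freqs_sin.shape)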
 
 
finetrainers/models/cogview4/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .base_specification import CogView4ModelSpecification
 
 
finetrainers/models/cogview4/base_specification.py DELETED
@@ -1,395 +0,0 @@
1
- import os
2
- from typing import Any, Dict, List, Optional, Tuple
3
-
4
- import torch
5
- from accelerate import init_empty_weights
6
- from diffusers import (
7
- AutoencoderKL,
8
- CogView4Pipeline,
9
- CogView4Transformer2DModel,
10
- FlowMatchEulerDiscreteScheduler,
11
- )
12
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
13
- from transformers import AutoTokenizer, GlmModel
14
-
15
- from ... import data
16
- from ... import functional as FF
17
- from ...logging import get_logger
18
- from ...processors import CogView4GLMProcessor, ProcessorMixin
19
- from ...typing import ArtifactType, SchedulerType
20
- from ...utils import get_non_null_items
21
- from ..modeling_utils import ModelSpecification
22
-
23
-
24
- logger = get_logger()
25
-
26
-
27
- class CogView4LatentEncodeProcessor(ProcessorMixin):
28
- r"""
29
- Processor to encode image/video into latents using the CogView4 VAE.
30
-
31
- Args:
32
- output_names (`List[str]`):
33
- The names of the outputs that the processor returns. The outputs are in the following order:
34
- - latents: The latents of the input image/video.
35
- - original_size: The original size of the input image/video.
36
- - target_size: The target size of the input image/video.
37
- - crop_coords: The top-left crop coordinates of the input image/video.
38
- """
39
-
40
- def __init__(self, output_names: List[str]):
41
- super().__init__()
42
-
43
- self.output_names = output_names
44
- assert len(self.output_names) == 4
45
-
46
- def forward(
47
- self,
48
- vae: AutoencoderKL,
49
- image: Optional[torch.Tensor] = None,
50
- video: Optional[torch.Tensor] = None,
51
- generator: Optional[torch.Generator] = None,
52
- compute_posterior: bool = True,
53
- _original_height: Optional[int] = None,
54
- _original_width: Optional[int] = None,
55
- ) -> Dict[str, torch.Tensor]:
56
- device = vae.device
57
- dtype = vae.dtype
58
-
59
- if video is not None:
60
- # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly
61
- image = video[:, 0] # [B, F, C, H, W] -> [B, 1, C, H, W]
62
-
63
- assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
64
- image = image.to(device=device, dtype=vae.dtype)
65
-
66
- if compute_posterior:
67
- latents = vae.encode(image).latent_dist.sample(generator=generator)
68
- latents = latents.to(dtype=dtype)
69
- else:
70
- if vae.use_slicing and image.shape[0] > 1:
71
- encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)]
72
- moments = torch.cat(encoded_slices)
73
- else:
74
- moments = vae._encode(image)
75
- latents = moments.to(dtype=dtype)
76
-
77
- batch_size = latents.size(0)
78
- target_height = image.size(2)
79
- target_width = image.size(3)
80
- original_size = torch.tensor([(_original_height, _original_width)], device=device, dtype=dtype).repeat(
81
- batch_size, 1
82
- )
83
- target_size = torch.tensor([(target_height, target_width)], device=device, dtype=dtype).repeat(batch_size, 1)
84
- crop_coords = torch.tensor([(0, 0)], device=device, dtype=dtype).repeat(batch_size, 1)
85
-
86
- return {
87
- self.output_names[0]: latents,
88
- self.output_names[1]: original_size,
89
- self.output_names[2]: target_size,
90
- self.output_names[3]: crop_coords,
91
- }
92
-
93
-
94
- class CogView4ModelSpecification(ModelSpecification):
95
- def __init__(
96
- self,
97
- pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
98
- tokenizer_id: Optional[str] = None,
99
- text_encoder_id: Optional[str] = None,
100
- transformer_id: Optional[str] = None,
101
- vae_id: Optional[str] = None,
102
- text_encoder_dtype: torch.dtype = torch.bfloat16,
103
- transformer_dtype: torch.dtype = torch.bfloat16,
104
- vae_dtype: torch.dtype = torch.bfloat16,
105
- revision: Optional[str] = None,
106
- cache_dir: Optional[str] = None,
107
- condition_model_processors: List[ProcessorMixin] = None,
108
- latent_model_processors: List[ProcessorMixin] = None,
109
- **kwargs,
110
- ) -> None:
111
- super().__init__(
112
- pretrained_model_name_or_path=pretrained_model_name_or_path,
113
- tokenizer_id=tokenizer_id,
114
- text_encoder_id=text_encoder_id,
115
- transformer_id=transformer_id,
116
- vae_id=vae_id,
117
- text_encoder_dtype=text_encoder_dtype,
118
- transformer_dtype=transformer_dtype,
119
- vae_dtype=vae_dtype,
120
- revision=revision,
121
- cache_dir=cache_dir,
122
- )
123
-
124
- if condition_model_processors is None:
125
- condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
126
- if latent_model_processors is None:
127
- latent_model_processors = [
128
- CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
129
- ]
130
-
131
- self.condition_model_processors = condition_model_processors
132
- self.latent_model_processors = latent_model_processors
133
-
134
- @property
135
- def _resolution_dim_keys(self):
136
- return {"latents": (2, 3)}
137
-
138
- def load_condition_models(self) -> Dict[str, torch.nn.Module]:
139
- if self.tokenizer_id is not None:
140
- tokenizer = AutoTokenizer.from_pretrained(
141
- self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
142
- )
143
- else:
144
- tokenizer = AutoTokenizer.from_pretrained(
145
- self.pretrained_model_name_or_path,
146
- subfolder="tokenizer",
147
- revision=self.revision,
148
- cache_dir=self.cache_dir,
149
- )
150
-
151
- if self.text_encoder_id is not None:
152
- text_encoder = GlmModel.from_pretrained(
153
- self.text_encoder_id,
154
- torch_dtype=self.text_encoder_dtype,
155
- revision=self.revision,
156
- cache_dir=self.cache_dir,
157
- )
158
- else:
159
- text_encoder = GlmModel.from_pretrained(
160
- self.pretrained_model_name_or_path,
161
- subfolder="text_encoder",
162
- torch_dtype=self.text_encoder_dtype,
163
- revision=self.revision,
164
- cache_dir=self.cache_dir,
165
- )
166
-
167
- return {"tokenizer": tokenizer, "text_encoder": text_encoder}
168
-
169
- def load_latent_models(self) -> Dict[str, torch.nn.Module]:
170
- if self.vae_id is not None:
171
- vae = AutoencoderKL.from_pretrained(
172
- self.vae_id,
173
- torch_dtype=self.vae_dtype,
174
- revision=self.revision,
175
- cache_dir=self.cache_dir,
176
- )
177
- else:
178
- vae = AutoencoderKL.from_pretrained(
179
- self.pretrained_model_name_or_path,
180
- subfolder="vae",
181
- torch_dtype=self.vae_dtype,
182
- revision=self.revision,
183
- cache_dir=self.cache_dir,
184
- )
185
-
186
- return {"vae": vae}
187
-
188
- def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
189
- if self.transformer_id is not None:
190
- transformer = CogView4Transformer2DModel.from_pretrained(
191
- self.transformer_id,
192
- torch_dtype=self.transformer_dtype,
193
- revision=self.revision,
194
- cache_dir=self.cache_dir,
195
- )
196
- else:
197
- transformer = CogView4Transformer2DModel.from_pretrained(
198
- self.pretrained_model_name_or_path,
199
- subfolder="transformer",
200
- torch_dtype=self.transformer_dtype,
201
- revision=self.revision,
202
- cache_dir=self.cache_dir,
203
- )
204
-
205
- scheduler = FlowMatchEulerDiscreteScheduler()
206
-
207
- return {"transformer": transformer, "scheduler": scheduler}
208
-
209
- def load_pipeline(
210
- self,
211
- tokenizer: Optional[AutoTokenizer] = None,
212
- text_encoder: Optional[GlmModel] = None,
213
- transformer: Optional[CogView4Transformer2DModel] = None,
214
- vae: Optional[AutoencoderKL] = None,
215
- scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
216
- enable_slicing: bool = False,
217
- enable_tiling: bool = False,
218
- enable_model_cpu_offload: bool = False,
219
- training: bool = False,
220
- **kwargs,
221
- ) -> CogView4Pipeline:
222
- components = {
223
- "tokenizer": tokenizer,
224
- "text_encoder": text_encoder,
225
- "transformer": transformer,
226
- "vae": vae,
227
- # Load the scheduler based on CogView4's config instead of using the default initialization being used for training
228
- # "scheduler": scheduler,
229
- }
230
- components = get_non_null_items(components)
231
-
232
- pipe = CogView4Pipeline.from_pretrained(
233
- self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
234
- )
235
- pipe.text_encoder.to(self.text_encoder_dtype)
236
- pipe.vae.to(self.vae_dtype)
237
-
238
- if not training:
239
- pipe.transformer.to(self.transformer_dtype)
240
-
241
- if enable_slicing:
242
- pipe.vae.enable_slicing()
243
- if enable_tiling:
244
- pipe.vae.enable_tiling()
245
- if enable_model_cpu_offload:
246
- pipe.enable_model_cpu_offload()
247
-
248
- return pipe
249
-
250
- @torch.no_grad()
251
- def prepare_conditions(
252
- self,
253
- tokenizer: AutoTokenizer,
254
- text_encoder: GlmModel,
255
- caption: str,
256
- max_sequence_length: int = 1024,
257
- **kwargs,
258
- ) -> Dict[str, Any]:
259
- conditions = {
260
- "tokenizer": tokenizer,
261
- "text_encoder": text_encoder,
262
- "caption": caption,
263
- "max_sequence_length": max_sequence_length,
264
- **kwargs,
265
- }
266
- input_keys = set(conditions.keys())
267
- conditions = super().prepare_conditions(**conditions)
268
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
269
- return conditions
270
-
271
- @torch.no_grad()
272
- def prepare_latents(
273
- self,
274
- vae: AutoencoderKL,
275
- image: Optional[torch.Tensor] = None,
276
- video: Optional[torch.Tensor] = None,
277
- generator: Optional[torch.Generator] = None,
278
- compute_posterior: bool = True,
279
- _original_height: Optional[int] = None,
280
- _original_width: Optional[int] = None,
281
- **kwargs,
282
- ) -> Dict[str, torch.Tensor]:
283
- conditions = {
284
- "vae": vae,
285
- "image": image,
286
- "video": video,
287
- "generator": generator,
288
- "compute_posterior": compute_posterior,
289
- "_original_height": _original_height,
290
- "_original_width": _original_width,
291
- **kwargs,
292
- }
293
- input_keys = set(conditions.keys())
294
- conditions = super().prepare_latents(**conditions)
295
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
296
- return conditions
297
-
298
- def forward(
299
- self,
300
- transformer: CogView4Transformer2DModel,
301
- condition_model_conditions: Dict[str, torch.Tensor],
302
- latent_model_conditions: Dict[str, torch.Tensor],
303
- sigmas: torch.Tensor,
304
- generator: Optional[torch.Generator] = None,
305
- compute_posterior: bool = True,
306
- **kwargs,
307
- ) -> Tuple[torch.Tensor, ...]:
308
- if compute_posterior:
309
- latents = latent_model_conditions.pop("latents")
310
- else:
311
- posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
312
- latents = posterior.sample(generator=generator)
313
- del posterior
314
-
315
- latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
316
- noise = torch.zeros_like(latents).normal_(generator=generator)
317
- timesteps = (sigmas.flatten() * 1000.0).long()
318
-
319
- base_image_sequence_length = 256
320
- base_shift = 0.25
321
- max_shift = 0.75
322
-
323
- image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
324
- mu = (image_sequence_length / base_image_sequence_length) ** 0.5
325
- mu = mu * max_shift + base_shift
326
- shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
327
- noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
328
-
329
- latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
330
-
331
- pred = transformer(
332
- **latent_model_conditions,
333
- **condition_model_conditions,
334
- timestep=timesteps,
335
- return_dict=False,
336
- )[0]
337
- target = FF.flow_match_target(noise, latents)
338
-
339
- # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
340
- # but let's keep it this way for now. Longer training runs should reveal more insights.
341
- # return pred, target, sigmas
342
- return pred, target, shifted_sigmas
343
-
344
- def validation(
345
- self,
346
- pipeline: CogView4Pipeline,
347
- prompt: str,
348
- height: Optional[int] = None,
349
- width: Optional[int] = None,
350
- num_inference_steps: int = 50,
351
- generator: Optional[torch.Generator] = None,
352
- **kwargs,
353
- ) -> List[ArtifactType]:
354
- generation_kwargs = {
355
- "prompt": prompt,
356
- "height": height,
357
- "width": width,
358
- "num_inference_steps": num_inference_steps,
359
- "generator": generator,
360
- "return_dict": True,
361
- "output_type": "pil",
362
- }
363
- generation_kwargs = get_non_null_items(generation_kwargs)
364
- image = pipeline(**generation_kwargs).images[0]
365
- return [data.ImageArtifact(value=image)]
366
-
367
- def _save_lora_weights(
368
- self,
369
- directory: str,
370
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
371
- scheduler: Optional[SchedulerType] = None,
372
- *args,
373
- **kwargs,
374
- ) -> None:
375
- # TODO(aryan): this needs refactoring
376
- if transformer_state_dict is not None:
377
- CogView4Pipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
378
- if scheduler is not None:
379
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
380
-
381
- def _save_model(
382
- self,
383
- directory: str,
384
- transformer: CogView4Transformer2DModel,
385
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
386
- scheduler: Optional[SchedulerType] = None,
387
- ) -> None:
388
- # TODO(aryan): this needs refactoring
389
- if transformer_state_dict is not None:
390
- with init_empty_weights():
391
- transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
392
- transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
393
- transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
394
- if scheduler is not None:
395
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
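The forward pass above rescales the flow-matching sigmas based on the image token count; a small standalone sketch of that schedule, with the constants copied from the code (the latent size and patch size here are assumptions for illustration):

import torch

BASE_IMAGE_SEQUENCE_LENGTH = 256
BASE_SHIFT, MAX_SHIFT = 0.25, 0.75

def shift_sigmas(sigmas: torch.Tensor, latent_height: int, latent_width: int, patch_size: int) -> torch.Tensor:
    image_sequence_length = latent_height * latent_width // patch_size**2
    mu = (image_sequence_length / BASE_IMAGE_SEQUENCE_LENGTH) ** 0.5
    mu = mu * MAX_SHIFT + BASE_SHIFT
    return mu / (mu + (1 / sigmas - 1))

sigmas = torch.tensor([0.25, 0.50, 0.75])
print(shift_sigmas(sigmas, latent_height=128, latent_width=128, patch_size=2))  # larger images push sigmas toward 1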
 
 
finetrainers/models/hunyuan_video/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .base_specification import HunyuanVideoModelSpecification
 
 
finetrainers/models/hunyuan_video/base_specification.py DELETED
@@ -1,410 +0,0 @@
1
- import os
2
- from typing import Any, Dict, List, Optional, Tuple
3
-
4
- import torch
5
- from accelerate import init_empty_weights
6
- from diffusers import (
7
- AutoencoderKLHunyuanVideo,
8
- FlowMatchEulerDiscreteScheduler,
9
- HunyuanVideoPipeline,
10
- HunyuanVideoTransformer3DModel,
11
- )
12
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
13
- from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
14
-
15
- from ... import data
16
- from ... import functional as FF
17
- from ...logging import get_logger
18
- from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
19
- from ...typing import ArtifactType, SchedulerType
20
- from ...utils import get_non_null_items
21
- from ..modeling_utils import ModelSpecification
22
-
23
-
24
- logger = get_logger()
25
-
26
-
27
- class HunyuanLatentEncodeProcessor(ProcessorMixin):
28
- r"""
29
- Processor to encode image/video into latents using the HunyuanVideo VAE.
30
-
31
- Args:
32
- output_names (`List[str]`):
33
- The names of the outputs that the processor returns. The outputs are in the following order:
34
- - latents: The latents of the input image/video.
35
- """
36
-
37
- def __init__(self, output_names: List[str]):
38
- super().__init__()
39
- self.output_names = output_names
40
- assert len(self.output_names) == 1
41
-
42
- def forward(
43
- self,
44
- vae: AutoencoderKLHunyuanVideo,
45
- image: Optional[torch.Tensor] = None,
46
- video: Optional[torch.Tensor] = None,
47
- generator: Optional[torch.Generator] = None,
48
- compute_posterior: bool = True,
49
- ) -> Dict[str, torch.Tensor]:
50
- device = vae.device
51
- dtype = vae.dtype
52
-
53
- if image is not None:
54
- video = image.unsqueeze(1)
55
-
56
- assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
57
- video = video.to(device=device, dtype=vae.dtype)
58
- video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
59
-
60
- if compute_posterior:
61
- latents = vae.encode(video).latent_dist.sample(generator=generator)
62
- latents = latents.to(dtype=dtype)
63
- else:
64
- if vae.use_slicing and video.shape[0] > 1:
65
- encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
66
- moments = torch.cat(encoded_slices)
67
- else:
68
- moments = vae._encode(video)
69
- latents = moments.to(dtype=dtype)
70
-
71
- return {self.output_names[0]: latents}
72
-
73
-
74
- class HunyuanVideoModelSpecification(ModelSpecification):
75
- def __init__(
76
- self,
77
- pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo",
78
- tokenizer_id: Optional[str] = None,
79
- text_encoder_id: Optional[str] = None,
80
- transformer_id: Optional[str] = None,
81
- vae_id: Optional[str] = None,
82
- text_encoder_dtype: torch.dtype = torch.bfloat16,
83
- transformer_dtype: torch.dtype = torch.bfloat16,
84
- vae_dtype: torch.dtype = torch.bfloat16,
85
- revision: Optional[str] = None,
86
- cache_dir: Optional[str] = None,
87
- condition_model_processors: List[ProcessorMixin] = None,
88
- latent_model_processors: List[ProcessorMixin] = None,
89
- **kwargs,
90
- ) -> None:
91
- super().__init__(
92
- pretrained_model_name_or_path=pretrained_model_name_or_path,
93
- tokenizer_id=tokenizer_id,
94
- text_encoder_id=text_encoder_id,
95
- transformer_id=transformer_id,
96
- vae_id=vae_id,
97
- text_encoder_dtype=text_encoder_dtype,
98
- transformer_dtype=transformer_dtype,
99
- vae_dtype=vae_dtype,
100
- revision=revision,
101
- cache_dir=cache_dir,
102
- )
103
-
104
- if condition_model_processors is None:
105
- condition_model_processors = [
106
- LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]),
107
- CLIPPooledProcessor(
108
- ["pooled_projections"],
109
- input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"},
110
- ),
111
- ]
112
- if latent_model_processors is None:
113
- latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])]
114
-
115
- self.condition_model_processors = condition_model_processors
116
- self.latent_model_processors = latent_model_processors
117
-
118
- @property
119
- def _resolution_dim_keys(self):
120
- return {"latents": (2, 3, 4)}
121
-
122
- def load_condition_models(self) -> Dict[str, torch.nn.Module]:
123
- if self.tokenizer_id is not None:
124
- tokenizer = AutoTokenizer.from_pretrained(
125
- self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
126
- )
127
- else:
128
- tokenizer = AutoTokenizer.from_pretrained(
129
- self.pretrained_model_name_or_path,
130
- subfolder="tokenizer",
131
- revision=self.revision,
132
- cache_dir=self.cache_dir,
133
- )
134
-
135
- if self.tokenizer_2_id is not None:
136
- tokenizer_2 = CLIPTokenizer.from_pretrained(
137
- self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir
138
- )
139
- else:
140
- tokenizer_2 = CLIPTokenizer.from_pretrained(
141
- self.pretrained_model_name_or_path,
142
- subfolder="tokenizer_2",
143
- revision=self.revision,
144
- cache_dir=self.cache_dir,
145
- )
146
-
147
- if self.text_encoder_id is not None:
148
- text_encoder = LlamaModel.from_pretrained(
149
- self.text_encoder_id,
150
- torch_dtype=self.text_encoder_dtype,
151
- revision=self.revision,
152
- cache_dir=self.cache_dir,
153
- )
154
- else:
155
- text_encoder = LlamaModel.from_pretrained(
156
- self.pretrained_model_name_or_path,
157
- subfolder="text_encoder",
158
- torch_dtype=self.text_encoder_dtype,
159
- revision=self.revision,
160
- cache_dir=self.cache_dir,
161
- )
162
-
163
- if self.text_encoder_2_id is not None:
164
- text_encoder_2 = CLIPTextModel.from_pretrained(
165
- self.text_encoder_2_id,
166
- torch_dtype=self.text_encoder_2_dtype,
167
- revision=self.revision,
168
- cache_dir=self.cache_dir,
169
- )
170
- else:
171
- text_encoder_2 = CLIPTextModel.from_pretrained(
172
- self.pretrained_model_name_or_path,
173
- subfolder="text_encoder_2",
174
- torch_dtype=self.text_encoder_2_dtype,
175
- revision=self.revision,
176
- cache_dir=self.cache_dir,
177
- )
178
-
179
- return {
180
- "tokenizer": tokenizer,
181
- "tokenizer_2": tokenizer_2,
182
- "text_encoder": text_encoder,
183
- "text_encoder_2": text_encoder_2,
184
- }
185
-
186
- def load_latent_models(self) -> Dict[str, torch.nn.Module]:
187
- if self.vae_id is not None:
188
- vae = AutoencoderKLHunyuanVideo.from_pretrained(
189
- self.vae_id,
190
- torch_dtype=self.vae_dtype,
191
- revision=self.revision,
192
- cache_dir=self.cache_dir,
193
- )
194
- else:
195
- vae = AutoencoderKLHunyuanVideo.from_pretrained(
196
- self.pretrained_model_name_or_path,
197
- subfolder="vae",
198
- torch_dtype=self.vae_dtype,
199
- revision=self.revision,
200
- cache_dir=self.cache_dir,
201
- )
202
-
203
- return {"vae": vae}
204
-
205
- def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
206
- if self.transformer_id is not None:
207
- transformer = HunyuanVideoTransformer3DModel.from_pretrained(
208
- self.transformer_id,
209
- torch_dtype=self.transformer_dtype,
210
- revision=self.revision,
211
- cache_dir=self.cache_dir,
212
- )
213
- else:
214
- transformer = HunyuanVideoTransformer3DModel.from_pretrained(
215
- self.pretrained_model_name_or_path,
216
- subfolder="transformer",
217
- torch_dtype=self.transformer_dtype,
218
- revision=self.revision,
219
- cache_dir=self.cache_dir,
220
- )
221
-
222
- scheduler = FlowMatchEulerDiscreteScheduler()
223
-
224
- return {"transformer": transformer, "scheduler": scheduler}
225
-
226
- def load_pipeline(
227
- self,
228
- tokenizer: Optional[AutoTokenizer] = None,
229
- tokenizer_2: Optional[CLIPTokenizer] = None,
230
- text_encoder: Optional[LlamaModel] = None,
231
- text_encoder_2: Optional[CLIPTextModel] = None,
232
- transformer: Optional[HunyuanVideoTransformer3DModel] = None,
233
- vae: Optional[AutoencoderKLHunyuanVideo] = None,
234
- scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
235
- enable_slicing: bool = False,
236
- enable_tiling: bool = False,
237
- enable_model_cpu_offload: bool = False,
238
- training: bool = False,
239
- **kwargs,
240
- ) -> HunyuanVideoPipeline:
241
- components = {
242
- "tokenizer": tokenizer,
243
- "tokenizer_2": tokenizer_2,
244
- "text_encoder": text_encoder,
245
- "text_encoder_2": text_encoder_2,
246
- "transformer": transformer,
247
- "vae": vae,
248
- "scheduler": scheduler,
249
- }
250
- components = get_non_null_items(components)
251
-
252
- pipe = HunyuanVideoPipeline.from_pretrained(
253
- self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
254
- )
255
- pipe.text_encoder.to(self.text_encoder_dtype)
256
- pipe.text_encoder_2.to(self.text_encoder_2_dtype)
257
- pipe.vae.to(self.vae_dtype)
258
-
259
- if not training:
260
- pipe.transformer.to(self.transformer_dtype)
261
-
262
- if enable_slicing:
263
- pipe.vae.enable_slicing()
264
- if enable_tiling:
265
- pipe.vae.enable_tiling()
266
- if enable_model_cpu_offload:
267
- pipe.enable_model_cpu_offload()
268
-
269
- return pipe
270
-
271
- @torch.no_grad()
272
- def prepare_conditions(
273
- self,
274
- tokenizer: AutoTokenizer,
275
- tokenizer_2: CLIPTokenizer,
276
- text_encoder: LlamaModel,
277
- text_encoder_2: CLIPTextModel,
278
- caption: str,
279
- max_sequence_length: int = 256,
280
- **kwargs,
281
- ) -> Dict[str, Any]:
282
- conditions = {
283
- "tokenizer": tokenizer,
284
- "tokenizer_2": tokenizer_2,
285
- "text_encoder": text_encoder,
286
- "text_encoder_2": text_encoder_2,
287
- "caption": caption,
288
- "max_sequence_length": max_sequence_length,
289
- **kwargs,
290
- }
291
- input_keys = set(conditions.keys())
292
- conditions = super().prepare_conditions(**conditions)
293
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
294
- return conditions
295
-
296
- @torch.no_grad()
297
- def prepare_latents(
298
- self,
299
- vae: AutoencoderKLHunyuanVideo,
300
- image: Optional[torch.Tensor] = None,
301
- video: Optional[torch.Tensor] = None,
302
- generator: Optional[torch.Generator] = None,
303
- compute_posterior: bool = True,
304
- **kwargs,
305
- ) -> Dict[str, torch.Tensor]:
306
- conditions = {
307
- "vae": vae,
308
- "image": image,
309
- "video": video,
310
- "generator": generator,
311
- "compute_posterior": compute_posterior,
312
- **kwargs,
313
- }
314
- input_keys = set(conditions.keys())
315
- conditions = super().prepare_latents(**conditions)
316
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
317
- return conditions
318
-
319
- def forward(
320
- self,
321
- transformer: HunyuanVideoTransformer3DModel,
322
- condition_model_conditions: Dict[str, torch.Tensor],
323
- latent_model_conditions: Dict[str, torch.Tensor],
324
- sigmas: torch.Tensor,
325
- guidance: float = 1.0,
326
- generator: Optional[torch.Generator] = None,
327
- compute_posterior: bool = True,
328
- **kwargs,
329
- ) -> Tuple[torch.Tensor, ...]:
330
- if compute_posterior:
331
- latents = latent_model_conditions.pop("latents")
332
- else:
333
- posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
334
- latents = posterior.sample(generator=generator)
335
- del posterior
336
-
337
- latents = latents * self.vae_config.scaling_factor
338
- noise = torch.zeros_like(latents).normal_(generator=generator)
339
- noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
340
-
341
- timesteps = (sigmas.flatten() * 1000.0).long()
342
- guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0
343
-
344
- latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
345
- latent_model_conditions["guidance"] = guidance
346
-
347
- pred = transformer(
348
- **latent_model_conditions,
349
- **condition_model_conditions,
350
- timestep=timesteps,
351
- return_dict=False,
352
- )[0]
353
- target = FF.flow_match_target(noise, latents)
354
-
355
- return pred, target, sigmas
356
-
357
- def validation(
358
- self,
359
- pipeline: HunyuanVideoPipeline,
360
- prompt: str,
361
- height: Optional[int] = None,
362
- width: Optional[int] = None,
363
- num_frames: Optional[int] = None,
364
- num_inference_steps: int = 50,
365
- generator: Optional[torch.Generator] = None,
366
- **kwargs,
367
- ) -> List[ArtifactType]:
368
- generation_kwargs = {
369
- "prompt": prompt,
370
- "height": height,
371
- "width": width,
372
- "num_frames": num_frames,
373
- "num_inference_steps": num_inference_steps,
374
- "generator": generator,
375
- "return_dict": True,
376
- "output_type": "pil",
377
- }
378
- generation_kwargs = get_non_null_items(generation_kwargs)
379
- video = pipeline(**generation_kwargs).frames[0]
380
- return [data.VideoArtifact(value=video)]
381
-
382
- def _save_lora_weights(
383
- self,
384
- directory: str,
385
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
386
- scheduler: Optional[SchedulerType] = None,
387
- *args,
388
- **kwargs,
389
- ) -> None:
390
- # TODO(aryan): this needs refactoring
391
- if transformer_state_dict is not None:
392
- HunyuanVideoPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
393
- if scheduler is not None:
394
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
395
-
396
- def _save_model(
397
- self,
398
- directory: str,
399
- transformer: HunyuanVideoTransformer3DModel,
400
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
401
- scheduler: Optional[SchedulerType] = None,
402
- ) -> None:
403
- # TODO(aryan): this needs refactoring
404
- if transformer_state_dict is not None:
405
- with init_empty_weights():
406
- transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config)
407
- transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
408
- transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
409
- if scheduler is not None:
410
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
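The `forward` implementation above builds its training pair with `FF.flow_match_xt` and `FF.flow_match_target`. Below is a minimal, self-contained sketch of what these helpers are assumed to compute (the standard rectified-flow interpolation and velocity target); the function names are kept, but the broadcasting shown here is illustrative rather than copied from the deleted module.

```python
import torch

def flow_match_xt(x0: torch.Tensor, noise: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # Interpolate between clean latents and noise: x_t = (1 - sigma) * x0 + sigma * noise.
    # `sigmas` is assumed to broadcast against x0 (e.g. shape [B, 1, 1, 1, 1] for video latents).
    return (1.0 - sigmas) * x0 + sigmas * noise

def flow_match_target(noise: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    # The flow-matching (velocity) target the transformer is trained to predict.
    return noise - x0

# Tiny usage example with video-shaped latents [B, C, F, H, W].
latents = torch.randn(2, 16, 9, 30, 40)
noise = torch.randn_like(latents)
sigmas = torch.rand(2).view(-1, 1, 1, 1, 1)
noisy = flow_match_xt(latents, noise, sigmas)
target = flow_match_target(noise, latents)
assert noisy.shape == target.shape == latents.shape
```

With this target the transformer learns the velocity field `noise - x0`, and the sigmas returned alongside `pred` and `target` can be used by the trainer for loss weighting.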
 
 
finetrainers/models/ltx_video/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .base_specification import LTXVideoModelSpecification
 
 
finetrainers/models/ltx_video/base_specification.py DELETED
@@ -1,517 +0,0 @@
1
- import os
2
- import random
3
- from typing import Any, Dict, List, Optional, Tuple
4
-
5
- import torch
6
- from accelerate import init_empty_weights
7
- from diffusers import (
8
- AutoencoderKLLTXVideo,
9
- FlowMatchEulerDiscreteScheduler,
10
- LTXImageToVideoPipeline,
11
- LTXPipeline,
12
- LTXVideoTransformer3DModel,
13
- )
14
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
15
- from PIL.Image import Image
16
- from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
17
-
18
- from ... import data
19
- from ... import functional as FF
20
- from ...logging import get_logger
21
- from ...parallel import ParallelBackendEnum
22
- from ...processors import ProcessorMixin, T5Processor
23
- from ...typing import ArtifactType, SchedulerType
24
- from ...utils import get_non_null_items
25
- from ..modeling_utils import ModelSpecification
26
-
27
-
28
- logger = get_logger()
29
-
30
-
31
- class LTXLatentEncodeProcessor(ProcessorMixin):
32
- r"""
33
- Processor to encode image/video into latents using the LTX VAE.
34
-
35
- Args:
36
- output_names (`List[str]`):
37
- The names of the outputs that the processor returns. The outputs are in the following order:
38
- - latents: The latents of the input image/video.
39
- - num_frames: The number of frames in the input video.
40
- - height: The height of the input image/video.
41
- - width: The width of the input image/video.
42
- - latents_mean: The latent channel means from the VAE state dict.
43
- - latents_std: The latent channel standard deviations from the VAE state dict.
44
- """
45
-
46
- def __init__(self, output_names: List[str]):
47
- super().__init__()
48
- self.output_names = output_names
49
- assert len(self.output_names) == 6
50
-
51
- def forward(
52
- self,
53
- vae: AutoencoderKLLTXVideo,
54
- image: Optional[torch.Tensor] = None,
55
- video: Optional[torch.Tensor] = None,
56
- generator: Optional[torch.Generator] = None,
57
- compute_posterior: bool = True,
58
- ) -> Dict[str, torch.Tensor]:
59
- device = vae.device
60
- dtype = vae.dtype
61
-
62
- if image is not None:
63
- video = image.unsqueeze(1)
64
-
65
- assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
66
- video = video.to(device=device, dtype=vae.dtype)
67
- video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
68
-
69
- if compute_posterior:
70
- latents = vae.encode(video).latent_dist.sample(generator=generator)
71
- latents = latents.to(dtype=dtype)
72
- else:
73
- if vae.use_slicing and video.shape[0] > 1:
74
- encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
75
- moments = torch.cat(encoded_slices)
76
- else:
77
- moments = vae._encode(video)
78
- latents = moments.to(dtype=dtype)
79
-
80
- _, _, num_frames, height, width = latents.shape
81
-
82
- return {
83
- self.output_names[0]: latents,
84
- self.output_names[1]: num_frames,
85
- self.output_names[2]: height,
86
- self.output_names[3]: width,
87
- self.output_names[4]: vae.latents_mean,
88
- self.output_names[5]: vae.latents_std,
89
- }
90
-
91
-
92
- class LTXVideoModelSpecification(ModelSpecification):
93
- def __init__(
94
- self,
95
- pretrained_model_name_or_path: str = "Lightricks/LTX-Video",
96
- tokenizer_id: Optional[str] = None,
97
- text_encoder_id: Optional[str] = None,
98
- transformer_id: Optional[str] = None,
99
- vae_id: Optional[str] = None,
100
- text_encoder_dtype: torch.dtype = torch.bfloat16,
101
- transformer_dtype: torch.dtype = torch.bfloat16,
102
- vae_dtype: torch.dtype = torch.bfloat16,
103
- revision: Optional[str] = None,
104
- cache_dir: Optional[str] = None,
105
- condition_model_processors: List[ProcessorMixin] = None,
106
- latent_model_processors: List[ProcessorMixin] = None,
107
- **kwargs,
108
- ) -> None:
109
- super().__init__(
110
- pretrained_model_name_or_path=pretrained_model_name_or_path,
111
- tokenizer_id=tokenizer_id,
112
- text_encoder_id=text_encoder_id,
113
- transformer_id=transformer_id,
114
- vae_id=vae_id,
115
- text_encoder_dtype=text_encoder_dtype,
116
- transformer_dtype=transformer_dtype,
117
- vae_dtype=vae_dtype,
118
- revision=revision,
119
- cache_dir=cache_dir,
120
- )
121
-
122
- if condition_model_processors is None:
123
- condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])]
124
- if latent_model_processors is None:
125
- latent_model_processors = [
126
- LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
127
- ]
128
-
129
- self.condition_model_processors = condition_model_processors
130
- self.latent_model_processors = latent_model_processors
131
-
132
- @property
133
- def _resolution_dim_keys(self):
134
- return {"latents": (2, 3, 4)}
135
-
136
- def load_condition_models(self) -> Dict[str, torch.nn.Module]:
137
- if self.tokenizer_id is not None:
138
- tokenizer = AutoTokenizer.from_pretrained(
139
- self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
140
- )
141
- else:
142
- tokenizer = T5Tokenizer.from_pretrained(
143
- self.pretrained_model_name_or_path,
144
- subfolder="tokenizer",
145
- revision=self.revision,
146
- cache_dir=self.cache_dir,
147
- )
148
-
149
- if self.text_encoder_id is not None:
150
- text_encoder = AutoModel.from_pretrained(
151
- self.text_encoder_id,
152
- torch_dtype=self.text_encoder_dtype,
153
- revision=self.revision,
154
- cache_dir=self.cache_dir,
155
- )
156
- else:
157
- text_encoder = T5EncoderModel.from_pretrained(
158
- self.pretrained_model_name_or_path,
159
- subfolder="text_encoder",
160
- torch_dtype=self.text_encoder_dtype,
161
- revision=self.revision,
162
- cache_dir=self.cache_dir,
163
- )
164
-
165
- return {"tokenizer": tokenizer, "text_encoder": text_encoder}
166
-
167
- def load_latent_models(self) -> Dict[str, torch.nn.Module]:
168
- if self.vae_id is not None:
169
- vae = AutoencoderKLLTXVideo.from_pretrained(
170
- self.vae_id,
171
- torch_dtype=self.vae_dtype,
172
- revision=self.revision,
173
- cache_dir=self.cache_dir,
174
- )
175
- else:
176
- vae = AutoencoderKLLTXVideo.from_pretrained(
177
- self.pretrained_model_name_or_path,
178
- subfolder="vae",
179
- torch_dtype=self.vae_dtype,
180
- revision=self.revision,
181
- cache_dir=self.cache_dir,
182
- )
183
-
184
- return {"vae": vae}
185
-
186
- def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
187
- if self.transformer_id is not None:
188
- transformer = LTXVideoTransformer3DModel.from_pretrained(
189
- self.transformer_id,
190
- torch_dtype=self.transformer_dtype,
191
- revision=self.revision,
192
- cache_dir=self.cache_dir,
193
- )
194
- else:
195
- transformer = LTXVideoTransformer3DModel.from_pretrained(
196
- self.pretrained_model_name_or_path,
197
- subfolder="transformer",
198
- torch_dtype=self.transformer_dtype,
199
- revision=self.revision,
200
- cache_dir=self.cache_dir,
201
- )
202
-
203
- scheduler = FlowMatchEulerDiscreteScheduler()
204
-
205
- return {"transformer": transformer, "scheduler": scheduler}
206
-
207
- def load_pipeline(
208
- self,
209
- tokenizer: Optional[T5Tokenizer] = None,
210
- text_encoder: Optional[T5EncoderModel] = None,
211
- transformer: Optional[LTXVideoTransformer3DModel] = None,
212
- vae: Optional[AutoencoderKLLTXVideo] = None,
213
- scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
214
- enable_slicing: bool = False,
215
- enable_tiling: bool = False,
216
- enable_model_cpu_offload: bool = False,
217
- training: bool = False,
218
- **kwargs,
219
- ) -> LTXPipeline:
220
- components = {
221
- "tokenizer": tokenizer,
222
- "text_encoder": text_encoder,
223
- "transformer": transformer,
224
- "vae": vae,
225
- "scheduler": scheduler,
226
- }
227
- components = get_non_null_items(components)
228
-
229
- pipe = LTXPipeline.from_pretrained(
230
- self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
231
- )
232
- pipe.text_encoder.to(self.text_encoder_dtype)
233
- pipe.vae.to(self.vae_dtype)
234
-
235
- if not training:
236
- pipe.transformer.to(self.transformer_dtype)
237
-
238
- if enable_slicing:
239
- pipe.vae.enable_slicing()
240
- if enable_tiling:
241
- pipe.vae.enable_tiling()
242
- if enable_model_cpu_offload:
243
- pipe.enable_model_cpu_offload()
244
-
245
- return pipe
246
-
247
- @torch.no_grad()
248
- def prepare_conditions(
249
- self,
250
- tokenizer: T5Tokenizer,
251
- text_encoder: T5EncoderModel,
252
- caption: str,
253
- max_sequence_length: int = 128,
254
- **kwargs,
255
- ) -> Dict[str, Any]:
256
- conditions = {
257
- "tokenizer": tokenizer,
258
- "text_encoder": text_encoder,
259
- "caption": caption,
260
- "max_sequence_length": max_sequence_length,
261
- **kwargs,
262
- }
263
- input_keys = set(conditions.keys())
264
- conditions = super().prepare_conditions(**conditions)
265
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
266
- return conditions
267
-
268
- @torch.no_grad()
269
- def prepare_latents(
270
- self,
271
- vae: AutoencoderKLLTXVideo,
272
- image: Optional[torch.Tensor] = None,
273
- video: Optional[torch.Tensor] = None,
274
- generator: Optional[torch.Generator] = None,
275
- compute_posterior: bool = True,
276
- **kwargs,
277
- ) -> Dict[str, torch.Tensor]:
278
- conditions = {
279
- "vae": vae,
280
- "image": image,
281
- "video": video,
282
- "generator": generator,
283
- "compute_posterior": compute_posterior,
284
- **kwargs,
285
- }
286
- input_keys = set(conditions.keys())
287
- conditions = super().prepare_latents(**conditions)
288
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
289
- return conditions
290
-
291
- def forward(
292
- self,
293
- transformer: LTXVideoTransformer3DModel,
294
- condition_model_conditions: Dict[str, torch.Tensor],
295
- latent_model_conditions: Dict[str, torch.Tensor],
296
- sigmas: torch.Tensor,
297
- generator: Optional[torch.Generator] = None,
298
- compute_posterior: bool = True,
299
- **kwargs,
300
- ) -> Tuple[torch.Tensor, ...]:
301
- # TODO(aryan): make this configurable? Should it be?
302
- first_frame_conditioning_p = 0.1
303
- min_first_frame_sigma = 0.25
304
-
305
- if compute_posterior:
306
- latents = latent_model_conditions.pop("latents")
307
- else:
308
- posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
309
- latents = posterior.sample(generator=generator)
310
- del posterior
311
-
312
- latents_mean = latent_model_conditions.pop("latents_mean")
313
- latents_std = latent_model_conditions.pop("latents_std")
314
-
315
- latents = self._normalize_latents(latents, latents_mean, latents_std)
316
- noise = torch.zeros_like(latents).normal_(generator=generator)
317
-
318
- if random.random() < first_frame_conditioning_p:
319
- # Based on Section 2.4 of the paper, it mentions that the first frame timesteps should be a small random value.
320
- # Making an estimated guess, we cap the first-frame sigma at 0.25 (min_first_frame_sigma below).
321
- # torch.rand_like returns values in [0, 1). We want to make sure that the first frame sigma is <= actual sigmas
322
- # for image conditioning. In order to do this, we rescale by multiplying with sigmas so the range is [0, sigmas).
323
- first_frame_sigma = torch.rand_like(sigmas) * sigmas
324
- first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma))
325
-
326
- latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:]
327
- noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma)
328
- noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas)
329
- noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2)
330
- else:
331
- noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
332
-
333
- patch_size = self.transformer_config.patch_size
334
- patch_size_t = self.transformer_config.patch_size_t
335
-
336
- latents = self._pack_latents(latents, patch_size, patch_size_t)
337
- noise = self._pack_latents(noise, patch_size, patch_size_t)
338
- noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
339
- sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
340
- timesteps = (sigmas * 1000.0).long()
341
-
342
- latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
343
-
344
- # TODO(aryan): make this configurable
345
- frame_rate = 25
346
- temporal_compression_ratio = 8
347
- vae_spatial_compression_ratio = 32
348
- latent_frame_rate = frame_rate / temporal_compression_ratio
349
-
350
- rope_interpolation_scale = [
351
- 1 / latent_frame_rate,
352
- vae_spatial_compression_ratio,
353
- vae_spatial_compression_ratio,
354
- ]
355
-
356
- pred = transformer(
357
- **latent_model_conditions,
358
- **condition_model_conditions,
359
- timestep=timesteps,
360
- rope_interpolation_scale=rope_interpolation_scale,
361
- return_dict=False,
362
- )[0]
363
- target = FF.flow_match_target(noise, latents)
364
-
365
- return pred, target, sigmas
366
-
367
- def validation(
368
- self,
369
- pipeline: LTXPipeline,
370
- prompt: str,
371
- image: Optional[Image] = None,
372
- height: Optional[int] = None,
373
- width: Optional[int] = None,
374
- num_frames: Optional[int] = None,
375
- frame_rate: int = 25,
376
- num_inference_steps: int = 50,
377
- generator: Optional[torch.Generator] = None,
378
- **kwargs,
379
- ) -> List[ArtifactType]:
380
- if image is not None:
381
- pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)
382
-
383
- generation_kwargs = {
384
- "prompt": prompt,
385
- "image": image,
386
- "height": height,
387
- "width": width,
388
- "num_frames": num_frames,
389
- "frame_rate": frame_rate,
390
- "num_inference_steps": num_inference_steps,
391
- "generator": generator,
392
- "return_dict": True,
393
- "output_type": "pil",
394
- }
395
- generation_kwargs = get_non_null_items(generation_kwargs)
396
- video = pipeline(**generation_kwargs).frames[0]
397
- return [data.VideoArtifact(value=video)]
398
-
399
- def _save_lora_weights(
400
- self,
401
- directory: str,
402
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
403
- scheduler: Optional[SchedulerType] = None,
404
- *args,
405
- **kwargs,
406
- ) -> None:
407
- # TODO(aryan): this needs refactoring
408
- if transformer_state_dict is not None:
409
- LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
410
- if scheduler is not None:
411
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
412
-
413
- def _save_model(
414
- self,
415
- directory: str,
416
- transformer: LTXVideoTransformer3DModel,
417
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
418
- scheduler: Optional[SchedulerType] = None,
419
- ) -> None:
420
- # TODO(aryan): this needs refactoring
421
- if transformer_state_dict is not None:
422
- with init_empty_weights():
423
- transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config)
424
- transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
425
- transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
426
- if scheduler is not None:
427
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
428
-
429
- def apply_tensor_parallel(
430
- self,
431
- backend: ParallelBackendEnum,
432
- device_mesh: torch.distributed.DeviceMesh,
433
- transformer: LTXVideoTransformer3DModel,
434
- **kwargs,
435
- ) -> None:
436
- if backend == ParallelBackendEnum.PTD:
437
- _apply_tensor_parallel_ptd(device_mesh, transformer)
438
- else:
439
- raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification")
440
-
441
- @staticmethod
442
- def _normalize_latents(
443
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
444
- ) -> torch.Tensor:
445
- # Normalize latents across the channel dimension [B, C, F, H, W]
446
- latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
447
- latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
448
- latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
449
- return latents
450
-
451
- @staticmethod
452
- def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
453
- # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
454
- # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
455
- # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
456
- # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
457
- batch_size, num_channels, num_frames, height, width = latents.shape
458
- post_patch_num_frames = num_frames // patch_size_t
459
- post_patch_height = height // patch_size
460
- post_patch_width = width // patch_size
461
- latents = latents.reshape(
462
- batch_size,
463
- -1,
464
- post_patch_num_frames,
465
- patch_size_t,
466
- post_patch_height,
467
- patch_size,
468
- post_patch_width,
469
- patch_size,
470
- )
471
- latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
472
- return latents
473
-
474
-
475
- def _apply_tensor_parallel_ptd(
476
- device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
477
- ) -> None:
478
- from torch.distributed.tensor.parallel import parallelize_module
479
- from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
480
-
481
- transformer_plan = {
482
- # ===== Condition embeddings =====
483
- # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
484
- # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
485
- # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
486
- # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
487
- # "caption_projection.linear_1": ColwiseParallel(),
488
- # "caption_projection.linear_2": RowwiseParallel(),
489
- # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
490
- # ===== =====
491
- }
492
-
493
- for block in transformer.transformer_blocks:
494
- block_plan = {}
495
-
496
- # ===== Attention =====
497
- # 8 all-to-all, 3 all-reduce
498
- # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
499
- # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
500
- # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
501
- # block_plan["attn1.norm_q"] = SequenceParallel()
502
- # block_plan["attn1.norm_k"] = SequenceParallel()
503
- # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
504
- # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
505
- # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
506
- # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
507
- # block_plan["attn2.norm_q"] = SequenceParallel()
508
- # block_plan["attn2.norm_k"] = SequenceParallel()
509
- # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
510
- # ===== =====
511
-
512
- block_plan["ff.net.0.proj"] = ColwiseParallel()
513
- block_plan["ff.net.2"] = RowwiseParallel()
514
-
515
- parallelize_module(block, device_mesh, block_plan)
516
-
517
- parallelize_module(transformer, device_mesh, transformer_plan)
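`_pack_latents` above flattens a `[B, C, F, H, W]` latent volume into the token sequence the LTX transformer consumes. The snippet below reproduces that reshape/permute in isolation to make the resulting shapes explicit, assuming the default patch sizes of 1; the sample channel count is illustrative.

```python
import torch

def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, (F // p_t) * (H // p) * (W // p), C * p_t * p * p]
    batch_size, num_channels, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size, -1,
        num_frames // patch_size_t, patch_size_t,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

# With patch sizes of 1, packing simply flattens the spatio-temporal grid into tokens.
x = torch.randn(1, 128, 8, 16, 16)                  # [B, C, F, H, W]
tokens = pack_latents(x, patch_size=1, patch_size_t=1)
print(tokens.shape)                                 # torch.Size([1, 2048, 128]) == [B, F*H*W, C]
```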
 
 
finetrainers/models/modeling_utils.py DELETED
@@ -1,289 +0,0 @@
1
- from typing import Any, Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- from diffusers import DiffusionPipeline
5
- from diffusers.configuration_utils import FrozenDict
6
- from PIL.Image import Image
7
-
8
- from ..logging import get_logger
9
- from ..parallel import ParallelBackendEnum
10
- from ..processors import ProcessorMixin
11
- from ..typing import ArtifactType, SchedulerType, TokenizerType
12
- from ..utils import resolve_component_cls
13
-
14
-
15
- logger = get_logger()
16
-
17
- # TODO(aryan): we most likely don't need this. take a look after refactoring more
18
- # fmt: off
19
- IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"}
20
- # fmt: on
21
-
22
-
23
- class ModelSpecification:
24
- r"""
25
- The ModelSpecification class is an interface to be used for Diffusion training recipes. It provides
26
- loose structure about how to organize the code for training. The trainer implementations will
27
- make use of this interface to load models, prepare conditions, prepare latents, forward pass, etc.
28
- """
29
-
30
- def __init__(
31
- self,
32
- pretrained_model_name_or_path: Optional[str] = None,
33
- tokenizer_id: Optional[str] = None,
34
- tokenizer_2_id: Optional[str] = None,
35
- tokenizer_3_id: Optional[str] = None,
36
- text_encoder_id: Optional[str] = None,
37
- text_encoder_2_id: Optional[str] = None,
38
- text_encoder_3_id: Optional[str] = None,
39
- transformer_id: Optional[str] = None,
40
- vae_id: Optional[str] = None,
41
- text_encoder_dtype: torch.dtype = torch.bfloat16,
42
- text_encoder_2_dtype: torch.dtype = torch.bfloat16,
43
- text_encoder_3_dtype: torch.dtype = torch.bfloat16,
44
- transformer_dtype: torch.dtype = torch.bfloat16,
45
- vae_dtype: torch.dtype = torch.bfloat16,
46
- revision: Optional[str] = None,
47
- cache_dir: Optional[str] = None,
48
- condition_model_processors: List[ProcessorMixin] = None,
49
- latent_model_processors: List[ProcessorMixin] = None,
50
- ) -> None:
51
- self.pretrained_model_name_or_path = pretrained_model_name_or_path
52
- self.tokenizer_id = tokenizer_id
53
- self.tokenizer_2_id = tokenizer_2_id
54
- self.tokenizer_3_id = tokenizer_3_id
55
- self.text_encoder_id = text_encoder_id
56
- self.text_encoder_2_id = text_encoder_2_id
57
- self.text_encoder_3_id = text_encoder_3_id
58
- self.transformer_id = transformer_id
59
- self.vae_id = vae_id
60
- self.text_encoder_dtype = text_encoder_dtype
61
- self.text_encoder_2_dtype = text_encoder_2_dtype
62
- self.text_encoder_3_dtype = text_encoder_3_dtype
63
- self.transformer_dtype = transformer_dtype
64
- self.vae_dtype = vae_dtype
65
- self.revision = revision
66
- self.cache_dir = cache_dir
67
- self.condition_model_processors = condition_model_processors or []
68
- self.latent_model_processors = latent_model_processors or []
69
-
70
- self.transformer_config: Dict[str, Any] = None
71
- self.vae_config: Dict[str, Any] = None
72
-
73
- self._load_configs()
74
-
75
- # TODO(aryan): revisit how to do this better without user having to worry about it
76
- @property
77
- def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]:
78
- raise NotImplementedError(
79
- f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}"
80
- )
81
-
82
- def load_condition_models(self) -> Dict[str, torch.nn.Module]:
83
- raise NotImplementedError(
84
- f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}"
85
- )
86
-
87
- def load_latent_models(self) -> Dict[str, torch.nn.Module]:
88
- raise NotImplementedError(
89
- f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}"
90
- )
91
-
92
- def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module]]:
93
- raise NotImplementedError(
94
- f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}"
95
- )
96
-
97
- def load_pipeline(
98
- self,
99
- tokenizer: Optional[TokenizerType] = None,
100
- tokenizer_2: Optional[TokenizerType] = None,
101
- tokenizer_3: Optional[TokenizerType] = None,
102
- text_encoder: Optional[torch.nn.Module] = None,
103
- text_encoder_2: Optional[torch.nn.Module] = None,
104
- text_encoder_3: Optional[torch.nn.Module] = None,
105
- transformer: Optional[torch.nn.Module] = None,
106
- vae: Optional[torch.nn.Module] = None,
107
- scheduler: Optional[SchedulerType] = None,
108
- enable_slicing: bool = False,
109
- enable_tiling: bool = False,
110
- enable_model_cpu_offload: bool = False,
111
- training: bool = False,
112
- **kwargs,
113
- ) -> DiffusionPipeline:
114
- raise NotImplementedError(
115
- f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
116
- )
117
-
118
- def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
119
- for processor in self.condition_model_processors:
120
- result = processor(**kwargs)
121
- result_keys = set(result.keys())
122
- repeat_keys = result_keys.intersection(kwargs.keys())
123
- if repeat_keys:
124
- logger.warning(
125
- f"Processor {processor.__class__.__name__} returned keys that already exist in "
126
- f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
127
- f"be intended. Please rename the keys in the processor to avoid conflicts."
128
- )
129
- kwargs.update(result)
130
- return kwargs
131
-
132
- def prepare_latents(self, **kwargs) -> Dict[str, Any]:
133
- for processor in self.latent_model_processors:
134
- result = processor(**kwargs)
135
- result_keys = set(result.keys())
136
- repeat_keys = result_keys.intersection(kwargs.keys())
137
- if repeat_keys:
138
- logger.warning(
139
- f"Processor {processor.__class__.__name__} returned keys that already exist in "
140
- f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
141
- f"be intended. Please rename the keys in the processor to avoid conflicts."
142
- )
143
- kwargs.update(result)
144
- return kwargs
145
-
146
- def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
147
- keys = list(data[0].keys())
148
- collated_data = {}
149
- for key in keys:
150
- if key in IGNORE_KEYS_FOR_COLLATION:
151
- collated_data[key] = data[0][key]
152
- continue
153
- collated_d = [d[key] for d in data]
154
- if isinstance(collated_d[0], torch.Tensor):
155
- collated_d = torch.cat(collated_d)
156
- collated_data[key] = collated_d
157
- return collated_data
158
-
159
- def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
160
- keys = list(data[0].keys())
161
- collated_data = {}
162
- for key in keys:
163
- if key in IGNORE_KEYS_FOR_COLLATION:
164
- collated_data[key] = data[0][key]
165
- continue
166
- collated_d = [d[key] for d in data]
167
- # TODO(aryan): Support multi-resolution collation
168
- if isinstance(collated_d[0], torch.Tensor):
169
- collated_d = torch.cat(collated_d)
170
- collated_data[key] = collated_d
171
- return collated_data
172
-
173
- def forward(
174
- self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs
175
- ) -> Dict[str, torch.Tensor]:
176
- raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}")
177
-
178
- def validation(
179
- self,
180
- pipeline: DiffusionPipeline,
181
- prompt: Optional[str] = None,
182
- image: Optional[Image] = None,
183
- video: Optional[List[Image]] = None,
184
- height: Optional[int] = None,
185
- width: Optional[int] = None,
186
- num_frames: Optional[int] = None,
187
- frame_rate: Optional[int] = None,
188
- generator: Optional[torch.Generator] = None,
189
- ) -> List[ArtifactType]:
190
- raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}")
191
-
192
- def _save_lora_weights(
193
- self,
194
- directory: str,
195
- transformer: torch.nn.Module,
196
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
197
- scheduler: Optional[SchedulerType] = None,
198
- ) -> None:
199
- r"""
200
- Save the lora state dicts of the model to the given directory.
201
-
202
- This API is not backwards compatible and will be changed in near future.
203
- """
204
- raise NotImplementedError(
205
- f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}"
206
- )
207
-
208
- def _save_model(
209
- self,
210
- directory: str,
211
- transformer: torch.nn.Module,
212
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
213
- scheduler: Optional[SchedulerType] = None,
214
- ) -> None:
215
- r"""
216
- Save the state dicts to the given directory.
217
-
218
- This API is not backwards compatible and will be changed in the near future.
219
- """
220
- raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}")
221
-
222
- def apply_tensor_parallel(
223
- self,
224
- backend: ParallelBackendEnum,
225
- device_mesh: torch.distributed.DeviceMesh,
226
- text_encoder: torch.nn.Module,
227
- text_encoder_2: torch.nn.Module,
228
- text_encoder_3: torch.nn.Module,
229
- transformer: torch.nn.Module,
230
- vae: torch.nn.Module,
231
- ) -> None:
232
- raise NotImplementedError(
233
- f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}"
234
- )
235
-
236
- def _load_configs(self) -> None:
237
- self._load_transformer_config()
238
- self._load_vae_config()
239
-
240
- def _load_transformer_config(self) -> None:
241
- if self.transformer_id is not None:
242
- transformer_cls = resolve_component_cls(
243
- self.transformer_id,
244
- component_name="_class_name",
245
- filename="config.json",
246
- revision=self.revision,
247
- cache_dir=self.cache_dir,
248
- )
249
- self.transformer_config = transformer_cls.load_config(
250
- self.transformer_id, revision=self.revision, cache_dir=self.cache_dir
251
- )
252
- else:
253
- transformer_cls = resolve_component_cls(
254
- self.pretrained_model_name_or_path,
255
- component_name="transformer",
256
- filename="model_index.json",
257
- revision=self.revision,
258
- cache_dir=self.cache_dir,
259
- )
260
- self.transformer_config = transformer_cls.load_config(
261
- self.pretrained_model_name_or_path,
262
- subfolder="transformer",
263
- revision=self.revision,
264
- cache_dir=self.cache_dir,
265
- )
266
- self.transformer_config = FrozenDict(**self.transformer_config)
267
-
268
- def _load_vae_config(self) -> None:
269
- if self.vae_id is not None:
270
- vae_cls = resolve_component_cls(
271
- self.vae_id,
272
- component_name="_class_name",
273
- filename="config.json",
274
- revision=self.revision,
275
- cache_dir=self.cache_dir,
276
- )
277
- self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir)
278
- else:
279
- vae_cls = resolve_component_cls(
280
- self.pretrained_model_name_or_path,
281
- component_name="vae",
282
- filename="model_index.json",
283
- revision=self.revision,
284
- cache_dir=self.cache_dir,
285
- )
286
- self.vae_config = vae_cls.load_config(
287
- self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir
288
- )
289
- self.vae_config = FrozenDict(**self.vae_config)
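`prepare_conditions` and `prepare_latents` above simply fold each processor's outputs back into the running keyword arguments. The toy sketch below illustrates that chaining pattern; `UppercaseCaptionProcessor` is a hypothetical stand-in, not one of the real `ProcessorMixin` subclasses.

```python
from typing import Any, Dict, List

class UppercaseCaptionProcessor:
    # Illustrative stand-in for a ProcessorMixin subclass: it consumes some of the
    # keyword arguments and returns new keys to be merged back into the condition dict.
    def __call__(self, caption: str, **kwargs) -> Dict[str, Any]:
        return {"caption_upper": caption.upper()}

def prepare_conditions(processors: List[Any], **kwargs) -> Dict[str, Any]:
    # Mirrors ModelSpecification.prepare_conditions: each processor sees the running
    # kwargs and its outputs are merged in, so later processors can read earlier outputs.
    for processor in processors:
        kwargs.update(processor(**kwargs))
    return kwargs

conditions = prepare_conditions([UppercaseCaptionProcessor()], caption="a cat playing piano")
print(conditions)  # {'caption': 'a cat playing piano', 'caption_upper': 'A CAT PLAYING PIANO'}
```

The concrete model specifications then drop the original input keys from the result, so only the newly produced conditions (for example, encoder hidden states) are kept and cached.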
 
 
finetrainers/models/utils.py DELETED
@@ -1,62 +0,0 @@
1
- from typing import Optional, Tuple
2
-
3
- import numpy as np
4
- import torch
5
- from diffusers.utils.torch_utils import randn_tensor
6
-
7
-
8
- class DiagonalGaussianDistribution(object):
9
- def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1):
10
- # Note: _dim is the new argument added here after copying from diffusers
11
- self.parameters = parameters
12
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim)
13
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
14
- self.deterministic = deterministic
15
- self.std = torch.exp(0.5 * self.logvar)
16
- self.var = torch.exp(self.logvar)
17
- if self.deterministic:
18
- self.var = self.std = torch.zeros_like(
19
- self.mean, device=self.parameters.device, dtype=self.parameters.dtype
20
- )
21
-
22
- def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
23
- # make sure sample is on the same device as the parameters and has same dtype
24
- sample = randn_tensor(
25
- self.mean.shape,
26
- generator=generator,
27
- device=self.parameters.device,
28
- dtype=self.parameters.dtype,
29
- )
30
- x = self.mean + self.std * sample
31
- return x
32
-
33
- def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
34
- if self.deterministic:
35
- return torch.Tensor([0.0])
36
- else:
37
- if other is None:
38
- return 0.5 * torch.sum(
39
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
40
- dim=[1, 2, 3],
41
- )
42
- else:
43
- return 0.5 * torch.sum(
44
- torch.pow(self.mean - other.mean, 2) / other.var
45
- + self.var / other.var
46
- - 1.0
47
- - self.logvar
48
- + other.logvar,
49
- dim=[1, 2, 3],
50
- )
51
-
52
- def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
53
- if self.deterministic:
54
- return torch.Tensor([0.0])
55
- logtwopi = np.log(2.0 * np.pi)
56
- return 0.5 * torch.sum(
57
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
58
- dim=dims,
59
- )
60
-
61
- def mode(self) -> torch.Tensor:
62
- return self.mean
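The vendored `DiagonalGaussianDistribution` above expects the mean and log-variance concatenated along a channel dimension, which is what an AutoencoderKL-style `_encode` returns as its "moments". A short usage sketch with made-up shapes follows; it mirrors the class's sampling path rather than calling the deleted module.

```python
import torch

# `parameters` holds mean and log-variance concatenated along dim=1 ("moments").
moments = torch.randn(2, 2 * 16, 9, 30, 40)           # [B, 2*C, F, H, W]
mean, logvar = torch.chunk(moments, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)

generator = torch.Generator().manual_seed(0)
sample = mean + std * torch.randn(mean.shape, generator=generator)  # reparameterization trick
print(sample.shape)                                    # torch.Size([2, 16, 9, 30, 40])
```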
 
 
finetrainers/models/wan/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .base_specification import WanModelSpecification
 
 
finetrainers/models/wan/base_specification.py DELETED
@@ -1,393 +0,0 @@
1
- import os
2
- from typing import Any, Dict, List, Optional, Tuple
3
-
4
- import torch
5
- from accelerate import init_empty_weights
6
- from diffusers import (
7
- AutoencoderKLWan,
8
- FlowMatchEulerDiscreteScheduler,
9
- WanImageToVideoPipeline,
10
- WanPipeline,
11
- WanTransformer3DModel,
12
- )
13
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
14
- from PIL.Image import Image
15
- from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel
16
-
17
- from ... import data
18
- from ... import functional as FF
19
- from ...logging import get_logger
20
- from ...processors import ProcessorMixin, T5Processor
21
- from ...typing import ArtifactType, SchedulerType
22
- from ...utils import get_non_null_items
23
- from ..modeling_utils import ModelSpecification
24
-
25
-
26
- logger = get_logger()
27
-
28
-
29
- class WanLatentEncodeProcessor(ProcessorMixin):
30
- r"""
31
- Processor to encode image/video into latents using the Wan VAE.
32
-
33
- Args:
34
- output_names (`List[str]`):
35
- The names of the outputs that the processor returns. The outputs are in the following order:
36
- - latents: The latents of the input image/video.
- - latents_mean: The per-channel latent means from the VAE config.
- - latents_std: The reciprocal of the per-channel latent standard deviations from the VAE config.
37
- """
38
-
39
- def __init__(self, output_names: List[str]):
40
- super().__init__()
41
- self.output_names = output_names
42
- assert len(self.output_names) == 3
43
-
44
- def forward(
45
- self,
46
- vae: AutoencoderKLWan,
47
- image: Optional[torch.Tensor] = None,
48
- video: Optional[torch.Tensor] = None,
49
- generator: Optional[torch.Generator] = None,
50
- compute_posterior: bool = True,
51
- ) -> Dict[str, torch.Tensor]:
52
- device = vae.device
53
- dtype = vae.dtype
54
-
55
- if image is not None:
56
- video = image.unsqueeze(1)
57
-
58
- assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
59
- video = video.to(device=device, dtype=vae.dtype)
60
- video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
61
-
62
- if compute_posterior:
63
- latents = vae.encode(video).latent_dist.sample(generator=generator)
64
- latents = latents.to(dtype=dtype)
65
- else:
66
- # TODO(aryan): refactor in diffusers to have use_slicing attribute
67
- # if vae.use_slicing and video.shape[0] > 1:
68
- # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
69
- # moments = torch.cat(encoded_slices)
70
- # else:
71
- # moments = vae._encode(video)
72
- moments = vae._encode(video)
73
- latents = moments.to(dtype=dtype)
74
-
75
- latents_mean = torch.tensor(vae.config.latents_mean)
76
- latents_std = 1.0 / torch.tensor(vae.config.latents_std)
77
-
78
- return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}
79
-
80
-
81
- class WanModelSpecification(ModelSpecification):
82
- def __init__(
83
- self,
84
- pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
85
- tokenizer_id: Optional[str] = None,
86
- text_encoder_id: Optional[str] = None,
87
- transformer_id: Optional[str] = None,
88
- vae_id: Optional[str] = None,
89
- text_encoder_dtype: torch.dtype = torch.bfloat16,
90
- transformer_dtype: torch.dtype = torch.bfloat16,
91
- vae_dtype: torch.dtype = torch.bfloat16,
92
- revision: Optional[str] = None,
93
- cache_dir: Optional[str] = None,
94
- condition_model_processors: List[ProcessorMixin] = None,
95
- latent_model_processors: List[ProcessorMixin] = None,
96
- **kwargs,
97
- ) -> None:
98
- super().__init__(
99
- pretrained_model_name_or_path=pretrained_model_name_or_path,
100
- tokenizer_id=tokenizer_id,
101
- text_encoder_id=text_encoder_id,
102
- transformer_id=transformer_id,
103
- vae_id=vae_id,
104
- text_encoder_dtype=text_encoder_dtype,
105
- transformer_dtype=transformer_dtype,
106
- vae_dtype=vae_dtype,
107
- revision=revision,
108
- cache_dir=cache_dir,
109
- )
110
-
111
- if condition_model_processors is None:
112
- condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
113
- if latent_model_processors is None:
114
- latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
115
-
116
- self.condition_model_processors = condition_model_processors
117
- self.latent_model_processors = latent_model_processors
118
-
119
- @property
120
- def _resolution_dim_keys(self):
121
- return {"latents": (2, 3, 4)}
122
-
123
- def load_condition_models(self) -> Dict[str, torch.nn.Module]:
124
- if self.tokenizer_id is not None:
125
- tokenizer = AutoTokenizer.from_pretrained(
126
- self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
127
- )
128
- else:
129
- tokenizer = AutoTokenizer.from_pretrained(
130
- self.pretrained_model_name_or_path,
131
- subfolder="tokenizer",
132
- revision=self.revision,
133
- cache_dir=self.cache_dir,
134
- )
135
-
136
- if self.text_encoder_id is not None:
137
- text_encoder = AutoModel.from_pretrained(
138
- self.text_encoder_id,
139
- torch_dtype=self.text_encoder_dtype,
140
- revision=self.revision,
141
- cache_dir=self.cache_dir,
142
- )
143
- else:
144
- text_encoder = UMT5EncoderModel.from_pretrained(
145
- self.pretrained_model_name_or_path,
146
- subfolder="text_encoder",
147
- torch_dtype=self.text_encoder_dtype,
148
- revision=self.revision,
149
- cache_dir=self.cache_dir,
150
- )
151
-
152
- return {"tokenizer": tokenizer, "text_encoder": text_encoder}
153
-
154
- def load_latent_models(self) -> Dict[str, torch.nn.Module]:
155
- if self.vae_id is not None:
156
- vae = AutoencoderKLWan.from_pretrained(
157
- self.vae_id,
158
- torch_dtype=self.vae_dtype,
159
- revision=self.revision,
160
- cache_dir=self.cache_dir,
161
- )
162
- else:
163
- vae = AutoencoderKLWan.from_pretrained(
164
- self.pretrained_model_name_or_path,
165
- subfolder="vae",
166
- torch_dtype=self.vae_dtype,
167
- revision=self.revision,
168
- cache_dir=self.cache_dir,
169
- )
170
-
171
- return {"vae": vae}
172
-
173
- def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
174
- if self.transformer_id is not None:
175
- transformer = WanTransformer3DModel.from_pretrained(
176
- self.transformer_id,
177
- torch_dtype=self.transformer_dtype,
178
- revision=self.revision,
179
- cache_dir=self.cache_dir,
180
- )
181
- else:
182
- transformer = WanTransformer3DModel.from_pretrained(
183
- self.pretrained_model_name_or_path,
184
- subfolder="transformer",
185
- torch_dtype=self.transformer_dtype,
186
- revision=self.revision,
187
- cache_dir=self.cache_dir,
188
- )
189
-
190
- scheduler = FlowMatchEulerDiscreteScheduler()
191
-
192
- return {"transformer": transformer, "scheduler": scheduler}
193
-
194
- def load_pipeline(
195
- self,
196
- tokenizer: Optional[AutoTokenizer] = None,
197
- text_encoder: Optional[UMT5EncoderModel] = None,
198
- transformer: Optional[WanTransformer3DModel] = None,
199
- vae: Optional[AutoencoderKLWan] = None,
200
- scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
201
- enable_slicing: bool = False,
202
- enable_tiling: bool = False,
203
- enable_model_cpu_offload: bool = False,
204
- training: bool = False,
205
- **kwargs,
206
- ) -> WanPipeline:
207
- components = {
208
- "tokenizer": tokenizer,
209
- "text_encoder": text_encoder,
210
- "transformer": transformer,
211
- "vae": vae,
212
- "scheduler": scheduler,
213
- }
214
- components = get_non_null_items(components)
215
-
216
- pipe = WanPipeline.from_pretrained(
217
- self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
218
- )
219
- pipe.text_encoder.to(self.text_encoder_dtype)
220
- pipe.vae.to(self.vae_dtype)
221
-
222
- if not training:
223
- pipe.transformer.to(self.transformer_dtype)
224
-
225
- # TODO(aryan): add support in diffusers
226
- # if enable_slicing:
227
- # pipe.vae.enable_slicing()
228
- # if enable_tiling:
229
- # pipe.vae.enable_tiling()
230
- if enable_model_cpu_offload:
231
- pipe.enable_model_cpu_offload()
232
-
233
- return pipe
234
-
235
- @torch.no_grad()
236
- def prepare_conditions(
237
- self,
238
- tokenizer: AutoTokenizer,
239
- text_encoder: UMT5EncoderModel,
240
- caption: str,
241
- max_sequence_length: int = 512,
242
- **kwargs,
243
- ) -> Dict[str, Any]:
244
- conditions = {
245
- "tokenizer": tokenizer,
246
- "text_encoder": text_encoder,
247
- "caption": caption,
248
- "max_sequence_length": max_sequence_length,
249
- **kwargs,
250
- }
251
- input_keys = set(conditions.keys())
252
- conditions = super().prepare_conditions(**conditions)
253
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
254
- conditions.pop("prompt_attention_mask", None)
255
- return conditions
256
-
257
- @torch.no_grad()
258
- def prepare_latents(
259
- self,
260
- vae: AutoencoderKLWan,
261
- image: Optional[torch.Tensor] = None,
262
- video: Optional[torch.Tensor] = None,
263
- generator: Optional[torch.Generator] = None,
264
- compute_posterior: bool = True,
265
- **kwargs,
266
- ) -> Dict[str, torch.Tensor]:
267
- conditions = {
268
- "vae": vae,
269
- "image": image,
270
- "video": video,
271
- "generator": generator,
272
- # We must force this to False because the latent normalization should be done before
273
- # the posterior is computed. The VAE does not handle this any more:
274
- # https://github.com/huggingface/diffusers/pull/10998
275
- "compute_posterior": False,
276
- **kwargs,
277
- }
278
- input_keys = set(conditions.keys())
279
- conditions = super().prepare_latents(**conditions)
280
- conditions = {k: v for k, v in conditions.items() if k not in input_keys}
281
- return conditions
282
-
283
- def forward(
284
- self,
285
- transformer: WanTransformer3DModel,
286
- condition_model_conditions: Dict[str, torch.Tensor],
287
- latent_model_conditions: Dict[str, torch.Tensor],
288
- sigmas: torch.Tensor,
289
- generator: Optional[torch.Generator] = None,
290
- compute_posterior: bool = True,
291
- **kwargs,
292
- ) -> Tuple[torch.Tensor, ...]:
293
- compute_posterior = False # See explanation in prepare_latents
294
- if compute_posterior:
295
- latents = latent_model_conditions.pop("latents")
296
- else:
297
- latents = latent_model_conditions.pop("latents")
298
- latents_mean = latent_model_conditions.pop("latents_mean")
299
- latents_std = latent_model_conditions.pop("latents_std")
300
-
301
- mu, logvar = torch.chunk(latents, 2, dim=1)
302
- mu = self._normalize_latents(mu, latents_mean, latents_std)
303
- logvar = self._normalize_latents(logvar, latents_mean, latents_std)
304
- latents = torch.cat([mu, logvar], dim=1)
305
-
306
- posterior = DiagonalGaussianDistribution(latents)
307
- latents = posterior.sample(generator=generator)
308
- del posterior
309
-
310
- noise = torch.zeros_like(latents).normal_(generator=generator)
311
- noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
312
- timesteps = (sigmas.flatten() * 1000.0).long()
313
-
314
- latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
315
-
316
- pred = transformer(
317
- **latent_model_conditions,
318
- **condition_model_conditions,
319
- timestep=timesteps,
320
- return_dict=False,
321
- )[0]
322
- target = FF.flow_match_target(noise, latents)
323
-
324
- return pred, target, sigmas
325
-
326
- def validation(
327
- self,
328
- pipeline: WanPipeline,
329
- prompt: str,
330
- image: Optional[Image] = None,
331
- height: Optional[int] = None,
332
- width: Optional[int] = None,
333
- num_frames: Optional[int] = None,
334
- num_inference_steps: int = 50,
335
- generator: Optional[torch.Generator] = None,
336
- **kwargs,
337
- ) -> List[ArtifactType]:
338
- if image is not None:
339
- pipeline = WanImageToVideoPipeline.from_pipe(pipeline)
340
-
341
- generation_kwargs = {
342
- "prompt": prompt,
343
- "image": image,
344
- "height": height,
345
- "width": width,
346
- "num_frames": num_frames,
347
- "num_inference_steps": num_inference_steps,
348
- "generator": generator,
349
- "return_dict": True,
350
- "output_type": "pil",
351
- }
352
- generation_kwargs = get_non_null_items(generation_kwargs)
353
- video = pipeline(**generation_kwargs).frames[0]
354
- return [data.VideoArtifact(value=video)]
355
-
356
- def _save_lora_weights(
357
- self,
358
- directory: str,
359
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
360
- scheduler: Optional[SchedulerType] = None,
361
- *args,
362
- **kwargs,
363
- ) -> None:
364
- # TODO(aryan): this needs refactoring
365
- if transformer_state_dict is not None:
366
- WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
367
- if scheduler is not None:
368
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
369
-
370
- def _save_model(
371
- self,
372
- directory: str,
373
- transformer: WanTransformer3DModel,
374
- transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
375
- scheduler: Optional[SchedulerType] = None,
376
- ) -> None:
377
- # TODO(aryan): this needs refactoring
378
- if transformer_state_dict is not None:
379
- with init_empty_weights():
380
- transformer_copy = WanTransformer3DModel.from_config(transformer.config)
381
- transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
382
- transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
383
- if scheduler is not None:
384
- scheduler.save_pretrained(os.path.join(directory, "scheduler"))
385
-
386
- @staticmethod
387
- def _normalize_latents(
388
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
389
- ) -> torch.Tensor:
390
- latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
391
- latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
392
- latents = ((latents.float() - latents_mean) * latents_std).to(latents)
393
- return latents
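One subtlety worth noting: `WanLatentEncodeProcessor` stores the reciprocal of the per-channel standard deviations (`latents_std = 1.0 / torch.tensor(vae.config.latents_std)`), so `_normalize_latents` multiplies rather than divides. The standalone sketch below makes that explicit; the per-channel statistics here are placeholders, and the real values come from `vae.config`.

```python
import torch

def normalize_wan_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, inv_latents_std: torch.Tensor
) -> torch.Tensor:
    # `inv_latents_std` is already 1 / std, so normalization is a multiply, not a divide.
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device)
    inv_latents_std = inv_latents_std.view(1, -1, 1, 1, 1).to(latents.device)
    return ((latents.float() - latents_mean) * inv_latents_std).to(latents)

# Example with placeholder statistics for a 16-channel latent space.
latents = torch.randn(1, 16, 5, 30, 40)
latents_mean = torch.zeros(16)
inv_latents_std = 1.0 / torch.full((16,), 2.0)
normalized = normalize_wan_latents(latents, latents_mean, inv_latents_std)
print(normalized.shape)  # torch.Size([1, 16, 5, 30, 40])
```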
 
 
finetrainers/optimizer.py DELETED
@@ -1,449 +0,0 @@
1
- import functools
2
- import math
3
- from typing import Any, Callable, Dict, List, Optional, Type, Union
4
-
5
- import torch
6
- from torch.distributed.checkpoint.state_dict import (
7
- StateDictOptions,
8
- get_optimizer_state_dict,
9
- set_optimizer_state_dict,
10
- )
11
- from torch.distributed.checkpoint.stateful import Stateful
12
-
13
- from .parallel import ParallelBackendEnum
14
- from .utils.import_utils import is_bitsandbytes_available
15
-
16
-
17
- class OptimizerWrapper(Stateful):
18
- r"""
19
- Optimizer wrapper that:
20
- - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages
21
- - saves/loading optimizer state_dict at checkpoint
22
- """
23
-
24
- def __init__(
25
- self,
26
- model_parts: List[torch.nn.Module],
27
- optimizer_cls: Type[torch.optim.Optimizer],
28
- optimizer_kwargs: Dict[str, Any],
29
- ) -> None:
30
- self.optimizer_cls = optimizer_cls
31
- self.optimizer_kwargs = optimizer_kwargs
32
-
33
- self.optimizers = []
34
- self.model_parts = model_parts
35
-
36
- for model in self.model_parts:
37
- optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
38
- self.optimizers.append(optimizer)
39
-
40
- def step(self) -> None:
41
- for optimizer in self.optimizers:
42
- optimizer.step()
43
-
44
- def zero_grad(self) -> None:
45
- for optimizer in self.optimizers:
46
- optimizer.zero_grad()
47
-
48
- def state_dict(self) -> Dict[str, Any]:
49
- func = functools.partial(
50
- get_optimizer_state_dict,
51
- options=StateDictOptions(flatten_optimizer_state_dict=True),
52
- )
53
- return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}
54
-
55
- def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
56
- func = functools.partial(
57
- set_optimizer_state_dict,
58
- optim_state_dict=state_dict,
59
- options=StateDictOptions(flatten_optimizer_state_dict=True),
60
- )
61
- list(map(func, self.model_parts, self.optimizers))
62
-
63
-
64
- class SchedulerWrapper:
65
- def __init__(
66
- self, optimizers, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int
67
- ) -> None:
68
- self.schedulers = []
69
- for optimizer in optimizers:
70
- self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))
71
-
72
- def step(self) -> None:
73
- for scheduler in self.schedulers:
74
- scheduler.step()
75
-
76
- def get_last_lr(self) -> Dict[str, List[float]]:
77
- # TODO(aryan): look into this later. Currently calling it leads to NCCL hang?????
78
- return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}
79
-
80
- def get_lr_scheduler_state(self) -> Dict[str, Any]:
81
- state_dict = {}
82
- if len(self.schedulers) == 1:
83
- state_dict["lr_scheduler"] = self.schedulers[0]
84
- else:
85
- # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
86
- # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
87
- for idx, lr_scheduler in enumerate(self.schedulers):
88
- state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
89
- return state_dict
90
-
91
-
92
- def get_optimizer(
93
- parallel_backend: ParallelBackendEnum,
94
- name: str,
95
- model_parts: List[torch.nn.Module],
96
- learning_rate: float = 1e-3,
97
- beta1: float = 0.9,
98
- beta2: float = 0.95,
99
- beta3: float = 0.999,
100
- epsilon: float = 1e-8,
101
- weight_decay: float = 1e-4,
102
- fused: bool = False,
103
- ) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
104
- name = name.lower()
105
-
106
- _raise_errors_if_packages_not_available(name)
107
-
108
- if name == "adam":
109
- optimizer_cls = torch.optim.Adam
110
- optimizer_kwargs = {
111
- "lr": learning_rate,
112
- "betas": (beta1, beta2),
113
- "eps": epsilon,
114
- "weight_decay": weight_decay,
115
- "fused": fused,
116
- }
117
- elif name == "adamw":
118
- optimizer_cls = torch.optim.AdamW
119
- optimizer_kwargs = {
120
- "lr": learning_rate,
121
- "betas": (beta1, beta2),
122
- "eps": epsilon,
123
- "weight_decay": weight_decay,
124
- "fused": fused,
125
- }
126
- elif name == "adam-bnb":
127
- from bitsandbytes.optim import Adam
128
-
129
- optimizer_cls = Adam
130
- optimizer_kwargs = {
131
- "lr": learning_rate,
132
- "betas": (beta1, beta2),
133
- "eps": epsilon,
134
- "weight_decay": weight_decay,
135
- }
136
- elif name == "adamw-bnb":
137
- from bitsandbytes.optim import AdamW
138
-
139
- optimizer_cls = AdamW
140
- optimizer_kwargs = {
141
- "lr": learning_rate,
142
- "betas": (beta1, beta2),
143
- "eps": epsilon,
144
- "weight_decay": weight_decay,
145
- }
146
- elif name == "adam-bnb-8bit":
147
- from bitsandbytes.optim import Adam8bit
148
-
149
- optimizer_cls = Adam8bit
150
- optimizer_kwargs = {
151
- "lr": learning_rate,
152
- "betas": (beta1, beta2),
153
- "eps": epsilon,
154
- "weight_decay": weight_decay,
155
- }
156
- elif name == "adamw-bnb-8bit":
157
- from bitsandbytes.optim import AdamW8bit
158
-
159
- optimizer_cls = AdamW8bit
160
- optimizer_kwargs = {
161
- "lr": learning_rate,
162
- "betas": (beta1, beta2),
163
- "eps": epsilon,
164
- "weight_decay": weight_decay,
165
- }
166
-
167
- # TODO(aryan): handle bitsandbytes and torchao
168
- else:
169
- raise ValueError(f"Unsupported optimizer: {name}")
170
-
171
- if parallel_backend == ParallelBackendEnum.ACCELERATE:
172
- return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
173
- elif parallel_backend == ParallelBackendEnum.PTD:
174
- return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
175
-
176
-
177
- def get_optimizer_accelerate(
178
- model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
179
- ) -> torch.optim.Optimizer:
180
- params = [param for model in model_parts for param in model.parameters() if param.requires_grad]
181
- optimizer = optimizer_cls(params, **optimizer_kwargs)
182
- return optimizer
183
-
184
-
185
- def get_optimizer_ptd(
186
- model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
187
- ) -> OptimizerWrapper:
188
- return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs)
189
-
190
-
191
- def get_lr_scheduler(
192
- parallel_backend: ParallelBackendEnum,
193
- name: str,
194
- optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
195
- step_rules: Optional[str] = None,
196
- num_warmup_steps: Optional[int] = None,
197
- num_training_steps: Optional[int] = None,
198
- num_cycles: int = 1,
199
- power: float = 1.0,
200
- lr_init: float = 1e-3,
201
- lr_end: float = 1e-7,
202
- last_epoch: int = -1,
203
- ) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
204
- name = name.lower()
205
- if name == "constant":
206
- scheduler_lambda_fn = get_constant_schedule()
207
- elif name == "constant_with_warmup":
208
- scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
209
- elif name == "piecewise_constant":
210
- scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
211
- elif name == "linear":
212
- scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
213
- elif name == "cosine":
214
- scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
215
- elif name == "cosine_with_restarts":
216
- scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
217
- num_warmup_steps, num_training_steps, num_cycles
218
- )
219
- elif name == "polynomial":
220
- scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
221
- num_warmup_steps, num_training_steps, lr_init, lr_end, power
222
- )
223
- else:
224
- raise ValueError(f"Unsupported scheduler: {name}")
225
-
226
- if parallel_backend == ParallelBackendEnum.ACCELERATE:
227
- return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
228
- elif parallel_backend == ParallelBackendEnum.PTD:
229
- return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
230
-
231
-
232
- def get_lr_scheduler_accelerate(
233
- optimizer: torch.optim.Optimizer,
234
- scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler],
235
- last_epoch: int = -1,
236
- ) -> torch.optim.lr_scheduler.LambdaLR:
237
- scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
238
- return scheduler
239
-
240
-
241
- def get_lr_scheduler_ptd(
242
- optimizer: OptimizerWrapper, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int = -1
243
- ) -> SchedulerWrapper:
244
- return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
245
-
246
-
247
- # ==============================
248
- # Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
249
- # ==============================
250
-
251
-
252
- def get_constant_schedule() -> Callable[[int], float]:
253
- r"""
254
- Create a schedule with a constant learning rate, using the learning rate set in optimizer.
255
- """
256
-
257
- def lr_lambda(current_step: int):
258
- return 1.0
259
-
260
- return lr_lambda
261
-
262
-
263
- def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
264
- r"""
265
- Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
266
- increases linearly between 0 and the initial lr set in the optimizer.
267
-
268
- Args:
269
- num_warmup_steps (`int`):
270
- The number of steps for the warmup phase.
271
- """
272
-
273
- def lr_lambda(current_step: int):
274
- if current_step < num_warmup_steps:
275
- return float(current_step) / float(max(1.0, num_warmup_steps))
276
- return 1.0
277
-
278
- return lr_lambda
279
-
280
-
281
- def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
282
- r"""
283
- Create a piecewise-constant schedule in which the learning rate set in the optimizer is scaled by the multipliers defined in `step_rules`.
284
-
285
- Args:
286
- step_rules (`string`):
287
- The rules for the learning rate, e.g. step_rules="1:10,0.1:20,0.01:30,0.005" means that the learning rate
288
- is multiplied by 1 for the first 10 steps, by 0.1 for the next 20 steps, by 0.01 for the next 30
289
- steps, and by 0.005 for the remaining steps.
290
- """
291
-
292
- rules_dict = {}
293
- rule_list = step_rules.split(",")
294
- for rule_str in rule_list[:-1]:
295
- value_str, steps_str = rule_str.split(":")
296
- steps = int(steps_str)
297
- value = float(value_str)
298
- rules_dict[steps] = value
299
- last_lr_multiple = float(rule_list[-1])
300
-
301
- def create_rules_function(rules_dict, last_lr_multiple):
302
- def rule_func(steps: int) -> float:
303
- sorted_steps = sorted(rules_dict.keys())
304
- for i, sorted_step in enumerate(sorted_steps):
305
- if steps < sorted_step:
306
- return rules_dict[sorted_steps[i]]
307
- return last_lr_multiple
308
-
309
- return rule_func
310
-
311
- rules_func = create_rules_function(rules_dict, last_lr_multiple)
312
- return rules_func
313
-
314
-
315
- def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
316
- r"""
317
- Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
318
- a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
319
-
320
- Args:
321
- num_warmup_steps (`int`):
322
- The number of steps for the warmup phase.
323
- num_training_steps (`int`):
324
- The total number of training steps.
325
- """
326
-
327
- def lr_lambda(current_step: int):
328
- if current_step < num_warmup_steps:
329
- return float(current_step) / float(max(1, num_warmup_steps))
330
- return max(
331
- 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
332
- )
333
-
334
- return lr_lambda
335
-
336
-
337
- def get_cosine_schedule_with_warmup(
338
- num_warmup_steps: int,
339
- num_training_steps: int,
340
- num_cycles: float = 0.5,
341
- ) -> Callable[[int], float]:
342
- r"""
343
- Create a schedule with a learning rate that decreases following the values of the cosine function between the
344
- initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
345
- initial lr set in the optimizer.
346
-
347
- Args:
348
- num_warmup_steps (`int`):
349
- The number of steps for the warmup phase.
350
- num_training_steps (`int`):
351
- The total number of training steps.
352
- num_cycles (`float`, *optional*, defaults to 0.5):
353
- The number of periods of the cosine function in a schedule (the default is to just decrease from the max
354
- value to 0 following a half-cosine).
355
- """
356
-
357
- def lr_lambda(current_step):
358
- if current_step < num_warmup_steps:
359
- return float(current_step) / float(max(1, num_warmup_steps))
360
- progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
361
- return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
362
-
363
- return lr_lambda
364
-
365
-
366
- def get_cosine_with_hard_restarts_schedule_with_warmup(
367
- num_warmup_steps: int,
368
- num_training_steps: int,
369
- num_cycles: int = 1,
370
- ) -> Callable[[int], float]:
371
- r"""
372
- Create a schedule with a learning rate that decreases following the values of the cosine function between the
373
- initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
374
- linearly between 0 and the initial lr set in the optimizer.
375
-
376
- Args:
377
- num_warmup_steps (`int`):
378
- The number of steps for the warmup phase.
379
- num_training_steps (`int`):
380
- The total number of training steps.
381
- num_cycles (`int`, *optional*, defaults to 1):
382
- The number of hard restarts to use.
383
- """
384
-
385
- def lr_lambda(current_step):
386
- if current_step < num_warmup_steps:
387
- return float(current_step) / float(max(1, num_warmup_steps))
388
- progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
389
- if progress >= 1.0:
390
- return 0.0
391
- return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
392
-
393
- return lr_lambda
394
-
395
-
396
- def get_polynomial_decay_schedule_with_warmup(
397
- num_warmup_steps: int,
398
- num_training_steps: int,
399
- lr_init: float,
400
- lr_end: float = 1e-7,
401
- power: float = 1.0,
402
- ) -> Callable[[int], float]:
403
- r"""
404
- Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
405
- optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
406
- initial lr set in the optimizer.
407
-
408
- Args:
409
- num_warmup_steps (`int`):
410
- The number of steps for the warmup phase.
411
- num_training_steps (`int`):
412
- The total number of training steps.
413
- lr_end (`float`, *optional*, defaults to 1e-7):
414
- The end LR.
415
- power (`float`, *optional*, defaults to 1.0):
416
- Power factor.
417
-
418
- Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at
419
- https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
420
- """
421
-
422
- if not (lr_init > lr_end):
423
- raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
424
-
425
- def lr_lambda(current_step: int):
426
- if current_step < num_warmup_steps:
427
- return float(current_step) / float(max(1, num_warmup_steps))
428
- elif current_step > num_training_steps:
429
- return lr_end / lr_init # as LambdaLR multiplies by lr_init
430
- else:
431
- lr_range = lr_init - lr_end
432
- decay_steps = num_training_steps - num_warmup_steps
433
- pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
434
- decay = lr_range * pct_remaining**power + lr_end
435
- return decay / lr_init # as LambdaLR multiplies by lr_init
436
-
437
- return lr_lambda
438
-
439
-
440
- def _raise_errors_if_packages_not_available(name: str) -> None:
441
- name_split = name.split("-")
442
- if len(name_split) < 2:
443
- return
444
- package_name = name_split[1]
445
- if package_name == "bnb":
446
- if not is_bitsandbytes_available():
447
- raise ImportError(
448
- f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer."
449
- )
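Note: each schedule helper above returns a plain `step -> multiplier` function that is handed to `torch.optim.lr_scheduler.LambdaLR`. The following minimal, self-contained sketch shows that wiring with a constant-with-warmup lambda equivalent to the deleted helper; the model and hyperparameters are placeholders, not values from this repository.

```python
import torch


def constant_with_warmup(num_warmup_steps: int):
    # Same contract as the deleted helpers: return a step -> LR-multiplier function.
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return lr_lambda


model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, constant_with_warmup(10))

for _ in range(12):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # [0.001] once the 10 warmup steps have passed
```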
finetrainers/parallel/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- from enum import Enum
2
- from typing import Union
3
-
4
- from .accelerate import AccelerateParallelBackend
5
- from .ptd import PytorchDTensorParallelBackend
6
- from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean
7
-
8
-
9
- ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
10
-
11
-
12
- class ParallelBackendEnum(str, Enum):
13
- ACCELERATE = "accelerate"
14
- PTD = "ptd"
15
-
16
-
17
- def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType:
18
- if backend == ParallelBackendEnum.ACCELERATE:
19
- return AccelerateParallelBackend
20
- if backend == ParallelBackendEnum.PTD:
21
- return PytorchDTensorParallelBackend
22
- raise ValueError(f"Unknown parallel backend: {backend}")
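Note: `ParallelBackendEnum` mixes in `str` so that raw CLI strings map directly onto members. A standalone sketch of that behaviour (the enum is re-declared here purely for the demo):

```python
from enum import Enum


class ParallelBackendEnum(str, Enum):  # re-declared only for this standalone demo
    ACCELERATE = "accelerate"
    PTD = "ptd"


# Value lookup turns a CLI string into a member, and members compare equal to
# their raw string values because of the str mixin.
assert ParallelBackendEnum("ptd") is ParallelBackendEnum.PTD
assert ParallelBackendEnum.ACCELERATE == "accelerate"
```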
finetrainers/parallel/accelerate.py DELETED
@@ -1,218 +0,0 @@
1
- import datetime
2
- import pathlib
3
- from typing import Optional
4
-
5
- import torch
6
- from diffusers.utils import is_accelerate_available
7
-
8
- from ..logging import get_logger
9
- from ..utils import get_device_info
10
- from .base import BaseParallelBackend
11
- from .utils import apply_ddp_accelerate
12
-
13
-
14
- if not is_accelerate_available():
15
- raise ImportError(
16
- "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend."
17
- )
18
-
19
- from accelerate import Accelerator
20
- from accelerate.data_loader import DataLoader
21
- from accelerate.utils import (
22
- DataLoaderConfiguration,
23
- DistributedDataParallelKwargs,
24
- InitProcessGroupKwargs,
25
- ProjectConfiguration,
26
- )
27
-
28
-
29
- logger = get_logger()
30
- _device_type, _device_module = get_device_info()
31
-
32
-
33
- class AccelerateParallelBackend(BaseParallelBackend):
34
- def __init__(
35
- self,
36
- world_size: int,
37
- pp_degree: int = 1,
38
- dp_degree: int = 1,
39
- dp_shards: int = -1,
40
- cp_degree: int = 1,
41
- tp_degree: int = 1,
42
- backend: str = "nccl",
43
- timeout: int = 180,
44
- logging_dir: Optional[str] = None,
45
- output_dir: Optional[str] = None,
46
- gradient_accumulation_steps: Optional[int] = None,
47
- ) -> None:
48
- super().__init__()
49
-
50
- self._world_size = world_size
51
- self._pp_degree = pp_degree
52
- self._dp_degree = dp_degree
53
- self._dp_shards = dp_shards
54
- self._cp_degree = cp_degree
55
- self._tp_degree = tp_degree
56
- self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
57
- self._logging_dir = (
58
- self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
59
- )
60
- self._backend = backend
61
- self._timeout = timeout
62
- self._gradient_accumulation_steps = gradient_accumulation_steps
63
-
64
- if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1:
65
- raise ValueError(
66
- "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment."
67
- )
68
- if dp_degree != world_size:
69
- raise ValueError("Data parallel degree must be equal to world size.")
70
-
71
- self._accelerator: Accelerator = None
72
- self._mesh: torch.distributed.DeviceMesh = None
73
-
74
- def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
75
- project_config = None
76
- ddp_kwargs = None
77
- init_process_group_kwargs = None
78
- if self._accelerator is None:
79
- project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
80
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
81
- dataloader_config = DataLoaderConfiguration(
82
- split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
83
- )
84
- init_process_group_kwargs = InitProcessGroupKwargs(
85
- backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
86
- )
87
- self._accelerator, model = apply_ddp_accelerate(
88
- model,
89
- project_config,
90
- ddp_kwargs,
91
- init_process_group_kwargs,
92
- dataloader_config,
93
- self._gradient_accumulation_steps,
94
- accelerator=self._accelerator,
95
- )
96
- logger.debug("Applied AccelerateParallel::apply_ddp to model.")
97
- return model
98
-
99
- def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
100
- logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
101
- return dataset
102
-
103
- def prepare_dataloader(
104
- self,
105
- dataset: torch.utils.data.IterableDataset,
106
- batch_size: int = 1,
107
- num_workers: int = 0,
108
- pin_memory: bool = False,
109
- ) -> DataLoader:
110
- dataloader = torch.utils.data.DataLoader(
111
- dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
112
- )
113
- dataloader = self._accelerator.prepare_data_loader(dataloader)
114
- logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
115
- return dataloader
116
-
117
- def prepare_optimizer(self, optimizer, lr_scheduler):
118
- optimizer = self._accelerator.prepare_optimizer(optimizer)
119
- lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler)
120
- return optimizer, lr_scheduler
121
-
122
- def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
123
- def _get_mesh():
124
- if name is None:
125
- return self._mesh
126
- try:
127
- return self._mesh[name]
128
- except (KeyError, RuntimeError):
129
- return self._mesh
130
-
131
- if self._mesh is not None:
132
- return _get_mesh()
133
-
134
- mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)]
135
- mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
136
- names = [x[0] for x in mesh_list]
137
- degrees = [x[1] for x in mesh_list]
138
- mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
139
-
140
- dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
141
-
142
- if self.data_replication_enabled:
143
- dp_mesh_names.append("dp_replicate")
144
- dp_cp_mesh_names.append("dp_replicate")
145
- if self.data_sharding_enabled:
146
- dp_mesh_names.append("dp_shard")
147
- dp_cp_mesh_names.append("dp_shard")
148
- dp_shard_cp_mesh_names.append("dp_shard")
149
- if self.context_parallel_enabled:
150
- dp_cp_mesh_names.append("cp")
151
- dp_shard_cp_mesh_names.append("cp")
152
-
153
- if len(dp_mesh_names) > 0:
154
- mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
155
- if len(dp_cp_mesh_names) > 0:
156
- mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
157
- if len(dp_shard_cp_mesh_names) > 0:
158
- mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
159
-
160
- logger.debug(f"Device mesh: {mesh}")
161
- self._mesh = mesh
162
- return _get_mesh()
163
-
164
- @property
165
- def world_size(self):
166
- return self._accelerator.num_processes
167
-
168
- @property
169
- def rank(self):
170
- return self._accelerator.process_index
171
-
172
- @property
173
- def local_rank(self):
174
- return self._accelerator.local_process_index
175
-
176
- @property
177
- def is_main_process(self):
178
- r"""Returns `True` if the current process is the main process on the master node."""
179
- return self._accelerator.is_main_process
180
-
181
- @property
182
- def is_local_main_process(self):
183
- r"""Returns `True` if the current process is the main process on local node."""
184
- return self._accelerator.is_local_main_process
185
-
186
- @property
187
- def device(self):
188
- return self._accelerator.device
189
-
190
- def wait_for_everyone(self):
191
- self._accelerator.wait_for_everyone()
192
-
193
- def destroy(self):
194
- self._accelerator.end_training()
195
-
196
- @property
197
- def pipeline_parallel_enabled(self):
198
- return self._pp_degree > 1
199
-
200
- @property
201
- def data_parallel_enabled(self):
202
- return self._dp_degree > 1 or self._dp_shards > 1
203
-
204
- @property
205
- def data_replication_enabled(self):
206
- return self._dp_degree > 1
207
-
208
- @property
209
- def data_sharding_enabled(self):
210
- return self._dp_shards > 1
211
-
212
- @property
213
- def context_parallel_enabled(self):
214
- return self._cp_degree > 1
215
-
216
- @property
217
- def tensor_parallel_enabled(self):
218
- return self._tp_degree > 1
finetrainers/parallel/base.py DELETED
@@ -1,96 +0,0 @@
1
- from contextlib import contextmanager
2
- from typing import Any, Dict, List, Optional
3
-
4
- import torch
5
-
6
- from ..trackers import TrackerType, initialize_trackers
7
-
8
-
9
- class BaseParallelBackend:
10
- r"""
11
- Base class that contains properties and methods that should be implemented by different parallel backends.
12
- """
13
-
14
- def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
15
- raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")
16
-
17
- def prepare_dataset(self, *args, **kwargs) -> Any:
18
- raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")
19
-
20
- def prepare_dataloader(self, *args, **kwargs) -> Any:
21
- raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.")
22
-
23
- def prepare_optimizer(self, *args, **kwargs) -> Any:
24
- raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.")
25
-
26
- def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
27
- raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")
28
-
29
- def initialize_trackers(
30
- self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
31
- ) -> TrackerType:
32
- self.tracker = None
33
- if self.is_main_process:
34
- self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
35
-
36
- def log(self, metrics: Dict[str, Any], step: int) -> None:
37
- if self.is_main_process:
38
- self.tracker.log(metrics, step)
39
-
40
- def wait_for_everyone(self):
41
- raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.")
42
-
43
- @contextmanager
44
- def main_process_first(self):
45
- raise NotImplementedError("Method `main_process_first` must be implemented by subclass.")
46
-
47
- def destroy(self):
48
- raise NotImplementedError("Method `destroy` must be implemented by subclass.")
49
-
50
- @property
51
- def world_size(self):
52
- raise NotImplementedError("Method `world_size` must be implemented by subclass.")
53
-
54
- @property
55
- def rank(self):
56
- raise NotImplementedError("Method `rank` must be implemented by subclass.")
57
-
58
- @property
59
- def local_rank(self):
60
- raise NotImplementedError("Method `local_rank` must be implemented by subclass.")
61
-
62
- @property
63
- def is_main_process(self):
64
- raise NotImplementedError("Method `is_main_process` must be implemented by subclass.")
65
-
66
- @property
67
- def is_local_main_process(self):
68
- raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.")
69
-
70
- @property
71
- def device(self):
72
- raise NotImplementedError("Method `device` must be implemented by subclass.")
73
-
74
- @property
75
- def pipeline_parallel_enabled(self):
76
- raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.")
77
-
78
- @property
79
- def data_parallel_enabled(self):
80
- raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.")
81
-
82
- @property
83
- def data_replication_enabled(self):
84
- raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.")
85
-
86
- @property
87
- def data_sharding_enabled(self):
88
- raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.")
89
-
90
- @property
91
- def context_parallel_enabled(self):
92
- raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.")
93
-
94
- @property
95
- def tensor_parallel_enabled(self):
96
- raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")
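Note: to make the contract above concrete, a hypothetical single-process backend could satisfy the interface as sketched below. The class is invented for illustration (it is not part of the repository) and assumes `BaseParallelBackend` from this file is in scope.

```python
import torch


class SingleProcessBackend(BaseParallelBackend):
    """Hypothetical no-op backend: one process, no sharding, no replication."""

    def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
        return model  # nothing to wrap when there is only one process

    def prepare_dataset(self, dataset):
        return dataset

    def prepare_dataloader(self, dataset, batch_size=1, num_workers=0, pin_memory=False):
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

    def prepare_optimizer(self, optimizer, lr_scheduler):
        return optimizer, lr_scheduler

    def wait_for_everyone(self):
        pass

    def destroy(self):
        pass

    @property
    def world_size(self):
        return 1

    @property
    def rank(self):
        return 0

    @property
    def local_rank(self):
        return 0

    @property
    def is_main_process(self):
        return True

    @property
    def is_local_main_process(self):
        return True

    @property
    def device(self):
        return torch.device("cpu")

    # The remaining *_enabled properties would simply return False for a
    # single-process setup; they are omitted here for brevity.
    @property
    def data_parallel_enabled(self):
        return False
```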
finetrainers/parallel/deepspeed.py DELETED
@@ -1,7 +0,0 @@
1
- from .base import BaseParallelBackend
2
-
3
-
4
- class DeepspeedParallelBackend(BaseParallelBackend):
5
- def __init__(self):
6
- # TODO(aryan)
7
- raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.")
finetrainers/parallel/ptd.py DELETED
@@ -1,228 +0,0 @@
1
- import datetime
2
- import os
3
- import pathlib
4
- from typing import Optional
5
-
6
- import datasets.distributed
7
- import torch
8
-
9
- from ..data import DPDataLoader
10
- from ..logging import get_logger
11
- from ..utils import get_device_info
12
- from .base import BaseParallelBackend
13
- from .utils import apply_ddp_ptd
14
-
15
-
16
- _device_type, _device_module = get_device_info()
17
- logger = get_logger()
18
-
19
-
20
- class PytorchDTensorParallelBackend(BaseParallelBackend):
21
- def __init__(
22
- self,
23
- world_size: int,
24
- pp_degree: int = 1,
25
- dp_degree: int = 1,
26
- dp_shards: int = -1,
27
- cp_degree: int = 1,
28
- tp_degree: int = 1,
29
- backend: str = "nccl",
30
- timeout: int = 180,
31
- logging_dir: Optional[str] = None,
32
- output_dir: Optional[str] = None,
33
- gradient_accumulation_steps: Optional[int] = None,
34
- ) -> None:
35
- super().__init__()
36
-
37
- self._world_size = world_size
38
- self._pp_degree = pp_degree
39
- self._dp_degree = dp_degree
40
- self._dp_shards = dp_shards
41
- self._cp_degree = cp_degree
42
- self._tp_degree = tp_degree
43
- self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
44
- self._logging_dir = (
45
- self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
46
- )
47
- self._backend = backend
48
- self._timeout = timeout
49
-
50
- for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
51
- if degree < 1:
52
- raise ValueError(f"Parallel degree must be at least 1, got {degree}.")
53
-
54
- if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
55
- raise ValueError(
56
- f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards."
57
- )
58
-
59
- torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
60
- _device_module.set_device(self.local_rank)
61
-
62
- logger.info(
63
- f"Initialized parallel state with:\n"
64
- f" - World size: {world_size}\n"
65
- f" - Pipeline parallel degree: {pp_degree}\n"
66
- f" - Data parallel degree: {dp_degree}\n"
67
- f" - Context parallel degree: {cp_degree}\n"
68
- f" - Tensor parallel degree: {tp_degree}\n"
69
- f" - Data parallel shards: {dp_shards}\n"
70
- )
71
-
72
- self._mesh: torch.distributed.DeviceMesh = None
73
-
74
- def apply_ddp(
75
- self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
76
- ) -> torch.nn.Module:
77
- if device_mesh is None:
78
- device_mesh = self.get_mesh()
79
- apply_ddp_ptd(model, device_mesh)
80
- logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
81
- return model
82
-
83
- def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
84
- dp_mesh = self.get_mesh("dp_replicate")
85
- if dp_mesh is None:
86
- dp_mesh = self.get_mesh()
87
- if self.world_size > 1:
88
- dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
89
- else:
90
- dp_local_rank, dp_world_size = 0, 1
91
- dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
92
- logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
93
- return dataset
94
-
95
- def prepare_dataloader(
96
- self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
97
- ) -> DPDataLoader:
98
- dp_mesh = self.get_mesh("dp_replicate")
99
- if dp_mesh is None:
100
- dp_mesh = self.get_mesh()
101
- if self.world_size > 1:
102
- dp_local_rank = dp_mesh.get_local_rank()
103
- else:
104
- dp_local_rank = 0
105
- dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
106
- logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
107
- return dataloader
108
-
109
- def prepare_optimizer(self, optimizer, lr_scheduler):
110
- logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
111
- return optimizer, lr_scheduler
112
-
113
- def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
114
- def _get_mesh():
115
- if name is None:
116
- return self._mesh
117
- try:
118
- return self._mesh[name]
119
- except (KeyError, RuntimeError):
120
- if self._mesh.ndim == 0:
121
- return None
122
- return self._mesh
123
-
124
- if self._mesh is not None:
125
- return _get_mesh()
126
-
127
- mesh_list = [
128
- ("pp", self._pp_degree),
129
- ("dp_replicate", self._dp_degree),
130
- ("dp_shard", self._dp_shards),
131
- ("cp", self._cp_degree),
132
- ("tp", self._tp_degree),
133
- ]
134
- mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
135
- names = [x[0] for x in mesh_list]
136
- degrees = [x[1] for x in mesh_list]
137
- mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
138
-
139
- dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
140
-
141
- if self.data_replication_enabled:
142
- dp_mesh_names.append("dp_replicate")
143
- dp_cp_mesh_names.append("dp_replicate")
144
- if self.data_sharding_enabled:
145
- dp_mesh_names.append("dp_shard")
146
- dp_cp_mesh_names.append("dp_shard")
147
- dp_shard_cp_mesh_names.append("dp_shard")
148
- if self.context_parallel_enabled:
149
- dp_cp_mesh_names.append("cp")
150
- dp_shard_cp_mesh_names.append("cp")
151
-
152
- if len(dp_mesh_names) > 0:
153
- mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
154
- if len(dp_cp_mesh_names) > 0:
155
- mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
156
- if len(dp_shard_cp_mesh_names) > 0:
157
- mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
158
-
159
- logger.debug(f"Device mesh: {mesh}")
160
- self._mesh = mesh
161
- return _get_mesh()
162
-
163
- @property
164
- def world_size(self):
165
- return torch.distributed.get_world_size()
166
-
167
- @property
168
- def rank(self):
169
- return torch.distributed.get_rank()
170
-
171
- @property
172
- def local_rank(self):
173
- return int(os.environ.get("LOCAL_RANK", 0))
174
-
175
- @property
176
- def is_main_process(self):
177
- r"""Returns `True` if the current process is the main process on the master node."""
178
- return self.rank == 0
179
-
180
- @property
181
- def is_local_main_process(self):
182
- r"""Returns `True` if the current process is the main process on local node."""
183
- return self.local_rank == 0
184
-
185
- @property
186
- def device(self):
187
- return torch.device(_device_type, self.local_rank)
188
-
189
- def wait_for_everyone(self):
190
- return torch.distributed.barrier()
191
-
192
- # @contextmanager
193
- # def main_process_first(self):
194
- # if self.is_main_process:
195
- # yield
196
- # self.wait_for_everyone()
197
- # else:
198
- # self.wait_for_everyone()
199
- # yield
200
-
201
- def destroy(self):
202
- if self.is_main_process:
203
- self.tracker.finish()
204
- return torch.distributed.destroy_process_group()
205
-
206
- @property
207
- def pipeline_parallel_enabled(self):
208
- return self._pp_degree > 1
209
-
210
- @property
211
- def data_parallel_enabled(self):
212
- return self._dp_degree > 1 or self._dp_shards > 1
213
-
214
- @property
215
- def data_replication_enabled(self):
216
- return self._dp_degree > 1
217
-
218
- @property
219
- def data_sharding_enabled(self):
220
- return self._dp_shards > 1
221
-
222
- @property
223
- def context_parallel_enabled(self):
224
- return self._cp_degree > 1
225
-
226
- @property
227
- def tensor_parallel_enabled(self):
228
- return self._tp_degree > 1
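Note: the constructor and `get_mesh` above boil down to two invariants: the product of all parallel degrees must equal the world size, and only degrees greater than one become device-mesh dimensions. A standalone sketch with illustrative numbers (not taken from the repository):

```python
world_size = 8
degrees = {"pp": 1, "dp_replicate": 2, "dp_shard": 4, "cp": 1, "tp": 1}

# Constructor check: the product of all degrees has to equal the world size.
product = 1
for degree in degrees.values():
    product *= degree
assert product == world_size

# get_mesh only materializes dimensions whose degree is greater than one.
mesh_dims = [(name, degree) for name, degree in degrees.items() if degree > 1]
print(mesh_dims)  # [('dp_replicate', 2), ('dp_shard', 4)] -> a 2 x 4 device mesh
```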
finetrainers/parallel/utils.py DELETED
@@ -1,99 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import torch.distributed._functional_collectives as funcol
5
- import torch.distributed.tensor
6
- from diffusers.utils import is_accelerate_available
7
- from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
8
- from torch.distributed._composable.replicate import replicate
9
-
10
- from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
11
-
12
-
13
- if is_accelerate_available():
14
- from accelerate import Accelerator
15
- from accelerate.utils import (
16
- DataLoaderConfiguration,
17
- DistributedDataParallelKwargs,
18
- InitProcessGroupKwargs,
19
- ProjectConfiguration,
20
- )
21
-
22
-
23
- def apply_fsdp2_ptd(
24
- model: torch.nn.Module,
25
- dp_mesh: torch.distributed.device_mesh.DeviceMesh,
26
- param_dtype: torch.dtype,
27
- reduce_dtype: torch.dtype,
28
- output_dtype: torch.dtype,
29
- pp_enabled: bool = False,
30
- cpu_offload: bool = False,
31
- ) -> None:
32
- r"""Apply FSDP2 on a model."""
33
- mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
34
- fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
35
-
36
- if cpu_offload:
37
- fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)
38
-
39
- def apply_fully_shard(blocks):
40
- for layer_index, block in enumerate(blocks):
41
- if pp_enabled:
42
- # For PP, do not reshard after forward to avoid per-microbatch
43
- # all-gathers, which can be expensive and non-overlapped
44
- reshard_after_forward = False
45
- else:
46
- # As an optimization, do not reshard after forward for the last
47
- # transformer block since FSDP would prefetch it immediately
48
- reshard_after_forward = layer_index < len(blocks) - 1
49
- fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
50
-
51
- for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
52
- blocks = getattr(model, transformer_block_name, None)
53
- if blocks is not None:
54
- apply_fully_shard(blocks)
55
-
56
- fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
57
-
58
-
59
- def apply_ddp_accelerate(
60
- model: torch.nn.Module,
61
- project_config: Optional[ProjectConfiguration] = None,
62
- ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
63
- init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
64
- dataloader_config: Optional[DataLoaderConfiguration] = None,
65
- gradient_accumulation_steps: Optional[int] = None,
66
- accelerator: Optional[Accelerator] = None,
67
- ) -> torch.nn.Module:
68
- if accelerator is None:
69
- accelerator = Accelerator(
70
- project_config=project_config,
71
- dataloader_config=dataloader_config,
72
- gradient_accumulation_steps=gradient_accumulation_steps,
73
- log_with=None,
74
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
75
- )
76
- if torch.backends.mps.is_available():
77
- accelerator.native_amp = False
78
- accelerator.prepare_model(model)
79
- return accelerator, model
80
-
81
-
82
- def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
83
- replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
84
-
85
-
86
- def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
87
- if isinstance(x, torch.distributed.tensor.DTensor):
88
- # functional collectives do not support DTensor inputs
89
- x = x.full_tensor()
90
- assert x.numel() == 1 # required by `.item()`
91
- return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
92
-
93
-
94
- def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
95
- return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)
96
-
97
-
98
- def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
99
- return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
finetrainers/patches/__init__.py DELETED
@@ -1,28 +0,0 @@
1
- from typing import TYPE_CHECKING
2
-
3
-
4
- if TYPE_CHECKING:
5
- from ..args import BaseArgs
6
- from ..parallel import ParallelBackendType
7
-
8
-
9
- def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None:
10
- # To avoid circular imports
11
- from ..config import ModelType, TrainingType
12
-
13
- if args.model_name == ModelType.LTX_VIDEO:
14
- from .models.ltx_video import patch
15
-
16
- patch.patch_transformer_forward()
17
- if parallel_backend.tensor_parallel_enabled:
18
- patch.patch_apply_rotary_emb_for_tp_compatibility()
19
-
20
- if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules:
21
- from .models.wan import patch
22
-
23
- patch.patch_time_text_image_embedding_forward()
24
-
25
- if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
26
- from .dependencies.peft import patch
27
-
28
- patch.patch_peft_move_adapter_to_device_of_base_layer()
finetrainers/patches/dependencies/peft/patch.py DELETED
@@ -1,25 +0,0 @@
1
- import functools
2
-
3
- from peft.tuners.tuners_utils import BaseTunerLayer
4
-
5
- from ...utils import DisableTensorToDtype
6
-
7
-
8
- def patch_peft_move_adapter_to_device_of_base_layer() -> None:
9
- _perform_patch_move_adapter_to_device_of_base_layer()
10
-
11
-
12
- def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
13
- BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
14
- BaseTunerLayer._move_adapter_to_device_of_base_layer
15
- )
16
-
17
-
18
- def _patched_move_adapter_to_device_of_base_layer(func) -> None:
19
- # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor.
20
- @functools.wraps(func)
21
- def wrapper(self, *args, **kwargs):
22
- with DisableTensorToDtype():
23
- return func(self, *args, **kwargs)
24
-
25
- return wrapper
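Note: the patch above is an instance of a generic wrap-and-replace pattern: take the original method, wrap it with `functools.wraps`, and assign the wrapper back onto the class. A standalone sketch of the same pattern on a toy class (all names are invented for the demo):

```python
import functools


class Example:
    def greet(self, name: str) -> str:
        return f"hello {name}"


def _wrap_with_logging(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        print(f"calling {func.__name__}")  # extra behaviour around the original call
        return func(self, *args, **kwargs)

    return wrapper


# Same move as _perform_patch_move_adapter_to_device_of_base_layer: replace the
# method on the class with its wrapped version.
Example.greet = _wrap_with_logging(Example.greet)
print(Example().greet("world"))  # prints "calling greet", then "hello world"
```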
finetrainers/patches/models/ltx_video/patch.py DELETED
@@ -1,127 +0,0 @@
1
- from typing import Any, Dict, Optional, Tuple
2
-
3
- import diffusers
4
- import torch
5
- from diffusers import LTXVideoTransformer3DModel
6
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
7
- from diffusers.utils.import_utils import is_torch_version
8
-
9
-
10
- def patch_transformer_forward() -> None:
11
- _perform_ltx_transformer_forward_patch()
12
-
13
-
14
- def patch_apply_rotary_emb_for_tp_compatibility() -> None:
15
- _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch()
16
-
17
-
18
- def _perform_ltx_transformer_forward_patch() -> None:
19
- LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward
20
-
21
-
22
- def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
23
- def apply_rotary_emb(x, freqs):
24
- cos, sin = freqs
25
- # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ========
26
- # The change is made due to unsupported DTensor operation aten.ops.unbind
27
- # FIXME: Once aten.ops.unbind support lands, this will no longer be required
28
- # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
29
- x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1) # [B, S, H, D // 2]
30
- # ==================================================================
31
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
32
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
33
- return out
34
-
35
- diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
36
-
37
-
38
- def _patched_LTXVideoTransformer3D_forward(
39
- self,
40
- hidden_states: torch.Tensor,
41
- encoder_hidden_states: torch.Tensor,
42
- timestep: torch.LongTensor,
43
- encoder_attention_mask: torch.Tensor,
44
- num_frames: int,
45
- height: int,
46
- width: int,
47
- rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
48
- return_dict: bool = True,
49
- *args,
50
- **kwargs,
51
- ) -> torch.Tensor:
52
- image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
53
-
54
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
55
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
56
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
57
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
58
-
59
- batch_size = hidden_states.size(0)
60
-
61
- # ===== This is modified compared to Diffusers =====
62
- # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep
63
- if timestep.ndim == 1:
64
- timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1)
65
- # ==================================================
66
-
67
- temb, embedded_timestep = self.time_embed(
68
- timestep.flatten(),
69
- batch_size=batch_size,
70
- hidden_dtype=hidden_states.dtype,
71
- )
72
-
73
- # ===== This is modified compared to Diffusers =====
74
- # temb = temb.view(batch_size, -1, temb.size(-1))
75
- # embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
76
- # ==================================================
77
- # This is done to make it possible to use per-token timestep embedding
78
- temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1))
79
- embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1))
80
- # ==================================================
81
-
82
- hidden_states = self.proj_in(hidden_states)
83
-
84
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
85
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
86
-
87
- for block in self.transformer_blocks:
88
- if torch.is_grad_enabled() and self.gradient_checkpointing:
89
-
90
- def create_custom_forward(module, return_dict=None):
91
- def custom_forward(*inputs):
92
- if return_dict is not None:
93
- return module(*inputs, return_dict=return_dict)
94
- else:
95
- return module(*inputs)
96
-
97
- return custom_forward
98
-
99
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
100
- hidden_states = torch.utils.checkpoint.checkpoint(
101
- create_custom_forward(block),
102
- hidden_states,
103
- encoder_hidden_states,
104
- temb,
105
- image_rotary_emb,
106
- encoder_attention_mask,
107
- **ckpt_kwargs,
108
- )
109
- else:
110
- hidden_states = block(
111
- hidden_states=hidden_states,
112
- encoder_hidden_states=encoder_hidden_states,
113
- temb=temb,
114
- image_rotary_emb=image_rotary_emb,
115
- encoder_attention_mask=encoder_attention_mask,
116
- )
117
-
118
- scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
119
- shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
120
-
121
- hidden_states = self.norm_out(hidden_states)
122
- hidden_states = hidden_states * (1 + scale) + shift
123
- output = self.proj_out(hidden_states)
124
-
125
- if not return_dict:
126
- return (output,)
127
- return Transformer2DModelOutput(sample=output)
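Note: the rotary-embedding patch above replaces `unbind(-1)` with `chunk(2, dim=-1)` because the `unbind` op is not supported on DTensor inputs. Both formulations yield the same rotated tensor once flattened; the standalone check below verifies this on a random CPU tensor:

```python
import torch

x = torch.randn(2, 16, 8)  # [batch, seq, dim]; dim must be even

# Original Diffusers formulation (unbind drops the trailing pair dimension).
x_real_u, x_imag_u = x.unflatten(2, (-1, 2)).unbind(-1)
rotated_unbind = torch.stack([-x_imag_u, x_real_u], dim=-1).flatten(2)

# Patched formulation (chunk keeps a trailing size-1 dimension, but flatten(2)
# restores the same element order).
x_real_c, x_imag_c = x.unflatten(2, (-1, 2)).chunk(2, dim=-1)
rotated_chunk = torch.stack([-x_imag_c, x_real_c], dim=-1).flatten(2)

assert torch.equal(rotated_unbind, rotated_chunk)
```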
finetrainers/patches/models/wan/patch.py DELETED
@@ -1,33 +0,0 @@
1
- from typing import Optional
2
-
3
- import diffusers
4
- import torch
5
-
6
-
7
- def patch_time_text_image_embedding_forward() -> None:
8
- _patch_time_text_image_embedding_forward()
9
-
10
-
11
- def _patch_time_text_image_embedding_forward() -> None:
12
- diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
13
- _patched_WanTimeTextImageEmbedding_forward
14
- )
15
-
16
-
17
- def _patched_WanTimeTextImageEmbedding_forward(
18
- self,
19
- timestep: torch.Tensor,
20
- encoder_hidden_states: torch.Tensor,
21
- encoder_hidden_states_image: Optional[torch.Tensor] = None,
22
- ):
23
- # Some code has been removed compared to original implementation in Diffusers
24
- # Also, timestep is typed as that of encoder_hidden_states
25
- timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
26
- temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
27
- timestep_proj = self.time_proj(self.act_fn(temb))
28
-
29
- encoder_hidden_states = self.text_embedder(encoder_hidden_states)
30
- if encoder_hidden_states_image is not None:
31
- encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
32
-
33
- return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
finetrainers/patches/utils.py DELETED
@@ -1,18 +0,0 @@
1
- import torch
2
-
3
-
4
- class DisableTensorToDtype:
5
- def __enter__(self):
6
- self.original_to = torch.Tensor.to
7
-
8
- def modified_to(tensor, *args, **kwargs):
9
- # remove dtype from args if present
10
- args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
11
- if "dtype" in kwargs:
12
- kwargs.pop("dtype")
13
- return self.original_to(tensor, *args, **kwargs)
14
-
15
- torch.Tensor.to = modified_to
16
-
17
- def __exit__(self, *args, **kwargs):
18
- torch.Tensor.to = self.original_to
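Note: a short usage sketch of the context manager above (it assumes the `DisableTensorToDtype` class from this file is in scope). Inside the block, any `dtype` passed to `Tensor.to` is dropped, so dtype casts become no-ops:

```python
import torch

t = torch.ones(2, dtype=torch.float32)

with DisableTensorToDtype():
    u = t.to(dtype=torch.bfloat16)  # the dtype kwarg is stripped inside the context

assert u.dtype == torch.float32                       # cast was suppressed
assert t.to(torch.bfloat16).dtype == torch.bfloat16   # normal behaviour outside
```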
finetrainers/processors/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from .base import ProcessorMixin
2
- from .clip import CLIPPooledProcessor
3
- from .glm import CogView4GLMProcessor
4
- from .llama import LlamaProcessor
5
- from .t5 import T5Processor
6
- from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
finetrainers/processors/base.py DELETED
@@ -1,20 +0,0 @@
1
- import inspect
2
- from typing import Any, Dict, List
3
-
4
-
5
- class ProcessorMixin:
6
- def __init__(self) -> None:
7
- self._forward_parameter_names = inspect.signature(self.forward).parameters.keys()
8
- self.output_names: List[str] = None
9
- self.input_names: Dict[str, Any] = None
10
-
11
- def __call__(self, *args, **kwargs) -> Any:
12
- shallow_copy_kwargs = dict(kwargs.items())
13
- if self.input_names is not None:
14
- for k, v in self.input_names.items():
15
- shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k)
16
- acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names}
17
- return self.forward(*args, **acceptable_kwargs)
18
-
19
- def forward(self, *args, **kwargs) -> Any:
20
- raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.")
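Note: a toy subclass makes the `__call__` behaviour above easier to see: `input_names` renames incoming kwargs, and anything that is not a parameter of `forward` is silently dropped. The subclass below is invented for the demo and assumes `ProcessorMixin` from this file is in scope.

```python
from typing import Any, Dict, List, Optional


class UpperCaseProcessor(ProcessorMixin):
    """Toy processor that upper-cases a caption."""

    def __init__(self, output_names: List[str], input_names: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self.output_names = output_names
        self.input_names = input_names

    def forward(self, caption: str) -> Dict[str, Any]:
        return {self.output_names[0]: caption.upper()}


proc = UpperCaseProcessor(output_names=["caption_upper"], input_names={"prompt": "caption"})
# "prompt" is remapped to "caption"; "unused" is dropped because it is not a
# parameter of forward().
print(proc(prompt="a cat playing piano", unused=123))  # {'caption_upper': 'A CAT PLAYING PIANO'}
```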
finetrainers/processors/clip.py DELETED
@@ -1,65 +0,0 @@
1
- from typing import Any, Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast
5
-
6
- from .base import ProcessorMixin
7
-
8
-
9
- class CLIPPooledProcessor(ProcessorMixin):
10
- r"""
11
- Processor for CLIP text-encoder models. This processor is used to encode text inputs and return the pooled
12
- text embedding for the input text.
13
-
14
- Args:
15
- output_names (`List[str]`):
16
- The names of the outputs that the processor should return. The single output is the pooled embedding of
17
- the input text.
18
- """
19
-
20
- def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None:
21
- super().__init__()
22
-
23
- self.output_names = output_names
24
- self.input_names = input_names
25
-
26
- assert len(output_names) == 1
27
- if input_names is not None:
28
- assert len(input_names) <= 3
29
-
30
- def forward(
31
- self,
32
- tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
33
- text_encoder: CLIPTextModel,
34
- caption: Union[str, List[str]],
35
- ) -> Tuple[torch.Tensor, torch.Tensor]:
36
- r"""
37
- Encode the input text and return the pooled embedding for the input text.
38
-
39
- Args:
40
- tokenizer (`Union[CLIPTokenizer, CLIPTokenizerFast]`):
41
- The tokenizer used to tokenize the input text.
42
- text_encoder (`CLIPTextModel`):
43
- The text encoder used to encode the input text.
44
- caption (`Union[str, List[str]]`):
45
- The input text to be encoded.
46
- """
47
- if isinstance(caption, str):
48
- caption = [caption]
49
-
50
- device = text_encoder.device
51
- dtype = text_encoder.dtype
52
-
53
- text_inputs = tokenizer(
54
- caption,
55
- padding="max_length",
56
- max_length=77,
57
- truncation=True,
58
- return_tensors="pt",
59
- )
60
- text_input_ids = text_inputs.input_ids.to(device)
61
-
62
- prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output
63
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
64
-
65
- return {self.output_names[0]: prompt_embeds}
finetrainers/processors/glm.py DELETED
@@ -1,74 +0,0 @@
1
- from typing import List, Tuple, Union
2
-
3
- import torch
4
- from transformers import AutoTokenizer, GlmModel
5
-
6
- from .base import ProcessorMixin
7
-
8
-
9
- class CogView4GLMProcessor(ProcessorMixin):
10
- r"""
11
- Processor for the GLM family of models. This processor is used to encode text inputs and return the embeddings
12
- and attention masks for the input text.
13
-
14
- This processor is specific to CogView4 but can be used with any other model.
15
-
16
- Args:
17
- output_names (`List[str]`):
18
- The names of the outputs that the processor should return. The single output is the sequence of text
19
- embeddings for the input text.
20
- """
21
-
22
- def __init__(self, output_names: List[str]):
23
- super().__init__()
24
-
25
- self.output_names = output_names
26
-
27
- assert len(self.output_names) == 1
28
-
29
- def forward(
30
- self,
31
- tokenizer: AutoTokenizer,
32
- text_encoder: GlmModel,
33
- caption: Union[str, List[str]],
34
- max_sequence_length: int,
35
- ) -> Tuple[torch.Tensor, torch.Tensor]:
36
- r"""
37
- Encode the input text and return the embeddings and attention mask for the input text.
38
-
39
- Args:
40
- tokenizer (`AutoTokenizer`):
41
- The tokenizer used to tokenize the input text.
42
- text_encoder (`GlmModel`):
43
- The text encoder used to encode the input text.
44
- caption (`Union[str, List[str]]`):
45
- The input text to be encoded.
46
- max_sequence_length (`int`):
47
- The maximum sequence length of the input text.
48
- """
49
- if isinstance(caption, str):
50
- caption = [caption]
51
-
52
- device = text_encoder.device
53
- dtype = text_encoder.dtype
54
-
55
- text_inputs = tokenizer(
56
- caption,
57
- padding="longest",
58
- max_length=max_sequence_length,
59
- truncation=True,
60
- add_special_tokens=True,
61
- return_tensors="pt",
62
- )
63
- text_input_ids = text_inputs.input_ids.to(device)
64
-
65
- current_length = text_input_ids.size(1)
66
- pad_length = 16 - current_length % 16
67
- if pad_length > 0:
68
- pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id)
69
- text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
70
-
71
- prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
72
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
73
-
74
- return {self.output_names[0]: prompt_embeds}
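Note: the left-padding above rounds the tokenized length up to a multiple of 16. A standalone sketch of the arithmetic with an illustrative length; observe that a length that is already a multiple of 16 still receives a full extra block of 16 pad tokens with this formula:

```python
def left_pad_length(current_length: int, multiple: int = 16) -> int:
    # Mirrors the deleted code: multiple - (current_length % multiple).
    return multiple - current_length % multiple


print(left_pad_length(45))  # 3  -> padded length 48
print(left_pad_length(48))  # 16 -> an exact multiple still gains one extra block
```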
finetrainers/processors/llama.py DELETED
@@ -1,118 +0,0 @@
- from typing import Any, Dict, List, Optional, Tuple, Union
-
- import torch
- from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast
-
- from .base import ProcessorMixin
-
-
- DEFAULT_PROMPT_TEMPLATE = {
-     "template": (
-         "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
-         "1. The main content and theme of the video."
-         "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
-         "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
-         "4. background environment, light, style and atmosphere."
-         "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
-         "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
-     ),
-     "crop_start": 95,
- }
-
-
- class LlamaProcessor(ProcessorMixin):
-     r"""
-     Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings
-     and attention masks for the input text.
-
-     Args:
-         output_names (`List[str]`):
-             The names of the outputs that the processor should return. The first output is the embeddings of the input
-             text and the second output is the attention mask for the input text.
-     """
-
-     def __init__(self, output_names: List[str] = None):
-         super().__init__()
-
-         self.output_names = output_names
-
-         assert len(output_names) == 2
-
-     def forward(
-         self,
-         tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast],
-         text_encoder: LlamaModel,
-         caption: Union[str, List[str]],
-         max_sequence_length: int,
-         prompt_template: Optional[Dict[str, Any]] = None,
-         num_layers_to_skip: int = 2,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         r"""
-         Encode the input text and return the embeddings and attention mask for the input text.
-
-         Args:
-             tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`):
-                 The tokenizer used to tokenize the input text.
-             text_encoder (`LlamaModel`):
-                 The text encoder used to encode the input text.
-             caption (`Union[str, List[str]]`):
-                 The input text to be encoded.
-             max_sequence_length (`int`):
-                 The maximum sequence length of the input text.
-             prompt_template (`Optional[Dict[str, Any]]`):
-                 The prompt template to be used to encode the input text.
-         """
-         if prompt_template is None:
-             prompt_template = DEFAULT_PROMPT_TEMPLATE
-         if isinstance(caption, str):
-             caption = [caption]
-
-         device = text_encoder.device
-         dtype = text_encoder.dtype
-
-         batch_size = len(caption)
-         caption = [prompt_template["template"].format(c) for c in caption]
-
-         crop_start = prompt_template.get("crop_start", None)
-         if crop_start is None:
-             prompt_template_input = tokenizer(
-                 prompt_template["template"],
-                 padding="max_length",
-                 return_tensors="pt",
-                 return_length=False,
-                 return_overflowing_tokens=False,
-                 return_attention_mask=False,
-             )
-             crop_start = prompt_template_input["input_ids"].shape[-1]
-             # Remove <|eot_id|> token and placeholder {}
-             crop_start -= 2
-
-         max_sequence_length += crop_start
-         text_inputs = tokenizer(
-             caption,
-             max_length=max_sequence_length,
-             padding="max_length",
-             truncation=True,
-             return_tensors="pt",
-             return_length=False,
-             return_overflowing_tokens=False,
-             return_attention_mask=True,
-         )
-         text_input_ids = text_inputs.input_ids.to(device)
-         prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
-
-         prompt_embeds = text_encoder(
-             text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
-         ).hidden_states[-(num_layers_to_skip + 1)]
-         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-         if crop_start is not None and crop_start > 0:
-             prompt_embeds = prompt_embeds[:, crop_start:]
-             prompt_attention_mask = prompt_attention_mask[:, crop_start:]
-
-         prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-
-         return {
-             self.output_names[0]: prompt_embeds,
-             self.output_names[1]: prompt_attention_mask,
-         }
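
Before this commit, exercising the Llama processor in isolation would have looked roughly like the sketch below. The checkpoint id and the `output_names` keys are placeholders rather than values taken from the repository; only the call signature mirrors the deleted `forward` above, which prepends the prompt template, encodes using the hidden states three layers from the end by default (`num_layers_to_skip=2`), and crops the template tokens back off the output.

```python
# Hypothetical usage sketch of the removed LlamaProcessor (model id is a placeholder).
from transformers import AutoTokenizer, LlamaModel

from finetrainers.processors.llama import LlamaProcessor  # module as it existed before this commit

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-text-encoder")    # placeholder
text_encoder = LlamaModel.from_pretrained("path/to/llama-text-encoder")    # placeholder

processor = LlamaProcessor(output_names=["encoder_hidden_states", "encoder_attention_mask"])
outputs = processor.forward(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    caption="A cat chasing a ball of yarn across a wooden floor",
    max_sequence_length=256,
)
# outputs["encoder_hidden_states"]: (1, seq_len, hidden_dim) with template tokens cropped off
# outputs["encoder_attention_mask"]: (1, seq_len) boolean mask
```
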
finetrainers/processors/t5.py DELETED
@@ -1,73 +0,0 @@
- from typing import List, Tuple, Union
-
- import torch
- from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast
-
- from .base import ProcessorMixin
-
-
- class T5Processor(ProcessorMixin):
-     r"""
-     Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings
-     and attention masks for the input text.
-
-     Args:
-         output_names (`List[str]`):
-             The names of the outputs that the processor should return. The first output is the embeddings of the input
-             text and the second output is the attention mask for the input text.
-     """
-
-     def __init__(self, output_names: List[str]):
-         super().__init__()
-
-         self.output_names = output_names
-
-         assert len(self.output_names) == 2
-
-     def forward(
-         self,
-         tokenizer: Union[T5Tokenizer, T5TokenizerFast],
-         text_encoder: T5EncoderModel,
-         caption: Union[str, List[str]],
-         max_sequence_length: int,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         r"""
-         Encode the input text and return the embeddings and attention mask for the input text.
-
-         Args:
-             tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`):
-                 The tokenizer used to tokenize the input text.
-             text_encoder (`T5EncoderModel`):
-                 The text encoder used to encode the input text.
-             caption (`Union[str, List[str]]`):
-                 The input text to be encoded.
-             max_sequence_length (`int`):
-                 The maximum sequence length of the input text.
-         """
-         if isinstance(caption, str):
-             caption = [caption]
-
-         device = text_encoder.device
-         dtype = text_encoder.dtype
-
-         batch_size = len(caption)
-         text_inputs = tokenizer(
-             caption,
-             padding="max_length",
-             max_length=max_sequence_length,
-             truncation=True,
-             add_special_tokens=True,
-             return_tensors="pt",
-         )
-         text_input_ids = text_inputs.input_ids
-         prompt_attention_mask = text_inputs.attention_mask
-         prompt_attention_mask = prompt_attention_mask.bool().to(device)
-
-         prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-         prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-
-         return {
-             self.output_names[0]: prompt_embeds,
-             self.output_names[1]: prompt_attention_mask,
-         }
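
Compared with the GLM and Llama variants, the T5 processor above pads every batch to `max_sequence_length` (rather than to the longest caption) and returns a boolean attention mask alongside the embeddings. A minimal sketch of that tokenization contract is shown below; `t5-small` is used only as an example checkpoint, not the tokenizer finetrainers actually loads.

```python
from transformers import AutoTokenizer

# Sketch of the tokenization contract used by the removed T5Processor.
# "t5-small" is only an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

text_inputs = tokenizer(
    ["A cat chasing a ball of yarn"],
    padding="max_length",
    max_length=128,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
prompt_attention_mask = text_inputs.attention_mask.bool()  # (1, 128), True on real tokens
print(text_inputs.input_ids.shape, int(prompt_attention_mask.sum()))
```
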
finetrainers/processors/text.py DELETED
@@ -1,22 +0,0 @@
- from typing import List, Union
-
- import torch
-
- from .. import functional as FF
- from .base import ProcessorMixin
-
-
- class CaptionTextDropoutProcessor(ProcessorMixin):
-     def __init__(self, dropout_p: float = 0.0) -> None:
-         self.dropout_p = dropout_p
-
-     def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]:
-         return FF.dropout_caption(caption, self.dropout_p)
-
-
- class CaptionEmbeddingDropoutProcessor(ProcessorMixin):
-     def __init__(self, dropout_p: float = 0.0) -> None:
-         self.dropout_p = dropout_p
-
-     def forward(self, embedding: torch.Tensor) -> torch.Tensor:
-         return FF.dropout_embeddings_to_zero(embedding, self.dropout_p)
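
The two dropout processors above delegate to helpers in `finetrainers.functional` (deleted elsewhere in this commit and not shown here). Caption dropout of this kind is typically used for classifier-free guidance: with probability `dropout_p` the caption is replaced by an empty string, or the precomputed text embedding is zeroed out. The sketch below illustrates that idea; it is an assumption about the behaviour, not the deleted `FF.dropout_caption` / `FF.dropout_embeddings_to_zero` implementations.

```python
import random
from typing import List, Union

import torch


def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0.0) -> Union[str, List[str]]:
    # Sketch: replace each caption with the empty string with probability dropout_p.
    if isinstance(caption, str):
        return "" if random.random() < dropout_p else caption
    return [dropout_caption(c, dropout_p) for c in caption]


def dropout_embeddings_to_zero(embedding: torch.Tensor, dropout_p: float = 0.0) -> torch.Tensor:
    # Sketch: zero the whole embedding tensor with probability dropout_p.
    return torch.zeros_like(embedding) if random.random() < dropout_p else embedding
```
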