jbilcke-hf HF staff commited on
Commit
d464085
·
1 Parent(s): 80ebcb3

upgrading our code to support the new finetrainers

Files changed (45)
  1. docs/finetrainers/documentation_dataset_README.md +165 -0
  2. docs/finetrainers/documentation_environment.md +28 -0
  3. docs/finetrainers/documentation_global_README.md +99 -0
  4. docs/finetrainers/documentation_models_README.md +56 -0
  5. docs/finetrainers/documentation_models_hunyuan_video.md +189 -0
  6. docs/finetrainers/documentation_models_ltx_video.md +196 -0
  7. docs/finetrainers/documentation_models_optimization.md +12 -0
  8. docs/finetrainers/documentation_models_wan.md +36 -0
  9. docs/finetrainers/documentation_optimizers.md +14 -0
  10. docs/finetrainers/documentation_parallel_processing_README.md +59 -0
  11. docs/finetrainers/examples_training_wan__train.sh +162 -0
  12. docs/finetrainers/examples_training_wan__training.json +14 -0
  13. docs/finetrainers/examples_training_wan__validation.json +40 -0
  14. training/README.md +0 -459
  15. training/README_zh.md +0 -455
  16. training/cogvideox/__init__.py +0 -0
  17. training/cogvideox/args.py +0 -484
  18. training/cogvideox/cogvideox_image_to_video_lora.py +0 -1016
  19. training/cogvideox/cogvideox_image_to_video_sft.py +0 -947
  20. training/cogvideox/cogvideox_text_to_video_lora.py +0 -955
  21. training/cogvideox/cogvideox_text_to_video_sft.py +0 -917
  22. training/cogvideox/dataset.py +0 -428
  23. training/cogvideox/prepare_dataset.py +0 -669
  24. training/cogvideox/text_encoder/__init__.py +0 -1
  25. training/cogvideox/text_encoder/text_encoder.py +0 -99
  26. training/cogvideox/utils.py +0 -260
  27. training/mochi-1/README.md +0 -111
  28. training/mochi-1/args.py +0 -268
  29. training/mochi-1/dataset_simple.py +0 -50
  30. training/mochi-1/embed.py +0 -111
  31. training/mochi-1/prepare_dataset.sh +0 -15
  32. training/mochi-1/requirements.txt +0 -8
  33. training/mochi-1/text_to_video_lora.py +0 -592
  34. training/mochi-1/train.sh +0 -37
  35. training/mochi-1/trim_and_crop_videos.py +0 -126
  36. training/mochi-1/utils.py +0 -22
  37. training/prepare_dataset.sh +0 -48
  38. training/train_image_to_video_lora.sh +0 -82
  39. training/train_image_to_video_sft.sh +0 -87
  40. training/train_text_to_video_lora.sh +0 -86
  41. training/train_text_to_video_sft.sh +0 -77
  42. vms/config.py +109 -14
  43. vms/services/trainer.py +75 -22
  44. vms/tabs/train_tab.py +96 -21
  45. vms/ui/video_trainer_ui.py +28 -2
docs/finetrainers/documentation_dataset_README.md ADDED
@@ -0,0 +1,165 @@
1
+ # Dataset
2
+
3
+ ## Dataset preparation
4
+
5
+ Please refer to [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) for a collection of scripts to prepare datasets for training. The scripts are designed to work with the HF datasets library and can be used to prepare datasets for training with `finetrainers`.
6
+
7
+ ## Training Dataset Format
8
+
9
+ Dataset loading format support is very limited at the moment. This will be improved in the future. For now, we support the following formats:
10
+
11
+ #### Two file format
12
+
13
+ > [!NOTE]
14
+ > Relevant classes to look for implementation:
15
+ > - ImageFileCaptionFileListDataset
16
+ > - VideoFileCaptionFileListDataset
17
+ >
18
+ > Supports loading directly from the HF Hub.
19
+
20
+ Your dataset structure should look like this. Running the `tree` command, you should see something similar to:
21
+
22
+ ```
23
+ dataset
24
+ ├── prompt.txt
25
+ ├── videos.txt
26
+ ├── videos
27
+     ├── 00000.mp4
28
+     ├── 00001.mp4
29
+     ├── ...
30
+ ```
31
+
32
+ - Make sure that the paths in `videos.txt` are relative to the `dataset` directory. The `prompt.txt` file should contain the captions for the videos in the same order as they appear in `videos.txt`.
33
+ - Supported names for the caption file: `captions.txt`, `caption.txt`, `prompt.txt`, `prompts.txt` (feel free to send PRs to add more common names).
34
+ - Supported names for the video file: `videos.txt`, `video.txt` (feel free to send PRs to add more common names).
35
+
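For illustration, here is a minimal sketch that generates the two metadata files from a folder of clips (the `dataset/videos/` layout and the placeholder captions are assumptions; substitute your own captions):

```python
from pathlib import Path

dataset_dir = Path("dataset")  # assumes your clips live in dataset/videos/
video_paths = sorted((dataset_dir / "videos").glob("*.mp4"))

# videos.txt: one path per line, relative to the dataset directory
with open(dataset_dir / "videos.txt", "w") as f:
    f.writelines(f"videos/{p.name}\n" for p in video_paths)

# prompt.txt: one caption per line, in the same order as videos.txt
with open(dataset_dir / "prompt.txt", "w") as f:
    f.writelines(f"A placeholder caption for {p.stem}\n" for p in video_paths)
```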
36
+ #### Caption-Data filename pair format
37
+
38
+ > [!NOTE]
39
+ > Relevant classes to look for implementation:
40
+ > - ImageCaptionFilePairDataset
41
+ > - VideoCaptionFilePairDataset
42
+ >
43
+ > Does not support loading directly from the HF Hub.
44
+
45
+ Your dataset structure should look like this. Running the `tree` command, you should see something similar to:
46
+
47
+ ```
48
+ dataset
49
+ ├── a.txt
50
+ ├── a.mp4
51
+ ├── bkjlaskdjg.txt
52
+ ├── bkjlaskdjg.mp4
53
+ ├── ...
54
+ ```
55
+
56
+ - Each caption file should have a corresponding image/video file with the same name.
57
+
58
+ #### CSV/JSON/JSONL format
59
+
60
+ > [!NOTE]
61
+ > Relevant classes to look for implementation:
62
+ > - ImageFolderDataset
63
+ > - VideoFolderDataset
64
+
65
+ Any dataset loadable directly via the 🤗 HF `datasets` library should work (not widely tested at the moment):
66
+ - https://huggingface.co/docs/datasets/v3.3.2/en/image_load#webdataset
67
+ - https://huggingface.co/docs/datasets/v3.3.2/en/video_load#webdataset
68
+
69
+ #### Webdataset format
70
+
71
+ > [!NOTE]
72
+ > Relevant classes to look for implementation:
73
+ > - ImageWebDataset
74
+ > - VideoWebDataset
75
+
76
+ Any dataset loadable directly via the 🤗 HF `datasets` library should work (not widely tested at the moment). We support the [image `webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [video `webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats.
77
+
78
+ ## Validation Dataset Format
79
+
80
+ Arguments related to validation are:
81
+ - `--validation_dataset_file`: Path to the validation dataset file. Supported formats are CSV, JSON, JSONL, PARQUET, and ARROW. Note: PARQUET and ARROW have not been tested after a major refactor, but should most likely work. (TODO(aryan): look into this)
82
+ - `--validation_steps`: Interval of training steps after which validation should be performed.
83
+ - `--enable_model_cpu_offload`: If set, CPU offloading will be enabled during validation. Note that this has not been tested for FSDP, TP, or DDP after a major refactor, but should most likely work for single GPU training.
84
+
85
+ > [!IMPORTANT]
86
+ >
87
+ > When using `dp_shards > 1` or `tp_degree > 1`, you must make sure that the number of data samples in the validation dataset is a multiple of `dp_shards * tp_degree`. If this is not the case, training will fail due to an NCCL timeout. This will be improved/fixed in the future.
88
+
89
+ - Must contain "caption" as a column. If an image must be provided for validation (for example, image-to-video inference), then the "image_path" field must be provided. If a video must be provided for validation (for example, video-to-video inference), then the "video_path" field must be provided. Other fields like "num_inference_steps", "height", "width", "num_frames", and "frame_rate" can be provided too but are optional.
90
+
91
+ #### CSV Example
92
+
93
+ <details>
94
+ <summary>Click to expand</summary>
95
+
96
+ ```csv
97
+ caption,image_path,video_path,num_inference_steps,height,width,num_frames,frame_rate
98
+ "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.",,"/raid/aryan/finetrainers-dummy-dataset-disney/a3c275fc2eb0a67168a7c58a6a9adb14.mp4",50,480,768,49,25
99
+ "<SECOND_CAPTION>",,"/path/to/second.mp4",50,512,704,161,25
100
+ ```
101
+
102
+ </details>
103
+
104
+ #### JSON Example
105
+
106
+ Must contain a "data" field, which should be a list of dictionaries. Each dictionary corresponds to one validation video that will be generated with the specified generation parameters.
107
+
108
+ <details>
109
+ <summary>Click to expand</summary>
110
+
111
+ ```json
112
+ {
113
+ "data": [
114
+ {
115
+ "caption": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.",
116
+ "image_path": "",
117
+ "video_path": "/raid/aryan/finetrainers-dummy-dataset-disney/a3c275fc2eb0a67168a7c58a6a9adb14.mp4",
118
+ "num_inference_steps": 50,
119
+ "height": 480,
120
+ "width": 768,
121
+ "num_frames": 49,
122
+ "frame_rate": 25
123
+ },
124
+ {
125
+ "caption": "<SECOND_CAPTION>",
126
+ "image_path": "",
127
+ "video_path": "/path/to/second.mp4",
128
+ "num_inference_steps": 50,
129
+ "height": 512,
130
+ "width": 704,
131
+ "num_frames": 161,
132
+ "frame_rate": 25
133
+ }
134
+ ]
135
+ }
136
+ ```
137
+
138
+ </details>
139
+
140
+ ## Understanding how datasets are loaded
141
+
142
+ For memory-efficient training, it is important to precompute the condition and latent embeddings. If this is not done, we will need to keep the conditioning models in memory, which can be memory intensive. To avoid this, we implement some abstractions that allow us to do the following efficiently:
143
+ - Loading datasets
144
+ - Chaining multiple datasets together
145
+ - Splitting datasets across data replicas
146
+ - Preprocessing datasets to user-configured resolution buckets
147
+ - Precomputing embeddings without exhaustively using too much disk space
148
+
149
+ The following is a high-level overview of how datasets are loaded and preprocessed:
150
+
151
+ - Initially, the dataset is lazily loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be iterated over indefinitely until some end condition is met (e.g. the user-configured number of training steps is completed). Users can chain together multiple datasets too! For example, if you only have high-resolution data available but want to perform multi-resolution training at certain lower resolutions too, you would have to perform the resizing manually and chain the data together. Finetrainers makes this easier by allowing you to specify multiple datasets (the same or different ones) at different resolutions.
152
+ - The dataset is split across data replicas (GPU groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset.
153
+ - If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer whose size is set by the user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended not to shuffle and to use the default value of `1`. For larger datasets, the overhead grows significantly with this value, so it is recommended to keep it low (< 1000) [this is because we do not yet store the buffered data in memory in a clever way].
154
+ - The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets. This is also necessary for collation when using batch_size > 1.
155
+ - Embeddings for the dataset are precomputed and stored to disk. This is done in batches of the user-configured `--precompute_batch_size`, to avoid exhausting disk space. The smaller this value, the more often the conditioning models will be reloaded when the precomputed data is exhausted. The larger this value, the more disk space will be used.
156
+ - When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speed up training eventually]
157
+
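To make the loading behaviour above more concrete, the sketch below uses plain 🤗 `datasets` primitives to stream, chain, and buffer-shuffle two datasets (the `videofolder` loader and the local paths are placeholders; finetrainers wraps equivalent logic in its own dataset classes):

```python
from datasets import load_dataset, interleave_datasets

# Lazily stream two datasets instead of materializing them in memory.
ds_a = load_dataset("videofolder", data_dir="/path/to/dataset_a", split="train", streaming=True)
ds_b = load_dataset("videofolder", data_dir="/path/to/dataset_b", split="train", streaming=True)

# Chain the datasets together, then shuffle within a small buffer
# (analogous to --dataset_shuffle_buffer_size).
combined = interleave_datasets([ds_a, ds_b]).shuffle(seed=42, buffer_size=10)

# Iterate lazily; in finetrainers, iteration continues until the configured train steps complete.
for step, sample in enumerate(combined):
    if step >= 4:
        break
    print(step, list(sample.keys()))
```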
158
+ ## Understanding how datasets are precomputed
159
+
160
+ There are 3 arguments related to precomputation:
161
+ - `--precomputation_items`: The number of data points to precompute and store to disk at a time. This is useful for performing memory-efficient training without exhausting disk space by precomputing embeddings of the entire dataset(s) at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion.
162
+ - `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in the `--output_dir/precomputed`.
163
+ - `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard-coded logic (which users can remove manually for now).
164
+
165
+ Batching is not yet supported for precomputation. This will be added in the future.
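As a rough sketch, the three flags can be combined in a training command like this (the values and the precomputation directory are illustrative):

```bash
# Precompute 256 items at a time, keep them in a dedicated directory, and
# (for small datasets) precompute everything once up front.
--precomputation_items 256 \
--precomputation_dir /path/to/precomputed \
--precomputation_once
```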
docs/finetrainers/documentation_environment.md ADDED
@@ -0,0 +1,28 @@
1
+ # Environment
2
+
3
+ Finetrainers has only been widely tested with the following environment (output obtained by running `diffusers-cli env`):
4
+
5
+ ```shell
6
+ - 🤗 Diffusers version: 0.33.0.dev0
7
+ - Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
8
+ - Running on Google Colab?: No
9
+ - Python version: 3.10.14
10
+ - PyTorch version (GPU?): 2.5.1+cu124 (True)
11
+ - Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
12
+ - Jax version: 0.4.31
13
+ - JaxLib version: 0.4.31
14
+ - Huggingface_hub version: 0.28.1
15
+ - Transformers version: 4.48.0.dev0
16
+ - Accelerate version: 1.1.0.dev0
17
+ - PEFT version: 0.14.1.dev0
18
+ - Bitsandbytes version: 0.43.3
19
+ - Safetensors version: 0.4.5
20
+ - xFormers version: not installed
21
+ - Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
22
+ NVIDIA A100-SXM4-80GB, 81920 MiB
23
+ NVIDIA A100-SXM4-80GB, 81920 MiB
24
+ NVIDIA DGX Display, 4096 MiB
25
+ NVIDIA A100-SXM4-80GB, 81920 MiB
26
+ ```
27
+
28
+ Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at this early stage of development, we are unable to do so yet. The long-term goals include compatibility with most PyTorch versions on CUDA, MPS, ROCm, and XLA devices.
docs/finetrainers/documentation_global_README.md ADDED
@@ -0,0 +1,99 @@
1
+ # finetrainers 🧪
2
+
3
+ FineTrainers is a work-in-progress library to support (accessible) training of video models. Our first priority is to support LoRA training for all popular video models in [Diffusers](https://github.com/huggingface/diffusers), and eventually other methods like controlnets, control-loras, distillation, etc.
4
+
5
+ `cogvideox-factory` was renamed to `finetrainers`. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to [this](./training/README.md) README instead. Everything in the `training/` directory will eventually be moved to and supported under `finetrainers`.
6
+
7
+ <table align="center">
8
+ <tr>
9
+ <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
10
+ </tr>
11
+ </table>
12
+
13
+ ## News
14
+
15
+ - 🔥 **2025-03-03**: Wan T2V support added!
16
+ - 🔥 **2025-03-03**: We have shipped a complete refactor to support multi-backend distributed training, better precomputation handling for big datasets, model specification format (externally usable for training custom models), FSDP & more.
17
+ - 🔥 **2025-02-12**: We have shipped a set of tooling to curate small and high-quality video datasets for fine-tuning. See [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) documentation page for details!
18
+ - 🔥 **2025-02-12**: Check out [eisneim/ltx_lora_training_i2v_t2v](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)! It builds off of `finetrainers` to support image to video training for LTX-Video and STG guidance for inference.
19
+ - 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
20
+ - 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to [@ArEnSc](https://github.com/ArEnSc) for taking up the initiative!
21
+ - 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
22
+ - 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
23
+ - 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added!
24
+
25
+ ## Table of Contents
26
+
27
+ - [Quickstart](#quickstart)
28
+ - [Support Matrix](#support-matrix)
29
+ - [Featured Projects](#featured-projects)
30
+ - [Acknowledgements](#acknowledgements)
31
+
32
+ ## Quickstart
33
+
34
+ Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt`, then install `diffusers` from source with `pip install git+https://github.com/huggingface/diffusers`. The requirements specify `diffusers>=0.32.1`, but it is always recommended to use the `main` branch of Diffusers for the latest features and bugfixes. Note that the `main` branch of `finetrainers` is also the development branch; stable support should be expected from the release tags.
35
+
36
+ Check out the latest release tag:
37
+
38
+ ```bash
39
+ git fetch --all --tags
40
+ git checkout tags/v0.0.1
41
+ ```
42
+
43
+ Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1) for the release tag.
44
+
45
+ #### Using the main branch
46
+
47
+ To get started quickly with example training scripts on the main development branch, refer to the following:
48
+ - [LTX-Video Pika Effects Crush](./examples/training/sft/ltx_video/crush_smol_lora/)
49
+ - [CogVideoX Pika Effects Crush](./examples/training/sft/cogvideox/crush_smol_lora/)
50
+ - [Wan T2V Pika Effects Crush](./examples/training/sft/wan/crush_smol_lora/)
51
+
52
+ The following are some simple datasets/HF orgs with good datasets to test training with quickly:
53
+ - [Disney Video Generation Dataset](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
54
+ - [bigdatapw Video Dataset Collection](https://huggingface.co/bigdata-pw)
55
+ - [Finetrainers HF Dataset Collection](https://huggingface.co/finetrainers)
56
+
57
+ Please check out [`docs/models`](./docs/models/) and [`examples/training`](./examples/training/) to learn more about the models supported for training, along with reproducible example training launch scripts.
58
+
59
+ > [!IMPORTANT]
60
+ > It is recommended to use PyTorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested. For fully reproducible training, please use the same environment as mentioned in [environment.md](./docs/environment.md).
61
+
62
+ ## Support Matrix
63
+
64
+ > [!NOTE]
65
+ > The following numbers were obtained from the [release branch](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1). The `main` branch is unstable at the moment and may use higher memory.
66
+
67
+ <div align="center">
68
+
69
+ | **Model Name** | **Tasks** | **Min. LoRA VRAM<sup>*</sup>** | **Min. Full Finetuning VRAM<sup>^</sup>** |
70
+ |:----------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
71
+ | [LTX-Video](./docs/models/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
72
+ | [HunyuanVideo](./docs/models/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
73
+ | [CogVideoX-5b](./docs/models/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |
74
+ | [Wan](./docs/models/wan.md) | Text-to-Video | TODO | TODO |
75
+
76
+ </div>
77
+
78
+ <sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
79
+ <sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing.</sub>
80
+
81
+ If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).
82
+
83
+ ## Featured Projects 🔥
84
+
85
+ Checkout some amazing projects citing `finetrainers`:
86
+ - [Diffusion as Shader](https://github.com/IGL-HKUST/DiffusionAsShader)
87
+ - [SkyworkAI's SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1)
88
+ - [eisneim's LTX Image-to-Video](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)
89
+ - [wileewang's TransPixar](https://github.com/wileewang/TransPixar)
90
+ - [Feizc's Video-In-Context](https://github.com/feizc/Video-In-Context)
91
+
92
+ Checkout the following UIs built for `finetrainers`:
93
+ - [jbilcke's VideoModelStudio](https://github.com/jbilcke-hf/VideoModelStudio)
94
+ - [neph1's finetrainers-ui](https://github.com/neph1/finetrainers-ui)
95
+
96
+ ## Acknowledgements
97
+
98
+ * `finetrainers` builds on top of & takes inspiration from great open-source libraries - `transformers`, `accelerate`, `torchtune`, `torchtitan`, `peft`, `diffusers`, `bitsandbytes`, `torchao` and `deepspeed` - to name a few.
99
+ * Some of the design choices of `finetrainers` were inspired by [`SimpleTuner`](https://github.com/bghira/SimpleTuner).
docs/finetrainers/documentation_models_README.md ADDED
@@ -0,0 +1,56 @@
1
+ # FineTrainers training documentation
2
+
3
+ This directory contains the training-related specifications for all the models we support in `finetrainers`. Each model page has:
4
+ - an example training command
5
+ - inference example
6
+ - numbers on memory consumption
7
+
8
+ By default, we don't include any validation-related arguments in the example training commands. To enable validation inference every 500 steps, one can add the following arguments:
9
+
10
+ ```diff
11
+ + --validation_dataset_file <Path to a CSV/JSON/PARQUET/ARROW> \
12
+ + --validation_steps 500
13
+ ```
14
+
15
+ Arguments for training are documented in the code. For more information, please run `python train.py --help`.
16
+
17
+ ## Support matrix
18
+
19
+ The following table shows the algorithms supported for training and the models they are supported for:
20
+
21
+ | Model | SFT | Control | ControlNet | Distillation |
22
+ |:-----------------------------------------:|:---:|:-------:|:----------:|:------------:|
23
+ | [CogVideoX](./cogvideox.md) | 🤗 | 😡 | 😡 | 😡 |
24
+ | [LTX-Video](./ltx_video.md) | 🤗 | 😡 | 😡 | 😡 |
25
+ | [HunyuanVideo](./hunyuan_video.md) | 🤗 | 😡 | 😡 | 😡 |
26
+
27
+ For launching SFT Training:
28
+ - `--training_type lora`: Trains a new set of low-rank weights of the model, yielding a smaller adapter model. Currently, only LoRA is supported from [🤗 PEFT](https://github.com/huggingface/peft)
29
+ - `--training_type full-finetune`: Trains the full-rank weights of the model, yielding a full-parameter trained model.
30
+
31
+ Any model architecture loadable in diffusers/transformers for the above models can be used for training. For example, [SkyReels-T2V](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V) is a finetune of HunyuanVideo, which is compatible with continual training out-of-the-box. Custom models can be loaded either by writing your own [ModelSpecification](TODO(aryan): add link) or by using the following set of arguments:
32
+ - `--tokenizer_id`, `--tokenizer_2_id`, `--tokenizer_3_id`: The tokenizers to use for training in conjunction with text encoder conditioning models.
33
+ - `--text_encoder_id`, `--text_encoder_2_id`, `--text_encoder_3_id`: The text encoder conditioning models.
34
+ - `--transformer_id`: The transformer model to use for training.
35
+ - `--vae_id`: The VAE model to use for training.
36
+
37
+ The above arguments should take care of most training scenarios. For any custom training scenarios, please use your own implementation of a `ModelSpecification`. These arguments should be used only if one wants to override the default components loaded from `--pretrained_model_name_or_path`. For each of these arguments, there is a corresponding `--<ARG>_dtype` argument to specify the precision of that component.
38
+
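For instance, a sketch of continuing training from a HunyuanVideo finetune by overriding only the transformer (the repository IDs are examples, not a tested recipe):

```bash
# The dtype value format is an assumption; check `python train.py --help` for accepted values.
--pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo \
--transformer_id Skywork/SkyReels-V1-Hunyuan-T2V \
--transformer_dtype bf16
```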
39
+ ## Resuming training
40
+
41
+ To resume training, the following arguments can be used:
42
+ - `--checkpointing_steps`: The interval of training steps that should be completed after which the training state should be saved.
43
+ - `--checkpointing_limit`: The maximum number of checkpoints that should be saved at once. If the limit is reached, the oldest checkpoint is purged.
44
+ - `--resume_from_checkpoint <STEP_OR_LATEST>`: Can be an integer or the string `"latest"`. If an integer is provided, training will resume from that step if a checkpoint corresponding to it exists. If `"latest"` is provided, training will resume from the latest checkpoint in the `--output_dir`.
45
+
46
+ > [!IMPORTANT]
47
+ > The `--resume_from_checkpoint` argument only works if the parallel backend and degrees of parallelism are the same as in the previous training run. For example, changing `--dp_degree 2 --dp_shards 1` from a past run to `--dp_degree 1 --dp_shards 2` in the current run will not work.
48
+
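For example, a run that checkpoints every 500 steps, keeps only the two most recent checkpoints, and resumes from the latest one could add the following to its training command (values are illustrative):

```diff
+ --checkpointing_steps 500 \
+ --checkpointing_limit 2 \
+ --resume_from_checkpoint latest
```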
49
+ ## How do we handle `mixed_precision`?
50
+
51
+ The accelerate config files (the ones seen [here](../../accelerate_configs/)) supplied when launching training should contain a `mixed_precision` field, and `accelerate` makes use of it if specified. We don't let users explicitly pass this via the CLI args because it can be confusing to have both `transformer_dtype` and `mixed_precision` in the codebase.
52
+
53
+ `transformer_dtype` is the ultimate source of truth for the precision to be used when training. It will also most likely always have to be `torch.bfloat16` because:
54
+
55
+ * All models currently supported (except Cog-2b) do not work well in FP16 for inference, so training would be broken as well. This can be revisited if it makes sense to train in FP16 for other models added.
56
+ * The `accelerate` config files default to using "bf16"; modifying that is at the user's own risk and assumes they understand the significance of the change.
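For reference, the relevant portion of such a config file might look like the sketch below (values are illustrative; see the files under `accelerate_configs/` for the actual configurations):

```yaml
# Sketch of an accelerate config; only mixed_precision is relevant to this discussion.
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: bf16
num_processes: 2
```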
docs/finetrainers/documentation_models_hunyuan_video.md ADDED
@@ -0,0 +1,189 @@
1
+ # HunyuanVideo
2
+
3
+ ## Training
4
+
5
+ For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
+
7
+ ```bash
8
+ #!/bin/bash
9
+
10
+ export WANDB_MODE="offline"
11
+ export NCCL_P2P_DISABLE=1
12
+ export TORCH_NCCL_ENABLE_MONITORING=0
13
+ export FINETRAINERS_LOG_LEVEL=DEBUG
14
+
15
+ GPU_IDS="0,1"
16
+
17
+ DATA_ROOT="/path/to/dataset"
18
+ CAPTION_COLUMN="prompts.txt"
19
+ VIDEO_COLUMN="videos.txt"
20
+ OUTPUT_DIR="/path/to/models/hunyuan-video/"
21
+
22
+ ID_TOKEN="afkx"
23
+
24
+ # Model arguments
25
+ model_cmd="--model_name hunyuan_video \
26
+ --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"
27
+
28
+ # Dataset arguments
29
+ dataset_cmd="--data_root $DATA_ROOT \
30
+ --video_column $VIDEO_COLUMN \
31
+ --caption_column $CAPTION_COLUMN \
32
+ --id_token $ID_TOKEN \
33
+ --video_resolution_buckets 17x512x768 49x512x768 61x512x768 \
34
+ --caption_dropout_p 0.05"
35
+
36
+ # Dataloader arguments
37
+ dataloader_cmd="--dataloader_num_workers 0"
38
+
39
+ # Diffusion arguments
40
+ diffusion_cmd=""
41
+
42
+ # Training arguments
43
+ training_cmd="--training_type lora \
44
+ --seed 42 \
45
+ --batch_size 1 \
46
+ --train_steps 500 \
47
+ --rank 128 \
48
+ --lora_alpha 128 \
49
+ --target_modules to_q to_k to_v to_out.0 \
50
+ --gradient_accumulation_steps 1 \
51
+ --gradient_checkpointing \
52
+ --checkpointing_steps 500 \
53
+ --checkpointing_limit 2 \
54
+ --enable_slicing \
55
+ --enable_tiling"
56
+
57
+ # Optimizer arguments
58
+ optimizer_cmd="--optimizer adamw \
59
+ --lr 2e-5 \
60
+ --lr_scheduler constant_with_warmup \
61
+ --lr_warmup_steps 100 \
62
+ --lr_num_cycles 1 \
63
+ --beta1 0.9 \
64
+ --beta2 0.95 \
65
+ --weight_decay 1e-4 \
66
+ --epsilon 1e-8 \
67
+ --max_grad_norm 1.0"
68
+
69
+ # Miscellaneous arguments
70
+ miscellaneous_cmd="--tracker_name finetrainers-hunyuan-video \
71
+ --output_dir $OUTPUT_DIR \
72
+ --nccl_timeout 1800 \
73
+ --report_to wandb"
74
+
75
+ cmd="accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids $GPU_IDS train.py \
76
+ $model_cmd \
77
+ $dataset_cmd \
78
+ $dataloader_cmd \
79
+ $diffusion_cmd \
80
+ $training_cmd \
81
+ $optimizer_cmd \
82
+ $miscellaneous_cmd"
83
+
84
+ echo "Running command: $cmd"
85
+ eval $cmd
86
+ echo -ne "-------------------- Finished executing script --------------------\n\n"
87
+ ```
88
+
89
+ ## Memory Usage
90
+
91
+ ### LoRA
92
+
93
+ > [!NOTE]
94
+ >
95
+ > The measurements below were done in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
96
+
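As a sketch, enabling the FP8 weight-casting amounts to adding a single flag to the training arguments of the script above (shown as a diff against the `training_cmd` block; exact placement is up to you):

```diff
  training_cmd="--training_type lora \
+   --layerwise_upcasting_modules transformer \
    --seed 42 \
```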
97
+ LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:
98
+
99
+ ```
100
+ Training configuration: {
101
+ "trainable parameters": 163577856,
102
+ "total samples": 69,
103
+ "train epochs": 1,
104
+ "train steps": 10,
105
+ "batches per device": 1,
106
+ "total batches observed per epoch": 69,
107
+ "train batch size": 1,
108
+ "gradient accumulation steps": 1
109
+ }
110
+ ```
111
+
112
+ | stage | memory_allocated | max_memory_reserved |
113
+ |:-----------------------:|:----------------:|:-------------------:|
114
+ | before training start | 38.889 | 39.020 |
115
+ | before validation start | 39.747 | 56.266 |
116
+ | after validation end | 39.748 | 58.385 |
117
+ | after epoch 1 | 39.748 | 40.910 |
118
+ | after training end | 25.288 | 40.910 |
119
+
120
+ Note: requires about `59` GB of VRAM when validation is performed.
121
+
122
+ LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **with precomputation**:
123
+
124
+ ```
125
+ Training configuration: {
126
+ "trainable parameters": 163577856,
127
+ "total samples": 1,
128
+ "train epochs": 10,
129
+ "train steps": 10,
130
+ "batches per device": 1,
131
+ "total batches observed per epoch": 1,
132
+ "train batch size": 1,
133
+ "gradient accumulation steps": 1
134
+ }
135
+ ```
136
+
137
+ | stage | memory_allocated | max_memory_reserved |
138
+ |:-----------------------------:|:----------------:|:-------------------:|
139
+ | after precomputing conditions | 14.232 | 14.461 |
140
+ | after precomputing latents | 14.717 | 17.244 |
141
+ | before training start | 24.195 | 26.039 |
142
+ | after epoch 1 | 24.83 | 42.387 |
143
+ | before validation start | 24.842 | 42.387 |
144
+ | after validation end | 39.558 | 46.947 |
145
+ | after training end | 24.842 | 41.039 |
146
+
147
+ Note: requires about `47` GB of VRAM with validation. If validation is not performed, the memory usage is reduced to about `42` GB.
148
+
149
+ ### Full finetuning
150
+
151
+ Currently, full finetuning is not supported for HunyuanVideo. It goes out of memory (OOM) at `49x512x768` resolution.
152
+
153
+ ## Inference
154
+
155
+ Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
156
+
157
+ ```py
158
+ import torch
+ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+ from diffusers.utils import export_to_video
164
+
165
+ model_id = "hunyuanvideo-community/HunyuanVideo"
166
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
167
+ model_id, subfolder="transformer", torch_dtype=torch.bfloat16
168
+ )
169
+ pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
170
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="hunyuanvideo-lora")
171
+ pipe.set_adapters(["hunyuanvideo-lora"], [0.6])
172
+ pipe.vae.enable_tiling()
173
+ pipe.to("cuda")
174
+
175
+ output = pipe(
176
+ prompt="A cat walks on the grass, realistic",
177
+ height=320,
178
+ width=512,
179
+ num_frames=61,
180
+ num_inference_steps=30,
181
+ ).frames[0]
182
+ export_to_video(output, "output.mp4", fps=15)
183
+ ```
184
+
185
+ You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
186
+
187
+ * [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video)
188
+ * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
189
+ * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers/documentation_models_ltx_video.md ADDED
@@ -0,0 +1,196 @@
1
+ # LTX-Video
2
+
3
+ ## Training
4
+
5
+ For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
+
7
+ ```bash
8
+ #!/bin/bash
9
+ export WANDB_MODE="offline"
10
+ export NCCL_P2P_DISABLE=1
11
+ export TORCH_NCCL_ENABLE_MONITORING=0
12
+ export FINETRAINERS_LOG_LEVEL=DEBUG
13
+
14
+ GPU_IDS="0,1"
15
+
16
+ DATA_ROOT="/path/to/dataset"
17
+ CAPTION_COLUMN="prompts.txt"
18
+ VIDEO_COLUMN="videos.txt"
19
+ OUTPUT_DIR="/path/to/models/ltx-video/"
20
+
21
+ ID_TOKEN="BW_STYLE"
22
+
23
+ # Model arguments
24
+ model_cmd="--model_name ltx_video \
25
+ --pretrained_model_name_or_path Lightricks/LTX-Video"
26
+
27
+ # Dataset arguments
28
+ dataset_cmd="--data_root $DATA_ROOT \
29
+ --video_column $VIDEO_COLUMN \
30
+ --caption_column $CAPTION_COLUMN \
31
+ --id_token $ID_TOKEN \
32
+ --video_resolution_buckets 49x512x768 \
33
+ --caption_dropout_p 0.05"
34
+
35
+ # Dataloader arguments
36
+ dataloader_cmd="--dataloader_num_workers 0"
37
+
38
+ # Diffusion arguments
39
+ diffusion_cmd="--flow_weighting_scheme logit_normal"
40
+
41
+ # Training arguments
42
+ training_cmd="--training_type lora \
43
+ --seed 42 \
44
+ --batch_size 1 \
45
+ --train_steps 3000 \
46
+ --rank 128 \
47
+ --lora_alpha 128 \
48
+ --target_modules to_q to_k to_v to_out.0 \
49
+ --gradient_accumulation_steps 4 \
50
+ --gradient_checkpointing \
51
+ --checkpointing_steps 500 \
52
+ --checkpointing_limit 2 \
53
+ --enable_slicing \
54
+ --enable_tiling"
55
+
56
+ # Optimizer arguments
57
+ optimizer_cmd="--optimizer adamw \
58
+ --lr 3e-5 \
59
+ --lr_scheduler constant_with_warmup \
60
+ --lr_warmup_steps 100 \
61
+ --lr_num_cycles 1 \
62
+ --beta1 0.9 \
63
+ --beta2 0.95 \
64
+ --weight_decay 1e-4 \
65
+ --epsilon 1e-8 \
66
+ --max_grad_norm 1.0"
67
+
68
+ # Miscellaneous arguments
69
+ miscellaneous_cmd="--tracker_name finetrainers-ltxv \
70
+ --output_dir $OUTPUT_DIR \
71
+ --nccl_timeout 1800 \
72
+ --report_to wandb"
73
+
74
+ cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
75
+ $model_cmd \
76
+ $dataset_cmd \
77
+ $dataloader_cmd \
78
+ $diffusion_cmd \
79
+ $training_cmd \
80
+ $optimizer_cmd \
81
+ $miscellaneous_cmd"
82
+
83
+ echo "Running command: $cmd"
84
+ eval $cmd
85
+ echo -ne "-------------------- Finished executing script --------------------\n\n"
86
+ ```
87
+
88
+ ## Memory Usage
89
+
90
+ ### LoRA
91
+
92
+ > [!NOTE]
93
+ >
94
+ > The below measurements are done in `torch.bfloat16` precision. Memory usage can further be reduce by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
95
+
96
+ LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
97
+
98
+ ```
99
+ Training configuration: {
100
+ "trainable parameters": 117440512,
101
+ "total samples": 69,
102
+ "train epochs": 1,
103
+ "train steps": 10,
104
+ "batches per device": 1,
105
+ "total batches observed per epoch": 69,
106
+ "train batch size": 1,
107
+ "gradient accumulation steps": 1
108
+ }
109
+ ```
110
+
111
+ | stage | memory_allocated | max_memory_reserved |
112
+ |:-----------------------:|:----------------:|:-------------------:|
113
+ | before training start | 13.486 | 13.879 |
114
+ | before validation start | 14.146 | 17.623 |
115
+ | after validation end | 14.146 | 17.623 |
116
+ | after epoch 1 | 14.146 | 17.623 |
117
+ | after training end | 4.461 | 17.623 |
118
+
119
+ Note: requires about `18` GB of VRAM without precomputation.
120
+
121
+ LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:
122
+
123
+ ```
124
+ Training configuration: {
125
+ "trainable parameters": 117440512,
126
+ "total samples": 1,
127
+ "train epochs": 10,
128
+ "train steps": 10,
129
+ "batches per device": 1,
130
+ "total batches observed per epoch": 1,
131
+ "train batch size": 1,
132
+ "gradient accumulation steps": 1
133
+ }
134
+ ```
135
+
136
+ | stage | memory_allocated | max_memory_reserved |
137
+ |:-----------------------------:|:----------------:|:-------------------:|
138
+ | after precomputing conditions | 8.88 | 8.920 |
139
+ | after precomputing latents | 9.684 | 11.613 |
140
+ | before training start | 3.809 | 10.010 |
141
+ | after epoch 1 | 4.26 | 10.916 |
142
+ | before validation start | 4.26 | 10.916 |
143
+ | after validation end | 13.924 | 17.262 |
144
+ | after training end | 4.26 | 14.314 |
145
+
146
+ Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB.
147
+
148
+ ### Full Finetuning
149
+
150
+ ```
151
+ Training configuration: {
152
+ "trainable parameters": 1923385472,
153
+ "total samples": 1,
154
+ "train epochs": 10,
155
+ "train steps": 10,
156
+ "batches per device": 1,
157
+ "total batches observed per epoch": 1,
158
+ "train batch size": 1,
159
+ "gradient accumulation steps": 1
160
+ }
161
+ ```
162
+
163
+ | stage | memory_allocated | max_memory_reserved |
164
+ |:-----------------------------:|:----------------:|:-------------------:|
165
+ | after precomputing conditions | 8.89 | 8.937 |
166
+ | after precomputing latents | 9.701 | 11.615 |
167
+ | before training start | 3.583 | 4.025 |
168
+ | after epoch 1 | 10.769 | 20.357 |
169
+ | before validation start | 10.769 | 20.357 |
170
+ | after validation end | 10.769 | 28.332 |
171
+ | after training end | 10.769 | 12.904 |
172
+
173
+ ## Inference
174
+
175
+ Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
176
+
177
+ ```diff
178
+ import torch
179
+ from diffusers import LTXPipeline
180
+ from diffusers.utils import export_to_video
181
+
182
+ pipe = LTXPipeline.from_pretrained(
183
+ "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
184
+ ).to("cuda")
185
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="ltxv-lora")
186
+ + pipe.set_adapters(["ltxv-lora"], [0.75])
187
+
188
+ video = pipe("<my-awesome-prompt>").frames[0]
189
+ export_to_video(video, "output.mp4", fps=8)
190
+ ```
191
+
192
+ You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
193
+
194
+ * [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video)
195
+ * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
196
+ * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers/documentation_models_optimization.md ADDED
@@ -0,0 +1,12 @@
1
+ # Memory optimizations
2
+
3
+ To lower memory requirements during training:
4
+
5
+ - `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory.
6
+ - `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass.
7
+ - `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`)
8
+ - `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states.
9
+ - Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example).
10
+ - Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
11
+
12
+ We will continue to add more features that help to reduce memory consumption.
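As a sketch, several of the options above can be combined in a single training command like so (the combination is illustrative, not a tested recipe):

```bash
--precompute_conditions \
--gradient_checkpointing \
--layerwise_upcasting_modules transformer \
--use_8bit_bnb
```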
docs/finetrainers/documentation_models_wan.md ADDED
@@ -0,0 +1,36 @@
1
+ # Wan
2
+
3
+ ## Training
4
+
5
+ For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
+
7
+ See [this](../../examples/training/sft/wan/crush_smol_lora/) example training script for training Wan with Pika Effects Crush.
8
+
9
+ ## Memory Usage
10
+
11
+ TODO
12
+
13
+ ## Inference
14
+
15
+ Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
16
+
17
+ ```diff
18
+ import torch
19
+ from diffusers import WanPipeline
20
+ from diffusers.utils import export_to_video
21
+
22
+ pipe = WanPipeline.from_pretrained(
23
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
24
+ ).to("cuda")
25
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="wan-lora")
26
+ + pipe.set_adapters(["wan-lora"], [0.75])
27
+
28
+ video = pipe("<my-awesome-prompt>").frames[0]
29
+ export_to_video(video, "output.mp4", fps=8)
30
+ ```
31
+
32
+ You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:
33
+
34
+ * [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan)
35
+ * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
36
+ * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers/documentation_optimizers.md ADDED
@@ -0,0 +1,14 @@
1
+ # Optimizers
2
+
3
+ The following optimizers are supported:
4
+ - **torch**:
5
+ - `Adam`
6
+ - `AdamW`
7
+ - **bitsandbytes**:
8
+ - `Adam`
9
+ - `AdamW`
10
+ - `Adam8Bit`
11
+ - `AdamW8Bit`
12
+
13
+ > [!NOTE]
14
+ > Not all optimizers have been tested with all models/parallel settings. They may or may not work, but this will gradually improve over time.
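As a sketch, selecting the bitsandbytes 8-bit AdamW variant looks like this in a training command (hyperparameter values are illustrative):

```bash
--optimizer adamw \
--use_8bit_bnb \
--lr 3e-5 \
--beta1 0.9 \
--beta2 0.95 \
--weight_decay 1e-4 \
--epsilon 1e-8
```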
docs/finetrainers/documentation_parallel_processing_README.md ADDED
@@ -0,0 +1,59 @@
1
+ # Finetrainers Parallel Backends
2
+
3
+ Finetrainers supports parallel training on multiple GPUs & nodes. This is done using the Pytorch DTensor backend. To run parallel training, `torchrun` is utilized.
4
+
5
+ As an experiment for comparing the performance of different training backends, Finetrainers has implemented multi-backend support. These backends may or may not fully rely on Pytorch's distributed DTensor solution. Currently, besides the DTensor backend, only [🤗 Accelerate](https://github.com/huggingface/accelerate) is supported, for backwards-compatibility reasons (as we initially started Finetrainers with only Accelerate). In the near future, there are plans for integrating with:
6
+ - [DeepSpeed](https://github.com/deepspeedai/DeepSpeed)
7
+ - [Nanotron](https://github.com/huggingface/nanotron)
8
+ - [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
9
+
10
+ > [!IMPORTANT]
11
+ > The multi-backend support is completely experimental and only serves to satisfy my curiosity of how much of a tradeoff there is between performance and ease of use. The Pytorch DTensor backend is the only one with stable support, following Accelerate.
12
+ >
13
+ > Users will not have to worry about backwards-breaking changes or dependencies if they stick to the Pytorch DTensor backend.
14
+
15
+ ## Support matrix
16
+
17
+ There are various algorithms for parallel training. Currently, we only support:
18
+ - [DDP](https://pytorch.org/docs/stable/notes/ddp.html)
19
+ - [FSDP2](https://pytorch.org/docs/stable/fsdp.html)
20
+ - [HSDP](https://pytorch.org/docs/stable/fsdp.html)
21
+ - [TP](https://pytorch.org/docs/stable/distributed.tensor.parallel.html)
22
+
23
+ ## Training
24
+
25
+ The following parameters are relevant for launching training:
26
+
27
+ - `parallel_backend`: The backend to use for parallel training. Available options are `ptd` & `accelerate`.
28
+ - `pp_degree`: The degree of pipeline parallelism. Currently unsupported.
29
+ - `dp_degree`: The degree of data parallelism (number of data replicas). Defaults to `1`.
30
+ - `dp_shards`: The number of shards for data parallelism. Defaults to `1`.
31
+ - `cp_degree`: The degree of context parallelism. Currently unsupported.
32
+ - `tp_degree`: The degree of tensor parallelism.
33
+
34
+ For launching training with the Pytorch DTensor backend, use the following:
35
+
36
+ ```bash
37
+ # Single node - 8 GPUs available
38
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" train.py <YOUR_OTHER_ARGS>
39
+
40
+ # Single node - 8 GPUs but only 4 available
41
+ export CUDA_VISIBLE_DEVICES=0,2,4,5
42
+ torchrun --standalone --nnodes=1 --nproc_per_node=4 --rdzv_backend c10d --rdzv_endpoint="localhost:0" train.py <YOUR_OTHER_ARGS>
43
+
44
+ # Multi-node - Nx8 GPUs available
45
+ # TODO(aryan): Add slurm script
46
+ ```
47
+
48
+ For launching training with the Accelerate backend, use the following:
49
+
50
+ ```bash
51
+ # Single node - 8 GPUs available
52
+ accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids 0,1,2,3,4,5,6,7 train.py <YOUR_OTHER_ARGS>
53
+
54
+ # Single node - 8 GPUs but only 4 available
55
+ accelerate launch --config_file accelerate_configs/uncompiled_4.yaml --gpu_ids 0,2,4,5 train.py <YOUR_OTHER_ARGS>
56
+
57
+ # Multi-node - Nx8 GPUs available
58
+ # TODO(aryan): Add slurm script
59
+ ```
docs/finetrainers/examples_training_wan__train.sh ADDED
@@ -0,0 +1,162 @@
1
+ #!/bin/bash
2
+
3
+ set -e -x
4
+
5
+ # export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
6
+ # export TORCHDYNAMO_VERBOSE=1
7
+ export WANDB_MODE="offline"
8
+ export NCCL_P2P_DISABLE=1
9
+ export TORCH_NCCL_ENABLE_MONITORING=0
10
+ export FINETRAINERS_LOG_LEVEL="DEBUG"
11
+
12
+ # Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
13
+ # BACKEND="accelerate"
14
+ BACKEND="ptd"
15
+
16
+ # In this setting, I'm using 2 GPUs on a 4-GPU node for training
17
+ NUM_GPUS=2
18
+ CUDA_VISIBLE_DEVICES="2,3"
19
+
20
+ # Check the JSON files for the expected JSON format
21
+ TRAINING_DATASET_CONFIG="examples/training/sft/wan/crush_smol_lora/training.json"
22
+ VALIDATION_DATASET_FILE="examples/training/sft/wan/crush_smol_lora/validation.json"
23
+
24
+ # Depending on how many GPUs you have available, choose your degree of parallelism and technique!
25
+ DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
26
+ DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
27
+ DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
28
+ FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
29
+ FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
30
+ HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"
31
+
32
+ # Parallel arguments
33
+ parallel_cmd=(
34
+ $DDP_2
35
+ )
36
+
37
+ # Model arguments
38
+ model_cmd=(
39
+ --model_name "wan"
40
+ --pretrained_model_name_or_path "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
41
+ )
42
+
43
+ # Dataset arguments
44
+ # Here, we know that the dataset size is about ~50 videos. Since we're using 2 GPUs, we precompute
45
+ # embeddings of 25 dataset items per GPU. Also, we're using a very small dataset for finetuning, so
46
+ # we are okay with precomputing embeddings once and re-using them without having to worry about disk
47
+ # space. Currently, however, every new training run performs precomputation even if it's not required
48
+ # (which is something we have to improve [TODO(aryan)])
49
+ dataset_cmd=(
50
+ --dataset_config $TRAINING_DATASET_CONFIG
51
+ --dataset_shuffle_buffer_size 10
52
+ --precomputation_items 25
53
+ --precomputation_once
54
+ )
55
+
56
+ # Dataloader arguments
57
+ dataloader_cmd=(
58
+ --dataloader_num_workers 0
59
+ )
60
+
61
+ # Diffusion arguments
62
+ diffusion_cmd=(
63
+ --flow_weighting_scheme "logit_normal"
64
+ )
65
+
66
+ # Training arguments
67
+ # We target just the attention projections layers for LoRA training here.
68
+ # You can modify as you please and target any layer (regex is supported)
69
+ training_cmd=(
70
+ --training_type "lora"
71
+ --seed 42
72
+ --batch_size 1
73
+ --train_steps 3000
74
+ --rank 32
75
+ --lora_alpha 32
76
+ --target_modules "blocks.*(to_q|to_k|to_v|to_out.0)"
77
+ --gradient_accumulation_steps 1
78
+ --gradient_checkpointing
79
+ --checkpointing_steps 500
80
+ --checkpointing_limit 2
81
+ # --resume_from_checkpoint 3000
82
+ --enable_slicing
83
+ --enable_tiling
84
+ )
85
+
86
+ # Optimizer arguments
87
+ optimizer_cmd=(
88
+ --optimizer "adamw"
89
+ --lr 5e-5
90
+ --lr_scheduler "constant_with_warmup"
91
+ --lr_warmup_steps 1000
92
+ --lr_num_cycles 1
93
+ --beta1 0.9
94
+ --beta2 0.99
95
+ --weight_decay 1e-4
96
+ --epsilon 1e-8
97
+ --max_grad_norm 1.0
98
+ )
99
+
100
+ # Validation arguments
101
+ validation_cmd=(
102
+ --validation_dataset_file "$VALIDATION_DATASET_FILE"
103
+ --validation_steps 500
104
+ )
105
+
106
+ # Miscellaneous arguments
107
+ miscellaneous_cmd=(
108
+ --tracker_name "finetrainers-wan"
109
+ --output_dir "/raid/aryan/wan"
110
+ --init_timeout 600
111
+ --nccl_timeout 600
112
+ --report_to "wandb"
113
+ )
114
+
115
+ # Execute the training script
116
+ if [ "$BACKEND" == "accelerate" ]; then
117
+
118
+ ACCELERATE_CONFIG_FILE=""
119
+ if [ "$NUM_GPUS" == 1 ]; then
120
+ ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
121
+ elif [ "$NUM_GPUS" == 2 ]; then
122
+ ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
123
+ elif [ "$NUM_GPUS" == 4 ]; then
124
+ ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
125
+ elif [ "$NUM_GPUS" == 8 ]; then
126
+ ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
127
+ fi
128
+
129
+ accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
130
+ "${parallel_cmd[@]}" \
131
+ "${model_cmd[@]}" \
132
+ "${dataset_cmd[@]}" \
133
+ "${dataloader_cmd[@]}" \
134
+ "${diffusion_cmd[@]}" \
135
+ "${training_cmd[@]}" \
136
+ "${optimizer_cmd[@]}" \
137
+ "${validation_cmd[@]}" \
138
+ "${miscellaneous_cmd[@]}"
139
+
140
+ elif [ "$BACKEND" == "ptd" ]; then
141
+
142
+ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
143
+
144
+ torchrun \
145
+ --standalone \
146
+ --nnodes=1 \
147
+ --nproc_per_node=$NUM_GPUS \
148
+ --rdzv_backend c10d \
149
+ --rdzv_endpoint="localhost:0" \
150
+ train.py \
151
+ "${parallel_cmd[@]}" \
152
+ "${model_cmd[@]}" \
153
+ "${dataset_cmd[@]}" \
154
+ "${dataloader_cmd[@]}" \
155
+ "${diffusion_cmd[@]}" \
156
+ "${training_cmd[@]}" \
157
+ "${optimizer_cmd[@]}" \
158
+ "${validation_cmd[@]}" \
159
+ "${miscellaneous_cmd[@]}"
160
+ fi
161
+
162
+ echo -ne "-------------------- Finished executing script --------------------\n\n"
docs/finetrainers/examples_training_wan__training.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "datasets": [
3
+ {
4
+ "data_root": "finetrainers/crush-smol",
5
+ "dataset_type": "video",
6
+ "id_token": "PIKA_CRUSH",
7
+ "video_resolution_buckets": [
8
+ [49, 480, 832]
9
+ ],
10
+ "reshape_mode": "bicubic",
11
+ "remove_common_llm_caption_prefixes": true
12
+ }
13
+ ]
14
+ }
docs/finetrainers/examples_training_wan__validation.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "data": [
3
+ {
4
+ "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.",
5
+ "image_path": null,
6
+ "video_path": null,
7
+ "num_inference_steps": 50,
8
+ "height": 480,
9
+ "width": 832,
10
+ "num_frames": 49
11
+ },
12
+ {
13
+ "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.",
14
+ "image_path": null,
15
+ "video_path": null,
16
+ "num_inference_steps": 50,
17
+ "height": 480,
18
+ "width": 832,
19
+ "num_frames": 49
20
+ },
21
+ {
22
+ "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.",
23
+ "image_path": null,
24
+ "video_path": null,
25
+ "num_inference_steps": 50,
26
+ "height": 480,
27
+ "width": 832,
28
+ "num_frames": 49
29
+ },
30
+ {
31
+ "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.",
32
+ "image_path": null,
33
+ "video_path": null,
34
+ "num_inference_steps": 50,
35
+ "height": 480,
36
+ "width": 832,
37
+ "num_frames": 49
38
+ }
39
+ ]
40
+ }
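
Each validation caption above is prefixed with the same `id_token` (`PIKA_CRUSH`) declared in the training config, and its height/width/frame count matches the single training bucket `[49, 480, 832]`. When editing either file, a small consistency check such as the illustrative sketch below (again assuming local copies of both JSON files; not a finetrainers API) helps keep them in sync:

```python
# Illustrative consistency check between the training and validation configs above.
import json
from pathlib import Path

training = json.loads(Path("examples/training/wan/training.json").read_text())      # assumed path
validation = json.loads(Path("examples/training/wan/validation.json").read_text())  # assumed path

id_token = training["datasets"][0]["id_token"]
buckets = {tuple(b) for b in training["datasets"][0]["video_resolution_buckets"]}

for sample in validation["data"]:
    # Validation prompts should carry the trigger token the LoRA is trained with.
    if not sample["caption"].startswith(id_token):
        raise ValueError(f"caption missing id_token: {sample['caption'][:60]!r}")
    # Resolution and frame count should correspond to a training bucket.
    if (sample["num_frames"], sample["height"], sample["width"]) not in buckets:
        raise ValueError(f"no bucket matches {sample['height']}x{sample['width']}, {sample['num_frames']} frames")

print(f"{len(validation['data'])} validation prompts consistent with training buckets {sorted(buckets)}")
```
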
training/README.md DELETED
@@ -1,459 +0,0 @@
1
- # CogVideoX Factory 🧪
2
-
3
- [中文阅读](./README_zh.md)
4
-
5
- Fine-tune Cog family of video models for custom video generation under 24GB of GPU memory ⚡️📼
6
-
7
- <table align="center">
8
- <tr>
9
- <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
10
- </tr>
11
- </table>
12
-
13
- **Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)!
14
-
15
- ## Quickstart
16
-
17
- Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`.
18
-
19
- Then download a dataset:
20
-
21
- ```bash
22
- # install `huggingface_hub`
23
- huggingface-cli download \
24
- --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \
25
- --local-dir video-dataset-disney
26
- ```
27
-
28
- Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice):
29
-
30
- ```bash
31
- # For LoRA finetuning of the text-to-video CogVideoX models
32
- ./train_text_to_video_lora.sh
33
-
34
- # For full finetuning of the text-to-video CogVideoX models
35
- ./train_text_to_video_sft.sh
36
-
37
- # For LoRA finetuning of the image-to-video CogVideoX models
38
- ./train_image_to_video_lora.sh
39
- ```
40
-
41
- Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
42
-
43
- ```diff
44
- import torch
45
- from diffusers import CogVideoXPipeline
46
- from diffusers.utils import export_to_video
47
-
48
- pipe = CogVideoXPipeline.from_pretrained(
49
- "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
50
- ).to("cuda")
51
- + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
52
- + pipe.set_adapters(["cogvideox-lora"], [1.0])
53
-
54
- video = pipe("<my-awesome-prompt>").frames[0]
55
- export_to_video(video, "output.mp4", fps=8)
56
- ```
57
-
58
- For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details):
59
-
60
- ```python
61
- from diffusers import CogVideoXImageToVideoPipeline
62
-
63
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
64
- "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
65
- ).to("cuda")
66
-
67
- # ...
68
-
69
- del pipe.transformer.patch_embed.pos_embedding
70
- pipe.transformer.patch_embed.use_learned_positional_embeddings = False
71
- pipe.transformer.config.use_learned_positional_embeddings = False
72
- ```
73
-
74
- You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).
75
-
76
- Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible.
77
-
78
- ## Prepare Dataset and Training
79
-
80
- Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example.
81
-
82
- - Configure environment variables as per your choice:
83
-
84
- ```bash
85
- export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
86
- export TORCHDYNAMO_VERBOSE=1
87
- export WANDB_MODE="offline"
88
- export NCCL_P2P_DISABLE=1
89
- export TORCH_NCCL_ENABLE_MONITORING=0
90
- ```
91
-
92
- - Configure which GPUs to use for training: `GPU_IDS="0,1"`
93
-
94
- - Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example:
95
-
96
- ```bash
97
- LEARNING_RATES=("1e-4" "1e-3")
98
- LR_SCHEDULES=("cosine_with_restarts")
99
- OPTIMIZERS=("adamw" "adam")
100
- MAX_TRAIN_STEPS=("3000")
101
- ```
102
-
103
- - Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`.
104
-
105
- - Specify the absolute paths and columns/files for captions and videos.
106
-
107
- ```bash
108
- DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
109
- CAPTION_COLUMN="prompt.txt"
110
- VIDEO_COLUMN="videos.txt"
111
- ```
112
-
113
- - Launch experiments sweeping different hyperparameters:
114
- ```
115
- for learning_rate in "${LEARNING_RATES[@]}"; do
116
- for lr_schedule in "${LR_SCHEDULES[@]}"; do
117
- for optimizer in "${OPTIMIZERS[@]}"; do
118
- for steps in "${MAX_TRAIN_STEPS[@]}"; do
119
- output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
120
-
121
- cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \
122
- --pretrained_model_name_or_path THUDM/CogVideoX-5b \
123
- --data_root $DATA_ROOT \
124
- --caption_column $CAPTION_COLUMN \
125
- --video_column $VIDEO_COLUMN \
126
- --id_token BW_STYLE \
127
- --height_buckets 480 \
128
- --width_buckets 720 \
129
- --frame_buckets 49 \
130
- --dataloader_num_workers 8 \
131
- --pin_memory \
132
- --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
133
- --validation_prompt_separator ::: \
134
- --num_validation_videos 1 \
135
- --validation_epochs 10 \
136
- --seed 42 \
137
- --rank 128 \
138
- --lora_alpha 128 \
139
- --mixed_precision bf16 \
140
- --output_dir $output_dir \
141
- --max_num_frames 49 \
142
- --train_batch_size 1 \
143
- --max_train_steps $steps \
144
- --checkpointing_steps 1000 \
145
- --gradient_accumulation_steps 1 \
146
- --gradient_checkpointing \
147
- --learning_rate $learning_rate \
148
- --lr_scheduler $lr_schedule \
149
- --lr_warmup_steps 400 \
150
- --lr_num_cycles 1 \
151
- --enable_slicing \
152
- --enable_tiling \
153
- --optimizer $optimizer \
154
- --beta1 0.9 \
155
- --beta2 0.95 \
156
- --weight_decay 0.001 \
157
- --max_grad_norm 1.0 \
158
- --allow_tf32 \
159
- --report_to wandb \
160
- --nccl_timeout 1800"
161
-
162
- echo "Running command: $cmd"
163
- eval $cmd
164
- echo -ne "-------------------- Finished executing script --------------------\n\n"
165
- done
166
- done
167
- done
168
- done
169
- ```
170
-
171
- To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`.
172
-
173
- Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below.
174
-
175
- ## Memory requirements
176
-
177
- <table align="center">
178
- <tr>
179
- <td align="center" colspan="2"><b>CogVideoX LoRA Finetuning</b></td>
180
- </tr>
181
- <tr>
182
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
183
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
184
- </tr>
185
- <tr>
186
- <td align="center"><img src="../assets/lora_2b.png" /></td>
187
- <td align="center"><img src="../assets/lora_5b.png" /></td>
188
- </tr>
189
-
190
- <tr>
191
- <td align="center" colspan="2"><b>CogVideoX Full Finetuning</b></td>
192
- </tr>
193
- <tr>
194
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
195
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
196
- </tr>
197
- <tr>
198
- <td align="center"><img src="../assets/sft_2b.png" /></td>
199
- <td align="center"><img src="../assets/sft_5b.png" /></td>
200
- </tr>
201
- </table>
202
-
203
- Supported and verified memory optimizations for training include:
204
-
205
- - `CPUOffloadOptimizer` from [`torchao`](https://github.com/pytorch/ao). You can read about its capabilities and limitations [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload). In short, it allows you to use the CPU for storing trainable parameters and gradients. This results in the optimizer step happening on the CPU, which requires a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` or applying `torch.compile` on the optimizer step. Additionally, it is recommended not to `torch.compile` your model for training. Gradient clipping and accumulation is not supported yet either.
206
- - Low-bit optimizers from [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers). TODO: to test and make [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) ones work
207
- - DeepSpeed Zero2: Since we rely on `accelerate`, follow [this guide](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) to configure your `accelerate` installation to enable training with DeepSpeed Zero2 optimizations.
208
-
209
- > [!IMPORTANT]
210
- > The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
211
- >
212
- > If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offload`.
213
-
214
- ### LoRA finetuning
215
-
216
- > [!NOTE]
217
- > The memory requirements for image-to-video lora finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
218
- >
219
- > Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
220
- > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
221
- > or provide a URL to a valid and accessible image.
222
-
223
- <details>
224
- <summary> AdamW </summary>
225
-
226
- **Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
227
-
228
- With `train_batch_size = 1`:
229
-
230
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
231
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
232
- | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
233
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
234
- | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
235
- | THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
236
- | THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
237
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
238
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
239
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
240
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
241
-
242
- With `train_batch_size = 4`:
243
-
244
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
245
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
246
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
247
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
248
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
249
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
250
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
251
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
252
-
253
- </details>
254
-
255
- <details>
256
- <summary> AdamW (8-bit bitsandbytes) </summary>
257
-
258
- **Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
259
-
260
- With `train_batch_size = 1`:
261
-
262
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
263
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
264
- | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
265
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
266
- | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
267
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
268
- | THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
269
- | THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
270
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
271
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
272
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
273
-
274
- With `train_batch_size = 4`:
275
-
276
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
277
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
278
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
279
- | THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
280
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
281
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
282
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
283
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
284
-
285
- </details>
286
-
287
- <details>
288
- <summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
289
-
290
- **Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
291
-
292
- With `train_batch_size = 1`:
293
-
294
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
295
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
296
- | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
297
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
298
- | THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
299
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
300
- | THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
301
- | THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
302
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
303
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
304
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
305
-
306
- With `train_batch_size = 4`:
307
-
308
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
309
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
310
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
311
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
312
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
313
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
314
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
315
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
316
-
317
- </details>
318
-
319
- <details>
320
- <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
321
-
322
- **Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.
323
-
324
- With `train_batch_size = 1`:
325
-
326
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
327
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
328
- | THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
329
- | THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
330
-
331
- With `train_batch_size = 4`:
332
-
333
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
334
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
335
- | THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
336
- | THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
337
-
338
- </details>
339
-
340
- ### Full finetuning
341
-
342
- > [!NOTE]
343
- > The memory requirements for image-to-video full finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
344
- >
345
- > Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
346
- > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
347
- > or provide a URL to a valid and accessible image.
348
-
349
- > [!NOTE]
350
- > Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.
351
-
352
- <details>
353
- <summary> AdamW </summary>
354
-
355
- With `train_batch_size = 1`:
356
-
357
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
358
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
359
- | THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
360
- | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
361
-
362
- With `train_batch_size = 4`:
363
-
364
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
365
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
366
- | THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
367
- | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
368
-
369
- </details>
370
-
371
- <details>
372
- <summary> AdamW (8-bit bitsandbytes) </summary>
373
-
374
- With `train_batch_size = 1`:
375
-
376
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
377
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
378
- | THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
379
- | THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
380
-
381
- With `train_batch_size = 4`:
382
-
383
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
384
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
385
- | THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
386
- | THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
387
-
388
- </details>
389
-
390
- <details>
391
- <summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
392
-
393
- With `train_batch_size = 1`:
394
-
395
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
396
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
397
- | THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
398
- | THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
399
-
400
- With `train_batch_size = 4`:
401
-
402
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
403
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
404
- | THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
405
- | THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
406
-
407
- </details>
408
-
409
- <details>
410
- <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
411
-
412
- **Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.
413
-
414
- With `train_batch_size = 1`:
415
-
416
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
417
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
418
- | THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
419
- | THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
420
-
421
- With `train_batch_size = 4`:
422
-
423
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
424
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
425
- | THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
426
- | THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
427
-
428
- </details>
429
-
430
- > [!NOTE]
431
- > - `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose not to perform validation/testing as part of the training script.
432
- >
433
- > - `memory_before_validation` is the true indicator of the peak memory required for training if you choose to not perform validation/testing.
434
-
435
- <table align="center">
436
- <tr>
437
- <td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
438
- </tr>
439
- <tr>
440
- <td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
441
- </tr>
442
- </table>
443
-
444
- ## TODOs
445
-
446
- - [x] Make scripts compatible with DDP
447
- - [ ] Make scripts compatible with FSDP
448
- - [x] Make scripts compatible with DeepSpeed
449
- - [ ] vLLM-powered captioning script
450
- - [x] Multi-resolution/frame support in `prepare_dataset.py`
451
- - [ ] Analyzing traces for potential speedups and removing as many syncs as possible
452
- - [x] Test scripts with memory-efficient optimizer from bitsandbytes
453
- - [x] Test scripts with CPUOffloadOptimizer, etc.
454
- - [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao))
455
- - [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out)
456
- - [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100)
457
-
458
- > [!IMPORTANT]
459
- > Since our goal is to make the scripts as memory-friendly as possible we don't guarantee multi-GPU training.
training/README_zh.md DELETED
@@ -1,455 +0,0 @@
1
- # CogVideoX Factory 🧪
2
-
3
- [Read in English](./README.md)
4
-
5
- 在 24GB GPU 内存下对 Cog 系列视频模型进行微调以实现自定义视频生成,支持多分辨率 ⚡️📼
6
-
7
- <table align="center">
8
- <tr>
9
- <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">您的浏览器不支持视频标签。</video></td>
10
- </tr>
11
- </table>
12
-
13
- ## 快速开始
14
-
15
- 克隆此仓库并确保安装了相关依赖:`pip install -r requirements.txt`。
16
-
17
- 接着下载数据集:
18
-
19
- ```
20
- # 安装 `huggingface_hub`
21
- huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney
22
- ```
23
-
24
- 然后启动 LoRA 微调进行文本到视频的生成(根据您的选择修改不同的超参数、数据集根目录以及其他配置选项):
25
-
26
- ```
27
- # 对 CogVideoX 模型进行文本到视频的 LoRA 微调
28
- ./train_text_to_video_lora.sh
29
-
30
- # 对 CogVideoX 模型进行文本到视频的完整微调
31
- ./train_text_to_video_sft.sh
32
-
33
- # 对 CogVideoX 模型进行图像到视频的 LoRA 微调
34
- ./train_image_to_video_lora.sh
35
- ```
36
-
37
- 假设您的 LoRA 已保存并推送到 HF Hub,并命名为 `my-awesome-name/my-awesome-lora`,现在我们可以使用微调模型进行推理:
38
-
39
- ```
40
- import torch
41
- from diffusers import CogVideoXPipeline
42
- from diffusers import export_to_video
43
-
44
- pipe = CogVideoXPipeline.from_pretrained(
45
- "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
46
- ).to("cuda")
47
- + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name=["cogvideox-lora"])
48
- + pipe.set_adapters(["cogvideox-lora"], [1.0])
49
-
50
- video = pipe("<my-awesome-prompt>").frames[0]
51
- export_to_video(video, "output.mp4", fps=8)
52
- ```
53
-
54
- 你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。
55
-
56
- **注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装
57
- diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。
58
-
59
- 以下我们提供了更多探索此仓库选项的额外部分。所有这些都旨在尽可能降低内存需求,使视频模型的微调变得更易于访问。
60
-
61
- ## 训练
62
-
63
- 在开始训练之前,请你检查是否按照[数据集规范](assets/dataset_zh.md)准备好了数据集。 我们提供了适用于文本到视频 (text-to-video) 和图像到视频 (image-to-video) 生成的训练脚本,兼容 [CogVideoX 模型家族](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce)。训练可以通过 `train*.sh` 脚本启动,具体取决于你想要训练的任务。让我们以文本到视频的 LoRA 微调为例。
64
-
65
- - 根据你的需求配置环境变量:
66
-
67
- ```
68
- export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
69
- export TORCHDYNAMO_VERBOSE=1
70
- export WANDB_MODE="offline"
71
- export NCCL_P2P_DISABLE=1
72
- export TORCH_NCCL_ENABLE_MONITORING=0
73
- ```
74
-
75
- - 配置用于训练的 GPU:`GPU_IDS="0,1"`
76
-
77
- - 选择训练的超参数。让我们以学习率和优化器类型的超参数遍历为例:
78
-
79
- ```
80
- LEARNING_RATES=("1e-4" "1e-3")
81
- LR_SCHEDULES=("cosine_with_restarts")
82
- OPTIMIZERS=("adamw" "adam")
83
- MAX_TRAIN_STEPS=("3000")
84
- ```
85
-
86
- - 选择用于训练的 Accelerate 配置文件:`ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`
87
- 。我们在 `accelerate_configs/` 目录中提供了一些默认配置 - 单 GPU 编译/未编译、2x GPU DDP、DeepSpeed
88
- 等。你也可以使用 `accelerate config --config_file my_config.yaml` 自定义配置文件。
89
-
90
- - 指定字幕和视频的绝对路径以及列/文件。
91
-
92
- ```
93
- DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
94
- CAPTION_COLUMN="prompt.txt"
95
- VIDEO_COLUMN="videos.txt"
96
- ```
97
-
98
- - 运行实验,遍历不同的超参数:
99
- ```
100
- for learning_rate in "${LEARNING_RATES[@]}"; do
101
- for lr_schedule in "${LR_SCHEDULES[@]}"; do
102
- for optimizer in "${OPTIMIZERS[@]}"; do
103
- for steps in "${MAX_TRAIN_STEPS[@]}"; do
104
- output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
105
-
106
- cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
107
- --pretrained_model_name_or_path THUDM/CogVideoX-5b \
108
- --data_root $DATA_ROOT \
109
- --caption_column $CAPTION_COLUMN \
110
- --video_column $VIDEO_COLUMN \
111
- --id_token BW_STYLE \
112
- --height_buckets 480 \
113
- --width_buckets 720 \
114
- --frame_buckets 49 \
115
- --dataloader_num_workers 8 \
116
- --pin_memory \
117
- --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
118
- --validation_prompt_separator ::: \
119
- --num_validation_videos 1 \
120
- --validation_epochs 10 \
121
- --seed 42 \
122
- --rank 128 \
123
- --lora_alpha 128 \
124
- --mixed_precision bf16 \
125
- --output_dir $output_dir \
126
- --max_num_frames 49 \
127
- --train_batch_size 1 \
128
- --max_train_steps $steps \
129
- --checkpointing_steps 1000 \
130
- --gradient_accumulation_steps 1 \
131
- --gradient_checkpointing \
132
- --learning_rate $learning_rate \
133
- --lr_scheduler $lr_schedule \
134
- --lr_warmup_steps 400 \
135
- --lr_num_cycles 1 \
136
- --enable_slicing \
137
- --enable_tiling \
138
- --optimizer $optimizer \
139
- --beta1 0.9 \
140
- --beta2 0.95 \
141
- --weight_decay 0.001 \
142
- --max_grad_norm 1.0 \
143
- --allow_tf32 \
144
- --report_to wandb \
145
- --nccl_timeout 1800"
146
-
147
- echo "Running command: $cmd"
148
- eval $cmd
149
- echo -ne "-------------------- Finished executing script --------------------\n\n"
150
- done
151
- done
152
- done
153
- done
154
- ```
155
-
156
- 要了解不同参数的含义,你可以查看 [args](./training/args.py) 文件,或者使用 `--help` 运行训练脚本。
157
-
158
- 注意:训练脚本尚未在 MPS 上测试,因此性能和内存要求可能与下面的 CUDA 报告差异很大。
159
-
160
- ## 内存需求
161
-
162
- <table align="center">
163
- <tr>
164
- <td align="center" colspan="2"><b>CogVideoX LoRA 微调</b></td>
165
- </tr>
166
- <tr>
167
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
168
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
169
- </tr>
170
- <tr>
171
- <td align="center"><img src="assets/lora_2b.png" /></td>
172
- <td align="center"><img src="assets/lora_5b.png" /></td>
173
- </tr>
174
-
175
- <tr>
176
- <td align="center" colspan="2"><b>CogVideoX 全量微调</b></td>
177
- </tr>
178
- <tr>
179
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
180
- <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
181
- </tr>
182
- <tr>
183
- <td align="center"><img src="assets/sft_2b.png" /></td>
184
- <td align="center"><img src="assets/sft_5b.png" /></td>
185
- </tr>
186
- </table>
187
-
188
- 支持和验证的训练内存优化包括:
189
-
190
- - `CPUOffloadOptimizer` 来自 [`torchao`](https://github.com/pytorch/ao)
191
- 。你可以在[这里](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload)
192
- 阅读它的能力和局限性。简而言之,它允许你将可训练参数和梯度存储在 CPU 中,从而在 CPU 上进行优化步骤。这需要快速的 CPU
193
- 优化器,如 `torch.optim.AdamW(fused=True)`,或者在优化步骤中应用 `torch.compile`
194
- 。此外,建议不要在训练时对模型应用 `torch.compile`。梯度裁剪和累积目前还不支持。
195
- - 来自 [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers)
196
- 的低位优化器。TODO:测试并使 [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) 能正常工作。
197
- - DeepSpeed Zero2:由于我们依赖 `accelerate`
198
- ,请按照[此指南](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) 配置 `accelerate` 以启用 DeepSpeed
199
- Zero2 优化训练。
200
-
201
- > [!重要提示]
202
- > 内存需求是运行 `training/prepare_dataset.py`
203
- >
204
- 后报告的,该脚本将视频和字幕转换为潜在向量和嵌入。在训练期间,我们直接加载这些潜在向量和嵌入,不需要VAE或T5文本编码器。然而,如果执行验证/测试,则必须加载这些模块,并且会增加所需内存的数量。不进行验证/测试可以节省大量内存,这些内存可以用于较小显存的GPU上专注于训练。
205
- >
206
- > 如果选择运行验证/测试,可以通过指定 `--enable_model_cpu_offload` 来为较低显存的GPU节省一些内存。
207
-
208
- ### LoRA微调
209
-
210
- > [!重要提示]
211
- > 图像到视频的LoRA微调的内存需求与文本到视频上的 `THUDM/CogVideoX-5b` 类似,因此没有明确报告。
212
- >
213
- > 此外,为了准备I2V微调的测试图像,可以通过修改脚本实时生成它们,或使用以下命令从训练数据中提取一些帧:
214
- > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
215
- > 或提供一个有效且可访问的图像URL。
216
-
217
- <details>
218
- <summary> AdamW </summary>
219
-
220
- **注意:** 尝试在没有梯度检查点的情况下运行 CogVideoX-5b 即使在 A100(80 GB)上也会导致 OOM(内存不足)错误,因此内存需求尚未列出。
221
-
222
- 当 `train_batch_size = 1` 时:
223
-
224
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
225
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
226
- | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
227
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
228
- | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
229
- | THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
230
- | THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
231
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
232
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
233
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
234
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
235
-
236
- 当 `train_batch_size = 4` 时:
237
-
238
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
239
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
240
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
241
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
242
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
243
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
244
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
245
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
246
-
247
- </details>
248
-
249
- <details>
250
- <summary> AdamW (8-bit bitsandbytes) </summary>
251
-
252
- **注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
253
-
254
- 当 `train_batch_size = 1` 时:
255
-
256
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
257
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
258
- | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
259
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
260
- | THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
261
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
262
- | THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
263
- | THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
264
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
265
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
266
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
267
-
268
- 当 `train_batch_size = 4` 时:
269
-
270
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
271
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
272
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
273
- | THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
274
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
275
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
276
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
277
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
278
-
279
- </details>
280
-
281
- <details>
282
- <summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
283
-
284
- **注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
285
-
286
- 当 `train_batch_size = 1` 时:
287
-
288
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
289
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
290
- | THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
291
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
292
- | THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
293
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
294
- | THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
295
- | THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
296
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
297
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
298
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
299
-
300
- 当 `train_batch_size = 4` 时:
301
-
302
- | model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
303
- |:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
304
- | THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
305
- | THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
306
- | THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
307
- | THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
308
- | THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
309
- | THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
310
-
311
- </details>
312
-
313
- <details>
314
- <summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
315
-
316
- **注意:** 结果是在启用梯度检查点的情况下,使用 2x A100 运行时记录的。
317
-
318
- 当 `train_batch_size = 1` 时:
319
-
320
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
321
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
322
- | THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
323
- | THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
324
-
325
- 当 `train_batch_size = 4` 时:
326
-
327
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
328
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
329
- | THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
330
- | THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
331
-
332
- </details>
333
-
334
- ### Full finetuning
335
-
336
- > [!注意]
337
- > 图像到视频的完整微调内存需求与 `THUDM/CogVideoX-5b` 的文本到视频微调相似,因此没有单独列出。
338
- >
339
- > 此外,要准备用于 I2V 微调的测试图像,你可以通过修改脚本实时生成图像,或者从你的训练数据中提取一些帧:
340
- > `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
341
- > 或提供一个有效且可访问的图像 URL。
342
-
343
- > [!注意]
344
- > 在没有使用梯度检查点的情况下运行完整微调,即使是在 A100(80GB)上,也会出现 OOM(内存不足)错误,因此未列出内存需求。
345
-
346
- <details>
347
- <summary> AdamW </summary>
348
-
349
- 当 `train_batch_size = 1` 时:
350
-
351
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
352
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
353
- | THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
354
- | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
355
-
356
- 当 `train_batch_size = 4` 时:
357
-
358
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
359
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
360
- | THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
361
- | THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
362
-
363
- </details>
364
-
365
- <details>
366
- <summary> AdamW (8-bit 量化) </summary>
367
-
368
- 当 `train_batch_size = 1` 时:
369
-
370
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
371
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
372
- | THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
373
- | THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |
374
-
375
- 当 `train_batch_size = 4` 时:
376
-
377
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
378
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
379
- | THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
380
- | THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |
381
-
382
- </details>
383
-
384
- <details>
385
- <summary> AdamW + CPUOffloadOptimizer(带有梯度卸载)</summary>
386
-
387
- 当 `train_batch_size = 1` 时:
388
-
389
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
390
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
391
- | THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
392
- | THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |
393
-
394
- 当 `train_batch_size = 4` 时:
395
-
396
- | model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
397
- |:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
398
- | THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
399
- | THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |
400
-
401
- </details>
402
-
403
- <details>
404
- <summary> DeepSpeed(AdamW + CPU/参数卸载) </summary>
405
-
406
- **注意:** 结果是在启用 `gradient_checkpointing`(梯度检查点)功能,并在 2 台 A100 显卡上运行时报告的。
407
-
408
- 当 `train_batch_size = 1` 时:
409
-
410
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
411
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
412
- | THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
413
- | THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |
414
-
415
- 当 `train_batch_size = 4` 时:
416
-
417
- | model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
418
- |:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
419
- | THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
420
- | THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |
421
-
422
- </details>
423
-
424
- > [!注意]
425
- > - `memory_after_validation`(验证后内存) 表示训练所需的峰值内存。这是因为除了存储训练过程中需要的激活、参数和梯度之外,还需要加载
426
- VAE 和文本编码器到内存中,并且执行推理操作也会消耗一定内存。为了减少训练所需的总内存,您可以选择在训练脚本中不执行验证/测试。
427
- >
428
- > - 如果选择不进行验证/测试,`memory_before_validation`(验证前内存) 才是训练所需内存的真实指示器。
429
-
430
- <table align="center">
431
- <tr>
432
- <td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
433
- </tr>
434
- <tr>
435
- <td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
436
- </tr>
437
- </table>
438
-
439
- ## 待办事项
440
-
441
- - [x] 使脚本兼容 DDP
442
- - [ ] 使脚本兼容 FSDP
443
- - [x] 使脚本兼容 DeepSpeed
444
- - [ ] 基于 vLLM 的字幕脚本
445
- - [x] 在 `prepare_dataset.py` 中支持多分辨率/帧数
446
- - [ ] 分析性能瓶颈并尽可能减少同步操作
447
- - [ ] 支持 QLoRA(优先),以及其他高使用率的 LoRA 方法
448
- - [x] 使用 bitsandbytes 的节省内存优化器测试脚本
449
- - [x] 使用 CPUOffloadOptimizer 等测试脚本
450
- - [ ] 使用 torchao 量化和低位内存优化器测试脚本(目前在 AdamW(8/4-bit torchao)上报错)
451
- - [ ] 使用 AdamW(8-bit bitsandbytes)+ CPUOffloadOptimizer(带有梯度卸载)的测试脚本(目前报错)
452
- - [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (与作者合作支持反向传播,并针对 A100 进行优化)
453
-
454
- > [!重要]
455
- > 由于我们的目标是使脚本尽可能节省内存,因此我们不保证支持多 GPU 训练。
training/cogvideox/__init__.py DELETED
File without changes
training/cogvideox/args.py DELETED
@@ -1,484 +0,0 @@
1
- import argparse
2
-
3
-
4
- def _get_model_args(parser: argparse.ArgumentParser) -> None:
5
- parser.add_argument(
6
- "--pretrained_model_name_or_path",
7
- type=str,
8
- default=None,
9
- required=True,
10
- help="Path to pretrained model or model identifier from huggingface.co/models.",
11
- )
12
- parser.add_argument(
13
- "--revision",
14
- type=str,
15
- default=None,
16
- required=False,
17
- help="Revision of pretrained model identifier from huggingface.co/models.",
18
- )
19
- parser.add_argument(
20
- "--variant",
21
- type=str,
22
- default=None,
23
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
24
- )
25
- parser.add_argument(
26
- "--cache_dir",
27
- type=str,
28
- default=None,
29
- help="The directory where the downloaded models and datasets will be stored.",
30
- )
31
-
32
-
33
- def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
34
- parser.add_argument(
35
- "--data_root",
36
- type=str,
37
- default=None,
38
- help=("A folder containing the training data."),
39
- )
40
- parser.add_argument(
41
- "--dataset_file",
42
- type=str,
43
- default=None,
44
- help=("Path to a CSV file if loading prompts/video paths using this format."),
45
- )
46
- parser.add_argument(
47
- "--video_column",
48
- type=str,
49
- default="video",
50
- help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
51
- )
52
- parser.add_argument(
53
- "--caption_column",
54
- type=str,
55
- default="text",
56
- help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
57
- )
58
- parser.add_argument(
59
- "--id_token",
60
- type=str,
61
- default=None,
62
- help="Identifier token appended to the start of each prompt if provided.",
63
- )
64
- parser.add_argument(
65
- "--height_buckets",
66
- nargs="+",
67
- type=int,
68
- default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
69
- )
70
- parser.add_argument(
71
- "--width_buckets",
72
- nargs="+",
73
- type=int,
74
- default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
75
- )
76
- parser.add_argument(
77
- "--frame_buckets",
78
- nargs="+",
79
- type=int,
80
- default=[49],
81
- help="CogVideoX1.5 need to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, such as 53"
82
- )
83
- parser.add_argument(
84
- "--load_tensors",
85
- action="store_true",
86
- help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.",
87
- )
88
- parser.add_argument(
89
- "--random_flip",
90
- type=float,
91
- default=None,
92
- help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
93
- )
94
- parser.add_argument(
95
- "--dataloader_num_workers",
96
- type=int,
97
- default=0,
98
- help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
99
- )
100
- parser.add_argument(
101
- "--pin_memory",
102
- action="store_true",
103
- help="Whether or not to use the pinned memory setting in pytorch dataloader.",
104
- )
105
-
106
-
107
- def _get_validation_args(parser: argparse.ArgumentParser) -> None:
108
- parser.add_argument(
109
- "--validation_prompt",
110
- type=str,
111
- default=None,
112
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
113
- )
114
- parser.add_argument(
115
- "--validation_images",
116
- type=str,
117
- default=None,
118
- help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
119
- )
120
- parser.add_argument(
121
- "--validation_prompt_separator",
122
- type=str,
123
- default=":::",
124
- help="String that separates multiple validation prompts",
125
- )
126
- parser.add_argument(
127
- "--num_validation_videos",
128
- type=int,
129
- default=1,
130
- help="Number of videos that should be generated during validation per `validation_prompt`.",
131
- )
132
- parser.add_argument(
133
- "--validation_epochs",
134
- type=int,
135
- default=None,
136
- help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
137
- )
138
- parser.add_argument(
139
- "--validation_steps",
140
- type=int,
141
- default=None,
142
- help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
143
- )
144
- parser.add_argument(
145
- "--guidance_scale",
146
- type=float,
147
- default=6,
148
- help="The guidance scale to use while sampling validation videos.",
149
- )
150
- parser.add_argument(
151
- "--use_dynamic_cfg",
152
- action="store_true",
153
- default=False,
154
- help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
155
- )
156
- parser.add_argument(
157
- "--enable_model_cpu_offload",
158
- action="store_true",
159
- default=False,
160
- help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
161
- )
162
-
163
-
164
- def _get_training_args(parser: argparse.ArgumentParser) -> None:
165
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
166
- parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
167
- parser.add_argument(
168
- "--lora_alpha",
169
- type=int,
170
- default=64,
171
- help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
172
- )
173
- parser.add_argument(
174
- "--mixed_precision",
175
- type=str,
176
- default=None,
177
- choices=["no", "fp16", "bf16"],
178
- help=(
179
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. "
180
- "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this "
181
- "argument to override the accelerate config."
182
- ),
183
- )
184
- parser.add_argument(
185
- "--output_dir",
186
- type=str,
187
- default="cogvideox-sft",
188
- help="The output directory where the model predictions and checkpoints will be written.",
189
- )
190
- parser.add_argument(
191
- "--height",
192
- type=int,
193
- default=480,
194
- help="All input videos are resized to this height.",
195
- )
196
- parser.add_argument(
197
- "--width",
198
- type=int,
199
- default=720,
200
- help="All input videos are resized to this width.",
201
- )
202
- parser.add_argument(
203
- "--video_reshape_mode",
204
- type=str,
205
- default=None,
206
- help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
207
- )
208
- parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
209
- parser.add_argument(
210
- "--max_num_frames",
211
- type=int,
212
- default=49,
213
- help="All input videos will be truncated to these many frames.",
214
- )
215
- parser.add_argument(
216
- "--skip_frames_start",
217
- type=int,
218
- default=0,
219
- help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
220
- )
221
- parser.add_argument(
222
- "--skip_frames_end",
223
- type=int,
224
- default=0,
225
- help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
226
- )
227
- parser.add_argument(
228
- "--train_batch_size",
229
- type=int,
230
- default=4,
231
- help="Batch size (per device) for the training dataloader.",
232
- )
233
- parser.add_argument("--num_train_epochs", type=int, default=1)
234
- parser.add_argument(
235
- "--max_train_steps",
236
- type=int,
237
- default=None,
238
- help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
239
- )
240
- parser.add_argument(
241
- "--checkpointing_steps",
242
- type=int,
243
- default=500,
244
- help=(
245
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
246
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
247
- " training using `--resume_from_checkpoint`."
248
- ),
249
- )
250
- parser.add_argument(
251
- "--checkpoints_total_limit",
252
- type=int,
253
- default=None,
254
- help=("Max number of checkpoints to store."),
255
- )
256
- parser.add_argument(
257
- "--resume_from_checkpoint",
258
- type=str,
259
- default=None,
260
- help=(
261
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
262
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
263
- ),
264
- )
265
- parser.add_argument(
266
- "--gradient_accumulation_steps",
267
- type=int,
268
- default=1,
269
- help="Number of updates steps to accumulate before performing a backward/update pass.",
270
- )
271
- parser.add_argument(
272
- "--gradient_checkpointing",
273
- action="store_true",
274
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
275
- )
276
- parser.add_argument(
277
- "--learning_rate",
278
- type=float,
279
- default=1e-4,
280
- help="Initial learning rate (after the potential warmup period) to use.",
281
- )
282
- parser.add_argument(
283
- "--scale_lr",
284
- action="store_true",
285
- default=False,
286
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
287
- )
288
- parser.add_argument(
289
- "--lr_scheduler",
290
- type=str,
291
- default="constant",
292
- help=(
293
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
294
- ' "constant", "constant_with_warmup"]'
295
- ),
296
- )
297
- parser.add_argument(
298
- "--lr_warmup_steps",
299
- type=int,
300
- default=500,
301
- help="Number of steps for the warmup in the lr scheduler.",
302
- )
303
- parser.add_argument(
304
- "--lr_num_cycles",
305
- type=int,
306
- default=1,
307
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
308
- )
309
- parser.add_argument(
310
- "--lr_power",
311
- type=float,
312
- default=1.0,
313
- help="Power factor of the polynomial scheduler.",
314
- )
315
- parser.add_argument(
316
- "--enable_slicing",
317
- action="store_true",
318
- default=False,
319
- help="Whether or not to use VAE slicing for saving memory.",
320
- )
321
- parser.add_argument(
322
- "--enable_tiling",
323
- action="store_true",
324
- default=False,
325
- help="Whether or not to use VAE tiling for saving memory.",
326
- )
327
- parser.add_argument(
328
- "--noised_image_dropout",
329
- type=float,
330
- default=0.05,
331
- help="Image condition dropout probability when finetuning image-to-video.",
332
- )
333
- parser.add_argument(
334
- "--ignore_learned_positional_embeddings",
335
- action="store_true",
336
- default=False,
337
- help=(
338
- "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting "
339
- "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it "
340
- "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why."
341
- ),
342
- )
343
-
344
-
345
- def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
346
- parser.add_argument(
347
- "--optimizer",
348
- type=lambda s: s.lower(),
349
- default="adam",
350
- choices=["adam", "adamw", "prodigy", "came"],
351
- help=("The optimizer type to use."),
352
- )
353
- parser.add_argument(
354
- "--use_8bit",
355
- action="store_true",
356
- help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.",
357
- )
358
- parser.add_argument(
359
- "--use_4bit",
360
- action="store_true",
361
- help="Whether or not to use 4-bit optimizers from `torchao`.",
362
- )
363
- parser.add_argument(
364
- "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers."
365
- )
366
- parser.add_argument(
367
- "--beta1",
368
- type=float,
369
- default=0.9,
370
- help="The beta1 parameter for the Adam and Prodigy optimizers.",
371
- )
372
- parser.add_argument(
373
- "--beta2",
374
- type=float,
375
- default=0.95,
376
- help="The beta2 parameter for the Adam and Prodigy optimizers.",
377
- )
378
- parser.add_argument(
379
- "--beta3",
380
- type=float,
381
- default=None,
382
- help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
383
- )
384
- parser.add_argument(
385
- "--prodigy_decouple",
386
- action="store_true",
387
- help="Use AdamW style decoupled weight decay.",
388
- )
389
- parser.add_argument(
390
- "--weight_decay",
391
- type=float,
392
- default=1e-04,
393
- help="Weight decay to use for optimizer.",
394
- )
395
- parser.add_argument(
396
- "--epsilon",
397
- type=float,
398
- default=1e-8,
399
- help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
400
- )
401
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
402
- parser.add_argument(
403
- "--prodigy_use_bias_correction",
404
- action="store_true",
405
- help="Turn on Adam's bias correction.",
406
- )
407
- parser.add_argument(
408
- "--prodigy_safeguard_warmup",
409
- action="store_true",
410
- help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
411
- )
412
- parser.add_argument(
413
- "--use_cpu_offload_optimizer",
414
- action="store_true",
415
- help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.",
416
- )
417
- parser.add_argument(
418
- "--offload_gradients",
419
- action="store_true",
420
- help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.",
421
- )
422
-
423
-
424
- def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
425
- parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
426
- parser.add_argument(
427
- "--push_to_hub",
428
- action="store_true",
429
- help="Whether or not to push the model to the Hub.",
430
- )
431
- parser.add_argument(
432
- "--hub_token",
433
- type=str,
434
- default=None,
435
- help="The token to use to push to the Model Hub.",
436
- )
437
- parser.add_argument(
438
- "--hub_model_id",
439
- type=str,
440
- default=None,
441
- help="The name of the repository to keep in sync with the local `output_dir`.",
442
- )
443
- parser.add_argument(
444
- "--logging_dir",
445
- type=str,
446
- default="logs",
447
- help="Directory where logs are stored.",
448
- )
449
- parser.add_argument(
450
- "--allow_tf32",
451
- action="store_true",
452
- help=(
453
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
454
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
455
- ),
456
- )
457
- parser.add_argument(
458
- "--nccl_timeout",
459
- type=int,
460
- default=600,
461
- help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
462
- )
463
- parser.add_argument(
464
- "--report_to",
465
- type=str,
466
- default=None,
467
- help=(
468
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
469
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
470
- ),
471
- )
472
-
473
-
474
- def get_args():
475
- parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
476
-
477
- _get_model_args(parser)
478
- _get_dataset_args(parser)
479
- _get_training_args(parser)
480
- _get_validation_args(parser)
481
- _get_optimizer_args(parser)
482
- _get_configuration_args(parser)
483
-
484
- return parser.parse_args()
 
training/cogvideox/cogvideox_image_to_video_lora.py DELETED
@@ -1,1016 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import logging
18
- import math
19
- import os
20
- import random
21
- import shutil
22
- from datetime import timedelta
23
- from pathlib import Path
24
- from typing import Any, Dict
25
-
26
- import diffusers
27
- import torch
28
- import transformers
29
- import wandb
30
- from accelerate import Accelerator, DistributedType
31
- from accelerate.logging import get_logger
32
- from accelerate.utils import (
33
- DistributedDataParallelKwargs,
34
- InitProcessGroupKwargs,
35
- ProjectConfiguration,
36
- set_seed,
37
- )
38
- from diffusers import (
39
- AutoencoderKLCogVideoX,
40
- CogVideoXDPMScheduler,
41
- CogVideoXImageToVideoPipeline,
42
- CogVideoXTransformer3DModel,
43
- )
44
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
45
- from diffusers.optimization import get_scheduler
46
- from diffusers.training_utils import cast_training_params
47
- from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
48
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
49
- from huggingface_hub import create_repo, upload_folder
50
- from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
51
- from torch.utils.data import DataLoader
52
- from tqdm.auto import tqdm
53
- from transformers import AutoTokenizer, T5EncoderModel
54
-
55
-
56
- from args import get_args # isort:skip
57
- from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
58
- from text_encoder import compute_prompt_embeddings # isort:skip
59
- from utils import (
60
- get_gradient_norm,
61
- get_optimizer,
62
- prepare_rotary_positional_embeddings,
63
- print_memory,
64
- reset_memory,
65
- unwrap_model,
66
- )
67
-
68
-
69
- logger = get_logger(__name__)
70
-
71
-
72
- def save_model_card(
73
- repo_id: str,
74
- videos=None,
75
- base_model: str = None,
76
- validation_prompt=None,
77
- repo_folder=None,
78
- fps=8,
79
- ):
80
- widget_dict = []
81
- if videos is not None:
82
- for i, video in enumerate(videos):
83
- export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
84
- widget_dict.append(
85
- {
86
- "text": validation_prompt if validation_prompt else " ",
87
- "output": {"url": f"video_{i}.mp4"},
88
- }
89
- )
90
-
91
- model_description = f"""
92
- # CogVideoX LoRA Finetune
93
-
94
- <Gallery />
95
-
96
- ## Model description
97
-
98
- This is a lora finetune of the CogVideoX model `{base_model}`.
99
-
100
- The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
101
-
102
- ## Download model
103
-
104
- [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
105
-
106
- ## Usage
107
-
108
- Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
109
-
110
- ```py
111
- import torch
112
- from diffusers import CogVideoXImageToVideoPipeline
113
- from diffusers.utils import export_to_video, load_image
114
-
115
- pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda")
116
- pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
117
-
118
- # The LoRA adapter weights are determined by what was used for training.
119
- # In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
120
- # It can be made lower or higher from what was used in training to decrease or amplify the effect
121
- # of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.
122
- pipe.set_adapters(["cogvideox-lora"], [32 / 64])
123
-
124
- image = load_image("/path/to/image.png")
125
- video = pipe(image=image, prompt="{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
126
- export_to_video(video, "output.mp4", fps=8)
127
- ```
128
-
129
- For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
130
-
131
- ## License
132
-
133
- Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
134
- """
135
- model_card = load_or_create_model_card(
136
- repo_id_or_path=repo_id,
137
- from_training=True,
138
- license="other",
139
- base_model=base_model,
140
- prompt=validation_prompt,
141
- model_description=model_description,
142
- widget=widget_dict,
143
- )
144
- tags = [
145
- "text-to-video",
146
- "image-to-video",
147
- "diffusers-training",
148
- "diffusers",
149
- "lora",
150
- "cogvideox",
151
- "cogvideox-diffusers",
152
- "template:sd-lora",
153
- ]
154
-
155
- model_card = populate_model_card(model_card, tags=tags)
156
- model_card.save(os.path.join(repo_folder, "README.md"))
157
-
158
-
159
- def log_validation(
160
- accelerator: Accelerator,
161
- pipe: CogVideoXImageToVideoPipeline,
162
- args: Dict[str, Any],
163
- pipeline_args: Dict[str, Any],
164
- is_final_validation: bool = False,
165
- ):
166
- logger.info(
167
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
168
- )
169
-
170
- pipe = pipe.to(accelerator.device)
171
-
172
- # run inference
173
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
174
-
175
- videos = []
176
- for _ in range(args.num_validation_videos):
177
- video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
178
- videos.append(video)
179
-
180
- for tracker in accelerator.trackers:
181
- phase_name = "test" if is_final_validation else "validation"
182
- if tracker.name == "wandb":
183
- video_filenames = []
184
- for i, video in enumerate(videos):
185
- prompt = (
186
- pipeline_args["prompt"][:25]
187
- .replace(" ", "_")
188
- .replace(" ", "_")
189
- .replace("'", "_")
190
- .replace('"', "_")
191
- .replace("/", "_")
192
- )
193
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
194
- export_to_video(video, filename, fps=8)
195
- video_filenames.append(filename)
196
-
197
- tracker.log(
198
- {
199
- phase_name: [
200
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
201
- for i, filename in enumerate(video_filenames)
202
- ]
203
- }
204
- )
205
-
206
- return videos
207
-
208
-
209
- def run_validation(
210
- args: Dict[str, Any],
211
- accelerator: Accelerator,
212
- transformer,
213
- scheduler,
214
- model_config: Dict[str, Any],
215
- weight_dtype: torch.dtype,
216
- ) -> None:
217
- accelerator.print("===== Memory before validation =====")
218
- print_memory(accelerator.device)
219
- torch.cuda.synchronize(accelerator.device)
220
-
221
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
222
- args.pretrained_model_name_or_path,
223
- transformer=unwrap_model(accelerator, transformer),
224
- scheduler=scheduler,
225
- revision=args.revision,
226
- variant=args.variant,
227
- torch_dtype=weight_dtype,
228
- )
229
-
230
- if args.enable_slicing:
231
- pipe.vae.enable_slicing()
232
- if args.enable_tiling:
233
- pipe.vae.enable_tiling()
234
- if args.enable_model_cpu_offload:
235
- pipe.enable_model_cpu_offload()
236
-
237
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
238
- validation_images = args.validation_images.split(args.validation_prompt_separator)
239
- for validation_image, validation_prompt in zip(validation_images, validation_prompts):
240
- pipeline_args = {
241
- "image": load_image(validation_image),
242
- "prompt": validation_prompt,
243
- "guidance_scale": args.guidance_scale,
244
- "use_dynamic_cfg": args.use_dynamic_cfg,
245
- "height": args.height,
246
- "width": args.width,
247
- "max_sequence_length": model_config.max_text_seq_length,
248
- }
249
-
250
- log_validation(
251
- pipe=pipe,
252
- args=args,
253
- accelerator=accelerator,
254
- pipeline_args=pipeline_args,
255
- )
256
-
257
- accelerator.print("===== Memory after validation =====")
258
- print_memory(accelerator.device)
259
- reset_memory(accelerator.device)
260
-
261
- del pipe
262
- gc.collect()
263
- torch.cuda.empty_cache()
264
- torch.cuda.synchronize(accelerator.device)
265
-
266
-
267
- class CollateFunction:
268
- def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
269
- self.weight_dtype = weight_dtype
270
- self.load_tensors = load_tensors
271
-
272
- def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
273
- prompts = [x["prompt"] for x in data[0]]
274
-
275
- if self.load_tensors:
276
- prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
277
-
278
- images = [x["image"] for x in data[0]]
279
- images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
280
-
281
- videos = [x["video"] for x in data[0]]
282
- videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
283
-
284
- return {
285
- "images": images,
286
- "videos": videos,
287
- "prompts": prompts,
288
- }
289
-
290
-
291
- def main(args):
292
- if args.report_to == "wandb" and args.hub_token is not None:
293
- raise ValueError(
294
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
295
- " Please use `huggingface-cli login` to authenticate with the Hub."
296
- )
297
-
298
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
299
- # due to pytorch#99272, MPS does not yet support bfloat16.
300
- raise ValueError(
301
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
302
- )
303
-
304
- logging_dir = Path(args.output_dir, args.logging_dir)
305
-
306
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
307
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
308
- init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
309
- accelerator = Accelerator(
310
- gradient_accumulation_steps=args.gradient_accumulation_steps,
311
- mixed_precision=args.mixed_precision,
312
- log_with=args.report_to,
313
- project_config=accelerator_project_config,
314
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
315
- )
316
-
317
- # Disable AMP for MPS.
318
- if torch.backends.mps.is_available():
319
- accelerator.native_amp = False
320
-
321
- # Make one log on every process with the configuration for debugging.
322
- logging.basicConfig(
323
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
324
- datefmt="%m/%d/%Y %H:%M:%S",
325
- level=logging.INFO,
326
- )
327
- logger.info(accelerator.state, main_process_only=False)
328
- if accelerator.is_local_main_process:
329
- transformers.utils.logging.set_verbosity_warning()
330
- diffusers.utils.logging.set_verbosity_info()
331
- else:
332
- transformers.utils.logging.set_verbosity_error()
333
- diffusers.utils.logging.set_verbosity_error()
334
-
335
- # If passed along, set the training seed now.
336
- if args.seed is not None:
337
- set_seed(args.seed)
338
-
339
- # Handle the repository creation
340
- if accelerator.is_main_process:
341
- if args.output_dir is not None:
342
- os.makedirs(args.output_dir, exist_ok=True)
343
-
344
- if args.push_to_hub:
345
- repo_id = create_repo(
346
- repo_id=args.hub_model_id or Path(args.output_dir).name,
347
- exist_ok=True,
348
- ).repo_id
349
-
350
- # Prepare models and scheduler
351
- tokenizer = AutoTokenizer.from_pretrained(
352
- args.pretrained_model_name_or_path,
353
- subfolder="tokenizer",
354
- revision=args.revision,
355
- )
356
-
357
- text_encoder = T5EncoderModel.from_pretrained(
358
- args.pretrained_model_name_or_path,
359
- subfolder="text_encoder",
360
- revision=args.revision,
361
- )
362
-
363
- # CogVideoX-2b weights are stored in float16
364
- # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
365
- load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
366
- transformer = CogVideoXTransformer3DModel.from_pretrained(
367
- args.pretrained_model_name_or_path,
368
- subfolder="transformer",
369
- torch_dtype=load_dtype,
370
- revision=args.revision,
371
- variant=args.variant,
372
- )
373
-
374
- # These changes will also be required when trying to run inference with the trained lora
375
- if args.ignore_learned_positional_embeddings:
376
- del transformer.patch_embed.pos_embedding
377
- transformer.patch_embed.use_learned_positional_embeddings = False
378
- transformer.config.use_learned_positional_embeddings = False
379
-
380
- vae = AutoencoderKLCogVideoX.from_pretrained(
381
- args.pretrained_model_name_or_path,
382
- subfolder="vae",
383
- revision=args.revision,
384
- variant=args.variant,
385
- )
386
-
387
- scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
388
-
389
- if args.enable_slicing:
390
- vae.enable_slicing()
391
- if args.enable_tiling:
392
- vae.enable_tiling()
393
-
394
- # We only train the additional adapter LoRA layers
395
- text_encoder.requires_grad_(False)
396
- transformer.requires_grad_(False)
397
- vae.requires_grad_(False)
398
-
399
- VAE_SCALING_FACTOR = vae.config.scaling_factor
400
- VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
401
- RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
402
- RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
403
-
404
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
405
- # as these weights are only used for inference, keeping weights in full precision is not required.
406
- weight_dtype = torch.float32
407
- if accelerator.state.deepspeed_plugin:
408
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
409
- if (
410
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
411
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
412
- ):
413
- weight_dtype = torch.float16
414
- if (
415
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
416
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
417
- ):
418
- weight_dtype = torch.bfloat16
419
- else:
420
- if accelerator.mixed_precision == "fp16":
421
- weight_dtype = torch.float16
422
- elif accelerator.mixed_precision == "bf16":
423
- weight_dtype = torch.bfloat16
424
-
425
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
426
- # due to pytorch#99272, MPS does not yet support bfloat16.
427
- raise ValueError(
428
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
429
- )
430
-
431
- text_encoder.to(accelerator.device, dtype=weight_dtype)
432
- transformer.to(accelerator.device, dtype=weight_dtype)
433
- vae.to(accelerator.device, dtype=weight_dtype)
434
-
435
- if args.gradient_checkpointing:
436
- transformer.enable_gradient_checkpointing()
437
-
438
- # now we will add new LoRA weights to the attention layers
439
- transformer_lora_config = LoraConfig(
440
- r=args.rank,
441
- lora_alpha=args.lora_alpha,
442
- init_lora_weights=True,
443
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
444
- )
445
- transformer.add_adapter(transformer_lora_config)
446
-
447
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
448
- def save_model_hook(models, weights, output_dir):
449
- if accelerator.is_main_process:
450
- transformer_lora_layers_to_save = None
451
-
452
- for model in models:
453
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
454
- model = unwrap_model(accelerator, model)
455
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
456
- else:
457
- raise ValueError(f"Unexpected save model: {model.__class__}")
458
-
459
- # make sure to pop weight so that corresponding model is not saved again
460
- if weights:
461
- weights.pop()
462
-
463
- CogVideoXImageToVideoPipeline.save_lora_weights(
464
- output_dir,
465
- transformer_lora_layers=transformer_lora_layers_to_save,
466
- )
467
-
468
- def load_model_hook(models, input_dir):
469
- transformer_ = None
470
-
471
- # This is a bit of a hack but I don't know any other solution.
472
- if not accelerator.distributed_type == DistributedType.DEEPSPEED:
473
- while len(models) > 0:
474
- model = models.pop()
475
-
476
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
477
- transformer_ = unwrap_model(accelerator, model)
478
- else:
479
- raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
480
- else:
481
- transformer_ = CogVideoXTransformer3DModel.from_pretrained(
482
- args.pretrained_model_name_or_path, subfolder="transformer"
483
- )
484
- transformer_.add_adapter(transformer_lora_config)
485
-
486
- lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
487
-
488
- transformer_state_dict = {
489
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
490
- }
491
- transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
492
- incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
493
- if incompatible_keys is not None:
494
- # check only for unexpected keys
495
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
496
- if unexpected_keys:
497
- logger.warning(
498
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
499
- f" {unexpected_keys}. "
500
- )
501
-
502
- # Make sure the trainable params are in float32. This is again needed since the base models
503
- # are in `weight_dtype`. More details:
504
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
505
- if args.mixed_precision == "fp16":
506
- # only upcast trainable parameters (LoRA) into fp32
507
- cast_training_params([transformer_])
508
-
509
- accelerator.register_save_state_pre_hook(save_model_hook)
510
- accelerator.register_load_state_pre_hook(load_model_hook)
511
-
512
- # Enable TF32 for faster training on Ampere GPUs,
513
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
514
- if args.allow_tf32 and torch.cuda.is_available():
515
- torch.backends.cuda.matmul.allow_tf32 = True
516
-
517
- if args.scale_lr:
518
- args.learning_rate = (
519
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
520
- )
521
-
522
- # Make sure the trainable params are in float32.
523
- if args.mixed_precision == "fp16":
524
- # only upcast trainable parameters (LoRA) into fp32
525
- cast_training_params([transformer], dtype=torch.float32)
526
-
527
- transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
528
-
529
- # Optimization parameters
530
- transformer_parameters_with_lr = {
531
- "params": transformer_lora_parameters,
532
- "lr": args.learning_rate,
533
- }
534
- params_to_optimize = [transformer_parameters_with_lr]
535
- num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
536
-
537
- use_deepspeed_optimizer = (
538
- accelerator.state.deepspeed_plugin is not None
539
- and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
540
- )
541
- use_deepspeed_scheduler = (
542
- accelerator.state.deepspeed_plugin is not None
543
- and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
544
- )
545
-
546
- optimizer = get_optimizer(
547
- params_to_optimize=params_to_optimize,
548
- optimizer_name=args.optimizer,
549
- learning_rate=args.learning_rate,
550
- beta1=args.beta1,
551
- beta2=args.beta2,
552
- beta3=args.beta3,
553
- epsilon=args.epsilon,
554
- weight_decay=args.weight_decay,
555
- prodigy_decouple=args.prodigy_decouple,
556
- prodigy_use_bias_correction=args.prodigy_use_bias_correction,
557
- prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
558
- use_8bit=args.use_8bit,
559
- use_4bit=args.use_4bit,
560
- use_torchao=args.use_torchao,
561
- use_deepspeed=use_deepspeed_optimizer,
562
- use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
563
- offload_gradients=args.offload_gradients,
564
- )
565
-
566
- # Dataset and DataLoader
567
- dataset_init_kwargs = {
568
- "data_root": args.data_root,
569
- "dataset_file": args.dataset_file,
570
- "caption_column": args.caption_column,
571
- "video_column": args.video_column,
572
- "max_num_frames": args.max_num_frames,
573
- "id_token": args.id_token,
574
- "height_buckets": args.height_buckets,
575
- "width_buckets": args.width_buckets,
576
- "frame_buckets": args.frame_buckets,
577
- "load_tensors": args.load_tensors,
578
- "random_flip": args.random_flip,
579
- "image_to_video": True,
580
- }
581
- if args.video_reshape_mode is None:
582
- train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
583
- else:
584
- train_dataset = VideoDatasetWithResizeAndRectangleCrop(
585
- video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
586
- )
587
-
588
- collate_fn = CollateFunction(weight_dtype, args.load_tensors)
589
-
590
- train_dataloader = DataLoader(
591
- train_dataset,
592
- batch_size=1,
593
- sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
594
- collate_fn=collate_fn,
595
- num_workers=args.dataloader_num_workers,
596
- pin_memory=args.pin_memory,
597
- )
598
-
599
- # Scheduler and math around the number of training steps.
600
- overrode_max_train_steps = False
601
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
602
- if args.max_train_steps is None:
603
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
604
- overrode_max_train_steps = True
605
-
606
- if args.use_cpu_offload_optimizer:
607
- lr_scheduler = None
608
- accelerator.print(
609
- "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
610
- "you are training with those settings, they will be ignored."
611
- )
612
- else:
613
- if use_deepspeed_scheduler:
614
- from accelerate.utils import DummyScheduler
615
-
616
- lr_scheduler = DummyScheduler(
617
- name=args.lr_scheduler,
618
- optimizer=optimizer,
619
- total_num_steps=args.max_train_steps * accelerator.num_processes,
620
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
621
- )
622
- else:
623
- lr_scheduler = get_scheduler(
624
- args.lr_scheduler,
625
- optimizer=optimizer,
626
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
627
- num_training_steps=args.max_train_steps * accelerator.num_processes,
628
- num_cycles=args.lr_num_cycles,
629
- power=args.lr_power,
630
- )
631
-
632
- # Prepare everything with our `accelerator`.
633
- transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
634
- transformer, optimizer, train_dataloader, lr_scheduler
635
- )
636
-
637
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
638
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
639
- if overrode_max_train_steps:
640
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
641
- # Afterwards we recalculate our number of training epochs
642
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
643
-
644
- # We need to initialize the trackers we use, and also store our configuration.
645
- # The trackers initialize automatically on the main process.
646
- if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
647
- tracker_name = args.tracker_name or "cogvideox-lora"
648
- accelerator.init_trackers(tracker_name, config=vars(args))
649
-
650
- accelerator.print("===== Memory before training =====")
651
- reset_memory(accelerator.device)
652
- print_memory(accelerator.device)
653
-
654
- # Train!
655
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
656
-
657
- accelerator.print("***** Running training *****")
658
- accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
659
- accelerator.print(f" Num examples = {len(train_dataset)}")
660
- accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
661
- accelerator.print(f" Num epochs = {args.num_train_epochs}")
662
- accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
663
- accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
664
- accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
665
- accelerator.print(f" Total optimization steps = {args.max_train_steps}")
666
- global_step = 0
667
- first_epoch = 0
668
-
669
- # Potentially load in the weights and states from a previous save
670
- if not args.resume_from_checkpoint:
671
- initial_global_step = 0
672
- else:
673
- if args.resume_from_checkpoint != "latest":
674
- path = os.path.basename(args.resume_from_checkpoint)
675
- else:
676
- # Get the most recent checkpoint
677
- dirs = os.listdir(args.output_dir)
678
- dirs = [d for d in dirs if d.startswith("checkpoint")]
679
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
680
- path = dirs[-1] if len(dirs) > 0 else None
681
-
682
- if path is None:
683
- accelerator.print(
684
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
685
- )
686
- args.resume_from_checkpoint = None
687
- initial_global_step = 0
688
- else:
689
- accelerator.print(f"Resuming from checkpoint {path}")
690
- accelerator.load_state(os.path.join(args.output_dir, path))
691
- global_step = int(path.split("-")[1])
692
-
693
- initial_global_step = global_step
694
- first_epoch = global_step // num_update_steps_per_epoch
695
-
696
- progress_bar = tqdm(
697
- range(0, args.max_train_steps),
698
- initial=initial_global_step,
699
- desc="Steps",
700
- # Only show the progress bar once on each machine.
701
- disable=not accelerator.is_local_main_process,
702
- )
703
-
704
- # For DeepSpeed training
705
- model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
706
-
707
- if args.load_tensors:
708
- del vae, text_encoder
709
- gc.collect()
710
- torch.cuda.empty_cache()
711
- torch.cuda.synchronize(accelerator.device)
712
-
713
- alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
714
-
715
- for epoch in range(first_epoch, args.num_train_epochs):
716
- transformer.train()
717
-
718
- for step, batch in enumerate(train_dataloader):
719
- models_to_accumulate = [transformer]
720
- logs = {}
721
-
722
- with accelerator.accumulate(models_to_accumulate):
723
- images = batch["images"].to(accelerator.device, non_blocking=True)
724
- videos = batch["videos"].to(accelerator.device, non_blocking=True)
725
- prompts = batch["prompts"]
726
-
727
- # Encode videos
728
- if not args.load_tensors:
729
- images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
730
- image_noise_sigma = torch.normal(
731
- mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
732
- )
733
- image_noise_sigma = torch.exp(image_noise_sigma)
734
- noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
735
- image_latent_dist = vae.encode(noisy_images).latent_dist
736
-
737
- videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
738
- latent_dist = vae.encode(videos).latent_dist
739
- else:
740
- image_latent_dist = DiagonalGaussianDistribution(images)
741
- latent_dist = DiagonalGaussianDistribution(videos)
742
-
743
- image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
744
- image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
745
- image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
746
-
747
- video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
748
- video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
749
- video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
750
-
751
- padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
752
- latent_padding = image_latents.new_zeros(padding_shape)
753
- image_latents = torch.cat([image_latents, latent_padding], dim=1)
754
-
755
- if random.random() < args.noised_image_dropout:
756
- image_latents = torch.zeros_like(image_latents)
757
-
758
- # Encode prompts
759
- if not args.load_tensors:
760
- prompt_embeds = compute_prompt_embeddings(
761
- tokenizer,
762
- text_encoder,
763
- prompts,
764
- model_config.max_text_seq_length,
765
- accelerator.device,
766
- weight_dtype,
767
- requires_grad=False,
768
- )
769
- else:
770
- prompt_embeds = prompts.to(dtype=weight_dtype)
771
-
772
- # Sample noise that will be added to the latents
773
- noise = torch.randn_like(video_latents)
774
- batch_size, num_frames, num_channels, height, width = video_latents.shape
775
-
776
- # Sample a random timestep for each image
777
- timesteps = torch.randint(
778
- 0,
779
- scheduler.config.num_train_timesteps,
780
- (batch_size,),
781
- dtype=torch.int64,
782
- device=accelerator.device,
783
- )
784
-
785
- # Prepare rotary embeds
786
- image_rotary_emb = (
787
- prepare_rotary_positional_embeddings(
788
- height=height * VAE_SCALE_FACTOR_SPATIAL,
789
- width=width * VAE_SCALE_FACTOR_SPATIAL,
790
- num_frames=num_frames,
791
- vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
792
- patch_size=model_config.patch_size,
793
- patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
794
- attention_head_dim=model_config.attention_head_dim,
795
- device=accelerator.device,
796
- base_height=RoPE_BASE_HEIGHT,
797
- base_width=RoPE_BASE_WIDTH,
798
- )
799
- if model_config.use_rotary_positional_embeddings
800
- else None
801
- )
802
-
803
- # Add noise to the model input according to the noise magnitude at each timestep
804
- # (this is the forward diffusion process)
805
- noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
806
- noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
807
-
808
- ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None
809
- ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
810
- # Predict the noise residual
811
- model_output = transformer(
812
- hidden_states=noisy_model_input,
813
- encoder_hidden_states=prompt_embeds,
814
- timestep=timesteps,
815
- ofs=ofs_emb,
816
- image_rotary_emb=image_rotary_emb,
817
- return_dict=False,
818
- )[0]
819
-
820
- model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
821
-
822
- weights = 1 / (1 - alphas_cumprod[timesteps])
823
- while len(weights.shape) < len(model_pred.shape):
824
- weights = weights.unsqueeze(-1)
825
-
826
- target = video_latents
827
-
828
- loss = torch.mean(
829
- (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
830
- dim=1,
831
- )
832
- loss = loss.mean()
833
- accelerator.backward(loss)
834
-
835
- if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
836
- gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
837
- accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
838
- gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
839
- logs.update(
840
- {
841
- "gradient_norm_before_clip": gradient_norm_before_clip,
842
- "gradient_norm_after_clip": gradient_norm_after_clip,
843
- }
844
- )
845
-
846
- if accelerator.state.deepspeed_plugin is None:
847
- optimizer.step()
848
- optimizer.zero_grad()
849
-
850
- if not args.use_cpu_offload_optimizer:
851
- lr_scheduler.step()
852
-
853
- # Checks if the accelerator has performed an optimization step behind the scenes
854
- if accelerator.sync_gradients:
855
- progress_bar.update(1)
856
- global_step += 1
857
-
858
- # Checkpointing
859
- if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
860
- if global_step % args.checkpointing_steps == 0:
861
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
862
- if args.checkpoints_total_limit is not None:
863
- checkpoints = os.listdir(args.output_dir)
864
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
865
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
866
-
867
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
868
- if len(checkpoints) >= args.checkpoints_total_limit:
869
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
870
- removing_checkpoints = checkpoints[0:num_to_remove]
871
-
872
- logger.info(
873
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
874
- )
875
- logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
876
-
877
- for removing_checkpoint in removing_checkpoints:
878
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
879
- shutil.rmtree(removing_checkpoint)
880
-
881
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
882
- accelerator.save_state(save_path)
883
- logger.info(f"Saved state to {save_path}")
884
-
885
- # Validation
886
- should_run_validation = args.validation_prompt is not None and (
887
- args.validation_steps is not None and global_step % args.validation_steps == 0
888
- )
889
- if should_run_validation:
890
- run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
891
-
892
- last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
893
- logs.update(
894
- {
895
- "loss": loss.detach().item(),
896
- "lr": last_lr,
897
- }
898
- )
899
- progress_bar.set_postfix(**logs)
900
- accelerator.log(logs, step=global_step)
901
-
902
- if global_step >= args.max_train_steps:
903
- break
904
-
905
- if accelerator.is_main_process:
906
- should_run_validation = args.validation_prompt is not None and (
907
- args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
908
- )
909
- if should_run_validation:
910
- run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
911
-
912
- accelerator.wait_for_everyone()
913
-
914
- if accelerator.is_main_process:
915
- transformer = unwrap_model(accelerator, transformer)
916
- dtype = (
917
- torch.float16
918
- if args.mixed_precision == "fp16"
919
- else torch.bfloat16
920
- if args.mixed_precision == "bf16"
921
- else torch.float32
922
- )
923
- transformer = transformer.to(dtype)
924
- transformer_lora_layers = get_peft_model_state_dict(transformer)
925
-
926
- CogVideoXImageToVideoPipeline.save_lora_weights(
927
- save_directory=args.output_dir,
928
- transformer_lora_layers=transformer_lora_layers,
929
- )
930
-
931
- # Cleanup trained models to save memory
932
- if args.load_tensors:
933
- del transformer
934
- else:
935
- del transformer, text_encoder, vae
936
-
937
- gc.collect()
938
- torch.cuda.empty_cache()
939
- torch.cuda.synchronize(accelerator.device)
940
-
941
- accelerator.print("===== Memory before testing =====")
942
- print_memory(accelerator.device)
943
- reset_memory(accelerator.device)
944
-
945
- # Final test inference
946
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
947
- args.pretrained_model_name_or_path,
948
- revision=args.revision,
949
- variant=args.variant,
950
- torch_dtype=weight_dtype,
951
- )
952
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
953
-
954
- if args.enable_slicing:
955
- pipe.vae.enable_slicing()
956
- if args.enable_tiling:
957
- pipe.vae.enable_tiling()
958
- if args.enable_model_cpu_offload:
959
- pipe.enable_model_cpu_offload()
960
-
961
- # Load LoRA weights
962
- lora_scaling = args.lora_alpha / args.rank
963
- pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
964
- pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
965
-
966
- # Run inference
967
- validation_outputs = []
968
- if args.validation_prompt and args.num_validation_videos > 0:
969
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
970
- validation_images = args.validation_images.split(args.validation_prompt_separator)
971
- for validation_image, validation_prompt in zip(validation_images, validation_prompts):
972
- pipeline_args = {
973
- "image": load_image(validation_image),
974
- "prompt": validation_prompt,
975
- "guidance_scale": args.guidance_scale,
976
- "use_dynamic_cfg": args.use_dynamic_cfg,
977
- "height": args.height,
978
- "width": args.width,
979
- }
980
-
981
- video = log_validation(
982
- accelerator=accelerator,
983
- pipe=pipe,
984
- args=args,
985
- pipeline_args=pipeline_args,
986
- is_final_validation=True,
987
- )
988
- validation_outputs.extend(video)
989
-
990
- accelerator.print("===== Memory after testing =====")
991
- print_memory(accelerator.device)
992
- reset_memory(accelerator.device)
993
- torch.cuda.synchronize(accelerator.device)
994
-
995
- if args.push_to_hub:
996
- save_model_card(
997
- repo_id,
998
- videos=validation_outputs,
999
- base_model=args.pretrained_model_name_or_path,
1000
- validation_prompt=args.validation_prompt,
1001
- repo_folder=args.output_dir,
1002
- fps=args.fps,
1003
- )
1004
- upload_folder(
1005
- repo_id=repo_id,
1006
- folder_path=args.output_dir,
1007
- commit_message="End of training",
1008
- ignore_patterns=["step_*", "epoch_*"],
1009
- )
1010
-
1011
- accelerator.end_training()
1012
-
1013
-
1014
- if __name__ == "__main__":
1015
- args = get_args()
1016
- main(args)
 
training/cogvideox/cogvideox_image_to_video_sft.py DELETED
@@ -1,947 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import logging
18
- import math
19
- import os
20
- import random
21
- import shutil
22
- from datetime import timedelta
23
- from pathlib import Path
24
- from typing import Any, Dict
25
-
26
- import diffusers
27
- import torch
28
- import transformers
29
- import wandb
30
- from accelerate import Accelerator, DistributedType, init_empty_weights
31
- from accelerate.logging import get_logger
32
- from accelerate.utils import (
33
- DistributedDataParallelKwargs,
34
- InitProcessGroupKwargs,
35
- ProjectConfiguration,
36
- set_seed,
37
- )
38
- from diffusers import (
39
- AutoencoderKLCogVideoX,
40
- CogVideoXDPMScheduler,
41
- CogVideoXImageToVideoPipeline,
42
- CogVideoXTransformer3DModel,
43
- )
44
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
45
- from diffusers.optimization import get_scheduler
46
- from diffusers.training_utils import cast_training_params
47
- from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
48
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
49
- from huggingface_hub import create_repo, upload_folder
50
- from torch.utils.data import DataLoader
51
- from tqdm.auto import tqdm
52
- from transformers import AutoTokenizer, T5EncoderModel
53
-
54
-
55
- from args import get_args # isort:skip
56
- from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
57
- from text_encoder import compute_prompt_embeddings # isort:skip
58
- from utils import (
59
- get_gradient_norm,
60
- get_optimizer,
61
- prepare_rotary_positional_embeddings,
62
- print_memory,
63
- reset_memory,
64
- unwrap_model,
65
- )
66
-
67
-
68
- logger = get_logger(__name__)
69
-
70
-
71
- def save_model_card(
72
- repo_id: str,
73
- videos=None,
74
- base_model: str = None,
75
- validation_prompt=None,
76
- repo_folder=None,
77
- fps=8,
78
- ):
79
- widget_dict = []
80
- if videos is not None:
81
- for i, video in enumerate(videos):
82
- export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
83
- widget_dict.append(
84
- {
85
- "text": validation_prompt if validation_prompt else " ",
86
- "output": {"url": f"video_{i}.mp4"},
87
- }
88
- )
89
-
90
- model_description = f"""
91
- # CogVideoX Full Finetune
92
-
93
- <Gallery />
94
-
95
- ## Model description
96
-
97
- This is a full finetune of the CogVideoX model `{base_model}`.
98
-
99
- ## License
100
-
101
- Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
102
- """
103
- model_card = load_or_create_model_card(
104
- repo_id_or_path=repo_id,
105
- from_training=True,
106
- license="other",
107
- base_model=base_model,
108
- prompt=validation_prompt,
109
- model_description=model_description,
110
- widget=widget_dict,
111
- )
112
- tags = [
113
- "text-to-video",
114
- "image-to-video",
115
- "diffusers-training",
116
- "diffusers",
117
- "cogvideox",
118
- "cogvideox-diffusers",
119
- ]
120
-
121
- model_card = populate_model_card(model_card, tags=tags)
122
- model_card.save(os.path.join(repo_folder, "README.md"))
123
-
124
-
125
- def log_validation(
126
- accelerator: Accelerator,
127
- pipe: CogVideoXImageToVideoPipeline,
128
- args: Dict[str, Any],
129
- pipeline_args: Dict[str, Any],
130
- is_final_validation: bool = False,
131
- ):
132
- logger.info(
133
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
134
- )
135
-
136
- pipe = pipe.to(accelerator.device)
137
-
138
- # run inference
139
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
140
-
141
- videos = []
142
- for _ in range(args.num_validation_videos):
143
- video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
144
- videos.append(video)
145
-
146
- for tracker in accelerator.trackers:
147
- phase_name = "test" if is_final_validation else "validation"
148
- if tracker.name == "wandb":
149
- video_filenames = []
150
- for i, video in enumerate(videos):
151
- prompt = (
152
- pipeline_args["prompt"][:25]
153
- .replace(" ", "_")
154
- .replace(" ", "_")
155
- .replace("'", "_")
156
- .replace('"', "_")
157
- .replace("/", "_")
158
- )
159
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
160
- export_to_video(video, filename, fps=8)
161
- video_filenames.append(filename)
162
-
163
- tracker.log(
164
- {
165
- phase_name: [
166
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
167
- for i, filename in enumerate(video_filenames)
168
- ]
169
- }
170
- )
171
-
172
- return videos
173
-
174
-
175
- def run_validation(
176
- args: Dict[str, Any],
177
- accelerator: Accelerator,
178
- transformer,
179
- scheduler,
180
- model_config: Dict[str, Any],
181
- weight_dtype: torch.dtype,
182
- ) -> None:
183
- accelerator.print("===== Memory before validation =====")
184
- print_memory(accelerator.device)
185
- torch.cuda.synchronize(accelerator.device)
186
-
187
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
188
- args.pretrained_model_name_or_path,
189
- transformer=unwrap_model(accelerator, transformer),
190
- scheduler=scheduler,
191
- revision=args.revision,
192
- variant=args.variant,
193
- torch_dtype=weight_dtype,
194
- )
195
-
196
- if args.enable_slicing:
197
- pipe.vae.enable_slicing()
198
- if args.enable_tiling:
199
- pipe.vae.enable_tiling()
200
- if args.enable_model_cpu_offload:
201
- pipe.enable_model_cpu_offload()
202
-
203
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
204
- validation_images = args.validation_images.split(args.validation_prompt_separator)
205
- for validation_image, validation_prompt in zip(validation_images, validation_prompts):
206
- pipeline_args = {
207
- "image": load_image(validation_image),
208
- "prompt": validation_prompt,
209
- "guidance_scale": args.guidance_scale,
210
- "use_dynamic_cfg": args.use_dynamic_cfg,
211
- "height": args.height,
212
- "width": args.width,
213
- "max_sequence_length": model_config.max_text_seq_length,
214
- }
215
-
216
- log_validation(
217
- pipe=pipe,
218
- args=args,
219
- accelerator=accelerator,
220
- pipeline_args=pipeline_args,
221
- )
222
-
223
- accelerator.print("===== Memory after validation =====")
224
- print_memory(accelerator.device)
225
- reset_memory(accelerator.device)
226
-
227
- del pipe
228
- gc.collect()
229
- torch.cuda.empty_cache()
230
- torch.cuda.synchronize(accelerator.device)
231
-
232
-
233
- class CollateFunction:
234
- def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
235
- self.weight_dtype = weight_dtype
236
- self.load_tensors = load_tensors
237
-
238
- def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
239
- prompts = [x["prompt"] for x in data[0]]
240
-
241
- if self.load_tensors:
242
- prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
243
-
244
- images = [x["image"] for x in data[0]]
245
- images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
246
-
247
- videos = [x["video"] for x in data[0]]
248
- videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
249
-
250
- return {
251
- "images": images,
252
- "videos": videos,
253
- "prompts": prompts,
254
- }
255
-
256
-
257
- def main(args):
258
- if args.report_to == "wandb" and args.hub_token is not None:
259
- raise ValueError(
260
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
261
- " Please use `huggingface-cli login` to authenticate with the Hub."
262
- )
263
-
264
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
265
- # due to pytorch#99272, MPS does not yet support bfloat16.
266
- raise ValueError(
267
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
268
- )
269
-
270
- logging_dir = Path(args.output_dir, args.logging_dir)
271
-
272
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
273
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
274
- init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
275
- accelerator = Accelerator(
276
- gradient_accumulation_steps=args.gradient_accumulation_steps,
277
- mixed_precision=args.mixed_precision,
278
- log_with=args.report_to,
279
- project_config=accelerator_project_config,
280
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
281
- )
282
-
283
- # Disable AMP for MPS.
284
- if torch.backends.mps.is_available():
285
- accelerator.native_amp = False
286
-
287
- # Make one log on every process with the configuration for debugging.
288
- logging.basicConfig(
289
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
290
- datefmt="%m/%d/%Y %H:%M:%S",
291
- level=logging.INFO,
292
- )
293
- logger.info(accelerator.state, main_process_only=False)
294
- if accelerator.is_local_main_process:
295
- transformers.utils.logging.set_verbosity_warning()
296
- diffusers.utils.logging.set_verbosity_info()
297
- else:
298
- transformers.utils.logging.set_verbosity_error()
299
- diffusers.utils.logging.set_verbosity_error()
300
-
301
- # If passed along, set the training seed now.
302
- if args.seed is not None:
303
- set_seed(args.seed)
304
-
305
- # Handle the repository creation
306
- if accelerator.is_main_process:
307
- if args.output_dir is not None:
308
- os.makedirs(args.output_dir, exist_ok=True)
309
-
310
- if args.push_to_hub:
311
- repo_id = create_repo(
312
- repo_id=args.hub_model_id or Path(args.output_dir).name,
313
- exist_ok=True,
314
- ).repo_id
315
-
316
- # Prepare models and scheduler
317
- tokenizer = AutoTokenizer.from_pretrained(
318
- args.pretrained_model_name_or_path,
319
- subfolder="tokenizer",
320
- revision=args.revision,
321
- )
322
-
323
- text_encoder = T5EncoderModel.from_pretrained(
324
- args.pretrained_model_name_or_path,
325
- subfolder="text_encoder",
326
- revision=args.revision,
327
- )
328
-
329
- # CogVideoX-2b weights are stored in float16
330
- # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
331
- load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
332
- transformer = CogVideoXTransformer3DModel.from_pretrained(
333
- args.pretrained_model_name_or_path,
334
- subfolder="transformer",
335
- torch_dtype=load_dtype,
336
- revision=args.revision,
337
- variant=args.variant,
338
- )
339
-
340
- if args.ignore_learned_positional_embeddings:
341
- del transformer.patch_embed.pos_embedding
342
- transformer.patch_embed.use_learned_positional_embeddings = False
343
- transformer.config.use_learned_positional_embeddings = False
344
-
345
- vae = AutoencoderKLCogVideoX.from_pretrained(
346
- args.pretrained_model_name_or_path,
347
- subfolder="vae",
348
- revision=args.revision,
349
- variant=args.variant,
350
- )
351
-
352
- scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
353
-
354
- if args.enable_slicing:
355
- vae.enable_slicing()
356
- if args.enable_tiling:
357
- vae.enable_tiling()
358
-
359
- text_encoder.requires_grad_(False)
360
- vae.requires_grad_(False)
361
- transformer.requires_grad_(True)
362
-
363
- VAE_SCALING_FACTOR = vae.config.scaling_factor
364
- VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
365
- RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
366
- RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
367
-
368
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
369
- # as these weights are only used for inference, keeping weights in full precision is not required.
370
- weight_dtype = torch.float32
371
- if accelerator.state.deepspeed_plugin:
372
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
373
- if (
374
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
375
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
376
- ):
377
- weight_dtype = torch.float16
378
- if (
379
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
380
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
381
- ):
382
- weight_dtype = torch.bfloat16
383
- else:
384
- if accelerator.mixed_precision == "fp16":
385
- weight_dtype = torch.float16
386
- elif accelerator.mixed_precision == "bf16":
387
- weight_dtype = torch.bfloat16
388
-
389
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
390
- # due to pytorch#99272, MPS does not yet support bfloat16.
391
- raise ValueError(
392
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
393
- )
394
-
395
- text_encoder.to(accelerator.device, dtype=weight_dtype)
396
- transformer.to(accelerator.device, dtype=weight_dtype)
397
- vae.to(accelerator.device, dtype=weight_dtype)
398
-
399
- if args.gradient_checkpointing:
400
- transformer.enable_gradient_checkpointing()
401
-
402
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
403
- def save_model_hook(models, weights, output_dir):
404
- if accelerator.is_main_process:
405
- for model in models:
406
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
407
- model = unwrap_model(accelerator, model)
408
- model.save_pretrained(
409
- os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
410
- )
411
- else:
412
- raise ValueError(f"Unexpected save model: {model.__class__}")
413
-
414
- # make sure to pop weight so that corresponding model is not saved again
415
- if weights:
416
- weights.pop()
417
-
418
- def load_model_hook(models, input_dir):
419
- transformer_ = None
420
- init_under_meta = False
421
-
422
- # This is a bit of a hack but I don't know any other solution.
423
- if not accelerator.distributed_type == DistributedType.DEEPSPEED:
424
- while len(models) > 0:
425
- model = models.pop()
426
-
427
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
428
- transformer_ = unwrap_model(accelerator, model)
429
- else:
430
- raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
431
- else:
432
- with init_empty_weights():
433
- transformer_ = CogVideoXTransformer3DModel.from_config(
434
- args.pretrained_model_name_or_path, subfolder="transformer"
435
- )
436
- init_under_meta = True
437
-
438
- load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
439
- transformer_.register_to_config(**load_model.config)
440
- transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
441
- del load_model
442
-
443
- # Make sure the trainable params are in float32. This is again needed since the base models
444
- # are in `weight_dtype`. More details:
445
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
446
- if args.mixed_precision == "fp16":
447
- cast_training_params([transformer_])
448
-
449
- accelerator.register_save_state_pre_hook(save_model_hook)
450
- accelerator.register_load_state_pre_hook(load_model_hook)
451
-
452
- # Enable TF32 for faster training on Ampere GPUs,
453
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
454
- if args.allow_tf32 and torch.cuda.is_available():
455
- torch.backends.cuda.matmul.allow_tf32 = True
456
-
457
- if args.scale_lr:
458
- args.learning_rate = (
459
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
460
- )
461
-
462
- # Make sure the trainable params are in float32.
463
- if args.mixed_precision == "fp16":
464
- cast_training_params([transformer], dtype=torch.float32)
465
-
466
- transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
467
-
468
- # Optimization parameters
469
- transformer_parameters_with_lr = {
470
- "params": transformer_parameters,
471
- "lr": args.learning_rate,
472
- }
473
- params_to_optimize = [transformer_parameters_with_lr]
474
- num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
475
-
476
- use_deepspeed_optimizer = (
477
- accelerator.state.deepspeed_plugin is not None
478
- and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
479
- )
480
- use_deepspeed_scheduler = (
481
- accelerator.state.deepspeed_plugin is not None
482
- and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
483
- )
484
-
485
- optimizer = get_optimizer(
486
- params_to_optimize=params_to_optimize,
487
- optimizer_name=args.optimizer,
488
- learning_rate=args.learning_rate,
489
- beta1=args.beta1,
490
- beta2=args.beta2,
491
- beta3=args.beta3,
492
- epsilon=args.epsilon,
493
- weight_decay=args.weight_decay,
494
- prodigy_decouple=args.prodigy_decouple,
495
- prodigy_use_bias_correction=args.prodigy_use_bias_correction,
496
- prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
497
- use_8bit=args.use_8bit,
498
- use_4bit=args.use_4bit,
499
- use_torchao=args.use_torchao,
500
- use_deepspeed=use_deepspeed_optimizer,
501
- use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
502
- offload_gradients=args.offload_gradients,
503
- )
504
-
505
- # Dataset and DataLoader
506
- dataset_init_kwargs = {
507
- "data_root": args.data_root,
508
- "dataset_file": args.dataset_file,
509
- "caption_column": args.caption_column,
510
- "video_column": args.video_column,
511
- "max_num_frames": args.max_num_frames,
512
- "id_token": args.id_token,
513
- "height_buckets": args.height_buckets,
514
- "width_buckets": args.width_buckets,
515
- "frame_buckets": args.frame_buckets,
516
- "load_tensors": args.load_tensors,
517
- "random_flip": args.random_flip,
518
- "image_to_video": True,
519
- }
520
- if args.video_reshape_mode is None:
521
- train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
522
- else:
523
- train_dataset = VideoDatasetWithResizeAndRectangleCrop(
524
- video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
525
- )
526
-
527
- collate_fn = CollateFunction(weight_dtype, args.load_tensors)
528
-
529
- train_dataloader = DataLoader(
530
- train_dataset,
531
- batch_size=1,
532
- sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
533
- collate_fn=collate_fn,
534
- num_workers=args.dataloader_num_workers,
535
- pin_memory=args.pin_memory,
536
- )
537
-
538
- # Scheduler and math around the number of training steps.
539
- overrode_max_train_steps = False
540
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
541
- if args.max_train_steps is None:
542
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
543
- overrode_max_train_steps = True
544
-
545
- if args.use_cpu_offload_optimizer:
546
- lr_scheduler = None
547
- accelerator.print(
548
- "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
549
- "you are training with those settings, they will be ignored."
550
- )
551
- else:
552
- if use_deepspeed_scheduler:
553
- from accelerate.utils import DummyScheduler
554
-
555
- lr_scheduler = DummyScheduler(
556
- name=args.lr_scheduler,
557
- optimizer=optimizer,
558
- total_num_steps=args.max_train_steps * accelerator.num_processes,
559
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
560
- )
561
- else:
562
- lr_scheduler = get_scheduler(
563
- args.lr_scheduler,
564
- optimizer=optimizer,
565
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
566
- num_training_steps=args.max_train_steps * accelerator.num_processes,
567
- num_cycles=args.lr_num_cycles,
568
- power=args.lr_power,
569
- )
570
-
571
- # Prepare everything with our `accelerator`.
572
- transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
573
- transformer, optimizer, train_dataloader, lr_scheduler
574
- )
575
-
576
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
577
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
578
- if overrode_max_train_steps:
579
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
580
- # Afterwards we recalculate our number of training epochs
581
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
582
-
583
- # We need to initialize the trackers we use, and also store our configuration.
584
- # The trackers initializes automatically on the main process.
585
- if accelerator.is_main_process:
586
- tracker_name = args.tracker_name or "cogvideox-sft"
587
- accelerator.init_trackers(tracker_name, config=vars(args))
588
-
589
- accelerator.print("===== Memory before training =====")
590
- reset_memory(accelerator.device)
591
- print_memory(accelerator.device)
592
-
593
- # Train!
594
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
595
-
596
- accelerator.print("***** Running training *****")
597
- accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
598
- accelerator.print(f" Num examples = {len(train_dataset)}")
599
- accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
600
- accelerator.print(f" Num epochs = {args.num_train_epochs}")
601
- accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
602
- accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
603
- accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
604
- accelerator.print(f" Total optimization steps = {args.max_train_steps}")
605
- global_step = 0
606
- first_epoch = 0
607
-
608
- # Potentially load in the weights and states from a previous save
609
- if not args.resume_from_checkpoint:
610
- initial_global_step = 0
611
- else:
612
- if args.resume_from_checkpoint != "latest":
613
- path = os.path.basename(args.resume_from_checkpoint)
614
- else:
615
- # Get the most recent checkpoint
616
- dirs = os.listdir(args.output_dir)
617
- dirs = [d for d in dirs if d.startswith("checkpoint")]
618
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
619
- path = dirs[-1] if len(dirs) > 0 else None
620
-
621
- if path is None:
622
- accelerator.print(
623
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
624
- )
625
- args.resume_from_checkpoint = None
626
- initial_global_step = 0
627
- else:
628
- accelerator.print(f"Resuming from checkpoint {path}")
629
- accelerator.load_state(os.path.join(args.output_dir, path))
630
- global_step = int(path.split("-")[1])
631
-
632
- initial_global_step = global_step
633
- first_epoch = global_step // num_update_steps_per_epoch
634
-
635
- progress_bar = tqdm(
636
- range(0, args.max_train_steps),
637
- initial=initial_global_step,
638
- desc="Steps",
639
- # Only show the progress bar once on each machine.
640
- disable=not accelerator.is_local_main_process,
641
- )
642
-
643
- # For DeepSpeed training
644
- model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
645
-
646
- if args.load_tensors:
647
- del vae, text_encoder
648
- gc.collect()
649
- torch.cuda.empty_cache()
650
- torch.cuda.synchronize(accelerator.device)
651
-
652
- alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
653
-
654
- for epoch in range(first_epoch, args.num_train_epochs):
655
- transformer.train()
656
- for step, batch in enumerate(train_dataloader):
657
- models_to_accumulate = [transformer]
658
- logs = {}
659
-
660
- with accelerator.accumulate(models_to_accumulate):
661
- images = batch["images"].to(accelerator.device, non_blocking=True)
662
- videos = batch["videos"].to(accelerator.device, non_blocking=True)
663
- prompts = batch["prompts"]
664
-
665
- # Encode videos
666
- if not args.load_tensors:
667
- images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
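- # Lightly corrupt the conditioning image before VAE encoding; the noise scale is log-normal (exp of a sample from N(-3.0, 0.5)).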
668
- image_noise_sigma = torch.normal(
669
- mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
670
- )
671
- image_noise_sigma = torch.exp(image_noise_sigma)
672
- noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
673
- image_latent_dist = vae.encode(noisy_images).latent_dist
674
-
675
- videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
676
- latent_dist = vae.encode(videos).latent_dist
677
- else:
678
- image_latent_dist = DiagonalGaussianDistribution(images)
679
- latent_dist = DiagonalGaussianDistribution(videos)
680
-
681
- image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
682
- image_latents = image_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
683
- image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
684
-
685
- video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
686
- video_latents = video_latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
687
- video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
688
-
689
- padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
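- # Zero-pad the single-frame image latents along the frame axis so they match the video latent length before concatenation.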
690
- latent_padding = image_latents.new_zeros(padding_shape)
691
- image_latents = torch.cat([image_latents, latent_padding], dim=1)
692
-
693
- if random.random() < args.noised_image_dropout:
694
- image_latents = torch.zeros_like(image_latents)
695
-
696
- # Encode prompts
697
- if not args.load_tensors:
698
- prompt_embeds = compute_prompt_embeddings(
699
- tokenizer,
700
- text_encoder,
701
- prompts,
702
- model_config.max_text_seq_length,
703
- accelerator.device,
704
- weight_dtype,
705
- requires_grad=False,
706
- )
707
- else:
708
- prompt_embeds = prompts.to(dtype=weight_dtype)
709
-
710
- # Sample noise that will be added to the latents
711
- noise = torch.randn_like(video_latents)
712
- batch_size, num_frames, num_channels, height, width = video_latents.shape
713
-
714
- # Sample a random timestep for each image
715
- timesteps = torch.randint(
716
- 0,
717
- scheduler.config.num_train_timesteps,
718
- (batch_size,),
719
- dtype=torch.int64,
720
- device=accelerator.device,
721
- )
722
-
723
- # Prepare rotary embeds
724
- image_rotary_emb = (
725
- prepare_rotary_positional_embeddings(
726
- height=height * VAE_SCALE_FACTOR_SPATIAL,
727
- width=width * VAE_SCALE_FACTOR_SPATIAL,
728
- num_frames=num_frames,
729
- vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
730
- patch_size=model_config.patch_size,
731
- patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
732
- attention_head_dim=model_config.attention_head_dim,
733
- device=accelerator.device,
734
- base_height=RoPE_BASE_HEIGHT,
735
- base_width=RoPE_BASE_WIDTH,
736
- )
737
- if model_config.use_rotary_positional_embeddings
738
- else None
739
- )
740
-
741
- # Add noise to the model input according to the noise magnitude at each timestep
742
- # (this is the forward diffusion process)
743
- noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
744
- noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
745
- # `ofs_embed_dim` only exists on some CogVideoX checkpoints; fall back to None when the config lacks it.
- ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None
747
- ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
748
- # Predict the noise residual
749
- model_output = transformer(
750
- hidden_states=noisy_model_input,
751
- encoder_hidden_states=prompt_embeds,
752
- timestep=timesteps,
753
- ofs=ofs_emb,
754
- image_rotary_emb=image_rotary_emb,
755
- return_dict=False,
756
- )[0]
757
-
758
- model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
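- # The transformer predicts velocity; `get_velocity` converts that prediction back into clean latents, which are compared against `video_latents` below with a 1/(1 - alpha_bar_t) weighting.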
759
-
760
- weights = 1 / (1 - alphas_cumprod[timesteps])
761
- while len(weights.shape) < len(model_pred.shape):
762
- weights = weights.unsqueeze(-1)
763
-
764
- target = video_latents
765
-
766
- loss = torch.mean(
767
- (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
768
- dim=1,
769
- )
770
- loss = loss.mean()
771
- accelerator.backward(loss)
772
-
773
- if accelerator.sync_gradients:
774
- gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
775
- accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
776
- gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
777
- logs.update(
778
- {
779
- "gradient_norm_before_clip": gradient_norm_before_clip,
780
- "gradient_norm_after_clip": gradient_norm_after_clip,
781
- }
782
- )
783
- if accelerator.state.deepspeed_plugin is None:
784
- optimizer.step()
785
- optimizer.zero_grad()
786
-
787
- if not args.use_cpu_offload_optimizer:
788
- lr_scheduler.step()
789
-
790
- # Checks if the accelerator has performed an optimization step behind the scenes
791
- if accelerator.sync_gradients:
792
- progress_bar.update(1)
793
- global_step += 1
794
-
795
- # Checkpointing
796
- if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
797
- if global_step % args.checkpointing_steps == 0:
798
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
799
- if args.checkpoints_total_limit is not None:
800
- checkpoints = os.listdir(args.output_dir)
801
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
802
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
803
-
804
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
805
- if len(checkpoints) >= args.checkpoints_total_limit:
806
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
807
- removing_checkpoints = checkpoints[0:num_to_remove]
808
-
809
- logger.info(
810
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
811
- )
812
- logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
813
-
814
- for removing_checkpoint in removing_checkpoints:
815
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
816
- shutil.rmtree(removing_checkpoint)
817
-
818
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
819
- accelerator.save_state(save_path)
820
- logger.info(f"Saved state to {save_path}")
821
-
822
- # Validation
823
- should_run_validation = args.validation_prompt is not None and (
824
- args.validation_steps is not None and global_step % args.validation_steps == 0
825
- )
826
- if should_run_validation:
827
- run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
828
-
829
- last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
830
- logs.update(
831
- {
832
- "loss": loss.detach().item(),
833
- "lr": last_lr,
834
- }
835
- )
836
- progress_bar.set_postfix(**logs)
837
- accelerator.log(logs, step=global_step)
838
-
839
- if global_step >= args.max_train_steps:
840
- break
841
-
842
- if accelerator.is_main_process:
843
- should_run_validation = args.validation_prompt is not None and (
844
- args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
845
- )
846
- if should_run_validation:
847
- run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
848
- accelerator.wait_for_everyone()
849
-
850
- if accelerator.is_main_process:
851
- transformer = unwrap_model(accelerator, transformer)
852
- dtype = (
853
- torch.float16
854
- if args.mixed_precision == "fp16"
855
- else torch.bfloat16
856
- if args.mixed_precision == "bf16"
857
- else torch.float32
858
- )
859
- transformer = transformer.to(dtype)
860
-
861
- transformer.save_pretrained(
862
- os.path.join(args.output_dir, "transformer"),
863
- safe_serialization=True,
864
- max_shard_size="5GB",
865
- )
866
-
867
- # Cleanup trained models to save memory
868
- if args.load_tensors:
869
- del transformer
870
- else:
871
- del transformer, text_encoder, vae
872
-
873
- gc.collect()
874
- torch.cuda.empty_cache()
875
- torch.cuda.synchronize(accelerator.device)
876
-
877
- accelerator.print("===== Memory before testing =====")
878
- print_memory(accelerator.device)
879
- reset_memory(accelerator.device)
880
-
881
- # Final test inference
882
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
883
- args.pretrained_model_name_or_path,
884
- revision=args.revision,
885
- variant=args.variant,
886
- torch_dtype=weight_dtype,
887
- )
888
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
889
-
890
- if args.enable_slicing:
891
- pipe.vae.enable_slicing()
892
- if args.enable_tiling:
893
- pipe.vae.enable_tiling()
894
- if args.enable_model_cpu_offload:
895
- pipe.enable_model_cpu_offload()
896
-
897
- # Run inference
898
- validation_outputs = []
899
- if args.validation_prompt and args.num_validation_videos > 0:
900
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
901
- validation_images = args.validation_images.split(args.validation_prompt_separator)
902
- for validation_image, validation_prompt in zip(validation_images, validation_prompts):
903
- pipeline_args = {
904
- "image": load_image(validation_image),
905
- "prompt": validation_prompt,
906
- "guidance_scale": args.guidance_scale,
907
- "use_dynamic_cfg": args.use_dynamic_cfg,
908
- "height": args.height,
909
- "width": args.width,
910
- }
911
-
912
- video = log_validation(
913
- accelerator=accelerator,
914
- pipe=pipe,
915
- args=args,
916
- pipeline_args=pipeline_args,
917
- is_final_validation=True,
918
- )
919
- validation_outputs.extend(video)
920
-
921
- accelerator.print("===== Memory after testing =====")
922
- print_memory(accelerator.device)
923
- reset_memory(accelerator.device)
924
- torch.cuda.synchronize(accelerator.device)
925
-
926
- if args.push_to_hub:
927
- save_model_card(
928
- repo_id,
929
- videos=validation_outputs,
930
- base_model=args.pretrained_model_name_or_path,
931
- validation_prompt=args.validation_prompt,
932
- repo_folder=args.output_dir,
933
- fps=args.fps,
934
- )
935
- upload_folder(
936
- repo_id=repo_id,
937
- folder_path=args.output_dir,
938
- commit_message="End of training",
939
- ignore_patterns=["step_*", "epoch_*"],
940
- )
941
-
942
- accelerator.end_training()
943
-
944
-
945
- if __name__ == "__main__":
946
- args = get_args()
947
- main(args)
 
training/cogvideox/cogvideox_text_to_video_lora.py DELETED
@@ -1,955 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import logging
18
- import math
19
- import os
20
- import shutil
21
- from datetime import timedelta
22
- from pathlib import Path
23
- from typing import Any, Dict
24
-
25
- import diffusers
26
- import torch
27
- import transformers
28
- import wandb
29
- from accelerate import Accelerator, DistributedType
30
- from accelerate.logging import get_logger
31
- from accelerate.utils import (
32
- DistributedDataParallelKwargs,
33
- InitProcessGroupKwargs,
34
- ProjectConfiguration,
35
- set_seed,
36
- )
37
- from diffusers import (
38
- AutoencoderKLCogVideoX,
39
- CogVideoXDPMScheduler,
40
- CogVideoXPipeline,
41
- CogVideoXTransformer3DModel,
42
- )
43
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
44
- from diffusers.optimization import get_scheduler
45
- from diffusers.training_utils import cast_training_params
46
- from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video
47
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
48
- from huggingface_hub import create_repo, upload_folder
49
- from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
50
- from torch.utils.data import DataLoader
51
- from tqdm.auto import tqdm
52
- from transformers import AutoTokenizer, T5EncoderModel
53
-
54
-
55
- from args import get_args # isort:skip
56
- from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
57
- from text_encoder import compute_prompt_embeddings # isort:skip
58
- from utils import (
59
- get_gradient_norm,
60
- get_optimizer,
61
- prepare_rotary_positional_embeddings,
62
- print_memory,
63
- reset_memory,
64
- unwrap_model,
65
- ) # isort:skip
66
-
67
-
68
- logger = get_logger(__name__)
69
-
70
-
71
- def save_model_card(
72
- repo_id: str,
73
- videos=None,
74
- base_model: str = None,
75
- validation_prompt=None,
76
- repo_folder=None,
77
- fps=8,
78
- ):
79
- widget_dict = []
80
- if videos is not None:
81
- for i, video in enumerate(videos):
82
- export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
83
- widget_dict.append(
84
- {
85
- "text": validation_prompt if validation_prompt else " ",
86
- "output": {"url": f"video_{i}.mp4"},
87
- }
88
- )
89
-
90
- model_description = f"""
91
- # CogVideoX LoRA Finetune
92
-
93
- <Gallery />
94
-
95
- ## Model description
96
-
97
- This is a lora finetune of the CogVideoX model `{base_model}`.
98
-
99
- The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
100
-
101
- ## Download model
102
-
103
- [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
104
-
105
- ## Usage
106
-
107
- Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
108
-
109
- ```py
110
- import torch
111
- from diffusers import CogVideoXPipeline
112
- from diffusers.utils import export_to_video
113
-
114
- pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
115
- pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
116
-
117
- # The LoRA adapter weights are determined by what was used for training.
118
- # In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
119
- # It can be made lower or higher from what was used in training to decrease or amplify the effect
120
- # of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
121
- pipe.set_adapters(["cogvideox-lora"], [32 / 64])
122
-
123
- video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
124
- export_to_video(video, "output.mp4", fps=8)
125
- ```
126
-
127
- For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
128
-
129
- ## License
130
-
131
- Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
132
- """
133
- model_card = load_or_create_model_card(
134
- repo_id_or_path=repo_id,
135
- from_training=True,
136
- license="other",
137
- base_model=base_model,
138
- prompt=validation_prompt,
139
- model_description=model_description,
140
- widget=widget_dict,
141
- )
142
- tags = [
143
- "text-to-video",
144
- "diffusers-training",
145
- "diffusers",
146
- "lora",
147
- "cogvideox",
148
- "cogvideox-diffusers",
149
- "template:sd-lora",
150
- ]
151
-
152
- model_card = populate_model_card(model_card, tags=tags)
153
- model_card.save(os.path.join(repo_folder, "README.md"))
154
-
155
-
156
- def log_validation(
157
- accelerator: Accelerator,
158
- pipe: CogVideoXPipeline,
159
- args: Dict[str, Any],
160
- pipeline_args: Dict[str, Any],
161
- epoch,
162
- is_final_validation: bool = False,
163
- ):
164
- logger.info(
165
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
166
- )
167
-
168
- pipe = pipe.to(accelerator.device)
169
-
170
- # run inference
171
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
172
-
173
- videos = []
174
- for _ in range(args.num_validation_videos):
175
- video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
176
- videos.append(video)
177
-
178
- for tracker in accelerator.trackers:
179
- phase_name = "test" if is_final_validation else "validation"
180
- if tracker.name == "wandb":
181
- video_filenames = []
182
- for i, video in enumerate(videos):
183
- prompt = (
184
- pipeline_args["prompt"][:25]
185
- .replace(" ", "_")
186
- .replace(" ", "_")
187
- .replace("'", "_")
188
- .replace('"', "_")
189
- .replace("/", "_")
190
- )
191
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
192
- export_to_video(video, filename, fps=8)
193
- video_filenames.append(filename)
194
-
195
- tracker.log(
196
- {
197
- phase_name: [
198
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
199
- for i, filename in enumerate(video_filenames)
200
- ]
201
- }
202
- )
203
-
204
- return videos
205
-
206
-
207
- class CollateFunction:
208
- def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
209
- self.weight_dtype = weight_dtype
210
- self.load_tensors = load_tensors
211
-
212
- def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
213
- prompts = [x["prompt"] for x in data[0]]
214
-
215
- if self.load_tensors:
216
- prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
217
-
218
- videos = [x["video"] for x in data[0]]
219
- videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
220
-
221
- return {
222
- "videos": videos,
223
- "prompts": prompts,
224
- }
225
-
226
-
227
- def main(args):
228
- if args.report_to == "wandb" and args.hub_token is not None:
229
- raise ValueError(
230
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
231
- " Please use `huggingface-cli login` to authenticate with the Hub."
232
- )
233
-
234
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
235
- # due to pytorch#99272, MPS does not yet support bfloat16.
236
- raise ValueError(
237
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
238
- )
239
-
240
- logging_dir = Path(args.output_dir, args.logging_dir)
241
-
242
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
243
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
244
- init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
245
- accelerator = Accelerator(
246
- gradient_accumulation_steps=args.gradient_accumulation_steps,
247
- mixed_precision=args.mixed_precision,
248
- log_with=args.report_to,
249
- project_config=accelerator_project_config,
250
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
251
- )
252
-
253
- # Disable AMP for MPS.
254
- if torch.backends.mps.is_available():
255
- accelerator.native_amp = False
256
-
257
- # Make one log on every process with the configuration for debugging.
258
- logging.basicConfig(
259
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
260
- datefmt="%m/%d/%Y %H:%M:%S",
261
- level=logging.INFO,
262
- )
263
- logger.info(accelerator.state, main_process_only=False)
264
- if accelerator.is_local_main_process:
265
- transformers.utils.logging.set_verbosity_warning()
266
- diffusers.utils.logging.set_verbosity_info()
267
- else:
268
- transformers.utils.logging.set_verbosity_error()
269
- diffusers.utils.logging.set_verbosity_error()
270
-
271
- # If passed along, set the training seed now.
272
- if args.seed is not None:
273
- set_seed(args.seed)
274
-
275
- # Handle the repository creation
276
- if accelerator.is_main_process:
277
- if args.output_dir is not None:
278
- os.makedirs(args.output_dir, exist_ok=True)
279
-
280
- if args.push_to_hub:
281
- repo_id = create_repo(
282
- repo_id=args.hub_model_id or Path(args.output_dir).name,
283
- exist_ok=True,
284
- ).repo_id
285
-
286
- # Prepare models and scheduler
287
- tokenizer = AutoTokenizer.from_pretrained(
288
- args.pretrained_model_name_or_path,
289
- subfolder="tokenizer",
290
- revision=args.revision,
291
- )
292
-
293
- text_encoder = T5EncoderModel.from_pretrained(
294
- args.pretrained_model_name_or_path,
295
- subfolder="text_encoder",
296
- revision=args.revision,
297
- )
298
-
299
- # CogVideoX-2b weights are stored in float16
300
- # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
301
- load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
302
- transformer = CogVideoXTransformer3DModel.from_pretrained(
303
- args.pretrained_model_name_or_path,
304
- subfolder="transformer",
305
- torch_dtype=load_dtype,
306
- revision=args.revision,
307
- variant=args.variant,
308
- )
309
-
310
- vae = AutoencoderKLCogVideoX.from_pretrained(
311
- args.pretrained_model_name_or_path,
312
- subfolder="vae",
313
- revision=args.revision,
314
- variant=args.variant,
315
- )
316
-
317
- scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
318
-
319
- if args.enable_slicing:
320
- vae.enable_slicing()
321
- if args.enable_tiling:
322
- vae.enable_tiling()
323
-
324
- # We only train the additional adapter LoRA layers
325
- text_encoder.requires_grad_(False)
326
- transformer.requires_grad_(False)
327
- vae.requires_grad_(False)
328
-
329
- VAE_SCALING_FACTOR = vae.config.scaling_factor
330
- VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
331
- RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
332
- RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
333
-
334
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
335
- # as these weights are only used for inference, keeping weights in full precision is not required.
336
- weight_dtype = torch.float32
337
- if accelerator.state.deepspeed_plugin:
338
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
339
- if (
340
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
341
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
342
- ):
343
- weight_dtype = torch.float16
344
- if (
345
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
346
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
347
- ):
348
- weight_dtype = torch.bfloat16
349
- else:
350
- if accelerator.mixed_precision == "fp16":
351
- weight_dtype = torch.float16
352
- elif accelerator.mixed_precision == "bf16":
353
- weight_dtype = torch.bfloat16
354
-
355
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
356
- # due to pytorch#99272, MPS does not yet support bfloat16.
357
- raise ValueError(
358
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
359
- )
360
-
361
- text_encoder.to(accelerator.device, dtype=weight_dtype)
362
- transformer.to(accelerator.device, dtype=weight_dtype)
363
- vae.to(accelerator.device, dtype=weight_dtype)
364
-
365
- if args.gradient_checkpointing:
366
- transformer.enable_gradient_checkpointing()
367
-
368
- # now we will add new LoRA weights to the attention layers
369
- transformer_lora_config = LoraConfig(
370
- r=args.rank,
371
- lora_alpha=args.lora_alpha,
372
- init_lora_weights=True,
373
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
374
- )
375
- transformer.add_adapter(transformer_lora_config)
376
-
377
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
378
- def save_model_hook(models, weights, output_dir):
379
- if accelerator.is_main_process:
380
- transformer_lora_layers_to_save = None
381
-
382
- for model in models:
383
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
384
- model = unwrap_model(accelerator, model)
385
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
386
- else:
387
- raise ValueError(f"unexpected save model: {model.__class__}")
388
-
389
- # make sure to pop weight so that corresponding model is not saved again
390
- if weights:
391
- weights.pop()
392
-
393
- CogVideoXPipeline.save_lora_weights(
394
- output_dir,
395
- transformer_lora_layers=transformer_lora_layers_to_save,
396
- )
397
-
398
- def load_model_hook(models, input_dir):
399
- transformer_ = None
400
-
401
- # This is a bit of a hack but I don't know any other solution.
402
- if not accelerator.distributed_type == DistributedType.DEEPSPEED:
403
- while len(models) > 0:
404
- model = models.pop()
405
-
406
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
407
- transformer_ = unwrap_model(accelerator, model)
408
- else:
409
- raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
410
- else:
411
- transformer_ = CogVideoXTransformer3DModel.from_pretrained(
412
- args.pretrained_model_name_or_path, subfolder="transformer"
413
- )
414
- transformer_.add_adapter(transformer_lora_config)
415
-
416
- lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
417
-
418
- transformer_state_dict = {
419
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
420
- }
421
- transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
422
- incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
423
- if incompatible_keys is not None:
424
- # check only for unexpected keys
425
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
426
- if unexpected_keys:
427
- logger.warning(
428
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
429
- f" {unexpected_keys}. "
430
- )
431
-
432
- # Make sure the trainable params are in float32. This is again needed since the base models
433
- # are in `weight_dtype`. More details:
434
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
435
- if args.mixed_precision == "fp16":
436
- # only upcast trainable parameters (LoRA) into fp32
437
- cast_training_params([transformer_])
438
-
439
- accelerator.register_save_state_pre_hook(save_model_hook)
440
- accelerator.register_load_state_pre_hook(load_model_hook)
441
-
442
- # Enable TF32 for faster training on Ampere GPUs,
443
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
444
- if args.allow_tf32 and torch.cuda.is_available():
445
- torch.backends.cuda.matmul.allow_tf32 = True
446
-
447
- if args.scale_lr:
448
- args.learning_rate = (
449
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
450
- )
451
-
452
- # Make sure the trainable params are in float32.
453
- if args.mixed_precision == "fp16":
454
- # only upcast trainable parameters (LoRA) into fp32
455
- cast_training_params([transformer], dtype=torch.float32)
456
-
457
- transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
458
-
459
- # Optimization parameters
460
- transformer_parameters_with_lr = {
461
- "params": transformer_lora_parameters,
462
- "lr": args.learning_rate,
463
- }
464
- params_to_optimize = [transformer_parameters_with_lr]
465
- num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
466
-
467
- use_deepspeed_optimizer = (
468
- accelerator.state.deepspeed_plugin is not None
469
- and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
470
- )
471
- use_deepspeed_scheduler = (
472
- accelerator.state.deepspeed_plugin is not None
473
- and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
474
- )
475
-
476
- optimizer = get_optimizer(
477
- params_to_optimize=params_to_optimize,
478
- optimizer_name=args.optimizer,
479
- learning_rate=args.learning_rate,
480
- beta1=args.beta1,
481
- beta2=args.beta2,
482
- beta3=args.beta3,
483
- epsilon=args.epsilon,
484
- weight_decay=args.weight_decay,
485
- prodigy_decouple=args.prodigy_decouple,
486
- prodigy_use_bias_correction=args.prodigy_use_bias_correction,
487
- prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
488
- use_8bit=args.use_8bit,
489
- use_4bit=args.use_4bit,
490
- use_torchao=args.use_torchao,
491
- use_deepspeed=use_deepspeed_optimizer,
492
- use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
493
- offload_gradients=args.offload_gradients,
494
- )
495
-
496
- # Dataset and DataLoader
497
- dataset_init_kwargs = {
498
- "data_root": args.data_root,
499
- "dataset_file": args.dataset_file,
500
- "caption_column": args.caption_column,
501
- "video_column": args.video_column,
502
- "max_num_frames": args.max_num_frames,
503
- "id_token": args.id_token,
504
- "height_buckets": args.height_buckets,
505
- "width_buckets": args.width_buckets,
506
- "frame_buckets": args.frame_buckets,
507
- "load_tensors": args.load_tensors,
508
- "random_flip": args.random_flip,
509
- }
510
- if args.video_reshape_mode is None:
511
- train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
512
- else:
513
- train_dataset = VideoDatasetWithResizeAndRectangleCrop(
514
- video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
515
- )
516
-
517
- collate_fn = CollateFunction(weight_dtype, args.load_tensors)
518
-
519
- train_dataloader = DataLoader(
520
- train_dataset,
521
- batch_size=1,
522
- sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
523
- collate_fn=collate_fn,
524
- num_workers=args.dataloader_num_workers,
525
- pin_memory=args.pin_memory,
526
- )
527
-
528
- # Scheduler and math around the number of training steps.
529
- overrode_max_train_steps = False
530
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
531
- if args.max_train_steps is None:
532
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
533
- overrode_max_train_steps = True
534
-
535
- if args.use_cpu_offload_optimizer:
536
- lr_scheduler = None
537
- accelerator.print(
538
- "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
539
- "you are training with those settings, they will be ignored."
540
- )
541
- else:
542
- if use_deepspeed_scheduler:
543
- from accelerate.utils import DummyScheduler
544
-
545
- lr_scheduler = DummyScheduler(
546
- name=args.lr_scheduler,
547
- optimizer=optimizer,
548
- total_num_steps=args.max_train_steps * accelerator.num_processes,
549
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
550
- )
551
- else:
552
- lr_scheduler = get_scheduler(
553
- args.lr_scheduler,
554
- optimizer=optimizer,
555
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
556
- num_training_steps=args.max_train_steps * accelerator.num_processes,
557
- num_cycles=args.lr_num_cycles,
558
- power=args.lr_power,
559
- )
560
-
561
- # Prepare everything with our `accelerator`.
562
- transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
563
- transformer, optimizer, train_dataloader, lr_scheduler
564
- )
565
-
566
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
567
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
568
- if overrode_max_train_steps:
569
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
570
- # Afterwards we recalculate our number of training epochs
571
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
572
-
573
- # We need to initialize the trackers we use, and also store our configuration.
574
- # The trackers initializes automatically on the main process.
575
- if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
576
- tracker_name = args.tracker_name or "cogvideox-lora"
577
- accelerator.init_trackers(tracker_name, config=vars(args))
578
-
579
- accelerator.print("===== Memory before training =====")
580
- reset_memory(accelerator.device)
581
- print_memory(accelerator.device)
582
-
583
- # Train!
584
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
585
-
586
- accelerator.print("***** Running training *****")
587
- accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
588
- accelerator.print(f" Num examples = {len(train_dataset)}")
589
- accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
590
- accelerator.print(f" Num epochs = {args.num_train_epochs}")
591
- accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
592
- accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
593
- accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
594
- accelerator.print(f" Total optimization steps = {args.max_train_steps}")
595
- global_step = 0
596
- first_epoch = 0
597
-
598
- # Potentially load in the weights and states from a previous save
599
- if not args.resume_from_checkpoint:
600
- initial_global_step = 0
601
- else:
602
- if args.resume_from_checkpoint != "latest":
603
- path = os.path.basename(args.resume_from_checkpoint)
604
- else:
605
- # Get the most recent checkpoint
606
- dirs = os.listdir(args.output_dir)
607
- dirs = [d for d in dirs if d.startswith("checkpoint")]
608
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
609
- path = dirs[-1] if len(dirs) > 0 else None
610
-
611
- if path is None:
612
- accelerator.print(
613
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
614
- )
615
- args.resume_from_checkpoint = None
616
- initial_global_step = 0
617
- else:
618
- accelerator.print(f"Resuming from checkpoint {path}")
619
- accelerator.load_state(os.path.join(args.output_dir, path))
620
- global_step = int(path.split("-")[1])
621
-
622
- initial_global_step = global_step
623
- first_epoch = global_step // num_update_steps_per_epoch
624
-
625
- progress_bar = tqdm(
626
- range(0, args.max_train_steps),
627
- initial=initial_global_step,
628
- desc="Steps",
629
- # Only show the progress bar once on each machine.
630
- disable=not accelerator.is_local_main_process,
631
- )
632
-
633
- # For DeepSpeed training
634
- model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
635
-
636
- if args.load_tensors:
637
- del vae, text_encoder
638
- gc.collect()
639
- torch.cuda.empty_cache()
640
- torch.cuda.synchronize(accelerator.device)
641
-
642
- alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
643
-
644
- for epoch in range(first_epoch, args.num_train_epochs):
645
- transformer.train()
646
-
647
- for step, batch in enumerate(train_dataloader):
648
- models_to_accumulate = [transformer]
649
- logs = {}
650
-
651
- with accelerator.accumulate(models_to_accumulate):
652
- videos = batch["videos"].to(accelerator.device, non_blocking=True)
653
- prompts = batch["prompts"]
654
-
655
- # Encode videos
656
- if not args.load_tensors:
657
- videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
658
- latent_dist = vae.encode(videos).latent_dist
659
- else:
660
- latent_dist = DiagonalGaussianDistribution(videos)
661
-
662
- videos = latent_dist.sample() * VAE_SCALING_FACTOR
663
- videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
664
- videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
665
- model_input = videos
666
-
667
- # Encode prompts
668
- if not args.load_tensors:
669
- prompt_embeds = compute_prompt_embeddings(
670
- tokenizer,
671
- text_encoder,
672
- prompts,
673
- model_config.max_text_seq_length,
674
- accelerator.device,
675
- weight_dtype,
676
- requires_grad=False,
677
- )
678
- else:
679
- prompt_embeds = prompts.to(dtype=weight_dtype)
680
-
681
- # Sample noise that will be added to the latents
682
- noise = torch.randn_like(model_input)
683
- batch_size, num_frames, num_channels, height, width = model_input.shape
684
-
685
- # Sample a random timestep for each image
686
- timesteps = torch.randint(
687
- 0,
688
- scheduler.config.num_train_timesteps,
689
- (batch_size,),
690
- dtype=torch.int64,
691
- device=model_input.device,
692
- )
693
-
694
- # Prepare rotary embeds
695
- image_rotary_emb = (
696
- prepare_rotary_positional_embeddings(
697
- height=height * VAE_SCALE_FACTOR_SPATIAL,
698
- width=width * VAE_SCALE_FACTOR_SPATIAL,
699
- num_frames=num_frames,
700
- vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
701
- patch_size=model_config.patch_size,
702
- patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
703
- attention_head_dim=model_config.attention_head_dim,
704
- device=accelerator.device,
705
- base_height=RoPE_BASE_HEIGHT,
706
- base_width=RoPE_BASE_WIDTH,
707
- )
708
- if model_config.use_rotary_positional_embeddings
709
- else None
710
- )
711
-
712
- # Add noise to the model input according to the noise magnitude at each timestep
713
- # (this is the forward diffusion process)
714
- noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
715
-
716
- # Predict the noise residual
717
- model_output = transformer(
718
- hidden_states=noisy_model_input,
719
- encoder_hidden_states=prompt_embeds,
720
- timestep=timesteps,
721
- image_rotary_emb=image_rotary_emb,
722
- return_dict=False,
723
- )[0]
724
-
725
- model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
726
-
727
- weights = 1 / (1 - alphas_cumprod[timesteps])
728
- while len(weights.shape) < len(model_pred.shape):
729
- weights = weights.unsqueeze(-1)
730
-
731
- target = model_input
732
-
733
- loss = torch.mean(
734
- (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
735
- dim=1,
736
- )
737
- loss = loss.mean()
738
- accelerator.backward(loss)
739
-
740
- if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
741
- gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
742
- accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
743
- gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
744
- logs.update(
745
- {
746
- "gradient_norm_before_clip": gradient_norm_before_clip,
747
- "gradient_norm_after_clip": gradient_norm_after_clip,
748
- }
749
- )
750
-
751
- if accelerator.state.deepspeed_plugin is None:
752
- optimizer.step()
753
- optimizer.zero_grad()
754
-
755
- if not args.use_cpu_offload_optimizer:
756
- lr_scheduler.step()
757
-
758
- # Checks if the accelerator has performed an optimization step behind the scenes
759
- if accelerator.sync_gradients:
760
- progress_bar.update(1)
761
- global_step += 1
762
-
763
- if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
764
- if global_step % args.checkpointing_steps == 0:
765
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
766
- if args.checkpoints_total_limit is not None:
767
- checkpoints = os.listdir(args.output_dir)
768
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
769
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
770
-
771
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
772
- if len(checkpoints) >= args.checkpoints_total_limit:
773
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
774
- removing_checkpoints = checkpoints[0:num_to_remove]
775
-
776
- logger.info(
777
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
778
- )
779
- logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
780
-
781
- for removing_checkpoint in removing_checkpoints:
782
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
783
- shutil.rmtree(removing_checkpoint)
784
-
785
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
786
- accelerator.save_state(save_path)
787
- logger.info(f"Saved state to {save_path}")
788
-
789
- last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
790
- logs.update(
791
- {
792
- "loss": loss.detach().item(),
793
- "lr": last_lr,
794
- }
795
- )
796
- progress_bar.set_postfix(**logs)
797
- accelerator.log(logs, step=global_step)
798
-
799
- if global_step >= args.max_train_steps:
800
- break
801
-
802
- if accelerator.is_main_process:
803
- if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
804
- accelerator.print("===== Memory before validation =====")
805
- print_memory(accelerator.device)
806
- torch.cuda.synchronize(accelerator.device)
807
-
808
- pipe = CogVideoXPipeline.from_pretrained(
809
- args.pretrained_model_name_or_path,
810
- transformer=unwrap_model(accelerator, transformer),
811
- scheduler=scheduler,
812
- revision=args.revision,
813
- variant=args.variant,
814
- torch_dtype=weight_dtype,
815
- )
816
-
817
- if args.enable_slicing:
818
- pipe.vae.enable_slicing()
819
- if args.enable_tiling:
820
- pipe.vae.enable_tiling()
821
- if args.enable_model_cpu_offload:
822
- pipe.enable_model_cpu_offload()
823
-
824
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
825
- for validation_prompt in validation_prompts:
826
- pipeline_args = {
827
- "prompt": validation_prompt,
828
- "guidance_scale": args.guidance_scale,
829
- "use_dynamic_cfg": args.use_dynamic_cfg,
830
- "height": args.height,
831
- "width": args.width,
832
- "max_sequence_length": model_config.max_text_seq_length,
833
- }
834
-
835
- log_validation(
836
- pipe=pipe,
837
- args=args,
838
- accelerator=accelerator,
839
- pipeline_args=pipeline_args,
840
- epoch=epoch,
841
- )
842
-
843
- accelerator.print("===== Memory after validation =====")
844
- print_memory(accelerator.device)
845
- reset_memory(accelerator.device)
846
-
847
- del pipe
848
- gc.collect()
849
- torch.cuda.empty_cache()
850
- torch.cuda.synchronize(accelerator.device)
851
-
852
- accelerator.wait_for_everyone()
853
-
854
- if accelerator.is_main_process:
855
- transformer = unwrap_model(accelerator, transformer)
856
- dtype = (
857
- torch.float16
858
- if args.mixed_precision == "fp16"
859
- else torch.bfloat16
860
- if args.mixed_precision == "bf16"
861
- else torch.float32
862
- )
863
- transformer = transformer.to(dtype)
864
- transformer_lora_layers = get_peft_model_state_dict(transformer)
865
-
866
- CogVideoXPipeline.save_lora_weights(
867
- save_directory=args.output_dir,
868
- transformer_lora_layers=transformer_lora_layers,
869
- )
870
-
871
- # Cleanup trained models to save memory
872
- if args.load_tensors:
873
- del transformer
874
- else:
875
- del transformer, text_encoder, vae
876
-
877
- gc.collect()
878
- torch.cuda.empty_cache()
879
- torch.cuda.synchronize(accelerator.device)
880
-
881
- accelerator.print("===== Memory before testing =====")
882
- print_memory(accelerator.device)
883
- reset_memory(accelerator.device)
884
-
885
- # Final test inference
886
- pipe = CogVideoXPipeline.from_pretrained(
887
- args.pretrained_model_name_or_path,
888
- revision=args.revision,
889
- variant=args.variant,
890
- torch_dtype=weight_dtype,
891
- )
892
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
893
-
894
- if args.enable_slicing:
895
- pipe.vae.enable_slicing()
896
- if args.enable_tiling:
897
- pipe.vae.enable_tiling()
898
- if args.enable_model_cpu_offload:
899
- pipe.enable_model_cpu_offload()
900
-
901
- # Load LoRA weights
902
- lora_scaling = args.lora_alpha / args.rank
903
- pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
904
- pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
905
-
906
- # Run inference
907
- validation_outputs = []
908
- if args.validation_prompt and args.num_validation_videos > 0:
909
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
910
- for validation_prompt in validation_prompts:
911
- pipeline_args = {
912
- "prompt": validation_prompt,
913
- "guidance_scale": args.guidance_scale,
914
- "use_dynamic_cfg": args.use_dynamic_cfg,
915
- "height": args.height,
916
- "width": args.width,
917
- }
918
-
919
- video = log_validation(
920
- accelerator=accelerator,
921
- pipe=pipe,
922
- args=args,
923
- pipeline_args=pipeline_args,
924
- epoch=epoch,
925
- is_final_validation=True,
926
- )
927
- validation_outputs.extend(video)
928
-
929
- accelerator.print("===== Memory after testing =====")
930
- print_memory(accelerator.device)
931
- reset_memory(accelerator.device)
932
- torch.cuda.synchronize(accelerator.device)
933
-
934
- if args.push_to_hub:
935
- save_model_card(
936
- repo_id,
937
- videos=validation_outputs,
938
- base_model=args.pretrained_model_name_or_path,
939
- validation_prompt=args.validation_prompt,
940
- repo_folder=args.output_dir,
941
- fps=args.fps,
942
- )
943
- upload_folder(
944
- repo_id=repo_id,
945
- folder_path=args.output_dir,
946
- commit_message="End of training",
947
- ignore_patterns=["step_*", "epoch_*"],
948
- )
949
-
950
- accelerator.end_training()
951
-
952
-
953
- if __name__ == "__main__":
954
- args = get_args()
955
- main(args)
 
training/cogvideox/cogvideox_text_to_video_sft.py DELETED
@@ -1,917 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import logging
18
- import math
19
- import os
20
- import shutil
21
- from datetime import timedelta
22
- from pathlib import Path
23
- from typing import Any, Dict
24
-
25
- import diffusers
26
- import torch
27
- import transformers
28
- import wandb
29
- from accelerate import Accelerator, DistributedType, init_empty_weights
30
- from accelerate.logging import get_logger
31
- from accelerate.utils import (
32
- DistributedDataParallelKwargs,
33
- InitProcessGroupKwargs,
34
- ProjectConfiguration,
35
- set_seed,
36
- )
37
- from diffusers import (
38
- AutoencoderKLCogVideoX,
39
- CogVideoXDPMScheduler,
40
- CogVideoXPipeline,
41
- CogVideoXTransformer3DModel,
42
- )
43
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
44
- from diffusers.optimization import get_scheduler
45
- from diffusers.training_utils import cast_training_params
46
- from diffusers.utils import export_to_video
47
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
48
- from huggingface_hub import create_repo, upload_folder
49
- from torch.utils.data import DataLoader
50
- from tqdm.auto import tqdm
51
- from transformers import AutoTokenizer, T5EncoderModel
52
-
53
-
54
- from args import get_args # isort:skip
55
- from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
56
- from text_encoder import compute_prompt_embeddings # isort:skip
57
- from utils import (
58
- get_gradient_norm,
59
- get_optimizer,
60
- prepare_rotary_positional_embeddings,
61
- print_memory,
62
- reset_memory,
63
- unwrap_model,
64
- ) # isort:skip
65
-
66
-
67
- logger = get_logger(__name__)
68
-
69
-
70
- def save_model_card(
71
- repo_id: str,
72
- videos=None,
73
- base_model: str = None,
74
- validation_prompt=None,
75
- repo_folder=None,
76
- fps=8,
77
- ):
78
- widget_dict = []
79
- if videos is not None:
80
- for i, video in enumerate(videos):
81
- export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
82
- widget_dict.append(
83
- {
84
- "text": validation_prompt if validation_prompt else " ",
85
- "output": {"url": f"video_{i}.mp4"},
86
- }
87
- )
88
-
89
- model_description = f"""
90
- # CogVideoX Full Finetune
91
-
92
- <Gallery />
93
-
94
- ## Model description
95
-
96
- This is a full finetune of the CogVideoX model `{base_model}`.
97
-
98
- The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
99
-
100
- ## Download model
101
-
102
- [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
103
-
104
- ## Usage
105
-
106
- Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
107
-
108
- ```py
109
- import torch
110
- from diffusers import CogVideoXPipeline
111
- from diffusers.utils import export_to_video
112
-
113
- pipe = CogVideoXPipeline.from_pretrained("{repo_id}", torch_dtype=torch.bfloat16).to("cuda")
114
-
115
- video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
116
- export_to_video(video, "output.mp4", fps=8)
117
- ```
118
-
119
- For more details, checkout the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) for CogVideoX.
120
-
121
- ## License
122
-
123
- Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
124
- """
125
- model_card = load_or_create_model_card(
126
- repo_id_or_path=repo_id,
127
- from_training=True,
128
- license="other",
129
- base_model=base_model,
130
- prompt=validation_prompt,
131
- model_description=model_description,
132
- widget=widget_dict,
133
- )
134
- tags = [
135
- "text-to-video",
136
- "diffusers-training",
137
- "diffusers",
138
- "cogvideox",
139
- "cogvideox-diffusers",
140
- ]
141
-
142
- model_card = populate_model_card(model_card, tags=tags)
143
- model_card.save(os.path.join(repo_folder, "README.md"))
144
-
145
-
146
- def log_validation(
147
- accelerator: Accelerator,
148
- pipe: CogVideoXPipeline,
149
- args: Dict[str, Any],
150
- pipeline_args: Dict[str, Any],
151
- epoch,
152
- is_final_validation: bool = False,
153
- ):
154
- logger.info(
155
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
156
- )
157
-
158
- pipe = pipe.to(accelerator.device)
159
-
160
- # run inference
161
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
162
-
163
- videos = []
164
- for _ in range(args.num_validation_videos):
165
- video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
166
- videos.append(video)
167
-
168
- for tracker in accelerator.trackers:
169
- phase_name = "test" if is_final_validation else "validation"
170
- if tracker.name == "wandb":
171
- video_filenames = []
172
- for i, video in enumerate(videos):
173
- prompt = (
174
- pipeline_args["prompt"][:25]
175
- .replace(" ", "_")
176
- .replace(" ", "_")
177
- .replace("'", "_")
178
- .replace('"', "_")
179
- .replace("/", "_")
180
- )
181
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
182
- export_to_video(video, filename, fps=8)
183
- video_filenames.append(filename)
184
-
185
- tracker.log(
186
- {
187
- phase_name: [
188
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
189
- for i, filename in enumerate(video_filenames)
190
- ]
191
- }
192
- )
193
-
194
- return videos
195
-
196
-
197
- class CollateFunction:
198
- def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
199
- self.weight_dtype = weight_dtype
200
- self.load_tensors = load_tensors
201
-
202
- def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
203
- prompts = [x["prompt"] for x in data[0]]
204
-
205
- if self.load_tensors:
206
- prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
207
-
208
- videos = [x["video"] for x in data[0]]
209
- videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
210
-
211
- return {
212
- "videos": videos,
213
- "prompts": prompts,
214
- }
215
-
216
-
217
- def main(args):
218
- if args.report_to == "wandb" and args.hub_token is not None:
219
- raise ValueError(
220
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
221
- " Please use `huggingface-cli login` to authenticate with the Hub."
222
- )
223
-
224
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
225
- # due to pytorch#99272, MPS does not yet support bfloat16.
226
- raise ValueError(
227
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
228
- )
229
-
230
- logging_dir = Path(args.output_dir, args.logging_dir)
231
-
232
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
233
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
234
- init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
235
- accelerator = Accelerator(
236
- gradient_accumulation_steps=args.gradient_accumulation_steps,
237
- mixed_precision=args.mixed_precision,
238
- log_with=args.report_to,
239
- project_config=accelerator_project_config,
240
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
241
- )
242
-
243
- # Disable AMP for MPS.
244
- if torch.backends.mps.is_available():
245
- accelerator.native_amp = False
246
-
247
- # Make one log on every process with the configuration for debugging.
248
- logging.basicConfig(
249
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
250
- datefmt="%m/%d/%Y %H:%M:%S",
251
- level=logging.INFO,
252
- )
253
- logger.info(accelerator.state, main_process_only=False)
254
- if accelerator.is_local_main_process:
255
- transformers.utils.logging.set_verbosity_warning()
256
- diffusers.utils.logging.set_verbosity_info()
257
- else:
258
- transformers.utils.logging.set_verbosity_error()
259
- diffusers.utils.logging.set_verbosity_error()
260
-
261
- # If passed along, set the training seed now.
262
- if args.seed is not None:
263
- set_seed(args.seed)
264
-
265
- # Handle the repository creation
266
- if accelerator.is_main_process:
267
- if args.output_dir is not None:
268
- os.makedirs(args.output_dir, exist_ok=True)
269
-
270
- if args.push_to_hub:
271
- repo_id = create_repo(
272
- repo_id=args.hub_model_id or Path(args.output_dir).name,
273
- exist_ok=True,
274
- ).repo_id
275
-
276
- # Prepare models and scheduler
277
- tokenizer = AutoTokenizer.from_pretrained(
278
- args.pretrained_model_name_or_path,
279
- subfolder="tokenizer",
280
- revision=args.revision,
281
- )
282
-
283
- text_encoder = T5EncoderModel.from_pretrained(
284
- args.pretrained_model_name_or_path,
285
- subfolder="text_encoder",
286
- revision=args.revision,
287
- )
288
-
289
- # CogVideoX-2b weights are stored in float16
290
- # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
291
- load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
292
- transformer = CogVideoXTransformer3DModel.from_pretrained(
293
- args.pretrained_model_name_or_path,
294
- subfolder="transformer",
295
- torch_dtype=load_dtype,
296
- revision=args.revision,
297
- variant=args.variant,
298
- )
299
-
300
- vae = AutoencoderKLCogVideoX.from_pretrained(
301
- args.pretrained_model_name_or_path,
302
- subfolder="vae",
303
- revision=args.revision,
304
- variant=args.variant,
305
- )
306
-
307
- scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
308
-
309
- if args.enable_slicing:
310
- vae.enable_slicing()
311
- if args.enable_tiling:
312
- vae.enable_tiling()
313
-
314
- text_encoder.requires_grad_(False)
315
- vae.requires_grad_(False)
316
- transformer.requires_grad_(True)
317
-
318
- VAE_SCALING_FACTOR = vae.config.scaling_factor
319
- VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
320
- RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
321
- RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
322
-
323
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
324
- # as these weights are only used for inference, keeping weights in full precision is not required.
325
- weight_dtype = torch.float32
326
- if accelerator.state.deepspeed_plugin:
327
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
328
- if (
329
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
330
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
331
- ):
332
- weight_dtype = torch.float16
333
- if (
334
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
335
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
336
- ):
337
- weight_dtype = torch.bfloat16
338
- else:
339
- if accelerator.mixed_precision == "fp16":
340
- weight_dtype = torch.float16
341
- elif accelerator.mixed_precision == "bf16":
342
- weight_dtype = torch.bfloat16
343
-
344
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
345
- # due to pytorch#99272, MPS does not yet support bfloat16.
346
- raise ValueError(
347
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
348
- )
349
-
350
- text_encoder.to(accelerator.device, dtype=weight_dtype)
351
- transformer.to(accelerator.device, dtype=weight_dtype)
352
- vae.to(accelerator.device, dtype=weight_dtype)
353
-
354
- if args.gradient_checkpointing:
355
- transformer.enable_gradient_checkpointing()
356
-
357
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
358
- def save_model_hook(models, weights, output_dir):
359
- if accelerator.is_main_process:
360
- for model in models:
361
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
362
- model: CogVideoXTransformer3DModel
363
- model = unwrap_model(accelerator, model)
364
- model.save_pretrained(
365
- os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
366
- )
367
- else:
368
- raise ValueError(f"Unexpected save model: {model.__class__}")
369
-
370
- # make sure to pop weight so that corresponding model is not saved again
371
- if weights:
372
- weights.pop()
373
-
374
- def load_model_hook(models, input_dir):
375
- transformer_ = None
376
- init_under_meta = False
377
-
378
- # This is a bit of a hack but I don't know any other solution.
379
- if not accelerator.distributed_type == DistributedType.DEEPSPEED:
380
- while len(models) > 0:
381
- model = models.pop()
382
-
383
- if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
384
- transformer_ = unwrap_model(accelerator, model)
385
- else:
386
- raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
387
- else:
388
- with init_empty_weights():
389
- transformer_ = CogVideoXTransformer3DModel.from_config(
390
- args.pretrained_model_name_or_path, subfolder="transformer"
391
- )
392
- init_under_meta = True
393
-
394
- load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
395
- transformer_.register_to_config(**load_model.config)
396
- transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
397
- del load_model
398
-
399
- # Make sure the trainable params are in float32. This is again needed since the base models
400
- # are in `weight_dtype`. More details:
401
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
402
- if args.mixed_precision == "fp16":
403
- cast_training_params([transformer_])
404
-
405
- accelerator.register_save_state_pre_hook(save_model_hook)
406
- accelerator.register_load_state_pre_hook(load_model_hook)
407
-
408
- # Enable TF32 for faster training on Ampere GPUs,
409
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
410
- if args.allow_tf32 and torch.cuda.is_available():
411
- torch.backends.cuda.matmul.allow_tf32 = True
412
-
413
- if args.scale_lr:
414
- args.learning_rate = (
415
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
416
- )
417
-
418
- # Make sure the trainable params are in float32.
419
- if args.mixed_precision == "fp16":
420
- # only upcast trainable parameters (LoRA) into fp32
421
- cast_training_params([transformer], dtype=torch.float32)
422
-
423
- transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
424
-
425
- # Optimization parameters
426
- transformer_parameters_with_lr = {
427
- "params": transformer_parameters,
428
- "lr": args.learning_rate,
429
- }
430
- params_to_optimize = [transformer_parameters_with_lr]
431
- num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
432
-
433
- use_deepspeed_optimizer = (
434
- accelerator.state.deepspeed_plugin is not None
435
- and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
436
- )
437
- use_deepspeed_scheduler = (
438
- accelerator.state.deepspeed_plugin is not None
439
- and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
440
- )
441
-
442
- optimizer = get_optimizer(
443
- params_to_optimize=params_to_optimize,
444
- optimizer_name=args.optimizer,
445
- learning_rate=args.learning_rate,
446
- beta1=args.beta1,
447
- beta2=args.beta2,
448
- beta3=args.beta3,
449
- epsilon=args.epsilon,
450
- weight_decay=args.weight_decay,
451
- prodigy_decouple=args.prodigy_decouple,
452
- prodigy_use_bias_correction=args.prodigy_use_bias_correction,
453
- prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
454
- use_8bit=args.use_8bit,
455
- use_4bit=args.use_4bit,
456
- use_torchao=args.use_torchao,
457
- use_deepspeed=use_deepspeed_optimizer,
458
- use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
459
- offload_gradients=args.offload_gradients,
460
- )
461
-
462
- # Dataset and DataLoader
463
- dataset_init_kwargs = {
464
- "data_root": args.data_root,
465
- "dataset_file": args.dataset_file,
466
- "caption_column": args.caption_column,
467
- "video_column": args.video_column,
468
- "max_num_frames": args.max_num_frames,
469
- "id_token": args.id_token,
470
- "height_buckets": args.height_buckets,
471
- "width_buckets": args.width_buckets,
472
- "frame_buckets": args.frame_buckets,
473
- "load_tensors": args.load_tensors,
474
- "random_flip": args.random_flip,
475
- }
476
- if args.video_reshape_mode is None:
477
- train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
478
- else:
479
- train_dataset = VideoDatasetWithResizeAndRectangleCrop(
480
- video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
481
- )
482
-
483
- collate_fn = CollateFunction(weight_dtype, args.load_tensors)
484
-
485
- train_dataloader = DataLoader(
486
- train_dataset,
487
- batch_size=1,
488
- sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
489
- collate_fn=collate_fn,
490
- num_workers=args.dataloader_num_workers,
491
- pin_memory=args.pin_memory,
492
- )
493
-
494
- # Scheduler and math around the number of training steps.
495
- overrode_max_train_steps = False
496
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
497
- if args.max_train_steps is None:
498
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
499
- overrode_max_train_steps = True
500
-
501
- if args.use_cpu_offload_optimizer:
502
- lr_scheduler = None
503
- accelerator.print(
504
- "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
505
- "you are training with those settings, they will be ignored."
506
- )
507
- else:
508
- if use_deepspeed_scheduler:
509
- from accelerate.utils import DummyScheduler
510
-
511
- lr_scheduler = DummyScheduler(
512
- name=args.lr_scheduler,
513
- optimizer=optimizer,
514
- total_num_steps=args.max_train_steps * accelerator.num_processes,
515
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
516
- )
517
- else:
518
- lr_scheduler = get_scheduler(
519
- args.lr_scheduler,
520
- optimizer=optimizer,
521
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
522
- num_training_steps=args.max_train_steps * accelerator.num_processes,
523
- num_cycles=args.lr_num_cycles,
524
- power=args.lr_power,
525
- )
526
-
527
- # Prepare everything with our `accelerator`.
528
- transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
529
- transformer, optimizer, train_dataloader, lr_scheduler
530
- )
531
-
532
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
533
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
534
- if overrode_max_train_steps:
535
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
536
- # Afterwards we recalculate our number of training epochs
537
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
538
-
539
- # We need to initialize the trackers we use, and also store our configuration.
540
- # The trackers initializes automatically on the main process.
541
- if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
542
- tracker_name = args.tracker_name or "cogvideox-sft"
543
- accelerator.init_trackers(tracker_name, config=vars(args))
544
-
545
- accelerator.print("===== Memory before training =====")
546
- reset_memory(accelerator.device)
547
- print_memory(accelerator.device)
548
-
549
- # Train!
550
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
551
-
552
- accelerator.print("***** Running training *****")
553
- accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
554
- accelerator.print(f" Num examples = {len(train_dataset)}")
555
- accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
556
- accelerator.print(f" Num epochs = {args.num_train_epochs}")
557
- accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
558
- accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
559
- accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
560
- accelerator.print(f" Total optimization steps = {args.max_train_steps}")
561
- global_step = 0
562
- first_epoch = 0
563
-
564
- # Potentially load in the weights and states from a previous save
565
- if not args.resume_from_checkpoint:
566
- initial_global_step = 0
567
- else:
568
- if args.resume_from_checkpoint != "latest":
569
- path = os.path.basename(args.resume_from_checkpoint)
570
- else:
571
- # Get the most recent checkpoint
572
- dirs = os.listdir(args.output_dir)
573
- dirs = [d for d in dirs if d.startswith("checkpoint")]
574
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
575
- path = dirs[-1] if len(dirs) > 0 else None
576
-
577
- if path is None:
578
- accelerator.print(
579
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
580
- )
581
- args.resume_from_checkpoint = None
582
- initial_global_step = 0
583
- else:
584
- accelerator.print(f"Resuming from checkpoint {path}")
585
- accelerator.load_state(os.path.join(args.output_dir, path))
586
- global_step = int(path.split("-")[1])
587
-
588
- initial_global_step = global_step
589
- first_epoch = global_step // num_update_steps_per_epoch
590
-
591
- progress_bar = tqdm(
592
- range(0, args.max_train_steps),
593
- initial=initial_global_step,
594
- desc="Steps",
595
- # Only show the progress bar once on each machine.
596
- disable=not accelerator.is_local_main_process,
597
- )
598
-
599
- # For DeepSpeed training
600
- model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
601
-
602
- if args.load_tensors:
603
- del vae, text_encoder
604
- gc.collect()
605
- torch.cuda.empty_cache()
606
- torch.cuda.synchronize(accelerator.device)
607
-
608
- alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
609
-
610
- for epoch in range(first_epoch, args.num_train_epochs):
611
- transformer.train()
612
-
613
- for step, batch in enumerate(train_dataloader):
614
- models_to_accumulate = [transformer]
615
- logs = {}
616
-
617
- with accelerator.accumulate(models_to_accumulate):
618
- videos = batch["videos"].to(accelerator.device, non_blocking=True)
619
- prompts = batch["prompts"]
620
-
621
- # Encode videos
622
- if not args.load_tensors:
623
- videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
624
- latent_dist = vae.encode(videos).latent_dist
625
- else:
626
- latent_dist = DiagonalGaussianDistribution(videos)
627
-
628
- videos = latent_dist.sample() * VAE_SCALING_FACTOR
629
- videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
630
- videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
631
- model_input = videos
632
-
633
- # Encode prompts
634
- if not args.load_tensors:
635
- prompt_embeds = compute_prompt_embeddings(
636
- tokenizer,
637
- text_encoder,
638
- prompts,
639
- model_config.max_text_seq_length,
640
- accelerator.device,
641
- weight_dtype,
642
- requires_grad=False,
643
- )
644
- else:
645
- prompt_embeds = prompts.to(dtype=weight_dtype)
646
-
647
- # Sample noise that will be added to the latents
648
- noise = torch.randn_like(model_input)
649
- batch_size, num_frames, num_channels, height, width = model_input.shape
650
-
651
- # Sample a random timestep for each image
652
- timesteps = torch.randint(
653
- 0,
654
- scheduler.config.num_train_timesteps,
655
- (batch_size,),
656
- dtype=torch.int64,
657
- device=model_input.device,
658
- )
659
-
660
- # Prepare rotary embeds
661
- image_rotary_emb = (
662
- prepare_rotary_positional_embeddings(
663
- height=height * VAE_SCALE_FACTOR_SPATIAL,
664
- width=width * VAE_SCALE_FACTOR_SPATIAL,
665
- num_frames=num_frames,
666
- vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
667
- patch_size=model_config.patch_size,
668
- patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
669
- attention_head_dim=model_config.attention_head_dim,
670
- device=accelerator.device,
671
- base_height=RoPE_BASE_HEIGHT,
672
- base_width=RoPE_BASE_WIDTH,
673
- )
674
- if model_config.use_rotary_positional_embeddings
675
- else None
676
- )
677
-
678
- # Add noise to the model input according to the noise magnitude at each timestep
679
- # (this is the forward diffusion process)
680
- noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
681
-
682
- # Predict the noise residual
683
- model_output = transformer(
684
- hidden_states=noisy_model_input,
685
- encoder_hidden_states=prompt_embeds,
686
- timestep=timesteps,
687
- image_rotary_emb=image_rotary_emb,
688
- return_dict=False,
689
- )[0]
690
-
691
- model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
692
-
693
- weights = 1 / (1 - alphas_cumprod[timesteps])
694
- while len(weights.shape) < len(model_pred.shape):
695
- weights = weights.unsqueeze(-1)
696
-
697
- target = model_input
698
-
699
- loss = torch.mean(
700
- (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
701
- dim=1,
702
- )
703
- loss = loss.mean()
704
- accelerator.backward(loss)
705
-
706
- if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
707
- gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
708
- accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
709
- gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
710
- logs.update(
711
- {
712
- "gradient_norm_before_clip": gradient_norm_before_clip,
713
- "gradient_norm_after_clip": gradient_norm_after_clip,
714
- }
715
- )
716
-
717
- if accelerator.state.deepspeed_plugin is None:
718
- optimizer.step()
719
- optimizer.zero_grad()
720
-
721
- if not args.use_cpu_offload_optimizer:
722
- lr_scheduler.step()
723
-
724
- # Checks if the accelerator has performed an optimization step behind the scenes
725
- if accelerator.sync_gradients:
726
- progress_bar.update(1)
727
- global_step += 1
728
-
729
- if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
730
- if global_step % args.checkpointing_steps == 0:
731
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
732
- if args.checkpoints_total_limit is not None:
733
- checkpoints = os.listdir(args.output_dir)
734
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
735
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
736
-
737
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
738
- if len(checkpoints) >= args.checkpoints_total_limit:
739
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
740
- removing_checkpoints = checkpoints[0:num_to_remove]
741
-
742
- logger.info(
743
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
744
- )
745
- logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
746
-
747
- for removing_checkpoint in removing_checkpoints:
748
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
749
- shutil.rmtree(removing_checkpoint)
750
-
751
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
752
- accelerator.save_state(save_path)
753
- logger.info(f"Saved state to {save_path}")
754
-
755
- last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
756
- logs.update(
757
- {
758
- "loss": loss.detach().item(),
759
- "lr": last_lr,
760
- }
761
- )
762
- progress_bar.set_postfix(**logs)
763
- accelerator.log(logs, step=global_step)
764
-
765
- if global_step >= args.max_train_steps:
766
- break
767
-
768
- if accelerator.is_main_process:
769
- if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
770
- accelerator.print("===== Memory before validation =====")
771
- print_memory(accelerator.device)
772
- torch.cuda.synchronize(accelerator.device)
773
-
774
- pipe = CogVideoXPipeline.from_pretrained(
775
- args.pretrained_model_name_or_path,
776
- transformer=unwrap_model(accelerator, transformer),
777
- scheduler=scheduler,
778
- revision=args.revision,
779
- variant=args.variant,
780
- torch_dtype=weight_dtype,
781
- )
782
-
783
- if args.enable_slicing:
784
- pipe.vae.enable_slicing()
785
- if args.enable_tiling:
786
- pipe.vae.enable_tiling()
787
- if args.enable_model_cpu_offload:
788
- pipe.enable_model_cpu_offload()
789
-
790
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
791
- for validation_prompt in validation_prompts:
792
- pipeline_args = {
793
- "prompt": validation_prompt,
794
- "guidance_scale": args.guidance_scale,
795
- "use_dynamic_cfg": args.use_dynamic_cfg,
796
- "height": args.height,
797
- "width": args.width,
798
- "max_sequence_length": model_config.max_text_seq_length,
799
- }
800
-
801
- log_validation(
802
- accelerator=accelerator,
803
- pipe=pipe,
804
- args=args,
805
- pipeline_args=pipeline_args,
806
- epoch=epoch,
807
- is_final_validation=False,
808
- )
809
-
810
- accelerator.print("===== Memory after validation =====")
811
- print_memory(accelerator.device)
812
- reset_memory(accelerator.device)
813
-
814
- del pipe
815
- gc.collect()
816
- torch.cuda.empty_cache()
817
- torch.cuda.synchronize(accelerator.device)
818
-
819
- accelerator.wait_for_everyone()
820
-
821
- if accelerator.is_main_process:
822
- transformer = unwrap_model(accelerator, transformer)
823
- dtype = (
824
- torch.float16
825
- if args.mixed_precision == "fp16"
826
- else torch.bfloat16
827
- if args.mixed_precision == "bf16"
828
- else torch.float32
829
- )
830
- transformer = transformer.to(dtype)
831
-
832
- transformer.save_pretrained(
833
- os.path.join(args.output_dir, "transformer"),
834
- safe_serialization=True,
835
- max_shard_size="5GB",
836
- )
837
-
838
- # Cleanup trained models to save memory
839
- if args.load_tensors:
840
- del transformer
841
- else:
842
- del transformer, text_encoder, vae
843
-
844
- gc.collect()
845
- torch.cuda.empty_cache()
846
- torch.cuda.synchronize(accelerator.device)
847
-
848
- accelerator.print("===== Memory before testing =====")
849
- print_memory(accelerator.device)
850
- reset_memory(accelerator.device)
851
-
852
- # Final test inference
853
- pipe = CogVideoXPipeline.from_pretrained(
854
- args.pretrained_model_name_or_path,
855
- revision=args.revision,
856
- variant=args.variant,
857
- torch_dtype=weight_dtype,
858
- )
859
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
860
-
861
- if args.enable_slicing:
862
- pipe.vae.enable_slicing()
863
- if args.enable_tiling:
864
- pipe.vae.enable_tiling()
865
- if args.enable_model_cpu_offload:
866
- pipe.enable_model_cpu_offload()
867
-
868
- # Run inference
869
- validation_outputs = []
870
- if args.validation_prompt and args.num_validation_videos > 0:
871
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
872
- for validation_prompt in validation_prompts:
873
- pipeline_args = {
874
- "prompt": validation_prompt,
875
- "guidance_scale": args.guidance_scale,
876
- "use_dynamic_cfg": args.use_dynamic_cfg,
877
- "height": args.height,
878
- "width": args.width,
879
- }
880
-
881
- video = log_validation(
882
- accelerator=accelerator,
883
- pipe=pipe,
884
- args=args,
885
- pipeline_args=pipeline_args,
886
- epoch=epoch,
887
- is_final_validation=True,
888
- )
889
- validation_outputs.extend(video)
890
-
891
- accelerator.print("===== Memory after testing =====")
892
- print_memory(accelerator.device)
893
- reset_memory(accelerator.device)
894
- torch.cuda.synchronize(accelerator.device)
895
-
896
- if args.push_to_hub:
897
- save_model_card(
898
- repo_id,
899
- videos=validation_outputs,
900
- base_model=args.pretrained_model_name_or_path,
901
- validation_prompt=args.validation_prompt,
902
- repo_folder=args.output_dir,
903
- fps=args.fps,
904
- )
905
- upload_folder(
906
- repo_id=repo_id,
907
- folder_path=args.output_dir,
908
- commit_message="End of training",
909
- ignore_patterns=["step_*", "epoch_*"],
910
- )
911
-
912
- accelerator.end_training()
913
-
914
-
915
- if __name__ == "__main__":
916
- args = get_args()
917
- main(args)
 
training/cogvideox/dataset.py DELETED
@@ -1,428 +0,0 @@
1
- import random
2
- from pathlib import Path
3
- from typing import Any, Dict, List, Optional, Tuple
4
-
5
- import numpy as np
6
- import pandas as pd
7
- import torch
8
- import torchvision.transforms as TT
9
- from accelerate.logging import get_logger
10
- from torch.utils.data import Dataset, Sampler
11
- from torchvision import transforms
12
- from torchvision.transforms import InterpolationMode
13
- from torchvision.transforms.functional import resize
14
-
15
-
16
- # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
17
- # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
18
- import decord # isort:skip
19
-
20
- decord.bridge.set_bridge("torch")
21
-
22
- logger = get_logger(__name__)
23
-
24
- HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
25
- WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
26
- FRAME_BUCKETS = [16, 24, 32, 48, 64, 80]
27
-
28
-
29
- class VideoDataset(Dataset):
30
- def __init__(
31
- self,
32
- data_root: str,
33
- dataset_file: Optional[str] = None,
34
- caption_column: str = "text",
35
- video_column: str = "video",
36
- max_num_frames: int = 49,
37
- id_token: Optional[str] = None,
38
- height_buckets: List[int] = None,
39
- width_buckets: List[int] = None,
40
- frame_buckets: List[int] = None,
41
- load_tensors: bool = False,
42
- random_flip: Optional[float] = None,
43
- image_to_video: bool = False,
44
- ) -> None:
45
- super().__init__()
46
-
47
- self.data_root = Path(data_root)
48
- self.dataset_file = dataset_file
49
- self.caption_column = caption_column
50
- self.video_column = video_column
51
- self.max_num_frames = max_num_frames
52
- self.id_token = f"{id_token.strip()} " if id_token else ""
53
- self.height_buckets = height_buckets or HEIGHT_BUCKETS
54
- self.width_buckets = width_buckets or WIDTH_BUCKETS
55
- self.frame_buckets = frame_buckets or FRAME_BUCKETS
56
- self.load_tensors = load_tensors
57
- self.random_flip = random_flip
58
- self.image_to_video = image_to_video
59
-
60
- self.resolution_buckets = [
61
- (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
62
- ]
63
-
64
- # Two methods of loading data are supported.
65
- # - Using a CSV: caption_column and video_column must be some column in the CSV. One could
66
- # make use of other columns too, such as a motion score or aesthetic score, by modifying the
67
- # logic in CSV processing.
68
- # - Using two files containing line-separate captions and relative paths to videos.
69
- # For a more detailed explanation about preparing dataset format, checkout the README.
70
- if dataset_file is None:
71
- (
72
- self.prompts,
73
- self.video_paths,
74
- ) = self._load_dataset_from_local_path()
75
- else:
76
- (
77
- self.prompts,
78
- self.video_paths,
79
- ) = self._load_dataset_from_csv()
80
-
81
- if len(self.video_paths) != len(self.prompts):
82
- raise ValueError(
83
- f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
84
- )
85
-
86
- self.video_transforms = transforms.Compose(
87
- [
88
- transforms.RandomHorizontalFlip(random_flip)
89
- if random_flip
90
- else transforms.Lambda(self.identity_transform),
91
- transforms.Lambda(self.scale_transform),
92
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
93
- ]
94
- )
95
-
96
- @staticmethod
97
- def identity_transform(x):
98
- return x
99
-
100
- @staticmethod
101
- def scale_transform(x):
102
- return x / 255.0
103
-
104
- def __len__(self) -> int:
105
- return len(self.video_paths)
106
-
107
- def __getitem__(self, index: int) -> Dict[str, Any]:
108
- if isinstance(index, list):
109
- # Here, index is actually a list of data objects that we need to return.
110
- # The BucketSampler should ideally return indices. But, in the sampler, we'd like
111
- # to have information about num_frames, height and width. Since this is not stored
112
- # as metadata, we need to read the video to get this information. You could read this
113
- # information without loading the full video in memory, but we do it anyway. In order
114
- # to not load the video twice (once to get the metadata, and once to return the loaded video
115
- # based on sampled indices), we cache it in the BucketSampler. When the sampler is
116
- # to yield, we yield the cache data instead of indices. So, this special check ensures
117
- # that data is not loaded a second time. PRs are welcome for improvements.
118
- return index
119
-
120
- if self.load_tensors:
121
- image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index])
122
-
123
- # This is hardcoded for now.
124
- # The VAE's temporal compression ratio is 4.
125
- # The VAE's spatial compression ratio is 8.
126
- latent_num_frames = video_latents.size(1)
127
- if latent_num_frames % 2 == 0:
128
- num_frames = latent_num_frames * 4
129
- else:
130
- num_frames = (latent_num_frames - 1) * 4 + 1
131
-
132
- height = video_latents.size(2) * 8
133
- width = video_latents.size(3) * 8
134
-
135
- return {
136
- "prompt": prompt_embeds,
137
- "image": image_latents,
138
- "video": video_latents,
139
- "video_metadata": {
140
- "num_frames": num_frames,
141
- "height": height,
142
- "width": width,
143
- },
144
- }
145
- else:
146
- image, video, _ = self._preprocess_video(self.video_paths[index])
147
-
148
- return {
149
- "prompt": self.id_token + self.prompts[index],
150
- "image": image,
151
- "video": video,
152
- "video_metadata": {
153
- "num_frames": video.shape[0],
154
- "height": video.shape[2],
155
- "width": video.shape[3],
156
- },
157
- }
158
-
159
- def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
160
- if not self.data_root.exists():
161
- raise ValueError("Root folder for videos does not exist")
162
-
163
- prompt_path = self.data_root.joinpath(self.caption_column)
164
- video_path = self.data_root.joinpath(self.video_column)
165
-
166
- if not prompt_path.exists() or not prompt_path.is_file():
167
- raise ValueError(
168
- "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
169
- )
170
- if not video_path.exists() or not video_path.is_file():
171
- raise ValueError(
172
- "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
173
- )
174
-
175
- with open(prompt_path, "r", encoding="utf-8") as file:
176
- prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
177
- with open(video_path, "r", encoding="utf-8") as file:
178
- video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
179
-
180
- if not self.load_tensors and any(not path.is_file() for path in video_paths):
181
- raise ValueError(
182
- f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
183
- )
184
-
185
- return prompts, video_paths
186
-
187
- def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
188
- df = pd.read_csv(self.dataset_file)
189
- prompts = df[self.caption_column].tolist()
190
- video_paths = df[self.video_column].tolist()
191
- video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
192
-
193
- if any(not path.is_file() for path in video_paths):
194
- raise ValueError(
195
- f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
196
- )
197
-
198
- return prompts, video_paths
199
-
200
- def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
201
- r"""
202
- Loads a single video, or latent and prompt embedding, based on initialization parameters.
203
-
204
- If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here,
205
- F, C, H and W are the frames, channels, height and width of the input video.
206
-
207
- If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D].
208
- F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length
209
- and embedding dimension of prompt embeddings.
210
- """
211
- if self.load_tensors:
212
- return self._load_preprocessed_latents_and_embeds(path)
213
- else:
214
- video_reader = decord.VideoReader(uri=path.as_posix())
215
- video_num_frames = len(video_reader)
216
-
217
- indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
218
- frames = video_reader.get_batch(indices)
219
- frames = frames[: self.max_num_frames].float()
220
- frames = frames.permute(0, 3, 1, 2).contiguous()
221
- frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
222
-
223
- image = frames[:1].clone() if self.image_to_video else None
224
-
225
- return image, frames, None
226
-
227
- def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
228
- filename_without_ext = path.name.split(".")[0]
229
- pt_filename = f"{filename_without_ext}.pt"
230
-
231
- # The current path is something like: /a/b/c/d/videos/00001.mp4
232
- # We need to reach: /a/b/c/d/video_latents/00001.pt
233
- image_latents_path = path.parent.parent.joinpath("image_latents")
234
- video_latents_path = path.parent.parent.joinpath("video_latents")
235
- embeds_path = path.parent.parent.joinpath("prompt_embeds")
236
-
237
- if (
238
- not video_latents_path.exists()
239
- or not embeds_path.exists()
240
- or (self.image_to_video and not image_latents_path.exists())
241
- ):
242
- raise ValueError(
243
- f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
244
- )
245
-
246
- if self.image_to_video:
247
- image_latent_filepath = image_latents_path.joinpath(pt_filename)
248
- video_latent_filepath = video_latents_path.joinpath(pt_filename)
249
- embeds_filepath = embeds_path.joinpath(pt_filename)
250
-
251
- if not video_latent_filepath.is_file() or not embeds_filepath.is_file():
252
- if self.image_to_video:
253
- image_latent_filepath = image_latent_filepath.as_posix()
254
- video_latent_filepath = video_latent_filepath.as_posix()
255
- embeds_filepath = embeds_filepath.as_posix()
256
- raise ValueError(
257
- f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
258
- )
259
-
260
- images = (
261
- torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
262
- )
263
- latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
264
- embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
265
-
266
- return images, latents, embeds
267
-
268
-
269
- class VideoDatasetWithResizing(VideoDataset):
270
- def __init__(self, *args, **kwargs) -> None:
271
- super().__init__(*args, **kwargs)
272
-
273
- def _preprocess_video(self, path: Path) -> torch.Tensor:
274
- if self.load_tensors:
275
- return self._load_preprocessed_latents_and_embeds(path)
276
- else:
277
- video_reader = decord.VideoReader(uri=path.as_posix())
278
- video_num_frames = len(video_reader)
279
- nearest_frame_bucket = min(
280
- self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
281
- )
282
-
283
- frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
284
-
285
- frames = video_reader.get_batch(frame_indices)
286
- frames = frames[:nearest_frame_bucket].float()
287
- frames = frames.permute(0, 3, 1, 2).contiguous()
288
-
289
- nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
290
- frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
291
- frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
292
-
293
- image = frames[:1].clone() if self.image_to_video else None
294
-
295
- return image, frames, None
296
-
297
- def _find_nearest_resolution(self, height, width):
298
- nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
299
- return nearest_res[1], nearest_res[2]
300
-
301
-
302
- class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
303
- def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
304
- super().__init__(*args, **kwargs)
305
- self.video_reshape_mode = video_reshape_mode
306
-
307
- def _resize_for_rectangle_crop(self, arr, image_size):
308
- reshape_mode = self.video_reshape_mode
309
- if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
310
- arr = resize(
311
- arr,
312
- size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
313
- interpolation=InterpolationMode.BICUBIC,
314
- )
315
- else:
316
- arr = resize(
317
- arr,
318
- size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
319
- interpolation=InterpolationMode.BICUBIC,
320
- )
321
-
322
- h, w = arr.shape[2], arr.shape[3]
323
- arr = arr.squeeze(0)
324
-
325
- delta_h = h - image_size[0]
326
- delta_w = w - image_size[1]
327
-
328
- if reshape_mode == "random" or reshape_mode == "none":
329
- top = np.random.randint(0, delta_h + 1)
330
- left = np.random.randint(0, delta_w + 1)
331
- elif reshape_mode == "center":
332
- top, left = delta_h // 2, delta_w // 2
333
- else:
334
- raise NotImplementedError
335
- arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
336
- return arr
337
-
338
- def _preprocess_video(self, path: Path) -> torch.Tensor:
339
- if self.load_tensors:
340
- return self._load_preprocessed_latents_and_embeds(path)
341
- else:
342
- video_reader = decord.VideoReader(uri=path.as_posix())
343
- video_num_frames = len(video_reader)
344
- nearest_frame_bucket = min(
345
- self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
346
- )
347
-
348
- frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
349
-
350
- frames = video_reader.get_batch(frame_indices)
351
- frames = frames[:nearest_frame_bucket].float()
352
- frames = frames.permute(0, 3, 1, 2).contiguous()
353
-
354
- nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
355
- frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
356
- frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
357
-
358
- image = frames[:1].clone() if self.image_to_video else None
359
-
360
- return image, frames, None
361
-
362
- def _find_nearest_resolution(self, height, width):
363
- nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
364
- return nearest_res[1], nearest_res[2]
365
-
366
-
367
- class BucketSampler(Sampler):
368
- r"""
369
- PyTorch Sampler that groups 3D data by height, width and frames.
370
-
371
- Args:
372
- data_source (`VideoDataset`):
373
- A PyTorch dataset object that is an instance of `VideoDataset`.
374
- batch_size (`int`, defaults to `8`):
375
- The batch size to use for training.
376
- shuffle (`bool`, defaults to `True`):
377
- Whether or not to shuffle the data in each batch before dispatching to dataloader.
378
- drop_last (`bool`, defaults to `False`):
379
- Whether or not to drop incomplete buckets of data after completely iterating over all data
380
- in the dataset. If set to True, only batches that have `batch_size` number of entries will
381
- be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
382
- and batches that do not have `batch_size` number of entries will also be yielded.
383
- """
384
-
385
- def __init__(
386
- self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
387
- ) -> None:
388
- self.data_source = data_source
389
- self.batch_size = batch_size
390
- self.shuffle = shuffle
391
- self.drop_last = drop_last
392
-
393
- self.buckets = {resolution: [] for resolution in data_source.resolution_buckets}
394
-
395
- self._raised_warning_for_drop_last = False
396
-
397
- def __len__(self):
398
- if self.drop_last and not self._raised_warning_for_drop_last:
399
- self._raised_warning_for_drop_last = True
400
- logger.warning(
401
- "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
402
- )
403
- return (len(self.data_source) + self.batch_size - 1) // self.batch_size
404
-
405
- def __iter__(self):
406
- for index, data in enumerate(self.data_source):
407
- video_metadata = data["video_metadata"]
408
- f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
409
-
410
- self.buckets[(f, h, w)].append(data)
411
- if len(self.buckets[(f, h, w)]) == self.batch_size:
412
- if self.shuffle:
413
- random.shuffle(self.buckets[(f, h, w)])
414
- yield self.buckets[(f, h, w)]
415
- del self.buckets[(f, h, w)]
416
- self.buckets[(f, h, w)] = []
417
-
418
- if self.drop_last:
419
- return
420
-
421
- for fhw, bucket in list(self.buckets.items()):
422
- if len(bucket) == 0:
423
- continue
424
- if self.shuffle:
425
- random.shuffle(bucket)
426
- yield bucket
427
- del self.buckets[fhw]
428
- self.buckets[fhw] = []
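The special-cased `__getitem__` above only works in tandem with `BucketSampler`: the sampler yields fully loaded buckets of samples rather than indices, so the `DataLoader` is driven with `batch_size=1` and a collate function that unpacks the single bucket it receives. A minimal sketch of that wiring, mirroring how `prepare_dataset.py` further down uses these classes (the data path and column file names are placeholders):

```python
import torch
from torch.utils.data import DataLoader

from dataset import BucketSampler, VideoDatasetWithResizing  # this module

dataset = VideoDatasetWithResizing(
    data_root="path/to/dataset",      # placeholder
    caption_column="prompts.txt",     # file with line-separated captions
    video_column="videos.txt",        # file with line-separated relative video paths
    max_num_frames=49,
)

def collate_fn(data):
    # `data` is a one-element list holding one full bucket of samples.
    samples = data[0]
    return {
        "prompts": [x["prompt"] for x in samples],
        "videos": torch.stack([x["video"] for x in samples]),
    }

dataloader = DataLoader(
    dataset,
    batch_size=1,  # the sampler already controls the effective batch size
    sampler=BucketSampler(dataset, batch_size=8, shuffle=True),
    collate_fn=collate_fn,
)
```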
 
training/cogvideox/prepare_dataset.py DELETED
@@ -1,669 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- import argparse
4
- import functools
5
- import json
6
- import os
7
- import pathlib
8
- import queue
9
- import traceback
10
- import uuid
11
- from concurrent.futures import ThreadPoolExecutor
12
- from typing import Any, Dict, List, Optional, Union
13
-
14
- import torch
15
- import torch.distributed as dist
16
- from diffusers import AutoencoderKLCogVideoX
17
- from diffusers.training_utils import set_seed
18
- from diffusers.utils import export_to_video, get_logger
19
- from torch.utils.data import DataLoader
20
- from torchvision import transforms
21
- from tqdm import tqdm
22
- from transformers import T5EncoderModel, T5Tokenizer
23
-
24
-
25
- import decord # isort:skip
26
-
27
- from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
28
-
29
-
30
- decord.bridge.set_bridge("torch")
31
-
32
- logger = get_logger(__name__)
33
-
34
- DTYPE_MAPPING = {
35
- "fp32": torch.float32,
36
- "fp16": torch.float16,
37
- "bf16": torch.bfloat16,
38
- }
39
-
40
-
41
- def check_height(x: Any) -> int:
42
- x = int(x)
43
- if x % 16 != 0:
44
- raise argparse.ArgumentTypeError(
45
- f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria."
46
- )
47
- return x
48
-
49
-
50
- def check_width(x: Any) -> int:
51
- x = int(x)
52
- if x % 16 != 0:
53
- raise argparse.ArgumentTypeError(
54
- f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria."
55
- )
56
- return x
57
-
58
-
59
- def check_frames(x: Any) -> int:
60
- x = int(x)
61
- if x % 4 != 0 and x % 4 != 1:
62
- raise argparse.ArgumentTypeError(
63
- f"`--frames_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria."
64
- )
65
- return x
66
-
67
-
68
- def get_args() -> Dict[str, Any]:
69
- parser = argparse.ArgumentParser()
70
- parser.add_argument(
71
- "--model_id",
72
- type=str,
73
- default="THUDM/CogVideoX-2b",
74
- help="Hugging Face model ID to use for tokenizer, text encoder and VAE.",
75
- )
76
- parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.")
77
- parser.add_argument(
78
- "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data."
79
- )
80
- parser.add_argument(
81
- "--caption_column",
82
- type=str,
83
- default="caption",
84
- help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).",
85
- )
86
- parser.add_argument(
87
- "--video_column",
88
- type=str,
89
- default="video",
90
- help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).",
91
- )
92
- parser.add_argument(
93
- "--id_token",
94
- type=str,
95
- default=None,
96
- help="Identifier token appended to the start of each prompt if provided.",
97
- )
98
- parser.add_argument(
99
- "--height_buckets",
100
- nargs="+",
101
- type=check_height,
102
- default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
103
- )
104
- parser.add_argument(
105
- "--width_buckets",
106
- nargs="+",
107
- type=check_width,
108
- default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
109
- )
110
- parser.add_argument(
111
- "--frame_buckets",
112
- nargs="+",
113
- type=check_frames,
114
- default=[49],
115
- )
116
- parser.add_argument(
117
- "--random_flip",
118
- type=float,
119
- default=None,
120
- help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
121
- )
122
- parser.add_argument(
123
- "--dataloader_num_workers",
124
- type=int,
125
- default=0,
126
- help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
127
- )
128
- parser.add_argument(
129
- "--pin_memory",
130
- action="store_true",
131
- help="Whether or not to use the pinned memory setting in pytorch dataloader.",
132
- )
133
- parser.add_argument(
134
- "--video_reshape_mode",
135
- type=str,
136
- default=None,
137
- help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
138
- )
139
- parser.add_argument(
140
- "--save_image_latents",
141
- action="store_true",
142
- help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.",
143
- )
144
- parser.add_argument(
145
- "--output_dir",
146
- type=str,
147
- required=True,
148
- help="Path to output directory where preprocessed videos/latents/embeddings will be saved.",
149
- )
150
- parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.")
151
- parser.add_argument(
152
- "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings."
153
- )
154
- parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.")
155
- parser.add_argument(
156
- "--save_latents_and_embeddings",
157
- action="store_true",
158
- help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.",
159
- )
160
- parser.add_argument(
161
- "--use_slicing",
162
- action="store_true",
163
- help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
164
- )
165
- parser.add_argument(
166
- "--use_tiling",
167
- action="store_true",
168
- help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
169
- )
170
- parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.")
171
- parser.add_argument(
172
- "--num_decode_threads",
173
- type=int,
174
- default=0,
175
- help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.",
176
- )
177
- parser.add_argument(
178
- "--dtype",
179
- type=str,
180
- choices=["fp32", "fp16", "bf16"],
181
- default="fp32",
182
- help="Data type to use when generating latents and prompt embeddings.",
183
- )
184
- parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
185
- parser.add_argument(
186
- "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts."
187
- )
188
- return parser.parse_args()
189
-
190
-
191
- def _get_t5_prompt_embeds(
192
- tokenizer: T5Tokenizer,
193
- text_encoder: T5EncoderModel,
194
- prompt: Union[str, List[str]],
195
- num_videos_per_prompt: int = 1,
196
- max_sequence_length: int = 226,
197
- device: Optional[torch.device] = None,
198
- dtype: Optional[torch.dtype] = None,
199
- text_input_ids=None,
200
- ):
201
- prompt = [prompt] if isinstance(prompt, str) else prompt
202
- batch_size = len(prompt)
203
-
204
- if tokenizer is not None:
205
- text_inputs = tokenizer(
206
- prompt,
207
- padding="max_length",
208
- max_length=max_sequence_length,
209
- truncation=True,
210
- add_special_tokens=True,
211
- return_tensors="pt",
212
- )
213
- text_input_ids = text_inputs.input_ids
214
- else:
215
- if text_input_ids is None:
216
- raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
217
-
218
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
219
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
220
-
221
- # duplicate text embeddings for each generation per prompt, using mps friendly method
222
- _, seq_len, _ = prompt_embeds.shape
223
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
224
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
225
-
226
- return prompt_embeds
227
-
228
-
229
- def encode_prompt(
230
- tokenizer: T5Tokenizer,
231
- text_encoder: T5EncoderModel,
232
- prompt: Union[str, List[str]],
233
- num_videos_per_prompt: int = 1,
234
- max_sequence_length: int = 226,
235
- device: Optional[torch.device] = None,
236
- dtype: Optional[torch.dtype] = None,
237
- text_input_ids=None,
238
- ):
239
- prompt = [prompt] if isinstance(prompt, str) else prompt
240
- prompt_embeds = _get_t5_prompt_embeds(
241
- tokenizer,
242
- text_encoder,
243
- prompt=prompt,
244
- num_videos_per_prompt=num_videos_per_prompt,
245
- max_sequence_length=max_sequence_length,
246
- device=device,
247
- dtype=dtype,
248
- text_input_ids=text_input_ids,
249
- )
250
- return prompt_embeds
251
-
252
-
253
- def compute_prompt_embeddings(
254
- tokenizer: T5Tokenizer,
255
- text_encoder: T5EncoderModel,
256
- prompts: List[str],
257
- max_sequence_length: int,
258
- device: torch.device,
259
- dtype: torch.dtype,
260
- requires_grad: bool = False,
261
- ):
262
- if requires_grad:
263
- prompt_embeds = encode_prompt(
264
- tokenizer,
265
- text_encoder,
266
- prompts,
267
- num_videos_per_prompt=1,
268
- max_sequence_length=max_sequence_length,
269
- device=device,
270
- dtype=dtype,
271
- )
272
- else:
273
- with torch.no_grad():
274
- prompt_embeds = encode_prompt(
275
- tokenizer,
276
- text_encoder,
277
- prompts,
278
- num_videos_per_prompt=1,
279
- max_sequence_length=max_sequence_length,
280
- device=device,
281
- dtype=dtype,
282
- )
283
- return prompt_embeds
284
-
285
-
286
- to_pil_image = transforms.ToPILImage(mode="RGB")
287
-
288
-
289
- def save_image(image: torch.Tensor, path: pathlib.Path) -> None:
290
- image = image.to(dtype=torch.float32).clamp(-1, 1)
291
- image = to_pil_image(image.float())
292
- image.save(path)
293
-
294
-
295
- def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None:
296
- video = video.to(dtype=torch.float32).clamp(-1, 1)
297
- video = [to_pil_image(frame) for frame in video]
298
- export_to_video(video, path, fps=fps)
299
-
300
-
301
- def save_prompt(prompt: str, path: pathlib.Path) -> None:
302
- with open(path, "w", encoding="utf-8") as file:
303
- file.write(prompt)
304
-
305
-
306
- def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None:
307
- with open(path, "w", encoding="utf-8") as file:
308
- file.write(json.dumps(metadata))
309
-
310
-
311
- @torch.no_grad()
312
- def serialize_artifacts(
313
- batch_size: int,
314
- fps: int,
315
- images_dir: Optional[pathlib.Path] = None,
316
- image_latents_dir: Optional[pathlib.Path] = None,
317
- videos_dir: Optional[pathlib.Path] = None,
318
- video_latents_dir: Optional[pathlib.Path] = None,
319
- prompts_dir: Optional[pathlib.Path] = None,
320
- prompt_embeds_dir: Optional[pathlib.Path] = None,
321
- images: Optional[torch.Tensor] = None,
322
- image_latents: Optional[torch.Tensor] = None,
323
- videos: Optional[torch.Tensor] = None,
324
- video_latents: Optional[torch.Tensor] = None,
325
- prompts: Optional[List[str]] = None,
326
- prompt_embeds: Optional[torch.Tensor] = None,
327
- ) -> None:
328
- num_frames, height, width = videos.size(1), videos.size(3), videos.size(4)
329
- metadata = [{"num_frames": num_frames, "height": height, "width": width}]
330
-
331
- data_folder_mapper_list = [
332
- (images, images_dir, lambda img, path: save_image(img[0], path), "png"),
333
- (image_latents, image_latents_dir, torch.save, "pt"),
334
- (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"),
335
- (video_latents, video_latents_dir, torch.save, "pt"),
336
- (prompts, prompts_dir, save_prompt, "txt"),
337
- (prompt_embeds, prompt_embeds_dir, torch.save, "pt"),
338
- (metadata, videos_dir, save_metadata, "txt"),
339
- ]
340
- filenames = [uuid.uuid4() for _ in range(batch_size)]
341
-
342
- for data, folder, save_fn, extension in data_folder_mapper_list:
343
- if data is None:
344
- continue
345
- for slice, filename in zip(data, filenames):
346
- if isinstance(slice, torch.Tensor):
347
- slice = slice.clone().to("cpu")
348
- path = folder.joinpath(f"{filename}.{extension}")
349
- save_fn(slice, path)
350
-
351
-
352
- def save_intermediates(output_queue: queue.Queue) -> None:
353
- while True:
354
- try:
355
- item = output_queue.get(timeout=30)
356
- if item is None:
357
- break
358
- serialize_artifacts(**item)
359
-
360
- except queue.Empty:
361
- continue
362
-
363
-
364
- @torch.no_grad()
365
- def main():
366
- args = get_args()
367
- set_seed(args.seed)
368
-
369
- output_dir = pathlib.Path(args.output_dir)
370
- tmp_dir = output_dir.joinpath("tmp")
371
-
372
- output_dir.mkdir(parents=True, exist_ok=True)
373
- tmp_dir.mkdir(parents=True, exist_ok=True)
374
-
375
- # Create task queue for non-blocking serializing of artifacts
376
- output_queue = queue.Queue()
377
- save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers)
378
- save_future = save_thread.submit(save_intermediates, output_queue)
379
-
380
- # Initialize distributed processing
381
- if "LOCAL_RANK" in os.environ:
382
- local_rank = int(os.environ["LOCAL_RANK"])
383
- torch.cuda.set_device(local_rank)
384
- dist.init_process_group(backend="nccl")
385
- world_size = dist.get_world_size()
386
- rank = dist.get_rank()
387
- else:
388
- # Single GPU
389
- local_rank = 0
390
- world_size = 1
391
- rank = 0
392
- torch.cuda.set_device(rank)
393
-
394
- # Create folders where intermediate tensors from each rank will be saved
395
- images_dir = tmp_dir.joinpath(f"images/{rank}")
396
- image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}")
397
- videos_dir = tmp_dir.joinpath(f"videos/{rank}")
398
- video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}")
399
- prompts_dir = tmp_dir.joinpath(f"prompts/{rank}")
400
- prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}")
401
-
402
- images_dir.mkdir(parents=True, exist_ok=True)
403
- image_latents_dir.mkdir(parents=True, exist_ok=True)
404
- videos_dir.mkdir(parents=True, exist_ok=True)
405
- video_latents_dir.mkdir(parents=True, exist_ok=True)
406
- prompts_dir.mkdir(parents=True, exist_ok=True)
407
- prompt_embeds_dir.mkdir(parents=True, exist_ok=True)
408
-
409
- weight_dtype = DTYPE_MAPPING[args.dtype]
410
- target_fps = args.target_fps
411
-
412
- # 1. Dataset
413
- dataset_init_kwargs = {
414
- "data_root": args.data_root,
415
- "dataset_file": args.dataset_file,
416
- "caption_column": args.caption_column,
417
- "video_column": args.video_column,
418
- "max_num_frames": args.max_num_frames,
419
- "id_token": args.id_token,
420
- "height_buckets": args.height_buckets,
421
- "width_buckets": args.width_buckets,
422
- "frame_buckets": args.frame_buckets,
423
- "load_tensors": False,
424
- "random_flip": args.random_flip,
425
- "image_to_video": args.save_image_latents,
426
- }
427
- if args.video_reshape_mode is None:
428
- dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
429
- else:
430
- dataset = VideoDatasetWithResizeAndRectangleCrop(
431
- video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
432
- )
433
-
434
- original_dataset_size = len(dataset)
435
-
436
- # Split data among GPUs
437
- if world_size > 1:
438
- samples_per_gpu = original_dataset_size // world_size
439
- start_index = rank * samples_per_gpu
440
- end_index = start_index + samples_per_gpu
441
- if rank == world_size - 1:
442
- end_index = original_dataset_size # Make sure the last GPU gets the remaining data
443
-
444
- # Slice the data
445
- dataset.prompts = dataset.prompts[start_index:end_index]
446
- dataset.video_paths = dataset.video_paths[start_index:end_index]
447
- else:
448
- pass
449
-
450
- rank_dataset_size = len(dataset)
451
-
452
- # 2. Dataloader
453
- def collate_fn(data):
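- # The BucketSampler yields whole buckets, so data[0] below is the complete bucket of samples for this step.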
454
- prompts = [x["prompt"] for x in data[0]]
455
-
456
- images = None
457
- if args.save_image_latents:
458
- images = [x["image"] for x in data[0]]
459
- images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True)
460
-
461
- videos = [x["video"] for x in data[0]]
462
- videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
463
-
464
- return {
465
- "images": images,
466
- "videos": videos,
467
- "prompts": prompts,
468
- }
469
-
470
- dataloader = DataLoader(
471
- dataset,
472
- batch_size=1,
473
- sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False),
474
- collate_fn=collate_fn,
475
- num_workers=args.dataloader_num_workers,
476
- pin_memory=args.pin_memory,
477
- )
478
-
479
- # 3. Prepare models
480
- device = f"cuda:{rank}"
481
-
482
- if args.save_latents_and_embeddings:
483
- tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer")
484
- text_encoder = T5EncoderModel.from_pretrained(
485
- args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype
486
- )
487
- text_encoder = text_encoder.to(device)
488
-
489
- vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype)
490
- vae = vae.to(device)
491
-
492
- if args.use_slicing:
493
- vae.enable_slicing()
494
- if args.use_tiling:
495
- vae.enable_tiling()
496
-
497
- # 4. Compute latents and embeddings and save
498
- if rank == 0:
499
- iterator = tqdm(
500
- dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size
501
- )
502
- else:
503
- iterator = dataloader
504
-
505
- for step, batch in enumerate(iterator):
506
- try:
507
- images = None
508
- image_latents = None
509
- video_latents = None
510
- prompt_embeds = None
511
-
512
- if args.save_image_latents:
513
- images = batch["images"].to(device, non_blocking=True)
514
- images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
515
-
516
- videos = batch["videos"].to(device, non_blocking=True)
517
- videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
518
-
519
- prompts = batch["prompts"]
520
-
521
- # Encode videos & images
522
- if args.save_latents_and_embeddings:
523
- if args.use_slicing:
524
- if args.save_image_latents:
525
- encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)]
526
- image_latents = torch.cat(encoded_slices)
527
- image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
528
-
529
- encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)]
530
- video_latents = torch.cat(encoded_slices)
531
-
532
- else:
533
- if args.save_image_latents:
534
- image_latents = vae._encode(images)
535
- image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
536
-
537
- video_latents = vae._encode(videos)
538
-
539
- video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
540
-
541
- # Encode prompts
542
- prompt_embeds = compute_prompt_embeddings(
543
- tokenizer,
544
- text_encoder,
545
- prompts,
546
- args.max_sequence_length,
547
- device,
548
- weight_dtype,
549
- requires_grad=False,
550
- )
551
-
552
- if images is not None:
553
- images = (images.permute(0, 2, 1, 3, 4) + 1) / 2
554
-
555
- videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2
556
-
557
- output_queue.put(
558
- {
559
- "batch_size": len(prompts),
560
- "fps": target_fps,
561
- "images_dir": images_dir,
562
- "image_latents_dir": image_latents_dir,
563
- "videos_dir": videos_dir,
564
- "video_latents_dir": video_latents_dir,
565
- "prompts_dir": prompts_dir,
566
- "prompt_embeds_dir": prompt_embeds_dir,
567
- "images": images,
568
- "image_latents": image_latents,
569
- "videos": videos,
570
- "video_latents": video_latents,
571
- "prompts": prompts,
572
- "prompt_embeds": prompt_embeds,
573
- }
574
- )
575
-
576
- except Exception:
577
- print("-------------------------")
578
- print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}")
579
- traceback.print_exc()
580
- print("-------------------------")
581
-
582
- # 5. Complete distributed processing
583
- if world_size > 1:
584
- dist.barrier()
585
- dist.destroy_process_group()
586
-
587
- output_queue.put(None)
588
- save_thread.shutdown(wait=True)
589
- save_future.result()
590
-
591
- # 6. Combine results from each rank
592
- if rank == 0:
593
- print(
594
- f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`"
595
- )
596
-
597
- # Move files from each rank to common directory
598
- for subfolder, extension in [
599
- ("images", "png"),
600
- ("image_latents", "pt"),
601
- ("videos", "mp4"),
602
- ("video_latents", "pt"),
603
- ("prompts", "txt"),
604
- ("prompt_embeds", "pt"),
605
- ("videos", "txt"),
606
- ]:
607
- tmp_subfolder = tmp_dir.joinpath(subfolder)
608
- combined_subfolder = output_dir.joinpath(subfolder)
609
- combined_subfolder.mkdir(parents=True, exist_ok=True)
610
- pattern = f"*.{extension}"
611
-
612
- for file in tmp_subfolder.rglob(pattern):
613
- file.replace(combined_subfolder / file.name)
614
-
615
- # Remove temporary directories
616
- def rmdir_recursive(dir: pathlib.Path) -> None:
617
- for child in dir.iterdir():
618
- if child.is_file():
619
- child.unlink()
620
- else:
621
- rmdir_recursive(child)
622
- dir.rmdir()
623
-
624
- rmdir_recursive(tmp_dir)
625
-
626
- # Combine prompts and videos into individual text files and single jsonl
627
- prompts_folder = output_dir.joinpath("prompts")
628
- prompts = []
629
- stems = []
630
-
631
- for filename in prompts_folder.rglob("*.txt"):
632
- with open(filename, "r") as file:
633
- prompts.append(file.read().strip())
634
- stems.append(filename.stem)
635
-
636
- prompts_txt = output_dir.joinpath("prompts.txt")
637
- videos_txt = output_dir.joinpath("videos.txt")
638
- data_jsonl = output_dir.joinpath("data.jsonl")
639
-
640
- with open(prompts_txt, "w") as file:
641
- for prompt in prompts:
642
- file.write(f"{prompt}\n")
643
-
644
- with open(videos_txt, "w") as file:
645
- for stem in stems:
646
- file.write(f"videos/{stem}.mp4\n")
647
-
648
- with open(data_jsonl, "w") as file:
649
- for prompt, stem in zip(prompts, stems):
650
- video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt")
651
- with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file:
652
- metadata = json.loads(metadata_file.read())
653
-
654
- data = {
655
- "prompt": prompt,
656
- "prompt_embed": f"prompt_embeds/{stem}.pt",
657
- "image": f"images/{stem}.png",
658
- "image_latent": f"image_latents/{stem}.pt",
659
- "video": f"videos/{stem}.mp4",
660
- "video_latent": f"video_latents/{stem}.pt",
661
- "metadata": metadata,
662
- }
663
- file.write(json.dumps(data) + "\n")
664
-
665
- print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`")
666
-
667
-
668
- if __name__ == "__main__":
669
- main()
 
training/cogvideox/text_encoder/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .text_encoder import compute_prompt_embeddings
 
 
training/cogvideox/text_encoder/text_encoder.py DELETED
@@ -1,99 +0,0 @@
1
- from typing import List, Optional, Union
2
-
3
- import torch
4
- from transformers import T5EncoderModel, T5Tokenizer
5
-
6
-
7
- def _get_t5_prompt_embeds(
8
- tokenizer: T5Tokenizer,
9
- text_encoder: T5EncoderModel,
10
- prompt: Union[str, List[str]],
11
- num_videos_per_prompt: int = 1,
12
- max_sequence_length: int = 226,
13
- device: Optional[torch.device] = None,
14
- dtype: Optional[torch.dtype] = None,
15
- text_input_ids=None,
16
- ):
17
- prompt = [prompt] if isinstance(prompt, str) else prompt
18
- batch_size = len(prompt)
19
-
20
- if tokenizer is not None:
21
- text_inputs = tokenizer(
22
- prompt,
23
- padding="max_length",
24
- max_length=max_sequence_length,
25
- truncation=True,
26
- add_special_tokens=True,
27
- return_tensors="pt",
28
- )
29
- text_input_ids = text_inputs.input_ids
30
- else:
31
- if text_input_ids is None:
32
- raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
33
-
34
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
35
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
36
-
37
- # duplicate text embeddings for each generation per prompt, using mps friendly method
38
- _, seq_len, _ = prompt_embeds.shape
39
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
40
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
41
-
42
- return prompt_embeds
43
-
44
-
45
- def encode_prompt(
46
- tokenizer: T5Tokenizer,
47
- text_encoder: T5EncoderModel,
48
- prompt: Union[str, List[str]],
49
- num_videos_per_prompt: int = 1,
50
- max_sequence_length: int = 226,
51
- device: Optional[torch.device] = None,
52
- dtype: Optional[torch.dtype] = None,
53
- text_input_ids=None,
54
- ):
55
- prompt = [prompt] if isinstance(prompt, str) else prompt
56
- prompt_embeds = _get_t5_prompt_embeds(
57
- tokenizer,
58
- text_encoder,
59
- prompt=prompt,
60
- num_videos_per_prompt=num_videos_per_prompt,
61
- max_sequence_length=max_sequence_length,
62
- device=device,
63
- dtype=dtype,
64
- text_input_ids=text_input_ids,
65
- )
66
- return prompt_embeds
67
-
68
-
69
- def compute_prompt_embeddings(
70
- tokenizer: T5Tokenizer,
71
- text_encoder: T5EncoderModel,
72
- prompt: str,
73
- max_sequence_length: int,
74
- device: torch.device,
75
- dtype: torch.dtype,
76
- requires_grad: bool = False,
77
- ):
78
- if requires_grad:
79
- prompt_embeds = encode_prompt(
80
- tokenizer,
81
- text_encoder,
82
- prompt,
83
- num_videos_per_prompt=1,
84
- max_sequence_length=max_sequence_length,
85
- device=device,
86
- dtype=dtype,
87
- )
88
- else:
89
- with torch.no_grad():
90
- prompt_embeds = encode_prompt(
91
- tokenizer,
92
- text_encoder,
93
- prompt,
94
- num_videos_per_prompt=1,
95
- max_sequence_length=max_sequence_length,
96
- device=device,
97
- dtype=dtype,
98
- )
99
- return prompt_embeds
 
training/cogvideox/utils.py DELETED
@@ -1,260 +0,0 @@
1
- import gc
2
- import inspect
3
- from typing import Optional, Tuple, Union
4
-
5
- import torch
6
- from accelerate import Accelerator
7
- from accelerate.logging import get_logger
8
- from diffusers.models.embeddings import get_3d_rotary_pos_embed
9
- from diffusers.utils.torch_utils import is_compiled_module
10
-
11
-
12
- logger = get_logger(__name__)
13
-
14
-
15
- def get_optimizer(
16
- params_to_optimize,
17
- optimizer_name: str = "adam",
18
- learning_rate: float = 1e-3,
19
- beta1: float = 0.9,
20
- beta2: float = 0.95,
21
- beta3: float = 0.98,
22
- epsilon: float = 1e-8,
23
- weight_decay: float = 1e-4,
24
- prodigy_decouple: bool = False,
25
- prodigy_use_bias_correction: bool = False,
26
- prodigy_safeguard_warmup: bool = False,
27
- use_8bit: bool = False,
28
- use_4bit: bool = False,
29
- use_torchao: bool = False,
30
- use_deepspeed: bool = False,
31
- use_cpu_offload_optimizer: bool = False,
32
- offload_gradients: bool = False,
33
- ) -> torch.optim.Optimizer:
34
- optimizer_name = optimizer_name.lower()
35
-
36
- # Use DeepSpeed optimizer
37
- if use_deepspeed:
38
- from accelerate.utils import DummyOptim
39
-
40
- return DummyOptim(
41
- params_to_optimize,
42
- lr=learning_rate,
43
- betas=(beta1, beta2),
44
- eps=epsilon,
45
- weight_decay=weight_decay,
46
- )
47
-
48
- if use_8bit and use_4bit:
49
- raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
50
-
51
- if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
52
- try:
53
- import torchao
54
-
55
- torchao.__version__
56
- except ImportError:
57
- raise ImportError(
58
- "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
59
- )
60
-
61
- if not use_torchao and use_4bit:
62
- raise ValueError("4-bit Optimizers are only supported with torchao.")
63
-
64
- # Optimizer creation
65
- supported_optimizers = ["adam", "adamw", "prodigy", "came"]
66
- if optimizer_name not in supported_optimizers:
67
- logger.warning(
68
- f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
69
- )
70
- optimizer_name = "adamw"
71
-
72
- if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
73
- raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
74
-
75
- if use_8bit:
76
- try:
77
- import bitsandbytes as bnb
78
- except ImportError:
79
- raise ImportError(
80
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
81
- )
82
-
83
- if optimizer_name == "adamw":
84
- if use_torchao:
85
- from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
86
-
87
- optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
88
- else:
89
- optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
90
-
91
- init_kwargs = {
92
- "betas": (beta1, beta2),
93
- "eps": epsilon,
94
- "weight_decay": weight_decay,
95
- }
96
-
97
- elif optimizer_name == "adam":
98
- if use_torchao:
99
- from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
100
-
101
- optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
102
- else:
103
- optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
104
-
105
- init_kwargs = {
106
- "betas": (beta1, beta2),
107
- "eps": epsilon,
108
- "weight_decay": weight_decay,
109
- }
110
-
111
- elif optimizer_name == "prodigy":
112
- try:
113
- import prodigyopt
114
- except ImportError:
115
- raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
116
-
117
- optimizer_class = prodigyopt.Prodigy
118
-
119
- if learning_rate <= 0.1:
120
- logger.warning(
121
- "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
122
- )
123
-
124
- init_kwargs = {
125
- "lr": learning_rate,
126
- "betas": (beta1, beta2),
127
- "beta3": beta3,
128
- "eps": epsilon,
129
- "weight_decay": weight_decay,
130
- "decouple": prodigy_decouple,
131
- "use_bias_correction": prodigy_use_bias_correction,
132
- "safeguard_warmup": prodigy_safeguard_warmup,
133
- }
134
-
135
- elif optimizer_name == "came":
136
- try:
137
- import came_pytorch
138
- except ImportError:
139
- raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
140
-
141
- optimizer_class = came_pytorch.CAME
142
-
143
- init_kwargs = {
144
- "lr": learning_rate,
145
- "eps": (1e-30, 1e-16),
146
- "betas": (beta1, beta2, beta3),
147
- "weight_decay": weight_decay,
148
- }
149
-
150
- if use_cpu_offload_optimizer:
151
- from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
152
-
153
- if "fused" in inspect.signature(optimizer_class.__init__).parameters:
154
- init_kwargs.update({"fused": True})
155
-
156
- optimizer = CPUOffloadOptimizer(
157
- params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
158
- )
159
- else:
160
- optimizer = optimizer_class(params_to_optimize, **init_kwargs)
161
-
162
- return optimizer
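For reference, here is a minimal sketch of how this helper might be called from a training script; the model and hyperparameter values are illustrative only:

```python
import torch

from utils import get_optimizer  # this module

model = torch.nn.Linear(16, 16)

# For "adam"/"adamw" the learning rate is read from the parameter groups,
# not from the `learning_rate` argument (which this function only forwards
# on the Prodigy, CAME and DeepSpeed paths).
params_to_optimize = [{"params": model.parameters(), "lr": 1e-4}]

optimizer = get_optimizer(
    params_to_optimize,
    optimizer_name="adamw",
    weight_decay=1e-4,
    use_8bit=False,  # set True for bitsandbytes AdamW8bit (`pip install bitsandbytes`)
)
```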
163
-
164
-
165
- def get_gradient_norm(parameters):
166
- norm = 0
167
- for param in parameters:
168
- if param.grad is None:
169
- continue
170
- local_norm = param.grad.detach().data.norm(2)
171
- norm += local_norm.item() ** 2
172
- norm = norm**0.5
173
- return norm
174
-
175
-
176
- # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
177
- def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
178
- tw = tgt_width
179
- th = tgt_height
180
- h, w = src
181
- r = h / w
182
- if r > (th / tw):
183
- resize_height = th
184
- resize_width = int(round(th / h * w))
185
- else:
186
- resize_width = tw
187
- resize_height = int(round(tw / w * h))
188
-
189
- crop_top = int(round((th - resize_height) / 2.0))
190
- crop_left = int(round((tw - resize_width) / 2.0))
191
-
192
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
193
-
194
-
195
- def prepare_rotary_positional_embeddings(
196
- height: int,
197
- width: int,
198
- num_frames: int,
199
- vae_scale_factor_spatial: int = 8,
200
- patch_size: int = 2,
201
- patch_size_t: int = None,
202
- attention_head_dim: int = 64,
203
- device: Optional[torch.device] = None,
204
- base_height: int = 480,
205
- base_width: int = 720,
206
- ) -> Tuple[torch.Tensor, torch.Tensor]:
207
- grid_height = height // (vae_scale_factor_spatial * patch_size)
208
- grid_width = width // (vae_scale_factor_spatial * patch_size)
209
- base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
210
- base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
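- # e.g. the default 480x720 base resolution with spatial scale 8 and patch size 2 gives a 30x45 (height x width) base grid.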
211
-
212
- if patch_size_t is None:
213
- # CogVideoX 1.0
214
- grid_crops_coords = get_resize_crop_region_for_grid(
215
- (grid_height, grid_width), base_size_width, base_size_height
216
- )
217
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
218
- embed_dim=attention_head_dim,
219
- crops_coords=grid_crops_coords,
220
- grid_size=(grid_height, grid_width),
221
- temporal_size=num_frames,
222
- )
223
- else:
224
- # CogVideoX 1.5
225
- base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
226
-
227
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
228
- embed_dim=attention_head_dim,
229
- crops_coords=None,
230
- grid_size=(grid_height, grid_width),
231
- temporal_size=base_num_frames,
232
- grid_type="slice",
233
- max_size=(base_size_height, base_size_width),
234
- )
235
-
236
- freqs_cos = freqs_cos.to(device=device)
237
- freqs_sin = freqs_sin.to(device=device)
238
- return freqs_cos, freqs_sin
239
-
240
-
241
- def reset_memory(device: Union[str, torch.device]) -> None:
242
- gc.collect()
243
- torch.cuda.empty_cache()
244
- torch.cuda.reset_peak_memory_stats(device)
245
- torch.cuda.reset_accumulated_memory_stats(device)
246
-
247
-
248
- def print_memory(device: Union[str, torch.device]) -> None:
249
- memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
250
- max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
251
- max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
252
- print(f"{memory_allocated=:.3f} GB")
253
- print(f"{max_memory_allocated=:.3f} GB")
254
- print(f"{max_memory_reserved=:.3f} GB")
255
-
256
-
257
- def unwrap_model(accelerator: Accelerator, model):
258
- model = accelerator.unwrap_model(model)
259
- model = model._orig_mod if is_compiled_module(model) else model
260
- return model
 
training/mochi-1/README.md DELETED
@@ -1,111 +0,0 @@
1
- # Simple Mochi-1 finetuner
2
-
3
- <table align=center>
4
- <tr>
5
- <th align=center> Dataset Sample </th>
6
- <th align=center> Test Sample </th>
7
- </tr>
8
- <tr>
9
- <td align=center><video src="https://github.com/user-attachments/assets/6f906a32-b169-493f-a713-07679e87cd91"> Your browser does not support the video tag. </video></td>
10
- <td align=center><video src="https://github.com/user-attachments/assets/d356e70f-ccf4-47f7-be1d-8d21108d8a84"> Your browser does not support the video tag. </video></td>
11
- </tr>
12
- </table>
13
-
14
- Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨
15
-
16
- We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for all things LoRA in our implementation.
17
-
18
- **Updates**
19
-
20
- December 1 2024: Support for checkpoint saving and loading.
21
-
22
- ## Getting started
23
-
24
- Install the dependencies: `pip install -r requirements.txt`. Also make sure your `diffusers` installation is from the current `main`.
25
-
26
- Download a demo dataset:
27
-
28
- ```bash
29
- huggingface-cli download \
30
- --repo-type dataset sayakpaul/video-dataset-disney-organized \
31
- --local-dir video-dataset-disney-organized
32
- ```
33
-
34
- The dataset follows the directory structure expected by the subsequent scripts. In particular, it follows what's prescribed [here](https://github.com/genmoai/mochi/tree/main/demos/fine_tuner#1-collect-your-videos-and-captions):
35
-
36
- ```bash
37
- video_1.mp4
38
- video_1.txt -- One-paragraph description of video_1
39
- video_2.mp4
40
- video_2.txt -- One-paragraph description of video_2
41
- ...
42
- ```
43
-
44
- Then run (be sure to check the paths accordingly):
45
-
46
- ```bash
47
- bash prepare_dataset.sh
48
- ```
49
-
50
- We can adjust `num_frames` and `resolution`. By default, in `prepare_dataset.sh`, we use `--force_upsample`. This means if the original video resolution is smaller than the requested resolution, we will upsample the video.
51
-
52
- > [!IMPORTANT]
53
- > It's important to have a resolution of at least 480x848 to satisy Mochi-1's requirements.
54
-
55
- Now, we're ready to fine-tune. To launch, run:
56
-
57
- ```bash
58
- bash train.sh
59
- ```
60
-
61
- You can disable intermediate validation by:
62
-
63
- ```diff
64
- - --validation_prompt "..." \
65
- - --validation_prompt_separator ::: \
66
- - --num_validation_videos 1 \
67
- - --validation_epochs 1 \
68
- ```
69
-
70
- We haven't tested this rigorously, but with validation disabled, this script should run in under 40GB of GPU VRAM.
71
-
72
- To use the LoRA checkpoint:
73
-
74
- ```py
75
- from diffusers import MochiPipeline
76
- from diffusers.utils import export_to_video
77
- import torch
78
-
79
- pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
80
- pipe.load_lora_weights("path-to-lora")
81
- pipe.enable_model_cpu_offload()
82
-
83
- pipeline_args = {
84
- "prompt": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions",
85
- "guidance_scale": 6.0,
86
- "num_inference_steps": 64,
87
- "height": 480,
88
- "width": 848,
89
- "max_sequence_length": 256,
90
- "output_type": "np",
91
- }
92
-
93
- with torch.autocast("cuda", torch.bfloat16)
94
- video = pipe(**pipeline_args).frames[0]
95
- export_to_video(video)
96
- ```
97
-
98
- ## Known limitations
99
-
100
- (Contributions are welcome 🤗)
101
-
102
- Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below:
103
-
104
- * No support for distributed training.
105
- * `train_batch_size > 1` is supported but can potentially lead to OOMs because we currently don't have gradient accumulation support.
106
- * No support for 8bit optimizers (but should be relatively easy to add).
107
-
108
- **Misc**:
109
-
110
- * We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033).
111
- * `embed.py` script is non-batched.
 
training/mochi-1/args.py DELETED
@@ -1,268 +0,0 @@
1
- """
2
- Default values taken from
3
- https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
4
- when applicable.
5
- """
6
-
7
- import argparse
8
-
9
-
10
- def _get_model_args(parser: argparse.ArgumentParser) -> None:
11
- parser.add_argument(
12
- "--pretrained_model_name_or_path",
13
- type=str,
14
- default=None,
15
- required=True,
16
- help="Path to pretrained model or model identifier from huggingface.co/models.",
17
- )
18
- parser.add_argument(
19
- "--revision",
20
- type=str,
21
- default=None,
22
- required=False,
23
- help="Revision of pretrained model identifier from huggingface.co/models.",
24
- )
25
- parser.add_argument(
26
- "--variant",
27
- type=str,
28
- default=None,
29
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
30
- )
31
- parser.add_argument(
32
- "--cache_dir",
33
- type=str,
34
- default=None,
35
- help="The directory where the downloaded models and datasets will be stored.",
36
- )
37
- parser.add_argument(
38
- "--cast_dit",
39
- action="store_true",
40
- help="If we should cast DiT params to a lower precision.",
41
- )
42
- parser.add_argument(
43
- "--compile_dit",
44
- action="store_true",
45
- help="If we should compile the DiT.",
46
- )
47
-
48
-
49
- def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
50
- parser.add_argument(
51
- "--data_root",
52
- type=str,
53
- default=None,
54
- help=("A folder containing the training data."),
55
- )
56
- parser.add_argument(
57
- "--caption_dropout",
58
- type=float,
59
- default=None,
60
- help=("Probability to drop out captions randomly."),
61
- )
62
-
63
- parser.add_argument(
64
- "--dataloader_num_workers",
65
- type=int,
66
- default=0,
67
- help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
68
- )
69
- parser.add_argument(
70
- "--pin_memory",
71
- action="store_true",
72
- help="Whether or not to use the pinned memory setting in pytorch dataloader.",
73
- )
74
-
75
-
76
- def _get_validation_args(parser: argparse.ArgumentParser) -> None:
77
- parser.add_argument(
78
- "--validation_prompt",
79
- type=str,
80
- default=None,
81
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
82
- )
83
- parser.add_argument(
84
- "--validation_images",
85
- type=str,
86
- default=None,
87
- help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
88
- )
89
- parser.add_argument(
90
- "--validation_prompt_separator",
91
- type=str,
92
- default=":::",
93
- help="String that separates multiple validation prompts",
94
- )
95
- parser.add_argument(
96
- "--num_validation_videos",
97
- type=int,
98
- default=1,
99
- help="Number of videos that should be generated during validation per `validation_prompt`.",
100
- )
101
- parser.add_argument(
102
- "--validation_epochs",
103
- type=int,
104
- default=50,
105
- help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
106
- )
107
- parser.add_argument(
108
- "--enable_slicing",
109
- action="store_true",
110
- default=False,
111
- help="Whether or not to use VAE slicing for saving memory.",
112
- )
113
- parser.add_argument(
114
- "--enable_tiling",
115
- action="store_true",
116
- default=False,
117
- help="Whether or not to use VAE tiling for saving memory.",
118
- )
119
- parser.add_argument(
120
- "--enable_model_cpu_offload",
121
- action="store_true",
122
- default=False,
123
- help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
124
- )
125
- parser.add_argument(
126
- "--fps",
127
- type=int,
128
- default=30,
129
- help="FPS to use when serializing the output videos.",
130
- )
131
- parser.add_argument(
132
- "--height",
133
- type=int,
134
- default=480,
135
- )
136
- parser.add_argument(
137
- "--width",
138
- type=int,
139
- default=848,
140
- )
141
-
142
-
143
- def _get_training_args(parser: argparse.ArgumentParser) -> None:
144
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
145
- parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
146
- parser.add_argument(
147
- "--lora_alpha",
148
- type=int,
149
- default=16,
150
- help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
151
- )
152
- parser.add_argument(
153
- "--target_modules",
154
- nargs="+",
155
- type=str,
156
- default=["to_k", "to_q", "to_v", "to_out.0"],
157
- help="Target modules to train LoRA for.",
158
- )
159
- parser.add_argument(
160
- "--output_dir",
161
- type=str,
162
- default="mochi-lora",
163
- help="The output directory where the model predictions and checkpoints will be written.",
164
- )
165
- parser.add_argument(
166
- "--train_batch_size",
167
- type=int,
168
- default=4,
169
- help="Batch size (per device) for the training dataloader.",
170
- )
171
- parser.add_argument("--num_train_epochs", type=int, default=1)
172
- parser.add_argument(
173
- "--max_train_steps",
174
- type=int,
175
- default=None,
176
- help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
177
- )
178
- parser.add_argument(
179
- "--gradient_checkpointing",
180
- action="store_true",
181
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
182
- )
183
- parser.add_argument(
184
- "--learning_rate",
185
- type=float,
186
- default=2e-4,
187
- help="Initial learning rate (after the potential warmup period) to use.",
188
- )
189
- parser.add_argument(
190
- "--scale_lr",
191
- action="store_true",
192
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
193
- )
194
- parser.add_argument(
195
- "--lr_warmup_steps",
196
- type=int,
197
- default=200,
198
- help="Number of steps for the warmup in the lr scheduler.",
199
- )
200
- parser.add_argument(
201
- "--checkpointing_steps",
202
- type=int,
203
- default=None,
204
- )
205
- parser.add_argument(
206
- "--resume_from_checkpoint",
207
- type=str,
208
- default=None,
209
- )
210
-
211
-
212
- def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
213
- parser.add_argument(
214
- "--optimizer",
215
- type=lambda s: s.lower(),
216
- default="adam",
217
- choices=["adam", "adamw"],
218
- help=("The optimizer type to use."),
219
- )
220
- parser.add_argument(
221
- "--weight_decay",
222
- type=float,
223
- default=0.01,
224
- help="Weight decay to use for optimizer.",
225
- )
226
-
227
-
228
- def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
229
- parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
230
- parser.add_argument(
231
- "--push_to_hub",
232
- action="store_true",
233
- help="Whether or not to push the model to the Hub.",
234
- )
235
- parser.add_argument(
236
- "--hub_token",
237
- type=str,
238
- default=None,
239
- help="The token to use to push to the Model Hub.",
240
- )
241
- parser.add_argument(
242
- "--hub_model_id",
243
- type=str,
244
- default=None,
245
- help="The name of the repository to keep in sync with the local `output_dir`.",
246
- )
247
- parser.add_argument(
248
- "--allow_tf32",
249
- action="store_true",
250
- help=(
251
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
252
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
253
- ),
254
- )
255
- parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.")
256
-
257
-
258
- def get_args():
259
- parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")
260
-
261
- _get_model_args(parser)
262
- _get_dataset_args(parser)
263
- _get_training_args(parser)
264
- _get_validation_args(parser)
265
- _get_optimizer_args(parser)
266
- _get_configuration_args(parser)
267
-
268
- return parser.parse_args()
 
training/mochi-1/dataset_simple.py DELETED
@@ -1,50 +0,0 @@
- """
- Taken from
- https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py
- """
- 
- from pathlib import Path
- 
- import click
- import torch
- from torch.utils.data import DataLoader, Dataset
- 
- 
- def load_to_cpu(x):
-     return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
- 
- 
- class LatentEmbedDataset(Dataset):
-     def __init__(self, file_paths, repeat=1):
-         self.items = [
-             (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
-             for p in file_paths
-             if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
-         ]
-         self.items = self.items * repeat
-         print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
- 
-     def __len__(self):
-         return len(self.items)
- 
-     def __getitem__(self, idx):
-         latent_path, embed_path = self.items[idx]
-         return load_to_cpu(latent_path), load_to_cpu(embed_path)
- 
- 
- @click.command()
- @click.argument("directory", type=click.Path(exists=True, file_okay=False))
- def process_videos(directory):
-     dir_path = Path(directory)
-     mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
-     assert mp4_files, f"No mp4 files found"
- 
-     dataset = LatentEmbedDataset(mp4_files)
-     dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
- 
-     for latents, embeds in dataloader:
-         print([(k, v.shape) for k, v in latents.items()])
- 
- 
- if __name__ == "__main__":
-     process_videos()
 
training/mochi-1/embed.py DELETED
@@ -1,111 +0,0 @@
1
- """
2
- Adapted from:
3
- https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
4
- https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
5
- """
6
-
7
- import click
8
- import torch
9
- import torchvision
10
- from pathlib import Path
11
- from diffusers import AutoencoderKLMochi, MochiPipeline
12
- from transformers import T5EncoderModel, T5Tokenizer
13
- from tqdm.auto import tqdm
14
-
15
-
16
- def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
17
- T, H, W = [int(s) for s in shape.split("x")]
18
- assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
19
- video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
20
- fps = metadata["video_fps"]
21
- video = video.permute(3, 0, 1, 2)
22
- og_shape = video.shape
23
- assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
24
- assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
25
- assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
26
- if video.shape[1] > T:
27
- video = video[:, :T]
28
- print(f"Trimmed video from {og_shape[1]} to first {T} frames")
29
- video = video.unsqueeze(0)
30
- video = video.float() / 127.5 - 1.0
31
- video = video.to(model.device)
32
-
33
- assert video.ndim == 5
34
-
35
- with torch.inference_mode():
36
- with torch.autocast("cuda", dtype=torch.bfloat16):
37
- ldist = model._encode(video)
38
-
39
- torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))
40
-
41
-
42
- @click.command()
43
- @click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
44
- @click.option(
45
- "--model_id",
46
- type=str,
47
- help="Repo id. Should be genmo/mochi-1-preview",
48
- default="genmo/mochi-1-preview",
49
- )
50
- @click.option("--shape", default="163x480x848", help="Shape of the video to encode")
51
- @click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
52
- def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None:
53
- """Process all videos and captions in a directory using a single GPU."""
54
- # comment out when running on unsupported hardware
55
- torch.backends.cuda.matmul.allow_tf32 = True
56
- torch.backends.cudnn.allow_tf32 = True
57
-
58
- # Get all video paths
59
- video_paths = list(output_dir.glob("**/*.mp4"))
60
- if not video_paths:
61
- print(f"No MP4 files found in {output_dir}")
62
- return
63
-
64
- text_paths = list(output_dir.glob("**/*.txt"))
65
- if not text_paths:
66
- print(f"No text files found in {output_dir}")
67
- return
68
-
69
- # load the models
70
- vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
71
- text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
72
- tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
73
- pipeline = MochiPipeline.from_pretrained(
74
- model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
75
- ).to("cuda")
76
-
77
- for idx, video_path in tqdm(enumerate(sorted(video_paths))):
78
- print(f"Processing {video_path}")
79
- try:
80
- if video_path.with_suffix(".latent.pt").exists() and not overwrite:
81
- print(f"Skipping {video_path}")
82
- continue
83
-
84
- # encode videos.
85
- encode_videos(vae, vid_path=video_path, shape=shape)
86
-
87
- # embed captions.
88
- prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt")
89
- embed_path = prompt_path.with_suffix(".embed.pt")
90
-
91
- if embed_path.exists() and not overwrite:
92
- print(f"Skipping {prompt_path} - embeddings already exist")
93
- continue
94
-
95
- with open(prompt_path) as f:
96
- text = f.read().strip()
97
- with torch.inference_mode():
98
- conditioning = pipeline.encode_prompt(prompt=[text])
99
-
100
- conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
101
- torch.save(conditioning, embed_path)
102
-
103
- except Exception as e:
104
- import traceback
105
-
106
- traceback.print_exc()
107
- print(f"Error processing {video_path}: {str(e)}")
108
-
109
-
110
- if __name__ == "__main__":
111
- batch_process()
 
training/mochi-1/prepare_dataset.sh DELETED
@@ -1,15 +0,0 @@
- #!/bin/bash
- 
- GPU_ID=0
- VIDEO_DIR=video-dataset-disney-organized
- OUTPUT_DIR=videos_prepared
- NUM_FRAMES=37
- RESOLUTION=480x848
- 
- # Extract width and height from RESOLUTION
- WIDTH=$(echo $RESOLUTION | cut -dx -f1)
- HEIGHT=$(echo $RESOLUTION | cut -dx -f2)
- 
- python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample
- 
- CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT}
 
training/mochi-1/requirements.txt DELETED
@@ -1,8 +0,0 @@
- peft
- transformers
- wandb
- torch
- torchvision
- av==11.0.0
- moviepy==1.0.3
- click
 
training/mochi-1/text_to_video_lora.py DELETED
@@ -1,592 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import random
18
- from glob import glob
19
- import math
20
- import os
21
- import torch.nn.functional as F
22
- import numpy as np
23
- from pathlib import Path
24
- from typing import Any, Dict, Tuple, List
25
-
26
- import torch
27
- import wandb
28
- from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
29
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
30
- from diffusers.training_utils import cast_training_params
31
- from diffusers.utils import export_to_video
32
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
33
- from huggingface_hub import create_repo, upload_folder
34
- from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
35
- from torch.utils.data import DataLoader
36
- from tqdm.auto import tqdm
37
-
38
-
39
- from args import get_args # isort:skip
40
- from dataset_simple import LatentEmbedDataset
41
-
42
- import sys
43
- from utils import print_memory, reset_memory # isort:skip
44
-
45
-
46
- # Taken from
47
- # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
48
- def get_cosine_annealing_lr_scheduler(
49
- optimizer: torch.optim.Optimizer,
50
- warmup_steps: int,
51
- total_steps: int,
52
- ):
53
- def lr_lambda(step):
54
- if step < warmup_steps:
55
- return float(step) / float(max(1, warmup_steps))
56
- else:
57
- return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
58
-
59
- return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
60
-
61
-
62
- def save_model_card(
63
- repo_id: str,
64
- videos=None,
65
- base_model: str = None,
66
- validation_prompt=None,
67
- repo_folder=None,
68
- fps=30,
69
- ):
70
- widget_dict = []
71
- if videos is not None and len(videos) > 0:
72
- for i, video in enumerate(videos):
73
- export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
74
- widget_dict.append(
75
- {
76
- "text": validation_prompt if validation_prompt else " ",
77
- "output": {"url": f"final_video_{i}.mp4"},
78
- }
79
- )
80
-
81
- model_description = f"""
82
- # Mochi-1 Preview LoRA Finetune
83
-
84
- <Gallery />
85
-
86
- ## Model description
87
-
88
- This is a lora finetune of the Mochi-1 preview model `{base_model}`.
89
-
90
- The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
91
-
92
- ## Download model
93
-
94
- [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
95
-
96
- ## Usage
97
-
98
- Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
99
-
100
- ```py
101
- from diffusers import MochiPipeline
102
- from diffusers.utils import export_to_video
103
- import torch
104
-
105
- pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
106
- pipe.load_lora_weights("CHANGE_ME")
107
- pipe.enable_model_cpu_offload()
108
-
109
- with torch.autocast("cuda", torch.bfloat16):
110
- video = pipe(
111
- prompt="CHANGE_ME",
112
- guidance_scale=6.0,
113
- num_inference_steps=64,
114
- height=480,
115
- width=848,
116
- max_sequence_length=256,
117
- output_type="np"
118
- ).frames[0]
119
- export_to_video(video)
120
- ```
121
-
122
- For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
123
-
124
- """
125
- model_card = load_or_create_model_card(
126
- repo_id_or_path=repo_id,
127
- from_training=True,
128
- license="apache-2.0",
129
- base_model=base_model,
130
- prompt=validation_prompt,
131
- model_description=model_description,
132
- widget=widget_dict,
133
- )
134
- tags = [
135
- "text-to-video",
136
- "diffusers-training",
137
- "diffusers",
138
- "lora",
139
- "mochi-1-preview",
140
- "mochi-1-preview-diffusers",
141
- "template:sd-lora",
142
- ]
143
-
144
- model_card = populate_model_card(model_card, tags=tags)
145
- model_card.save(os.path.join(repo_folder, "README.md"))
146
-
147
-
148
- def log_validation(
149
- pipe: MochiPipeline,
150
- args: Dict[str, Any],
151
- pipeline_args: Dict[str, Any],
152
- epoch,
153
- wandb_run: str = None,
154
- is_final_validation: bool = False,
155
- ):
156
- print(
157
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
158
- )
159
- phase_name = "test" if is_final_validation else "validation"
160
-
161
- if not args.enable_model_cpu_offload:
162
- pipe = pipe.to("cuda")
163
-
164
- # run inference
165
- generator = torch.manual_seed(args.seed) if args.seed else None
166
-
167
- videos = []
168
- with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
169
- for _ in range(args.num_validation_videos):
170
- video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
171
- videos.append(video)
172
-
173
- video_filenames = []
174
- for i, video in enumerate(videos):
175
- prompt = (
176
- pipeline_args["prompt"][:25]
177
- .replace(" ", "_")
178
- .replace(" ", "_")
179
- .replace("'", "_")
180
- .replace('"', "_")
181
- .replace("/", "_")
182
- )
183
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
184
- export_to_video(video, filename, fps=30)
185
- video_filenames.append(filename)
186
-
187
- if wandb_run:
188
- wandb.log(
189
- {
190
- phase_name: [
191
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30)
192
- for i, filename in enumerate(video_filenames)
193
- ]
194
- }
195
- )
196
-
197
- return videos
198
-
199
-
200
- # Adapted from the original code:
201
- # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
202
- def cast_dit(model, dtype):
203
- for name, module in model.named_modules():
204
- if isinstance(module, torch.nn.Linear):
205
- assert any(
206
- n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
207
- ), f"Unexpected linear layer: {name}"
208
- module.to(dtype=dtype)
209
- elif isinstance(module, torch.nn.Conv2d):
210
- module.to(dtype=dtype)
211
- return model
212
-
213
-
214
- def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path):
215
- lora_state_dict = get_peft_model_state_dict(model)
216
- torch.save(
217
- {
218
- "state_dict": lora_state_dict,
219
- "optimizer": optimizer.state_dict(),
220
- "lr_scheduler": lr_scheduler.state_dict(),
221
- "global_step": global_step,
222
- },
223
- checkpoint_path,
224
- )
225
-
226
-
227
- class CollateFunction:
228
- def __init__(self, caption_dropout: float = None) -> None:
229
- self.caption_dropout = caption_dropout
230
-
231
- def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
232
- ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
233
- z = DiagonalGaussianDistribution(ldists).sample()
234
- assert torch.isfinite(z).all()
235
-
236
- # Sample noise which we will add to the samples.
237
- eps = torch.randn_like(z)
238
- sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
239
-
240
- prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
241
- prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
242
- if self.caption_dropout and random.random() < self.caption_dropout:
243
- prompt_embeds.zero_()
244
- prompt_attention_mask = prompt_attention_mask.long()
245
- prompt_attention_mask.zero_()
246
- prompt_attention_mask = prompt_attention_mask.bool()
247
-
248
- return dict(
249
- z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
250
- )
251
-
252
-
253
- def main(args):
254
- if not torch.cuda.is_available():
255
- raise ValueError("Not supported without CUDA.")
256
-
257
- if args.report_to == "wandb" and args.hub_token is not None:
258
- raise ValueError(
259
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
260
- " Please use `huggingface-cli login` to authenticate with the Hub."
261
- )
262
-
263
- # Handle the repository creation
264
- if args.output_dir is not None:
265
- os.makedirs(args.output_dir, exist_ok=True)
266
-
267
- if args.push_to_hub:
268
- repo_id = create_repo(
269
- repo_id=args.hub_model_id or Path(args.output_dir).name,
270
- exist_ok=True,
271
- ).repo_id
272
-
273
- # Prepare models and scheduler
274
- transformer = MochiTransformer3DModel.from_pretrained(
275
- args.pretrained_model_name_or_path,
276
- subfolder="transformer",
277
- revision=args.revision,
278
- variant=args.variant,
279
- )
280
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
281
- args.pretrained_model_name_or_path, subfolder="scheduler"
282
- )
283
-
284
- transformer.requires_grad_(False)
285
- transformer.to("cuda")
286
- if args.gradient_checkpointing:
287
- transformer.enable_gradient_checkpointing()
288
- if args.cast_dit:
289
- transformer = cast_dit(transformer, torch.bfloat16)
290
- if args.compile_dit:
291
- transformer.compile()
292
-
293
- # now we will add new LoRA weights to the attention layers
294
- transformer_lora_config = LoraConfig(
295
- r=args.rank,
296
- lora_alpha=args.lora_alpha,
297
- init_lora_weights="gaussian",
298
- target_modules=args.target_modules,
299
- )
300
- transformer.add_adapter(transformer_lora_config)
301
-
302
- # Enable TF32 for faster training on Ampere GPUs,
303
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
304
- if args.allow_tf32 and torch.cuda.is_available():
305
- torch.backends.cuda.matmul.allow_tf32 = True
306
-
307
- if args.scale_lr:
308
- args.learning_rate = args.learning_rate * args.train_batch_size
309
- # only upcast trainable parameters (LoRA) into fp32
310
- cast_training_params([transformer], dtype=torch.float32)
311
-
312
- # Prepare optimizer
313
- transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
314
- num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters)
315
- optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
316
-
317
- # Dataset and DataLoader
318
- train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
319
- train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
320
- print(f"Found {len(train_vids)} training videos in {args.data_root}")
321
- assert len(train_vids) > 0, f"No training data found in {args.data_root}"
322
-
323
- collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
324
- train_dataset = LatentEmbedDataset(train_vids, repeat=1)
325
- train_dataloader = DataLoader(
326
- train_dataset,
327
- collate_fn=collate_fn,
328
- batch_size=args.train_batch_size,
329
- num_workers=args.dataloader_num_workers,
330
- pin_memory=args.pin_memory,
331
- )
332
-
333
- # LR scheduler and math around the number of training steps.
334
- overrode_max_train_steps = False
335
- num_update_steps_per_epoch = len(train_dataloader)
336
- if args.max_train_steps is None:
337
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
338
- overrode_max_train_steps = True
339
-
340
- lr_scheduler = get_cosine_annealing_lr_scheduler(
341
- optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
342
- )
343
-
344
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
345
- num_update_steps_per_epoch = len(train_dataloader)
346
- if overrode_max_train_steps:
347
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
348
- # Afterwards we recalculate our number of training epochs
349
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
350
-
351
- # We need to initialize the trackers we use, and also store our configuration.
352
- # The trackers initializes automatically on the main process.
353
- wandb_run = None
354
- if args.report_to == "wandb":
355
- tracker_name = args.tracker_name or "mochi-1-lora"
356
- wandb_run = wandb.init(project=tracker_name, config=vars(args))
357
-
358
- # Resume from checkpoint if specified
359
- if args.resume_from_checkpoint:
360
- checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu", weights_only=True)
361
- if "global_step" in checkpoint:
362
- global_step = checkpoint["global_step"]
363
- if "optimizer" in checkpoint:
364
- optimizer.load_state_dict(checkpoint["optimizer"])
365
- if "lr_scheduler" in checkpoint:
366
- lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
367
-
368
- set_peft_model_state_dict(transformer, checkpoint["state_dict"])
369
-
370
- print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
371
- print(f"Resuming from global step: {global_step}")
372
- else:
373
- global_step = 0
374
-
375
- print("===== Memory before training =====")
376
- reset_memory("cuda")
377
- print_memory("cuda")
378
-
379
- # Train!
380
- total_batch_size = args.train_batch_size
381
- print("***** Running training *****")
382
- print(f" Num trainable parameters = {num_trainable_parameters}")
383
- print(f" Num examples = {len(train_dataset)}")
384
- print(f" Num batches each epoch = {len(train_dataloader)}")
385
- print(f" Num epochs = {args.num_train_epochs}")
386
- print(f" Instantaneous batch size per device = {args.train_batch_size}")
387
- print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
388
- print(f" Total optimization steps = {args.max_train_steps}")
389
-
390
- first_epoch = 0
391
- progress_bar = tqdm(
392
- range(0, args.max_train_steps),
393
- initial=global_step,
394
- desc="Steps",
395
- )
396
- for epoch in range(first_epoch, args.num_train_epochs):
397
- transformer.train()
398
-
399
- for step, batch in enumerate(train_dataloader):
400
- with torch.no_grad():
401
- z = batch["z"].to("cuda")
402
- eps = batch["eps"].to("cuda")
403
- sigma = batch["sigma"].to("cuda")
404
- prompt_embeds = batch["prompt_embeds"].to("cuda")
405
- prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")
406
-
407
- sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
408
- # Add noise according to flow matching.
409
- # zt = (1 - texp) * x + texp * z1
410
- z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
411
- ut = z - eps
412
-
413
- # (1 - sigma) because of
414
- # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
415
- # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
416
- timesteps = (1 - sigma) * scheduler.config.num_train_timesteps
417
-
418
- with torch.autocast("cuda", torch.bfloat16):
419
- model_pred = transformer(
420
- hidden_states=z_sigma,
421
- encoder_hidden_states=prompt_embeds,
422
- encoder_attention_mask=prompt_attention_mask,
423
- timestep=timesteps,
424
- return_dict=False,
425
- )[0]
426
- assert model_pred.shape == z.shape
427
- loss = F.mse_loss(model_pred.float(), ut.float())
428
- loss.backward()
429
-
430
- optimizer.step()
431
- optimizer.zero_grad()
432
- lr_scheduler.step()
433
-
434
- progress_bar.update(1)
435
- global_step += 1
436
-
437
- last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
438
- logs = {"loss": loss.detach().item(), "lr": last_lr}
439
- progress_bar.set_postfix(**logs)
440
- if wandb_run:
441
- wandb_run.log(logs, step=global_step)
442
-
443
- if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0:
444
- print(f"Saving checkpoint at step {global_step}")
445
- checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
446
- save_checkpoint(
447
- transformer,
448
- optimizer,
449
- lr_scheduler,
450
- global_step,
451
- checkpoint_path,
452
- )
453
-
454
- if global_step >= args.max_train_steps:
455
- break
456
-
457
- if global_step >= args.max_train_steps:
458
- break
459
-
460
- if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
461
- print("===== Memory before validation =====")
462
- print_memory("cuda")
463
-
464
- transformer.eval()
465
- pipe = MochiPipeline.from_pretrained(
466
- args.pretrained_model_name_or_path,
467
- transformer=transformer,
468
- scheduler=scheduler,
469
- revision=args.revision,
470
- variant=args.variant,
471
- )
472
-
473
- if args.enable_slicing:
474
- pipe.vae.enable_slicing()
475
- if args.enable_tiling:
476
- pipe.vae.enable_tiling()
477
- if args.enable_model_cpu_offload:
478
- pipe.enable_model_cpu_offload()
479
-
480
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
481
- for validation_prompt in validation_prompts:
482
- pipeline_args = {
483
- "prompt": validation_prompt,
484
- "guidance_scale": 6.0,
485
- "num_inference_steps": 64,
486
- "height": args.height,
487
- "width": args.width,
488
- "max_sequence_length": 256,
489
- }
490
- log_validation(
491
- pipe=pipe,
492
- args=args,
493
- pipeline_args=pipeline_args,
494
- epoch=epoch,
495
- wandb_run=wandb_run,
496
- )
497
-
498
- print("===== Memory after validation =====")
499
- print_memory("cuda")
500
- reset_memory("cuda")
501
-
502
- del pipe.text_encoder
503
- del pipe.vae
504
- del pipe
505
- gc.collect()
506
- torch.cuda.empty_cache()
507
-
508
- transformer.train()
509
-
510
- transformer.eval()
511
- transformer_lora_layers = get_peft_model_state_dict(transformer)
512
- MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)
513
-
514
- # Cleanup trained models to save memory
515
- del transformer
516
-
517
- gc.collect()
518
- torch.cuda.empty_cache()
519
-
520
- # Final test inference
521
- validation_outputs = []
522
- if args.validation_prompt and args.num_validation_videos > 0:
523
- print("===== Memory before testing =====")
524
- print_memory("cuda")
525
- reset_memory("cuda")
526
-
527
- pipe = MochiPipeline.from_pretrained(
528
- args.pretrained_model_name_or_path,
529
- revision=args.revision,
530
- variant=args.variant,
531
- )
532
-
533
- if args.enable_slicing:
534
- pipe.vae.enable_slicing()
535
- if args.enable_tiling:
536
- pipe.vae.enable_tiling()
537
- if args.enable_model_cpu_offload:
538
- pipe.enable_model_cpu_offload()
539
-
540
- # Load LoRA weights
541
- lora_scaling = args.lora_alpha / args.rank
542
- pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
543
- pipe.set_adapters(["mochi-lora"], [lora_scaling])
544
-
545
- # Run inference
546
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
547
- for validation_prompt in validation_prompts:
548
- pipeline_args = {
549
- "prompt": validation_prompt,
550
- "guidance_scale": 6.0,
551
- "num_inference_steps": 64,
552
- "height": args.height,
553
- "width": args.width,
554
- "max_sequence_length": 256,
555
- }
556
-
557
- video = log_validation(
558
- pipe=pipe,
559
- args=args,
560
- pipeline_args=pipeline_args,
561
- epoch=epoch,
562
- wandb_run=wandb_run,
563
- is_final_validation=True,
564
- )
565
- validation_outputs.extend(video)
566
-
567
- print("===== Memory after testing =====")
568
- print_memory("cuda")
569
- reset_memory("cuda")
570
- torch.cuda.synchronize("cuda")
571
-
572
- if args.push_to_hub:
573
- save_model_card(
574
- repo_id,
575
- videos=validation_outputs,
576
- base_model=args.pretrained_model_name_or_path,
577
- validation_prompt=args.validation_prompt,
578
- repo_folder=args.output_dir,
579
- fps=args.fps,
580
- )
581
- upload_folder(
582
- repo_id=repo_id,
583
- folder_path=args.output_dir,
584
- commit_message="End of training",
585
- ignore_patterns=["*.bin"],
586
- )
587
- print(f"Params pushed to {repo_id}.")
588
-
589
-
590
- if __name__ == "__main__":
591
- args = get_args()
592
- main(args)
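The inner loop of the deleted trainer builds its flow-matching quantities inline (the `z_sigma`, `ut`, and rescaled `timesteps` lines). The same construction can be reproduced standalone; a small illustrative sketch with arbitrary latent shapes:

```python
import torch

# Illustrative reconstruction of the flow-matching target math from the loop above.
num_train_timesteps = 1000                 # scheduler.config.num_train_timesteps in the script

z = torch.randn(2, 12, 7, 60, 106)         # "clean" video latents [B, C, T, H, W]; shape is arbitrary
eps = torch.randn_like(z)                  # Gaussian noise
sigma = torch.rand(z.shape[0])             # one time value per sample, uniform in [0, 1]

sigma_bcthw = sigma[:, None, None, None, None]
z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps   # noisy interpolant fed to the transformer
ut = z - eps                                          # velocity target regressed with MSE
timesteps = (1 - sigma) * num_train_timesteps         # diffusers-style (1 - sigma) timestep convention

# In training: loss = F.mse_loss(model_pred.float(), ut.float())
```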
 
training/mochi-1/train.sh DELETED
@@ -1,37 +0,0 @@
1
- #!/bin/bash
2
- export NCCL_P2P_DISABLE=1
3
- export TORCH_NCCL_ENABLE_MONITORING=0
4
-
5
- GPU_IDS="0"
6
-
7
- DATA_ROOT="videos_prepared"
8
- MODEL="genmo/mochi-1-preview"
9
- OUTPUT_PATH="mochi-lora"
10
-
11
- cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \
12
- --pretrained_model_name_or_path $MODEL \
13
- --cast_dit \
14
- --data_root $DATA_ROOT \
15
- --seed 42 \
16
- --output_dir $OUTPUT_PATH \
17
- --train_batch_size 1 \
18
- --dataloader_num_workers 4 \
19
- --pin_memory \
20
- --caption_dropout 0.1 \
21
- --max_train_steps 2000 \
22
- --gradient_checkpointing \
23
- --enable_slicing \
24
- --enable_tiling \
25
- --enable_model_cpu_offload \
26
- --optimizer adamw \
27
- --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \
28
- --validation_prompt_separator ::: \
29
- --num_validation_videos 1 \
30
- --validation_epochs 1 \
31
- --allow_tf32 \
32
- --report_to wandb \
33
- --push_to_hub"
34
-
35
- echo "Running command: $cmd"
36
- eval $cmd
37
- echo -ne "-------------------- Finished executing script --------------------\n\n"
 
training/mochi-1/trim_and_crop_videos.py DELETED
@@ -1,126 +0,0 @@
1
- """
2
- Adapted from:
3
- https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py
4
- """
5
-
6
- from pathlib import Path
7
- import shutil
8
-
9
- import click
10
- from moviepy.editor import VideoFileClip
11
- from tqdm import tqdm
12
-
13
-
14
- @click.command()
15
- @click.argument("folder", type=click.Path(exists=True, dir_okay=True))
16
- @click.argument("output_folder", type=click.Path(dir_okay=True))
17
- @click.option("--num_frames", "-f", type=float, default=30, help="Number of frames")
18
- @click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution")
19
- @click.option("--force_upsample", is_flag=True, help="Force upsample.")
20
- def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample):
21
- """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution"""
22
- input_path = Path(folder)
23
- output_path = Path(output_folder)
24
- output_path.mkdir(parents=True, exist_ok=True)
25
-
26
- # Parse target resolution
27
- target_height, target_width = map(int, resolution.split("x"))
28
-
29
- # Calculate duration
30
- duration = (num_frames / 30) + 0.09
31
-
32
- # Find all MP4 and MOV files
33
- video_files = (
34
- list(input_path.rglob("*.mp4"))
35
- + list(input_path.rglob("*.MOV"))
36
- + list(input_path.rglob("*.mov"))
37
- + list(input_path.rglob("*.MP4"))
38
- )
39
-
40
- for file_path in tqdm(video_files):
41
- try:
42
- relative_path = file_path.relative_to(input_path)
43
- output_file = output_path / relative_path.with_suffix(".mp4")
44
- output_file.parent.mkdir(parents=True, exist_ok=True)
45
-
46
- click.echo(f"Processing: {file_path}")
47
- video = VideoFileClip(str(file_path))
48
-
49
- # Skip if video is too short
50
- if video.duration < duration:
51
- click.echo(f"Skipping {file_path} as it is too short")
52
- continue
53
-
54
- # Skip if target resolution is larger than input
55
- if target_width > video.w or target_height > video.h:
56
- if force_upsample:
57
- click.echo(
58
- f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video."
59
- )
60
- video = video.resize(width=target_width, height=target_height)
61
- else:
62
- click.echo(
63
- f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
64
- )
65
- continue
66
-
67
- # First truncate duration
68
- truncated = video.subclip(0, duration)
69
-
70
- # Calculate crop dimensions to maintain aspect ratio
71
- target_ratio = target_width / target_height
72
- current_ratio = truncated.w / truncated.h
73
-
74
- if current_ratio > target_ratio:
75
- # Video is wider than target ratio - crop width
76
- new_width = int(truncated.h * target_ratio)
77
- x1 = (truncated.w - new_width) // 2
78
- final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
79
- else:
80
- # Video is taller than target ratio - crop height
81
- new_height = int(truncated.w / target_ratio)
82
- y1 = (truncated.h - new_height) // 2
83
- final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))
84
-
85
- # Set output parameters for consistent MP4 encoding
86
- output_params = {
87
- "codec": "libx264",
88
- "audio": False, # Disable audio
89
- "preset": "medium", # Balance between speed and quality
90
- "bitrate": "5000k", # Adjust as needed
91
- }
92
-
93
- # Set FPS to 30
94
- final = final.set_fps(30)
95
-
96
- # Check for a corresponding .txt file
97
- txt_file_path = file_path.with_suffix(".txt")
98
- if txt_file_path.exists():
99
- output_txt_file = output_path / relative_path.with_suffix(".txt")
100
- output_txt_file.parent.mkdir(parents=True, exist_ok=True)
101
- shutil.copy(txt_file_path, output_txt_file)
102
- click.echo(f"Copied {txt_file_path} to {output_txt_file}")
103
- else:
104
- # Print warning in bold yellow with a warning emoji
105
- click.echo(
106
- f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m"
107
- )
108
- output_txt_file = output_path / relative_path.with_suffix(".txt")
109
- output_txt_file.parent.mkdir(parents=True, exist_ok=True)
110
- output_txt_file.touch()
111
-
112
- # Write the output file
113
- final.write_videofile(str(output_file), **output_params)
114
-
115
- # Clean up
116
- video.close()
117
- truncated.close()
118
- final.close()
119
-
120
- except Exception as e:
121
- click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
122
- raise
123
-
124
-
125
- if __name__ == "__main__":
126
- truncate_videos()
 
training/mochi-1/utils.py DELETED
@@ -1,22 +0,0 @@
- import gc
- import inspect
- import logging
- from typing import Optional, Tuple, Union
- 
- import torch
- 
- logger = logging.getLogger(__name__)
- 
- 
- def reset_memory(device: Union[str, torch.device]) -> None:
-     gc.collect()
-     torch.cuda.empty_cache()
-     torch.cuda.reset_peak_memory_stats(device)
-     torch.cuda.reset_accumulated_memory_stats(device)
- 
- 
- def print_memory(device: Union[str, torch.device]) -> None:
-     memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
-     max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
-     max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
-     print(f"{memory_allocated=:.3f} GB")
-     print(f"{max_memory_allocated=:.3f} GB")
-     print(f"{max_memory_reserved=:.3f} GB")
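A brief usage sketch for the two helpers above, bracketing an arbitrary workload; it assumes the module is importable as `utils`, as the deleted trainer did:

```python
import torch

from utils import print_memory, reset_memory  # the helpers defined above

reset_memory("cuda")            # empty the CUDA cache and reset peak-memory counters
# ... run a training step or a pipeline call here ...
torch.cuda.synchronize()
print_memory("cuda")            # reports allocated / max allocated / max reserved, in GB
```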
 
training/prepare_dataset.sh DELETED
@@ -1,48 +0,0 @@
1
- #!/bin/bash
2
-
3
- MODEL_ID="THUDM/CogVideoX-2b"
4
-
5
- NUM_GPUS=8
6
-
7
- # For more details on the expected data format, please refer to the README.
8
- DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located.
9
- CAPTION_COLUMN="prompt.txt"
10
- VIDEO_COLUMN="videos.txt"
11
- OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset"
12
- HEIGHT_BUCKETS="480 720"
13
- WIDTH_BUCKETS="720 960"
14
- FRAME_BUCKETS="49"
15
- MAX_NUM_FRAMES="49"
16
- MAX_SEQUENCE_LENGTH=226
17
- TARGET_FPS=8
18
- BATCH_SIZE=1
19
- DTYPE=fp32
20
-
21
- # To create a folder-style dataset structure without pre-encoding videos and captions
22
- # For Image-to-Video finetuning, make sure to pass `--save_image_latents`
23
- CMD_WITHOUT_PRE_ENCODING="\
24
- torchrun --nproc_per_node=$NUM_GPUS \
25
- training/prepare_dataset.py \
26
- --model_id $MODEL_ID \
27
- --data_root $DATA_ROOT \
28
- --caption_column $CAPTION_COLUMN \
29
- --video_column $VIDEO_COLUMN \
30
- --output_dir $OUTPUT_DIR \
31
- --height_buckets $HEIGHT_BUCKETS \
32
- --width_buckets $WIDTH_BUCKETS \
33
- --frame_buckets $FRAME_BUCKETS \
34
- --max_num_frames $MAX_NUM_FRAMES \
35
- --max_sequence_length $MAX_SEQUENCE_LENGTH \
36
- --target_fps $TARGET_FPS \
37
- --batch_size $BATCH_SIZE \
38
- --dtype $DTYPE
39
- "
40
-
41
- CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings"
42
-
43
- # Select which you'd like to run
44
- CMD=$CMD_WITH_PRE_ENCODING
45
-
46
- echo "===== Running \`$CMD\` ====="
47
- eval $CMD
48
- echo -ne "===== Finished running script =====\n"
 
training/train_image_to_video_lora.sh DELETED
@@ -1,82 +0,0 @@
1
- export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
2
- export TORCHDYNAMO_VERBOSE=1
3
- export WANDB_MODE="offline"
4
- export NCCL_P2P_DISABLE=1
5
- export TORCH_NCCL_ENABLE_MONITORING=0
6
-
7
- GPU_IDS="0"
8
-
9
- # Training Configurations
10
- # Experiment with as many hyperparameters as you want!
11
- LEARNING_RATES=("1e-4" "1e-3")
12
- LR_SCHEDULES=("cosine_with_restarts")
13
- OPTIMIZERS=("adamw" "adam")
14
- MAX_TRAIN_STEPS=("3000")
15
-
16
- # Single GPU uncompiled training
17
- ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
18
-
19
- # Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
20
- # This example assumes you downloaded an already prepared dataset from HF CLI as follows:
21
- # huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset
22
- DATA_ROOT="/path/to/my/datasets/disney-dataset"
23
- CAPTION_COLUMN="prompt.txt"
24
- VIDEO_COLUMN="videos.txt"
25
-
26
- # Launch experiments with different hyperparameters
27
- for learning_rate in "${LEARNING_RATES[@]}"; do
28
- for lr_schedule in "${LR_SCHEDULES[@]}"; do
29
- for optimizer in "${OPTIMIZERS[@]}"; do
30
- for steps in "${MAX_TRAIN_STEPS[@]}"; do
31
- output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
32
-
33
- cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_image_to_video_lora.py \
34
- --pretrained_model_name_or_path THUDM/CogVideoX-5b-I2V \
35
- --data_root $DATA_ROOT \
36
- --caption_column $CAPTION_COLUMN \
37
- --video_column $VIDEO_COLUMN \
38
- --id_token BW_STYLE \
39
- --height_buckets 480 \
40
- --width_buckets 720 \
41
- --frame_buckets 49 \
42
- --dataloader_num_workers 8 \
43
- --pin_memory \
44
- --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
45
- --validation_images \"/path/to/image1.png:::/path/to/image2.png\"
46
- --validation_prompt_separator ::: \
47
- --num_validation_videos 1 \
48
- --validation_epochs 10 \
49
- --seed 42 \
50
- --rank 128 \
51
- --lora_alpha 128 \
52
- --mixed_precision bf16 \
53
- --output_dir $output_dir \
54
- --max_num_frames 49 \
55
- --train_batch_size 1 \
56
- --max_train_steps $steps \
57
- --checkpointing_steps 1000 \
58
- --gradient_accumulation_steps 1 \
59
- --gradient_checkpointing \
60
- --learning_rate $learning_rate \
61
- --lr_scheduler $lr_schedule \
62
- --lr_warmup_steps 400 \
63
- --lr_num_cycles 1 \
64
- --enable_slicing \
65
- --enable_tiling \
66
- --noised_image_dropout 0.05 \
67
- --optimizer $optimizer \
68
- --beta1 0.9 \
69
- --beta2 0.95 \
70
- --weight_decay 0.001 \
71
- --max_grad_norm 1.0 \
72
- --allow_tf32 \
73
- --report_to wandb \
74
- --nccl_timeout 1800"
75
-
76
- echo "Running command: $cmd"
77
- eval $cmd
78
- echo -ne "-------------------- Finished executing script --------------------\n\n"
79
- done
80
- done
81
- done
82
- done
 
training/train_image_to_video_sft.sh DELETED
@@ -1,87 +0,0 @@
1
- export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
2
- export TORCHDYNAMO_VERBOSE=1
3
- export WANDB_MODE="offline"
4
- # export NCCL_P2P_DISABLE=1
5
- export TORCH_NCCL_ENABLE_MONITORING=0
6
- export TOKENIZERS_PARALLELISM=true
7
- export OMP_NUM_THREADS=16
8
-
9
- GPU_IDS="0,1"
10
-
11
- # Training Configurations
12
- # Experiment with as many hyperparameters as you want!
13
- LEARNING_RATES=("1e-4")
14
- LR_SCHEDULES=("cosine_with_restarts")
15
- OPTIMIZERS=("adamw")
16
- MAX_TRAIN_STEPS=("20000")
17
-
18
- # Single GPU uncompiled training
19
- ACCELERATE_CONFIG_FILE="accelerate_configs/deepspeed.yaml"
20
-
21
- # Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
22
- # This example assumes you downloaded an already prepared dataset from HF CLI as follows:
23
- # huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset
24
- DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
25
- CAPTION_COLUMN="prompt.txt"
26
- VIDEO_COLUMN="videos.txt"
27
- MODEL_PATH="THUDM/CogVideoX1.5-5B-I2V"
28
-
29
- # Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process.
30
- # Launch experiments with different hyperparameters
31
-
32
- for learning_rate in "${LEARNING_RATES[@]}"; do
33
- for lr_schedule in "${LR_SCHEDULES[@]}"; do
34
- for optimizer in "${OPTIMIZERS[@]}"; do
35
- for steps in "${MAX_TRAIN_STEPS[@]}"; do
36
- output_dir="./cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
37
-
38
- cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE \
39
- --gpu_ids $GPU_IDS \
40
- training/cogvideox/cogvideox_image_to_video_sft.py \
41
- --pretrained_model_name_or_path $MODEL_PATH \
42
- --data_root $DATA_ROOT \
43
- --caption_column $CAPTION_COLUMN \
44
- --video_column $VIDEO_COLUMN \
45
- --id_token BW_STYLE \
46
- --height_buckets 480 \
47
- --width_buckets 720 \
48
- --frame_buckets 77 \
49
- --dataloader_num_workers 8 \
50
- --pin_memory \
51
- --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
52
- --validation_images \"/path/to/image1.png:::/path/to/image2.png\" \
53
- --validation_prompt_separator ::: \
54
- --num_validation_videos 1 \
55
- --validation_epochs 1 \
56
- --seed 42 \
57
- --mixed_precision bf16 \
58
- --output_dir $output_dir \
59
- --max_num_frames 77 \
60
- --train_batch_size 1 \
61
- --max_train_steps $steps \
62
- --checkpointing_steps 2000 \
63
- --gradient_accumulation_steps 4 \
64
- --gradient_checkpointing \
65
- --learning_rate $learning_rate \
66
- --lr_scheduler $lr_schedule \
67
- --lr_warmup_steps 800 \
68
- --lr_num_cycles 1 \
69
- --enable_slicing \
70
- --enable_tiling \
71
- --noised_image_dropout 0.05 \
72
- --optimizer $optimizer \
73
- --beta1 0.9 \
74
- --beta2 0.95 \
75
- --weight_decay 0.001 \
76
- --max_grad_norm 1.0 \
77
- --allow_tf32 \
78
- --report_to wandb \
79
- --nccl_timeout 1800"
80
-
81
- echo "Running command: $cmd"
82
- eval $cmd
83
- echo -ne "-------------------- Finished executing script --------------------\n\n"
84
- done
85
- done
86
- done
87
- done
 
training/train_text_to_video_lora.sh DELETED
@@ -1,86 +0,0 @@
1
- export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
2
- export TORCHDYNAMO_VERBOSE=1
3
- export WANDB_MODE="offline"
4
- export NCCL_P2P_DISABLE=1
5
- export TORCH_NCCL_ENABLE_MONITORING=0
6
-
7
- GPU_IDS="0"
8
-
9
- # Training Configurations
10
- # Experiment with as many hyperparameters as you want!
11
- LEARNING_RATES=("1e-4" "1e-3")
12
- LR_SCHEDULES=("cosine_with_restarts")
13
- OPTIMIZERS=("adamw" "adam")
14
- MAX_TRAIN_STEPS=("3000")
15
-
16
- # Single GPU uncompiled training
17
- ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
18
-
19
- # Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
20
- # This example assumes you downloaded an already prepared dataset from HF CLI as follows:
21
- # huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset
22
- DATA_ROOT="/path/to/my/datasets/disney-dataset"
23
-
24
- CAPTION_COLUMN="prompt.txt"
25
- VIDEO_COLUMN="videos.txt"
26
- MODEL_PATH="THUDM/CogVideoX-5b"
27
-
28
- # Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process.
29
- # Launch experiments with different hyperparameters
30
-
31
- for learning_rate in "${LEARNING_RATES[@]}"; do
32
- for lr_schedule in "${LR_SCHEDULES[@]}"; do
33
- for optimizer in "${OPTIMIZERS[@]}"; do
34
- for steps in "${MAX_TRAIN_STEPS[@]}"; do
35
- output_dir="./cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
36
-
37
- cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \
38
- --pretrained_model_name_or_path $MODEL_PATH \
39
- --data_root $DATA_ROOT \
40
- --caption_column $CAPTION_COLUMN \
41
- --video_column $VIDEO_COLUMN \
42
- --id_token BW_STYLE \
43
- --height_buckets 480 \
44
- --width_buckets 720 \
45
- --frame_buckets 49 \
46
- --dataloader_num_workers 8 \
47
- --pin_memory \
48
- --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
49
- --validation_prompt_separator ::: \
50
- --num_validation_videos 1 \
51
- --validation_epochs 10 \
52
- --seed 42 \
53
- --rank 128 \
54
- --lora_alpha 128 \
55
- --mixed_precision bf16 \
56
- --output_dir $output_dir \
57
- --max_num_frames 49 \
58
- --train_batch_size 1 \
59
- --max_train_steps $steps \
60
- --checkpointing_steps 1000 \
61
- --gradient_accumulation_steps 1 \
62
- --gradient_checkpointing \
63
- --learning_rate $learning_rate \
64
- --lr_scheduler $lr_schedule \
65
- --lr_warmup_steps 400 \
66
- --lr_num_cycles 1 \
67
- --enable_slicing \
68
- --enable_tiling \
69
- --enable_model_cpu_offload \
70
- --load_tensors \
71
- --optimizer $optimizer \
72
- --beta1 0.9 \
73
- --beta2 0.95 \
74
- --weight_decay 0.001 \
75
- --max_grad_norm 1.0 \
76
- --allow_tf32 \
77
- --report_to wandb \
78
- --nccl_timeout 1800"
79
-
80
- echo "Running command: $cmd"
81
- eval $cmd
82
- echo -ne "-------------------- Finished executing script --------------------\n\n"
83
- done
84
- done
85
- done
86
- done
training/train_text_to_video_sft.sh DELETED
@@ -1,77 +0,0 @@
1
- export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
2
- export TORCHDYNAMO_VERBOSE=1
3
- export WANDB_MODE="offline"
4
- export NCCL_P2P_DISABLE=1
5
- export TORCH_NCCL_ENABLE_MONITORING=0
6
-
7
- GPU_IDS="0"
8
-
9
- # Training Configurations
10
- # Experiment with as many hyperparameters as you want!
11
- LEARNING_RATES=("1e-4")
12
- LR_SCHEDULES=("cosine_with_restarts")
13
- OPTIMIZERS=("adamw")
14
- MAX_TRAIN_STEPS=("20000")
15
-
16
- # Single GPU uncompiled training
17
- ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
18
-
19
- # Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
20
- # This example assumes you downloaded an already prepared dataset from HF CLI as follows:
21
- # huggingface-cli download --repo-type dataset Wild-Heart/Tom-and-Jerry-VideoGeneration-Dataset --local-dir /path/to/my/datasets/tom-and-jerry-dataset
22
- DATA_ROOT="/path/to/my/datasets/tom-and-jerry-dataset"
23
- CAPTION_COLUMN="captions.txt"
24
- VIDEO_COLUMN="videos.txt"
25
-
26
- # Launch experiments with different hyperparameters
27
- for learning_rate in "${LEARNING_RATES[@]}"; do
28
- for lr_schedule in "${LR_SCHEDULES[@]}"; do
29
- for optimizer in "${OPTIMIZERS[@]}"; do
30
- for steps in "${MAX_TRAIN_STEPS[@]}"; do
31
- output_dir="/path/to/my/models/cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
32
-
33
- cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_sft.py \
34
- --pretrained_model_name_or_path THUDM/CogVideoX-5b \
35
- --data_root $DATA_ROOT \
36
- --caption_column $CAPTION_COLUMN \
37
- --video_column $VIDEO_COLUMN \
38
- --height_buckets 480 \
39
- --width_buckets 720 \
40
- --frame_buckets 49 \
41
- --dataloader_num_workers 8 \
42
- --pin_memory \
43
- --validation_prompt \"Tom, the mischievous gray cat, is sprawled out on a vibrant red pillow, his body relaxed and his eyes half-closed, as if he's just woken up or is about to doze off. His white paws are stretched out in front of him, and his tail is casually draped over the edge of the pillow. The setting appears to be a cozy corner of a room, with a warm yellow wall in the background and a hint of a wooden floor. The scene captures a rare moment of tranquility for Tom, contrasting with his usual energetic and playful demeanor:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
44
- --validation_prompt_separator ::: \
45
- --num_validation_videos 1 \
46
- --validation_epochs 1 \
47
- --seed 42 \
48
- --mixed_precision bf16 \
49
- --output_dir $output_dir \
50
- --max_num_frames 49 \
51
- --train_batch_size 1 \
52
- --max_train_steps $steps \
53
- --checkpointing_steps 2000 \
54
- --gradient_accumulation_steps 4 \
55
- --gradient_checkpointing \
56
- --learning_rate $learning_rate \
57
- --lr_scheduler $lr_schedule \
58
- --lr_warmup_steps 800 \
59
- --lr_num_cycles 1 \
60
- --enable_slicing \
61
- --enable_tiling \
62
- --optimizer $optimizer \
63
- --beta1 0.9 \
64
- --beta2 0.95 \
65
- --weight_decay 0.001 \
66
- --max_grad_norm 1.0 \
67
- --allow_tf32 \
68
- --report_to wandb \
69
- --nccl_timeout 1800"
70
-
71
- echo "Running command: $cmd"
72
- eval $cmd
73
- echo -ne "-------------------- Finished executing script --------------------\n\n"
74
- done
75
- done
76
- done
77
- done
vms/config.py CHANGED
@@ -56,9 +56,17 @@ if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
56
  raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
57
  JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
58
 
 
59
  MODEL_TYPES = {
60
  "HunyuanVideo (LoRA)": "hunyuan_video",
61
- "LTX-Video (LoRA)": "ltx_video"
 
 
 
 
 
 
 
62
  }
63
 
64
 
@@ -136,9 +144,11 @@ MEDIUM_19_9_RATIO_BUCKETS = [
136
  (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
137
  ]
138
 
 
139
  TRAINING_PRESETS = {
140
  "HunyuanVideo (normal)": {
141
  "model_type": "hunyuan_video",
 
142
  "lora_rank": "128",
143
  "lora_alpha": "128",
144
  "num_epochs": 70,
@@ -146,9 +156,11 @@ TRAINING_PRESETS = {
146
  "learning_rate": 2e-5,
147
  "save_iterations": 500,
148
  "training_buckets": SMALL_TRAINING_BUCKETS,
 
149
  },
150
  "LTX-Video (normal)": {
151
  "model_type": "ltx_video",
 
152
  "lora_rank": "128",
153
  "lora_alpha": "128",
154
  "num_epochs": 70,
@@ -156,9 +168,11 @@ TRAINING_PRESETS = {
156
  "learning_rate": 3e-5,
157
  "save_iterations": 500,
158
  "training_buckets": SMALL_TRAINING_BUCKETS,
 
159
  },
160
  "LTX-Video (16:9, HQ)": {
161
  "model_type": "ltx_video",
 
162
  "lora_rank": "256",
163
  "lora_alpha": "128",
164
  "num_epochs": 50,
@@ -166,6 +180,41 @@ TRAINING_PRESETS = {
166
  "learning_rate": 3e-5,
167
  "save_iterations": 200,
168
  "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  }
170
  }
171
 
@@ -260,7 +309,8 @@ class TrainingConfig:
260
  lora_alpha=128,
261
  video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
262
  caption_dropout_p=0.05,
263
- flow_weighting_scheme="none" # Hunyuan specific
 
264
  )
265
 
266
  @classmethod
@@ -281,7 +331,51 @@ class TrainingConfig:
281
  lora_alpha=128,
282
  video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
283
  caption_dropout_p=0.05,
284
- flow_weighting_scheme="logit_normal" # LTX specific
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  )
286
 
287
  def to_args_list(self) -> List[str]:
@@ -302,9 +396,9 @@ class TrainingConfig:
302
  args.extend(["--cache_dir", self.cache_dir])
303
 
304
  # Dataset arguments
305
- args.extend(["--data_root", self.data_root])
306
- args.extend(["--video_column", self.video_column])
307
- args.extend(["--caption_column", self.caption_column])
308
  if self.id_token:
309
  args.extend(["--id_token", self.id_token])
310
 
@@ -313,9 +407,6 @@ class TrainingConfig:
313
  bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
314
  args.extend(["--video_resolution_buckets"] + bucket_strs)
315
 
316
- if self.video_reshape_mode:
317
- args.extend(["--video_reshape_mode", self.video_reshape_mode])
318
-
319
  args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
320
  args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
321
  if self.precompute_conditions:
@@ -333,14 +424,18 @@ class TrainingConfig:
333
  args.extend(["--training_type", self.training_type])
334
  args.extend(["--seed", str(self.seed)])
335
 
336
- # we don't use this, because mixed precision is handled by accelerate launch, not by the training script itself.
337
  #args.extend(["--mixed_precision", self.mixed_precision])
338
 
339
  args.extend(["--batch_size", str(self.batch_size)])
340
- args.extend(["--train_epochs", str(self.train_epochs)])
341
- args.extend(["--rank", str(self.lora_rank)])
342
- args.extend(["--lora_alpha", str(self.lora_alpha)])
343
- args.extend(["--target_modules"] + self.target_modules)
 
 
 
 
344
  args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
345
  if self.gradient_checkpointing:
346
  args.append("--gradient_checkpointing")
 
56
  raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
57
  JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
58
 
59
+ # Expanded model types to include Wan-2.1-T2V
60
  MODEL_TYPES = {
61
  "HunyuanVideo (LoRA)": "hunyuan_video",
62
+ "LTX-Video (LoRA)": "ltx_video",
63
+ "Wan-2.1-T2V (LoRA)": "wan"
64
+ }
65
+
66
+ # Training types
67
+ TRAINING_TYPES = {
68
+ "LoRA Finetune": "lora",
69
+ "Full Finetune": "full-finetune"
70
  }
71
 
72
 
 
144
  (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
145
  ]
146
 
147
+ # Updated training presets to include Wan-2.1-T2V and support both LoRA and full-finetune
148
  TRAINING_PRESETS = {
149
  "HunyuanVideo (normal)": {
150
  "model_type": "hunyuan_video",
151
+ "training_type": "lora",
152
  "lora_rank": "128",
153
  "lora_alpha": "128",
154
  "num_epochs": 70,
 
156
  "learning_rate": 2e-5,
157
  "save_iterations": 500,
158
  "training_buckets": SMALL_TRAINING_BUCKETS,
159
+ "flow_weighting_scheme": "none"
160
  },
161
  "LTX-Video (normal)": {
162
  "model_type": "ltx_video",
163
+ "training_type": "lora",
164
  "lora_rank": "128",
165
  "lora_alpha": "128",
166
  "num_epochs": 70,
 
168
  "learning_rate": 3e-5,
169
  "save_iterations": 500,
170
  "training_buckets": SMALL_TRAINING_BUCKETS,
171
+ "flow_weighting_scheme": "logit_normal"
172
  },
173
  "LTX-Video (16:9, HQ)": {
174
  "model_type": "ltx_video",
175
+ "training_type": "lora",
176
  "lora_rank": "256",
177
  "lora_alpha": "128",
178
  "num_epochs": 50,
 
180
  "learning_rate": 3e-5,
181
  "save_iterations": 200,
182
  "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
183
+ "flow_weighting_scheme": "logit_normal"
184
+ },
185
+ "LTX-Video (Full Finetune)": {
186
+ "model_type": "ltx_video",
187
+ "training_type": "full-finetune",
188
+ "num_epochs": 30,
189
+ "batch_size": 1,
190
+ "learning_rate": 1e-5,
191
+ "save_iterations": 300,
192
+ "training_buckets": SMALL_TRAINING_BUCKETS,
193
+ "flow_weighting_scheme": "logit_normal"
194
+ },
195
+ "Wan-2.1-T2V (normal)": {
196
+ "model_type": "wan",
197
+ "training_type": "lora",
198
+ "lora_rank": "32",
199
+ "lora_alpha": "32",
200
+ "num_epochs": 70,
201
+ "batch_size": 1,
202
+ "learning_rate": 5e-5,
203
+ "save_iterations": 500,
204
+ "training_buckets": SMALL_TRAINING_BUCKETS,
205
+ "flow_weighting_scheme": "logit_normal"
206
+ },
207
+ "Wan-2.1-T2V (HQ)": {
208
+ "model_type": "wan",
209
+ "training_type": "lora",
210
+ "lora_rank": "64",
211
+ "lora_alpha": "64",
212
+ "num_epochs": 50,
213
+ "batch_size": 1,
214
+ "learning_rate": 3e-5,
215
+ "save_iterations": 200,
216
+ "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
217
+ "flow_weighting_scheme": "logit_normal"
218
  }
219
  }
220
 
 
309
  lora_alpha=128,
310
  video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
311
  caption_dropout_p=0.05,
312
+ flow_weighting_scheme="none", # Hunyuan specific
313
+ training_type="lora"
314
  )
315
 
316
  @classmethod
 
331
  lora_alpha=128,
332
  video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
333
  caption_dropout_p=0.05,
334
+ flow_weighting_scheme="logit_normal", # LTX specific
335
+ training_type="lora"
336
+ )
337
+
338
+ @classmethod
339
+ def ltx_video_full_finetune(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
340
+ """Configuration for LTX-Video full finetune training"""
341
+ return cls(
342
+ model_name="ltx_video",
343
+ pretrained_model_name_or_path="Lightricks/LTX-Video",
344
+ data_root=data_path,
345
+ output_dir=output_path,
346
+ batch_size=1,
347
+ train_epochs=30,
348
+ lr=1e-5,
349
+ gradient_checkpointing=True,
350
+ id_token="BW_STYLE",
351
+ gradient_accumulation_steps=1,
352
+ video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
353
+ caption_dropout_p=0.05,
354
+ flow_weighting_scheme="logit_normal", # LTX specific
355
+ training_type="full-finetune"
356
+ )
357
+
358
+ @classmethod
359
+ def wan_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
360
+ """Configuration for Wan T2V LoRA training"""
361
+ return cls(
362
+ model_name="wan",
363
+ pretrained_model_name_or_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
364
+ data_root=data_path,
365
+ output_dir=output_path,
366
+ batch_size=1,
367
+ train_epochs=70,
368
+ lr=5e-5,
369
+ gradient_checkpointing=True,
370
+ id_token=None, # Default is no ID token for Wan
371
+ gradient_accumulation_steps=1,
372
+ lora_rank=32,
373
+ lora_alpha=32,
374
+ target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
375
+ video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
376
+ caption_dropout_p=0.05,
377
+ flow_weighting_scheme="logit_normal", # Wan specific
378
+ training_type="lora"
379
  )
380
 
381
  def to_args_list(self) -> List[str]:
 
396
  args.extend(["--cache_dir", self.cache_dir])
397
 
398
  # Dataset arguments
399
+ args.extend(["--dataset_config", self.data_root])
400
+
401
+ # Add ID token if specified
402
  if self.id_token:
403
  args.extend(["--id_token", self.id_token])
404
 
 
407
  bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
408
  args.extend(["--video_resolution_buckets"] + bucket_strs)
409
 
 
 
 
410
  args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
411
  args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
412
  if self.precompute_conditions:
 
424
  args.extend(["--training_type", self.training_type])
425
  args.extend(["--seed", str(self.seed)])
426
 
427
+ # We don't use this, because mixed precision is handled by accelerate launch, not by the training script itself.
428
  #args.extend(["--mixed_precision", self.mixed_precision])
429
 
430
  args.extend(["--batch_size", str(self.batch_size)])
431
+ args.extend(["--train_steps", str(self.train_epochs * 1000)]) # Convert epochs to steps for compatibility
432
+
433
+ # LoRA specific arguments
434
+ if self.training_type == "lora":
435
+ args.extend(["--rank", str(self.lora_rank)])
436
+ args.extend(["--lora_alpha", str(self.lora_alpha)])
437
+ args.extend(["--target_modules"] + self.target_modules)
438
+
439
  args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
440
  if self.gradient_checkpointing:
441
  args.append("--gradient_checkpointing")
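For context, a minimal sketch (not part of the commit; the paths and the chosen preset are hypothetical) of how the updated `TrainingConfig` factories, the presets, and `to_args_list()` are meant to work together — note that the LoRA-only flags are emitted only when `training_type == "lora"`:

```python
# Illustrative only: exercises the new wan_lora() factory and to_args_list()
# as defined in vms/config.py above; the dataset/output paths are made up.
from vms.config import TRAINING_PRESETS, TrainingConfig

preset = TRAINING_PRESETS["Wan-2.1-T2V (normal)"]

config = TrainingConfig.wan_lora(
    data_path="/data/my-dataset",   # hypothetical path to the prepared dataset
    output_path="/data/output",     # hypothetical output directory
    buckets=preset["training_buckets"],
)
config.flow_weighting_scheme = preset["flow_weighting_scheme"]

args = config.to_args_list()

# LoRA-specific flags are only appended for LoRA runs; a full-finetune config
# (e.g. TrainingConfig.ltx_video_full_finetune) would omit --rank / --lora_alpha.
assert "--rank" in args and "--lora_alpha" in args
print(" ".join(args))
```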
vms/services/trainer.py CHANGED
@@ -20,7 +20,11 @@ from typing import Any, Optional, Dict, List, Union, Tuple
20
 
21
  from huggingface_hub import upload_folder, create_repo
22
 
23
- from ..config import TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
 
 
 
 
24
  from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
25
 
26
  logger = logging.getLogger(__name__)
@@ -112,6 +116,7 @@ class TrainingService:
112
  ui_state_file = OUTPUT_PATH / "ui_state.json"
113
  default_state = {
114
  "model_type": list(MODEL_TYPES.keys())[0],
 
115
  "lora_rank": "128",
116
  "lora_alpha": "128",
117
  "num_epochs": 50,
@@ -153,7 +158,6 @@ class TrainingService:
153
  # Make sure we have all keys (in case structure changed)
154
  merged_state = default_state.copy()
155
  merged_state.update(saved_state)
156
- #logger.info(f"Successfully loaded UI state from {ui_state_file}")
157
  return merged_state
158
  except json.JSONDecodeError as e:
159
  logger.error(f"Error parsing UI state JSON: {str(e)}")
@@ -171,6 +175,7 @@ class TrainingService:
171
  logger.info("Creating new UI state file with default values")
172
  default_state = {
173
  "model_type": list(MODEL_TYPES.keys())[0],
 
174
  "lora_rank": "128",
175
  "lora_alpha": "128",
176
  "num_epochs": 50,
@@ -203,6 +208,7 @@ class TrainingService:
203
  # Create a new file with default values
204
  default_state = {
205
  "model_type": list(MODEL_TYPES.keys())[0],
 
206
  "lora_rank": "128",
207
  "lora_alpha": "128",
208
  "num_epochs": 50,
@@ -331,6 +337,9 @@ class TrainingService:
331
  elif model_type == "ltx_video":
332
  if config.batch_size > 4:
333
  return "LTX model recommended batch size is 1-4"
 
 
 
334
 
335
  logger.info(f"Config validation passed with {len(video_lines)} training files")
336
  return None
@@ -350,6 +359,7 @@ class TrainingService:
350
  save_iterations: int,
351
  repo_id: str,
352
  preset_name: str,
 
353
  resume_from_checkpoint: Optional[str] = None,
354
  ) -> Tuple[str, str]:
355
  """Start training with finetrainers"""
@@ -360,11 +370,13 @@ class TrainingService:
360
  raise ValueError("model_type cannot be empty")
361
  if model_type not in MODEL_TYPES.values():
362
  raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
 
 
363
 
364
  # Check if we're resuming or starting new
365
  is_resuming = resume_from_checkpoint is not None
366
  log_prefix = "Resuming" if is_resuming else "Initializing"
367
- logger.info(f"{log_prefix} training with model_type={model_type}")
368
 
369
  try:
370
  # Get absolute paths - FIXED to look in project root instead of within vms directory
@@ -409,32 +421,66 @@ class TrainingService:
409
  logger.error(error_msg)
410
  return error_msg, "No training data available"
411
 
412
-
413
  # Get preset configuration
414
  preset = TRAINING_PRESETS[preset_name]
415
  training_buckets = preset["training_buckets"]
 
 
416
 
417
  # Get config for selected model type with preset buckets
418
  if model_type == "hunyuan_video":
419
- config = TrainingConfig.hunyuan_video_lora(
420
- data_path=str(TRAINING_PATH),
421
- output_path=str(OUTPUT_PATH),
422
- buckets=training_buckets
423
- )
424
- else: # ltx_video
425
- config = TrainingConfig.ltx_video_lora(
426
- data_path=str(TRAINING_PATH),
427
- output_path=str(OUTPUT_PATH),
428
- buckets=training_buckets
429
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  # Update with UI parameters
432
  config.train_epochs = int(num_epochs)
433
- config.lora_rank = int(lora_rank)
434
- config.lora_alpha = int(lora_alpha)
435
  config.batch_size = int(batch_size)
436
  config.lr = float(learning_rate)
437
  config.checkpointing_steps = int(save_iterations)
 
 
 
 
 
 
 
438
 
439
  # Update with resume_from_checkpoint if provided
440
  if resume_from_checkpoint:
@@ -469,7 +515,6 @@ class TrainingService:
469
  # Convert config to command line arguments
470
  config_args = config.to_args_list()
471
 
472
-
473
  logger.debug("Generated args list: %s", config_args)
474
 
475
  # Log the full command for debugging
@@ -505,6 +550,7 @@ class TrainingService:
505
  # Save session info including repo_id for later hub upload
506
  self.save_session({
507
  "model_type": model_type,
 
508
  "lora_rank": lora_rank,
509
  "lora_alpha": lora_alpha,
510
  "num_epochs": num_epochs,
@@ -526,13 +572,14 @@ class TrainingService:
526
  total_epochs=num_epochs,
527
  message='Training started',
528
  repo_id=repo_id,
529
- model_type=model_type
 
530
  )
531
 
532
  # Start monitoring process output
533
  self._start_log_monitor(process)
534
 
535
- success_msg = f"Started training {model_type} model"
536
  self.append_log(success_msg)
537
  logger.info(success_msg)
538
 
@@ -668,6 +715,7 @@ class TrainingService:
668
  last_session = {
669
  "params": {
670
  "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
 
671
  "lora_rank": ui_state.get("lora_rank", "128"),
672
  "lora_alpha": ui_state.get("lora_alpha", "128"),
673
  "num_epochs": ui_state.get("num_epochs", 70),
@@ -724,10 +772,15 @@ class TrainingService:
724
  logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
725
  break
726
 
 
 
 
 
727
  # Add UI updates to restore the training parameters in the UI
728
  # This shows the user what values are being used for the resumed training
729
  ui_updates.update({
730
  "model_type": model_type_display, # Use the display name for the UI dropdown
 
731
  "lora_rank": params.get('lora_rank', "128"),
732
  "lora_alpha": params.get('lora_alpha', "128"),
733
  "num_epochs": params.get('num_epochs', 70),
@@ -755,6 +808,7 @@ class TrainingService:
755
  save_iterations=params.get('save_iterations', 500),
756
  repo_id=params.get('repo_id', ''),
757
  preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
 
758
  resume_from_checkpoint=str(latest_checkpoint)
759
  )
760
 
@@ -1029,5 +1083,4 @@ class TrainingService:
1029
  return temp_zip_path
1030
  except Exception as e:
1031
  print(f"Failed to create zip: {str(e)}")
1032
- raise gr.Error(f"Failed to create zip: {str(e)}")
1033
-
 
20
 
21
  from huggingface_hub import upload_folder, create_repo
22
 
23
+ from ..config import (
24
+ TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
25
+ STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
26
+ MODEL_TYPES, TRAINING_TYPES
27
+ )
28
  from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
29
 
30
  logger = logging.getLogger(__name__)
 
116
  ui_state_file = OUTPUT_PATH / "ui_state.json"
117
  default_state = {
118
  "model_type": list(MODEL_TYPES.keys())[0],
119
+ "training_type": list(TRAINING_TYPES.keys())[0],
120
  "lora_rank": "128",
121
  "lora_alpha": "128",
122
  "num_epochs": 50,
 
158
  # Make sure we have all keys (in case structure changed)
159
  merged_state = default_state.copy()
160
  merged_state.update(saved_state)
 
161
  return merged_state
162
  except json.JSONDecodeError as e:
163
  logger.error(f"Error parsing UI state JSON: {str(e)}")
 
175
  logger.info("Creating new UI state file with default values")
176
  default_state = {
177
  "model_type": list(MODEL_TYPES.keys())[0],
178
+ "training_type": list(TRAINING_TYPES.keys())[0],
179
  "lora_rank": "128",
180
  "lora_alpha": "128",
181
  "num_epochs": 50,
 
208
  # Create a new file with default values
209
  default_state = {
210
  "model_type": list(MODEL_TYPES.keys())[0],
211
+ "training_type": list(TRAINING_TYPES.keys())[0],
212
  "lora_rank": "128",
213
  "lora_alpha": "128",
214
  "num_epochs": 50,
 
337
  elif model_type == "ltx_video":
338
  if config.batch_size > 4:
339
  return "LTX model recommended batch size is 1-4"
340
+ elif model_type == "wan":
341
+ if config.batch_size > 4:
342
+ return "Wan model recommended batch size is 1-4"
343
 
344
  logger.info(f"Config validation passed with {len(video_lines)} training files")
345
  return None
 
359
  save_iterations: int,
360
  repo_id: str,
361
  preset_name: str,
362
+ training_type: str = "lora",
363
  resume_from_checkpoint: Optional[str] = None,
364
  ) -> Tuple[str, str]:
365
  """Start training with finetrainers"""
 
370
  raise ValueError("model_type cannot be empty")
371
  if model_type not in MODEL_TYPES.values():
372
  raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
373
+ if training_type not in TRAINING_TYPES.values():
374
+ raise ValueError(f"Invalid training_type: {training_type}. Must be one of {list(TRAINING_TYPES.values())}")
375
 
376
  # Check if we're resuming or starting new
377
  is_resuming = resume_from_checkpoint is not None
378
  log_prefix = "Resuming" if is_resuming else "Initializing"
379
+ logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
380
 
381
  try:
382
  # Get absolute paths - FIXED to look in project root instead of within vms directory
 
421
  logger.error(error_msg)
422
  return error_msg, "No training data available"
423
 
 
424
  # Get preset configuration
425
  preset = TRAINING_PRESETS[preset_name]
426
  training_buckets = preset["training_buckets"]
427
+ flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
428
+ preset_training_type = preset.get("training_type", "lora")
429
 
430
  # Get config for selected model type with preset buckets
431
  if model_type == "hunyuan_video":
432
+ if training_type == "lora":
433
+ config = TrainingConfig.hunyuan_video_lora(
434
+ data_path=str(TRAINING_PATH),
435
+ output_path=str(OUTPUT_PATH),
436
+ buckets=training_buckets
437
+ )
438
+ else:
439
+ # Hunyuan doesn't support full finetune in our UI yet
440
+ error_msg = "Full finetune is not supported for Hunyuan Video due to memory limitations"
441
+ logger.error(error_msg)
442
+ return error_msg, "Training configuration error"
443
+ elif model_type == "ltx_video":
444
+ if training_type == "lora":
445
+ config = TrainingConfig.ltx_video_lora(
446
+ data_path=str(TRAINING_PATH),
447
+ output_path=str(OUTPUT_PATH),
448
+ buckets=training_buckets
449
+ )
450
+ else:
451
+ config = TrainingConfig.ltx_video_full_finetune(
452
+ data_path=str(TRAINING_PATH),
453
+ output_path=str(OUTPUT_PATH),
454
+ buckets=training_buckets
455
+ )
456
+ elif model_type == "wan":
457
+ if training_type == "lora":
458
+ config = TrainingConfig.wan_lora(
459
+ data_path=str(TRAINING_PATH),
460
+ output_path=str(OUTPUT_PATH),
461
+ buckets=training_buckets
462
+ )
463
+ else:
464
+ error_msg = "Full finetune for Wan is not yet supported in this UI"
465
+ logger.error(error_msg)
466
+ return error_msg, "Training configuration error"
467
+ else:
468
+ error_msg = f"Unsupported model type: {model_type}"
469
+ logger.error(error_msg)
470
+ return error_msg, "Unsupported model"
471
 
472
  # Update with UI parameters
473
  config.train_epochs = int(num_epochs)
 
 
474
  config.batch_size = int(batch_size)
475
  config.lr = float(learning_rate)
476
  config.checkpointing_steps = int(save_iterations)
477
+ config.training_type = training_type
478
+ config.flow_weighting_scheme = flow_weighting_scheme
479
+
480
+ # Update LoRA parameters if using LoRA training type
481
+ if training_type == "lora":
482
+ config.lora_rank = int(lora_rank)
483
+ config.lora_alpha = int(lora_alpha)
484
 
485
  # Update with resume_from_checkpoint if provided
486
  if resume_from_checkpoint:
 
515
  # Convert config to command line arguments
516
  config_args = config.to_args_list()
517
 
 
518
  logger.debug("Generated args list: %s", config_args)
519
 
520
  # Log the full command for debugging
 
550
  # Save session info including repo_id for later hub upload
551
  self.save_session({
552
  "model_type": model_type,
553
+ "training_type": training_type,
554
  "lora_rank": lora_rank,
555
  "lora_alpha": lora_alpha,
556
  "num_epochs": num_epochs,
 
572
  total_epochs=num_epochs,
573
  message='Training started',
574
  repo_id=repo_id,
575
+ model_type=model_type,
576
+ training_type=training_type
577
  )
578
 
579
  # Start monitoring process output
580
  self._start_log_monitor(process)
581
 
582
+ success_msg = f"Started {training_type} training for {model_type} model"
583
  self.append_log(success_msg)
584
  logger.info(success_msg)
585
 
 
715
  last_session = {
716
  "params": {
717
  "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
718
+ "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])),
719
  "lora_rank": ui_state.get("lora_rank", "128"),
720
  "lora_alpha": ui_state.get("lora_alpha", "128"),
721
  "num_epochs": ui_state.get("num_epochs", 70),
 
772
  logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
773
  break
774
 
775
+ # Get training type (default to LoRA if not present in saved session)
776
+ training_type_internal = params.get('training_type', 'lora')
777
+ training_type_display = next((disp for disp, val in TRAINING_TYPES.items() if val == training_type_internal), list(TRAINING_TYPES.keys())[0])
778
+
779
  # Add UI updates to restore the training parameters in the UI
780
  # This shows the user what values are being used for the resumed training
781
  ui_updates.update({
782
  "model_type": model_type_display, # Use the display name for the UI dropdown
783
+ "training_type": training_type_display, # Use the display name for training type
784
  "lora_rank": params.get('lora_rank', "128"),
785
  "lora_alpha": params.get('lora_alpha', "128"),
786
  "num_epochs": params.get('num_epochs', 70),
 
808
  save_iterations=params.get('save_iterations', 500),
809
  repo_id=params.get('repo_id', ''),
810
  preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
811
+ training_type=training_type_internal,
812
  resume_from_checkpoint=str(latest_checkpoint)
813
  )
814
 
 
1083
  return temp_zip_path
1084
  except Exception as e:
1085
  print(f"Failed to create zip: {str(e)}")
1086
+ raise gr.Error(f"Failed to create zip: {str(e)}")
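To summarize the branching above, here is a small standalone sketch (an assumption drawn from this diff, not code that ships in the commit) of the model/training-type support matrix that `start_training()` now enforces:

```python
# Support matrix implied by the dispatch in TrainingService.start_training().
SUPPORTED_COMBINATIONS = {
    ("hunyuan_video", "lora"): True,
    ("hunyuan_video", "full-finetune"): False,  # rejected: memory limitations
    ("ltx_video", "lora"): True,
    ("ltx_video", "full-finetune"): True,
    ("wan", "lora"): True,
    ("wan", "full-finetune"): False,            # not yet wired up in this UI
}

def is_supported(model_type: str, training_type: str) -> bool:
    """Return True when start_training() would build a config instead of erroring."""
    return SUPPORTED_COMBINATIONS.get((model_type, training_type), False)

assert is_supported("ltx_video", "full-finetune")
assert not is_supported("wan", "full-finetune")
```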
 
vms/tabs/train_tab.py CHANGED
@@ -45,11 +45,21 @@ class TrainTab(BaseTab):
45
  label="Model Type",
46
  value=list(MODEL_TYPES.keys())[0]
47
  )
 
 
 
 
 
 
 
 
48
  self.components["model_info"] = gr.Markdown(
49
- value=self.get_model_info(list(MODEL_TYPES.keys())[0])
50
  )
51
 
52
- with gr.Row():
 
 
53
  self.components["lora_rank"] = gr.Dropdown(
54
  label="LoRA Rank",
55
  choices=["16", "32", "64", "128", "256", "512", "1024"],
@@ -62,6 +72,7 @@ class TrainTab(BaseTab):
62
  value="128",
63
  type="value"
64
  )
 
65
  with gr.Row():
66
  self.components["num_epochs"] = gr.Number(
67
  label="Number of Epochs",
@@ -143,15 +154,18 @@ class TrainTab(BaseTab):
143
  def connect_events(self) -> None:
144
  """Connect event handlers to UI components"""
145
  # Model type change event
146
- def update_model_info(model):
147
- params = self.get_default_params(MODEL_TYPES[model])
148
- info = self.get_model_info(MODEL_TYPES[model])
 
 
149
  return {
150
  self.components["model_info"]: info,
151
  self.components["num_epochs"]: params["num_epochs"],
152
  self.components["batch_size"]: params["batch_size"],
153
  self.components["learning_rate"]: params["learning_rate"],
154
- self.components["save_iterations"]: params["save_iterations"]
 
155
  }
156
 
157
  self.components["model_type"].change(
@@ -160,13 +174,32 @@ class TrainTab(BaseTab):
160
  outputs=[]
161
  ).then(
162
  fn=update_model_info,
163
- inputs=[self.components["model_type"]],
164
  outputs=[
165
  self.components["model_info"],
166
  self.components["num_epochs"],
167
  self.components["batch_size"],
168
  self.components["learning_rate"],
169
- self.components["save_iterations"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  ]
171
  )
172
 
@@ -217,13 +250,15 @@ class TrainTab(BaseTab):
217
  inputs=[self.components["training_preset"]],
218
  outputs=[
219
  self.components["model_type"],
 
220
  self.components["lora_rank"],
221
  self.components["lora_alpha"],
222
  self.components["num_epochs"],
223
  self.components["batch_size"],
224
  self.components["learning_rate"],
225
  self.components["save_iterations"],
226
- self.components["preset_info"]
 
227
  ]
228
  )
229
 
@@ -233,6 +268,7 @@ class TrainTab(BaseTab):
233
  inputs=[
234
  self.components["training_preset"],
235
  self.components["model_type"],
 
236
  self.components["lora_rank"],
237
  self.components["lora_alpha"],
238
  self.components["num_epochs"],
@@ -278,7 +314,7 @@ class TrainTab(BaseTab):
278
  ]
279
  )
280
 
281
- def handle_training_start(self, preset, model_type, *args):
282
  """Handle training start with proper log parser reset and checkpoint detection"""
283
  # Safely reset log parser if it exists
284
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
@@ -305,12 +341,20 @@ class TrainTab(BaseTab):
305
  logger.error(f"Invalid model type: {model_type}")
306
  return f"Error: Invalid model type '{model_type}'", "Model type not recognized"
307
 
 
 
 
 
 
 
 
308
  # Start training (it will automatically use the checkpoint if provided)
309
  try:
310
  return self.app.trainer.start_training(
311
  model_internal_type, # Use internal model type
312
  *args,
313
  preset_name=preset,
 
314
  resume_from_checkpoint=resume_from
315
  )
316
  except Exception as e:
@@ -318,25 +362,56 @@ class TrainTab(BaseTab):
318
  return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
319
 
320
 
321
- def get_model_info(self, model_type: str) -> str:
322
- """Get information about the selected model type"""
 
 
323
  if model_type == "hunyuan_video":
324
- return """### HunyuanVideo (LoRA)
325
  - Required VRAM: ~48GB minimum
326
  - Recommended batch size: 1-2
327
  - Typical training time: 2-4 hours
328
- - Default resolution: 49x512x768
329
- - Default LoRA rank: 128 (~600 MB)"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  elif model_type == "ltx_video":
332
- return """### LTX-Video (LoRA)
333
- - Required VRAM: ~18GB minimum
334
  - Recommended batch size: 1-4
335
  - Typical training time: 1-3 hours
336
- - Default resolution: 49x512x768
337
- - Default LoRA rank: 128"""
338
-
339
- return ""
340
 
341
  def get_default_params(self, model_type: str) -> Dict[str, Any]:
342
  """Get default training parameters for model type"""
 
45
  label="Model Type",
46
  value=list(MODEL_TYPES.keys())[0]
47
  )
48
+ with gr.Column():
49
+ self.components["training_type"] = gr.Dropdown(
50
+ choices=list(TRAINING_TYPES.keys()),
51
+ label="Training Type",
52
+ value=list(TRAINING_TYPES.keys())[0]
53
+ )
54
+
55
+ with gr.Row():
56
  self.components["model_info"] = gr.Markdown(
57
+ value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0])
58
  )
59
 
60
+ # LoRA specific parameters (will show/hide based on training type)
61
+ with gr.Row(visible=True) as lora_params_row:
62
+ self.components["lora_params_row"] = lora_params_row
63
  self.components["lora_rank"] = gr.Dropdown(
64
  label="LoRA Rank",
65
  choices=["16", "32", "64", "128", "256", "512", "1024"],
 
72
  value="128",
73
  type="value"
74
  )
75
+
76
  with gr.Row():
77
  self.components["num_epochs"] = gr.Number(
78
  label="Number of Epochs",
 
154
  def connect_events(self) -> None:
155
  """Connect event handlers to UI components"""
156
  # Model type change event
157
+ def update_model_info(model, training_type):
158
+ params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
159
+ info = self.get_model_info(MODEL_TYPES[model], TRAINING_TYPES[training_type])
160
+ show_lora_params = training_type == list(TRAINING_TYPES.keys())[0] # Show if LoRA Finetune
161
+
162
  return {
163
  self.components["model_info"]: info,
164
  self.components["num_epochs"]: params["num_epochs"],
165
  self.components["batch_size"]: params["batch_size"],
166
  self.components["learning_rate"]: params["learning_rate"],
167
+ self.components["save_iterations"]: params["save_iterations"],
168
+ self.components["lora_params_row"]: gr.Row(visible=show_lora_params)
169
  }
170
 
171
  self.components["model_type"].change(
 
174
  outputs=[]
175
  ).then(
176
  fn=update_model_info,
177
+ inputs=[self.components["model_type"], self.components["training_type"]],
178
  outputs=[
179
  self.components["model_info"],
180
  self.components["num_epochs"],
181
  self.components["batch_size"],
182
  self.components["learning_rate"],
183
+ self.components["save_iterations"],
184
+ self.components["lora_params_row"]
185
+ ]
186
+ )
187
+
188
+ # Training type change event
189
+ self.components["training_type"].change(
190
+ fn=lambda v: self.app.update_ui_state(training_type=v),
191
+ inputs=[self.components["training_type"]],
192
+ outputs=[]
193
+ ).then(
194
+ fn=update_model_info,
195
+ inputs=[self.components["model_type"], self.components["training_type"]],
196
+ outputs=[
197
+ self.components["model_info"],
198
+ self.components["num_epochs"],
199
+ self.components["batch_size"],
200
+ self.components["learning_rate"],
201
+ self.components["save_iterations"],
202
+ self.components["lora_params_row"]
203
  ]
204
  )
205
 
 
250
  inputs=[self.components["training_preset"]],
251
  outputs=[
252
  self.components["model_type"],
253
+ self.components["training_type"],
254
  self.components["lora_rank"],
255
  self.components["lora_alpha"],
256
  self.components["num_epochs"],
257
  self.components["batch_size"],
258
  self.components["learning_rate"],
259
  self.components["save_iterations"],
260
+ self.components["preset_info"],
261
+ self.components["lora_params_row"]
262
  ]
263
  )
264
 
 
268
  inputs=[
269
  self.components["training_preset"],
270
  self.components["model_type"],
271
+ self.components["training_type"],
272
  self.components["lora_rank"],
273
  self.components["lora_alpha"],
274
  self.components["num_epochs"],
 
314
  ]
315
  )
316
 
317
+ def handle_training_start(self, preset, model_type, training_type, *args):
318
  """Handle training start with proper log parser reset and checkpoint detection"""
319
  # Safely reset log parser if it exists
320
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
 
341
  logger.error(f"Invalid model type: {model_type}")
342
  return f"Error: Invalid model type '{model_type}'", "Model type not recognized"
343
 
344
+ # Convert training_type display name to internal name
345
+ training_internal_type = TRAINING_TYPES.get(training_type)
346
+
347
+ if not training_internal_type:
348
+ logger.error(f"Invalid training type: {training_type}")
349
+ return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
350
+
351
  # Start training (it will automatically use the checkpoint if provided)
352
  try:
353
  return self.app.trainer.start_training(
354
  model_internal_type, # Use internal model type
355
  *args,
356
  preset_name=preset,
357
+ training_type=training_internal_type, # Pass the internal training type
358
  resume_from_checkpoint=resume_from
359
  )
360
  except Exception as e:
 
362
  return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
363
 
364
 
365
+ def get_model_info(self, model_type: str, training_type: str) -> str:
366
+ """Get information about the selected model type and training method"""
367
+ training_method = "LoRA finetune" if training_type == "lora" else "Full finetune"
368
+
369
  if model_type == "hunyuan_video":
370
+ base_info = """### HunyuanVideo
371
  - Required VRAM: ~48GB minimum
372
  - Recommended batch size: 1-2
373
  - Typical training time: 2-4 hours
374
+ - Default resolution: 49x512x768"""
375
+
376
+ if training_type == "lora":
377
+ return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
378
+ else:
379
+ return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
380
+
381
+ elif model_type == "wan":
382
+ base_info = """### Wan-2.1-T2V
383
+ - Recommended batch size: 1-2
384
+ - Typical training time: 1-3 hours
385
+ - Default resolution: 49x512x768"""
386
+
387
+ if training_type == "lora":
388
+ return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
389
+ else:
390
+ return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
391
+ else:
392
+ return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
393
+
394
+ elif model_type == "wan":
395
+ base_info = """### Wan-2.1-T2V
396
+ - Recommended batch size: 1-2
397
+ - Typical training time: 1-3 hours
398
+ - Default resolution: 49x512x768"""
399
+
400
+ if training_type == "lora":
401
+ return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
402
+ else:
403
+ return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Default LoRA rank: 128 (~600 MB)"
404
+ else:
405
+ return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
406
 
407
  elif model_type == "ltx_video":
408
+ base_info = """### LTX-Video
 
409
  - Recommended batch size: 1-4
410
  - Typical training time: 1-3 hours
411
+ - Default resolution: 49x512x768"""
412
+
413
+ if training_type == "lora":
414
+ return base_
415
 
416
  def get_default_params(self, model_type: str) -> Dict[str, Any]:
417
  """Get default training parameters for model type"""
vms/ui/video_trainer_ui.py CHANGED
@@ -75,7 +75,7 @@ class VideoTrainerUI:
75
  for tab_id, tab_obj in self.tabs.items():
76
  tab_obj.connect_events()
77
 
78
- # Add app-level timers for auto-refresh functionality
79
  self._add_timers()
80
 
81
  # Initialize app state on load
@@ -89,6 +89,7 @@ class VideoTrainerUI:
89
  self.tabs["train_tab"].components["pause_resume_btn"],
90
  self.tabs["train_tab"].components["training_preset"],
91
  self.tabs["train_tab"].components["model_type"],
 
92
  self.tabs["train_tab"].components["lora_rank"],
93
  self.tabs["train_tab"].components["lora_alpha"],
94
  self.tabs["train_tab"].components["num_epochs"],
@@ -180,9 +181,23 @@ class VideoTrainerUI:
180
 
181
  ui_state["model_type"] = model_type_value
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # Copy other parameters
184
  for param in ["lora_rank", "lora_alpha", "num_epochs",
185
- "batch_size", "learning_rate", "save_iterations", "training_preset"]:
186
  if param in recovery_ui:
187
  ui_state[param] = recovery_ui[param]
188
 
@@ -205,6 +220,15 @@ class VideoTrainerUI:
205
  model_type_val = display_name
206
  break
207
 
 
 
 
 
 
 
 
 
 
208
  training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
209
  lora_rank_val = ui_state.get("lora_rank", "128")
210
  lora_alpha_val = ui_state.get("lora_alpha", "128")
@@ -222,6 +246,7 @@ class VideoTrainerUI:
222
  delete_checkpoints_btn,
223
  training_preset,
224
  model_type_val,
 
225
  lora_rank_val,
226
  lora_alpha_val,
227
  num_epochs_val,
@@ -238,6 +263,7 @@ class VideoTrainerUI:
238
  return (
239
  ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
240
  ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
 
241
  ui_state.get("lora_rank", "128"),
242
  ui_state.get("lora_alpha", "128"),
243
  ui_state.get("num_epochs", 70),
 
75
  for tab_id, tab_obj in self.tabs.items():
76
  tab_obj.connect_events()
77
 
78
+ # app-level timers for auto-refresh functionality
79
  self._add_timers()
80
 
81
  # Initialize app state on load
 
89
  self.tabs["train_tab"].components["pause_resume_btn"],
90
  self.tabs["train_tab"].components["training_preset"],
91
  self.tabs["train_tab"].components["model_type"],
92
+ self.tabs["train_tab"].components["training_type"], # Add the new training_type component to outputs
93
  self.tabs["train_tab"].components["lora_rank"],
94
  self.tabs["train_tab"].components["lora_alpha"],
95
  self.tabs["train_tab"].components["num_epochs"],
 
181
 
182
  ui_state["model_type"] = model_type_value
183
 
184
+ # Handle training_type
185
+ if "training_type" in recovery_ui:
186
+ training_type_value = recovery_ui["training_type"]
187
+
188
+ # If it's an internal name, convert to display name
189
+ if training_type_value not in TRAINING_TYPES:
190
+ for display_name, internal_name in TRAINING_TYPES.items():
191
+ if internal_name == training_type_value:
192
+ training_type_value = display_name
193
+ logger.info(f"Converted internal training type '{recovery_ui['training_type']}' to display name '{training_type_value}'")
194
+ break
195
+
196
+ ui_state["training_type"] = training_type_value
197
+
198
  # Copy other parameters
199
  for param in ["lora_rank", "lora_alpha", "num_epochs",
200
+ "batch_size", "learning_rate", "save_iterations", "training_preset"]:
201
  if param in recovery_ui:
202
  ui_state[param] = recovery_ui[param]
203
 
 
220
  model_type_val = display_name
221
  break
222
 
223
+ # Ensure training_type is a display name, not internal name
224
+ training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
225
+ if training_type_val not in TRAINING_TYPES:
226
+ # Convert from internal to display name
227
+ for display_name, internal_name in TRAINING_TYPES.items():
228
+ if internal_name == training_type_val:
229
+ training_type_val = display_name
230
+ break
231
+
232
  training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
233
  lora_rank_val = ui_state.get("lora_rank", "128")
234
  lora_alpha_val = ui_state.get("lora_alpha", "128")
 
246
  delete_checkpoints_btn,
247
  training_preset,
248
  model_type_val,
249
+ training_type_val,
250
  lora_rank_val,
251
  lora_alpha_val,
252
  num_epochs_val,
 
263
  return (
264
  ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
265
  ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
266
+ ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
267
  ui_state.get("lora_rank", "128"),
268
  ui_state.get("lora_alpha", "128"),
269
  ui_state.get("num_epochs", 70),