Commit d464085
Parent(s): 80ebcb3
upgrading our code to support the new finetrainers
- docs/finetrainers/documentation_dataset_README.md +165 -0
- docs/finetrainers/documentation_environment.md +28 -0
- docs/finetrainers/documentation_global_README.md +99 -0
- docs/finetrainers/documentation_models_README.md +56 -0
- docs/finetrainers/documentation_models_hunyuan_video.md +189 -0
- docs/finetrainers/documentation_models_ltx_video.md +196 -0
- docs/finetrainers/documentation_models_optimization.md +12 -0
- docs/finetrainers/documentation_models_wan.md +36 -0
- docs/finetrainers/documentation_optimizers.md +14 -0
- docs/finetrainers/documentation_parallel_processing_README.md +59 -0
- docs/finetrainers/examples_training_wan__train.sh +162 -0
- docs/finetrainers/examples_training_wan__training.json +14 -0
- docs/finetrainers/examples_training_wan__validation.json +40 -0
- training/README.md +0 -459
- training/README_zh.md +0 -455
- training/cogvideox/__init__.py +0 -0
- training/cogvideox/args.py +0 -484
- training/cogvideox/cogvideox_image_to_video_lora.py +0 -1016
- training/cogvideox/cogvideox_image_to_video_sft.py +0 -947
- training/cogvideox/cogvideox_text_to_video_lora.py +0 -955
- training/cogvideox/cogvideox_text_to_video_sft.py +0 -917
- training/cogvideox/dataset.py +0 -428
- training/cogvideox/prepare_dataset.py +0 -669
- training/cogvideox/text_encoder/__init__.py +0 -1
- training/cogvideox/text_encoder/text_encoder.py +0 -99
- training/cogvideox/utils.py +0 -260
- training/mochi-1/README.md +0 -111
- training/mochi-1/args.py +0 -268
- training/mochi-1/dataset_simple.py +0 -50
- training/mochi-1/embed.py +0 -111
- training/mochi-1/prepare_dataset.sh +0 -15
- training/mochi-1/requirements.txt +0 -8
- training/mochi-1/text_to_video_lora.py +0 -592
- training/mochi-1/train.sh +0 -37
- training/mochi-1/trim_and_crop_videos.py +0 -126
- training/mochi-1/utils.py +0 -22
- training/prepare_dataset.sh +0 -48
- training/train_image_to_video_lora.sh +0 -82
- training/train_image_to_video_sft.sh +0 -87
- training/train_text_to_video_lora.sh +0 -86
- training/train_text_to_video_sft.sh +0 -77
- vms/config.py +109 -14
- vms/services/trainer.py +75 -22
- vms/tabs/train_tab.py +96 -21
- vms/ui/video_trainer_ui.py +28 -2
docs/finetrainers/documentation_dataset_README.md
ADDED
@@ -0,0 +1,165 @@
# Dataset

## Dataset preparation

Please refer to [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) for a collection of scripts to prepare datasets for training. The scripts are designed to work with the HF datasets library and can be used to prepare datasets for training with `finetrainers`.

## Training Dataset Format

Dataset loading format support is very limited at the moment. This will be improved in the future. For now, we support the following formats:

#### Two file format

> [!NOTE]
> Relevant classes to look for implementation:
> - ImageFileCaptionFileListDataset
> - VideoFileCaptionFileListDataset
>
> Supports loading directly from the HF Hub.

Your dataset structure should look like this. Running the `tree` command, you should see something similar to:

```
dataset
├── prompt.txt
├── videos.txt
├── videos
    ├── 00000.mp4
    ├── 00001.mp4
    ├── ...
```

- Make sure that the paths in `videos.txt` are relative to the `dataset` directory. The `prompt.txt` should contain the captions for the videos, in the same order as the videos in `videos.txt`.
- Supported names for the caption file: `captions.txt`, `caption.txt`, `prompt.txt`, `prompts.txt` (feel free to send PRs to add more common names).
- Supported names for the video file: `videos.txt`, `video.txt` (feel free to send PRs to add more common names).
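If your clips and captions already live in Python lists, the two-file layout above can be generated with a few lines of standard-library code. The snippet below is a minimal illustrative sketch (not part of `finetrainers`); the captions, file names, and the `dataset` output path are placeholders to replace with your own data.

```python
from pathlib import Path

# Hypothetical example data: replace with your own clips and captions.
captions = [
    "A cat walks on the grass, realistic",
    "A dog runs along the beach at sunset",
]
video_files = ["videos/00000.mp4", "videos/00001.mp4"]  # paths relative to the dataset root

root = Path("dataset")
(root / "videos").mkdir(parents=True, exist_ok=True)

# One caption per line, in the same order as the entries in videos.txt.
(root / "prompt.txt").write_text("\n".join(captions) + "\n")
(root / "videos.txt").write_text("\n".join(video_files) + "\n")
```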
#### Caption-Data filename pair format

> [!NOTE]
> Relevant classes to look for implementation:
> - ImageCaptionFilePairDataset
> - VideoCaptionFilePairDataset
>
> Does not support loading directly from the HF Hub.

Your dataset structure should look like this. Running the `tree` command, you should see something similar to:

```
dataset
├── a.txt
├── a.mp4
├── bkjlaskdjg.txt
├── bkjlaskdjg.mp4
├── ...
```

- Each caption file should have a corresponding image/video file with the same name.
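Since every caption file must pair with a media file of the same name, a quick sanity check before training can surface missing pairs early. This is an illustrative sketch only; the `dataset` path and the `.mp4`/`.png` extensions are assumptions to adapt to your own files.

```python
from pathlib import Path

root = Path("dataset")  # assumed dataset root
missing = [
    txt.name
    for txt in root.glob("*.txt")
    if not (txt.with_suffix(".mp4").exists() or txt.with_suffix(".png").exists())
]
if missing:
    print("Caption files without a matching media file:", missing)
else:
    print("All caption files have a matching media file.")
```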
#### CSV/JSON/JSONL format

> [!NOTE]
> Relevant classes to look for implementation:
> - ImageFolderDataset
> - VideoFolderDataset

Any dataset loadable directly via the 🤗 HF `datasets` library should work (not widely tested at the moment):
- https://huggingface.co/docs/datasets/v3.3.2/en/image_load#webdataset
- https://huggingface.co/docs/datasets/v3.3.2/en/video_load#webdataset

#### Webdataset format

> [!NOTE]
> Relevant classes to look for implementation:
> - ImageWebDataset
> - VideoWebDataset

Any dataset loadable directly via the 🤗 HF `datasets` library should work (not widely tested at the moment). We support the [image `webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) and [video `webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) formats.

## Validation Dataset Format

Arguments related to validation are:
- `--validation_dataset_file`: Path to the validation dataset file. Supported formats are CSV, JSON, JSONL, PARQUET, and ARROW. Note: PARQUET and ARROW have not been tested after a major refactor, but should most likely work. (TODO(aryan): look into this)
- `--validation_steps`: Interval of training steps after which validation should be performed.
- `--enable_model_cpu_offload`: If set, CPU offloading will be enabled during validation. Note that this has not been tested for FSDP, TP, or DDP after a major refactor, but should most likely work for single GPU training.

> [!IMPORTANT]
>
> When using `dp_shards > 1` or `tp_degree > 1`, you must make sure that the number of data samples in the validation file is a multiple of `dp_shards * tp_degree`. If this is not the case, the training will fail due to a NCCL timeout. This will be improved/fixed in the future.

- The validation file must contain "caption" as a column. If an image must be provided for validation (for example, image-to-video inference), then the "image_path" field must be provided. If a video must be provided for validation (for example, video-to-video inference), then the "video_path" field must be provided. Other fields like "num_inference_steps", "height", "width", "num_frames", and "frame_rate" can be provided too, but are optional.

#### CSV Example

<details>
<summary>Click to expand</summary>

```csv
caption,image_path,video_path,num_inference_steps,height,width,num_frames,frame_rate
"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.",,"/raid/aryan/finetrainers-dummy-dataset-disney/a3c275fc2eb0a67168a7c58a6a9adb14.mp4",50,480,768,49,25
"<SECOND_CAPTION>",,"/path/to/second.mp4",50,512,704,161,25
```

</details>

#### JSON Example

The JSON file must contain a "data" field, which should be a list of dictionaries. Each dictionary corresponds to one validation video that will be generated with the selected configuration of generation parameters.

<details>
<summary>Click to expand</summary>

```json
{
  "data": [
    {
      "caption": "A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.",
      "image_path": "",
      "video_path": "/raid/aryan/finetrainers-dummy-dataset-disney/a3c275fc2eb0a67168a7c58a6a9adb14.mp4",
      "num_inference_steps": 50,
      "height": 480,
      "width": 768,
      "num_frames": 49,
      "frame_rate": 25
    },
    {
      "caption": "<SECOND_CAPTION>",
      "image_path": "",
      "video_path": "/path/to/second.mp4",
      "num_inference_steps": 50,
      "height": 512,
      "width": 704,
      "num_frames": 161,
      "frame_rate": 25
    }
  ]
}
```

</details>
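A validation file in the format above can also be generated programmatically. The sketch below only uses the fields documented in this section; the caption, paths, and the `validation.json` output name are placeholders.

```python
import json

# Illustrative sketch: build a validation file with the fields documented above.
validation_items = [
    {
        "caption": "A cat walks on the grass, realistic",
        "image_path": "",
        "video_path": "/path/to/reference.mp4",
        "num_inference_steps": 50,
        "height": 480,
        "width": 768,
        "num_frames": 49,
        "frame_rate": 25,
    },
]

with open("validation.json", "w") as f:
    json.dump({"data": validation_items}, f, indent=2)
```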
## Understanding how datasets are loaded

For memory-efficient training, it is important to precompute conditional and latent embeddings. If this is not done, we would need to keep the conditioning models in memory, which can be memory intensive. To avoid this, we implement some abstractions that allow us to do the following efficiently:
- Loading datasets
- Chaining multiple datasets together
- Splitting datasets across data replicas
- Preprocessing datasets to user-configured resolution buckets
- Precomputing embeddings without using too much disk space

The following is a high-level overview of how datasets are loaded and preprocessed:

- Initially, the dataset is lazily loaded using the HF `datasets` library. Every dataset is loaded in streaming and infinite mode. This means that the dataset will be loaded indefinitely until some end condition is reached (e.g. the user-configured number of training steps is completed). Users can chain together multiple datasets too! For example, if you only have high resolution data available, but want to perform multi-resolution training at certain lower resolutions too, you would otherwise have to perform the resizing manually and chain the data together. Finetrainers makes this easier by allowing you to specify multiple different, or identical, datasets with different resolutions.
- The dataset is split across data replicas (GPU groups that perform data parallelism). Each data replica will have a non-overlapping subset of the overall dataset.
- If multiple datasets have been provided, they will be chained together. Shuffling can also be done to ensure better dataset regularization. This is done by shuffling the iterable datasets in a buffer of user-configured `--dataset_shuffle_buffer_size`. For small datasets, it is recommended to not shuffle and use the default value of `1`. For larger datasets, the overhead grows significantly with this value, so it is recommended to keep it low (< 1000) [this is because we store the data in memory in a not-so-clever way yet].
- The dataset is preprocessed to the user-configured resolution buckets. This is done by resizing the images/videos to the specified resolution buckets, and is also necessary for collation when using a batch size > 1 (see the sketch after this list for the general idea behind bucketing).
- The dataset embeddings are precomputed and stored to disk. This is done in batches of user-configured `--precompute_batch_size`, to avoid exhausting disk space. The smaller this value, the more often the conditioning models will be reloaded when the precomputed data is exhausted. The larger this value, the more disk space will be used.
- When data points are required for training, they are loaded from disk on the main process and dispatched to data replicas. [TODO: this needs some improvements to speedup training eventually]
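To make the bucketing idea concrete, here is a minimal sketch of assigning a clip to the closest user-configured `frames x height x width` bucket. This illustrates the concept only and is not the exact logic used by `finetrainers`; the bucket list and the distance heuristic are assumptions for the example.

```python
# Buckets follow the FxHxW convention used by --video_resolution_buckets,
# e.g. "17x512x768 49x512x768 61x512x768".
BUCKETS = [(17, 512, 768), (49, 512, 768), (61, 512, 768)]

def nearest_bucket(num_frames: int, height: int, width: int) -> tuple[int, int, int]:
    # Pick the bucket whose shape is closest to the clip; frames are weighted
    # more heavily here purely for illustration.
    def distance(bucket: tuple[int, int, int]) -> int:
        f, h, w = bucket
        return abs(f - num_frames) * 10 + abs(h - height) + abs(w - width)
    return min(BUCKETS, key=distance)

print(nearest_bucket(num_frames=53, height=480, width=720))  # -> (49, 512, 768)
```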
## Understanding how datasets are precomputed

There are 3 arguments related to precomputation:
- `--precomputation_items`: The number of data points to precompute and store to disk at a time. This is useful for performing memory-efficient training without exhausting disk space by precomputing embeddings of the entire dataset(s) at once. We default to `512` data points, but configure this to a lower value for smaller datasets. As training progresses, the precomputed data will be read from disk and dispatched to data replicas. Once all precomputed data has been used, the next batch of data points will be precomputed and stored to disk in a rolling fashion (see the sketch after this list).
- `--precomputation_dir`: The directory where precomputed data will be stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If this directory is not provided, the precomputed data will be stored in `--output_dir/precomputed`.
- `--precomputation_once`: If you're working with small datasets and want to precompute all embeddings at once, set this flag. This will allow you to train without having to compute embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard-coded logic (which can be removed manually by users for now).
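The rolling behaviour of `--precomputation_items` can be pictured with a small generator: embed a fixed-size window of samples, serve them, then embed the next window. This is a conceptual sketch only, not the `finetrainers` implementation; the `embed` callable stands in for the conditioning and latent models.

```python
from typing import Callable, Iterable, Iterator

def rolling_precompute(
    samples: Iterable[dict],
    embed: Callable[[list[dict]], list[dict]],
    precomputation_items: int = 512,
) -> Iterator[dict]:
    """Embed `precomputation_items` samples at a time and yield them one by one."""
    window: list[dict] = []
    for sample in samples:
        window.append(sample)
        if len(window) == precomputation_items:
            # In the real trainer this batch would be written to disk and the
            # conditioning models offloaded until the window is exhausted.
            yield from embed(window)
            window.clear()
    if window:  # flush the final partial window
        yield from embed(window)
```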
Batching is not yet supported for precomputation. This will be added in the future.
docs/finetrainers/documentation_environment.md
ADDED
@@ -0,0 +1,28 @@
# Environment

Finetrainers has only been widely tested with the following environment (output obtained by running `diffusers-cli env`):

```shell
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
- Jax version: 0.4.31
- JaxLib version: 0.4.31
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.0.dev0
- Accelerate version: 1.1.0.dev0
- PEFT version: 0.14.1.dev0
- Bitsandbytes version: 0.43.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
  NVIDIA A100-SXM4-80GB, 81920 MiB
  NVIDIA A100-SXM4-80GB, 81920 MiB
  NVIDIA DGX Display, 4096 MiB
  NVIDIA A100-SXM4-80GB, 81920 MiB
```

Other versions of dependencies may or may not work as expected. We would like to make finetrainers work on a wider range of environments, but due to the complexity of testing at this early stage of development, we are unable to do so. The long-term goals include compatibility with most PyTorch versions on CUDA, MPS, ROCm and XLA devices.
docs/finetrainers/documentation_global_README.md
ADDED
@@ -0,0 +1,99 @@
# finetrainers 🧪

FineTrainers is a work-in-progress library to support (accessible) training of video models. Our first priority is to support LoRA training for all popular video models in [Diffusers](https://github.com/huggingface/diffusers), and eventually other methods like controlnets, control-loras, distillation, etc.

`cogvideox-factory` was renamed to `finetrainers`. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to [this](./training/README.md) README instead. Everything in the `training/` directory will eventually be moved to and supported under `finetrainers`.

<table align="center">
<tr>
  <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
</tr>
</table>

## News

- 🔥 **2025-03-03**: Wan T2V support added!
- 🔥 **2025-03-03**: We have shipped a complete refactor to support multi-backend distributed training, better precomputation handling for big datasets, a model specification format (externally usable for training custom models), FSDP & more.
- 🔥 **2025-02-12**: We have shipped a set of tooling to curate small and high-quality video datasets for fine-tuning. See the [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) documentation page for details!
- 🔥 **2025-02-12**: Check out [eisneim/ltx_lora_training_i2v_t2v](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)! It builds off of `finetrainers` to support image-to-video training for LTX-Video and STG guidance for inference.
- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to [@ArEnSc](https://github.com/ArEnSc) for taking up the initiative!
- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
- 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added!

## Table of Contents

- [Quickstart](#quickstart)
- [Support Matrix](#support-matrix)
- [Featured Projects](#featured-projects)
- [Acknowledgements](#acknowledgements)

## Quickstart

Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt`, and install `diffusers` from source with `pip install git+https://github.com/huggingface/diffusers`. The requirements specify `diffusers>=0.32.1`, but it is always recommended to use the `main` branch of Diffusers for the latest features and bugfixes. Note that the `main` branch of `finetrainers` is also the development branch, and stable support should be expected from the release tags.

Check out the latest release tag:

```bash
git fetch --all --tags
git checkout tags/v0.0.1
```

Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1) for the release tag.

#### Using the main branch

To get started quickly with example training scripts on the main development branch, refer to the following:
- [LTX-Video Pika Effects Crush](./examples/training/sft/ltx_video/crush_smol_lora/)
- [CogVideoX Pika Effects Crush](./examples/training/sft/cogvideox/crush_smol_lora/)
- [Wan T2V Pika Effects Crush](./examples/training/sft/wan/crush_smol_lora/)

The following are some simple datasets/HF orgs with good datasets to test training with quickly:
- [Disney Video Generation Dataset](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
- [bigdatapw Video Dataset Collection](https://huggingface.co/bigdata-pw)
- [Finetrainers HF Dataset Collection](https://huggingface.co/finetrainers)

Please check out [`docs/models`](./docs/models/) and [`examples/training`](./examples/training/) to learn more about supported models for training & example reproducible training launch scripts.

> [!IMPORTANT]
> It is recommended to use PyTorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested. For fully reproducible training, please use the same environment as mentioned in [environment.md](./docs/environment.md).

## Support Matrix

> [!NOTE]
> The following numbers were obtained from the [release branch](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1). The `main` branch is unstable at the moment and may use higher memory.

<div align="center">

| **Model Name** | **Tasks** | **Min. LoRA VRAM<sup>*</sup>** | **Min. Full Finetuning VRAM<sup>^</sup>** |
|:----------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
| [LTX-Video](./docs/models/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
| [HunyuanVideo](./docs/models/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
| [CogVideoX-5b](./docs/models/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |
| [Wan](./docs/models/wan.md) | Text-to-Video | TODO | TODO |

</div>

<sub><sup>*</sup>Measured for training only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
<sub><sup>^</sup>Measured for training only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing.</sub>

If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).

## Featured Projects 🔥

Check out some amazing projects citing `finetrainers`:
- [Diffusion as Shader](https://github.com/IGL-HKUST/DiffusionAsShader)
- [SkyworkAI's SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1)
- [eisneim's LTX Image-to-Video](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)
- [wileewang's TransPixar](https://github.com/wileewang/TransPixar)
- [Feizc's Video-In-Context](https://github.com/feizc/Video-In-Context)

Check out the following UIs built for `finetrainers`:
- [jbilcke's VideoModelStudio](https://github.com/jbilcke-hf/VideoModelStudio)
- [neph1's finetrainers-ui](https://github.com/neph1/finetrainers-ui)

## Acknowledgements

* `finetrainers` builds on top of & takes inspiration from great open-source libraries - `transformers`, `accelerate`, `torchtune`, `torchtitan`, `peft`, `diffusers`, `bitsandbytes`, `torchao` and `deepspeed` - to name a few.
* Some of the design choices of `finetrainers` were inspired by [`SimpleTuner`](https://github.com/bghira/SimpleTuner).
docs/finetrainers/documentation_models_README.md
ADDED
@@ -0,0 +1,56 @@
# FineTrainers training documentation

This directory contains the training-related specifications for all the models we support in `finetrainers`. Each model page has:
- an example training command
- an inference example
- numbers on memory consumption

By default, we don't include any validation-related arguments in the example training commands. To enable validation inference every 500 steps, one can add the following arguments:

```diff
+ --validation_dataset_file <Path to a CSV/JSON/PARQUET/ARROW> \
+ --validation_steps 500
```

Arguments for training are documented in the code. For more information, please run `python train.py --help`.

## Support matrix

The following table shows the algorithms supported for training and the models they are supported for:

| Model | SFT | Control | ControlNet | Distillation |
|:-----------------------------------------:|:---:|:-------:|:----------:|:------------:|
| [CogVideoX](./cogvideox.md) | 🤗 | 😡 | 😡 | 😡 |
| [LTX-Video](./ltx_video.md) | 🤗 | 😡 | 😡 | 😡 |
| [HunyuanVideo](./hunyuan_video.md) | 🤗 | 😡 | 😡 | 😡 |

For launching SFT training:
- `--training_type lora`: Trains a new set of low-rank weights of the model, yielding a smaller adapter model. Currently, only LoRA is supported from [🤗 PEFT](https://github.com/huggingface/peft).
- `--training_type full-finetune`: Trains the full-rank weights of the model, yielding a full-parameter trained model.

Any model architecture loadable in diffusers/transformers for the above models can be used for training. For example, [SkyReels-T2V](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V) is a finetune of HunyuanVideo, which is compatible for continual training out-of-the-box. Custom models can be loaded either by writing your own [ModelSpecification](TODO(aryan): add link) or by using the following set of arguments:
- `--tokenizer_id`, `--tokenizer_2_id`, `--tokenizer_3_id`: The tokenizers to use for training in conjunction with the text encoder conditioning models.
- `--text_encoder_id`, `--text_encoder_2_id`, `--text_encoder_3_id`: The text encoder conditioning models.
- `--transformer_id`: The transformer model to use for training.
- `--vae_id`: The VAE model to use for training.

The above arguments should take care of most training scenarios. For any custom training scenario, please use your own implementation of a `ModelSpecification`. These arguments should be used only if one wants to override the default components loaded from `--pretrained_model_name_or_path`. For each of these arguments, there also exists a corresponding `--<ARG>_dtype` argument to specify the precision of that component.

## Resuming training

To resume training, the following arguments can be used:
- `--checkpointing_steps`: The interval of training steps that should be completed after which the training state should be saved.
- `--checkpointing_limit`: The maximum number of checkpoints that should be kept at once. If the limit is reached, the oldest checkpoint is purged.
- `--resume_from_checkpoint <STEP_OR_LATEST>`: Can be an integer or the string `"latest"`. If an integer is provided, training will resume from that step if a checkpoint corresponding to it exists. If `"latest"` is provided, training will resume from the latest checkpoint in the `--output_dir`.

> [!IMPORTANT]
> The `--resume_from_checkpoint` argument is only compatible if the parallel backend and degrees of parallelism are the same as in the previous training run. For example, changing `--dp_degree 2 --dp_shards 1` from a past run to `--dp_degree 1 --dp_shards 2` in the current run will not work.

## How do we handle `mixed_precision`?

The accelerate config files (the ones seen [here](../../accelerate_configs/)) that are supplied while launching training should contain a field called `mixed_precision`, and `accelerate` makes use of it if specified. We don't let users explicitly pass that from the CLI args because it can be confusing to have both `transformer_dtype` and `mixed_precision` in the codebase.

`transformer_dtype` is the ultimate source of truth for the precision to be used when training. It will also most likely always have to be `torch.bfloat16` because:

* All models currently supported (except Cog-2b) do not work well in FP16 for inference, so training would be broken as well. This can be revisited if it makes sense to train in FP16 for other models added.
* The `accelerate` config files default to using "bf16", but modifying that would be at the user's own risk and assumes they understand the significance of their changes.
docs/finetrainers/documentation_models_hunyuan_video.md
ADDED
@@ -0,0 +1,189 @@
# HunyuanVideo

## Training

For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.

```bash
#!/bin/bash

export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1"

DATA_ROOT="/path/to/dataset"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/models/hunyuan-video/"

ID_TOKEN="afkx"

# Model arguments
model_cmd="--model_name hunyuan_video \
  --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 17x512x768 49x512x768 61x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd=""

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --batch_size 1 \
  --train_steps 500 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 2e-5 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-hunyuan-video \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
```

## Memory Usage

### LoRA

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:

```
Training configuration: {
    "trainable parameters": 163577856,
    "total samples": 69,
    "train epochs": 1,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 69,
    "train batch size": 1,
    "gradient accumulation steps": 1
}
```

| stage | memory_allocated | max_memory_reserved |
|:-----------------------:|:----------------:|:-------------------:|
| before training start | 38.889 | 39.020 |
| before validation start | 39.747 | 56.266 |
| after validation end | 39.748 | 58.385 |
| after epoch 1 | 39.748 | 40.910 |
| after training end | 25.288 | 40.910 |

Note: requires about `59` GB of VRAM when validation is performed.

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **with precomputation**:

```
Training configuration: {
    "trainable parameters": 163577856,
    "total samples": 1,
    "train epochs": 10,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 1,
    "train batch size": 1,
    "gradient accumulation steps": 1
}
```

| stage | memory_allocated | max_memory_reserved |
|:-----------------------------:|:----------------:|:-------------------:|
| after precomputing conditions | 14.232 | 14.461 |
| after precomputing latents | 14.717 | 17.244 |
| before training start | 24.195 | 26.039 |
| after epoch 1 | 24.83 | 42.387 |
| before validation start | 24.842 | 42.387 |
| after validation end | 39.558 | 46.947 |
| after training end | 24.842 | 41.039 |

Note: requires about `47` GB of VRAM with validation. If validation is not performed, the memory usage is reduced to about `42` GB.

### Full finetuning

Currently, full finetuning is not supported for HunyuanVideo. It goes out of memory (OOM) for `49x512x768` resolutions.

## Inference

Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:

```py
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="hunyuanvideo-lora")
pipe.set_adapters(["hunyuanvideo-lora"], [0.6])
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```

You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:

* [Hunyuan-Video in Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video)
* [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
* [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers/documentation_models_ltx_video.md
ADDED
@@ -0,0 +1,196 @@
# LTX-Video

## Training

For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.

```bash
#!/bin/bash
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1"

DATA_ROOT="/path/to/dataset"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/models/ltx-video/"

ID_TOKEN="BW_STYLE"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 49x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_weighting_scheme logit_normal"

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --batch_size 1 \
  --train_steps 3000 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 4 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 3e-5 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
```

## Memory Usage

### LoRA

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:

```
Training configuration: {
    "trainable parameters": 117440512,
    "total samples": 69,
    "train epochs": 1,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 69,
    "train batch size": 1,
    "gradient accumulation steps": 1
}
```

| stage | memory_allocated | max_memory_reserved |
|:-----------------------:|:----------------:|:-------------------:|
| before training start | 13.486 | 13.879 |
| before validation start | 14.146 | 17.623 |
| after validation end | 14.146 | 17.623 |
| after epoch 1 | 14.146 | 17.623 |
| after training end | 4.461 | 17.623 |

Note: requires about `18` GB of VRAM without precomputation.

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:

```
Training configuration: {
    "trainable parameters": 117440512,
    "total samples": 1,
    "train epochs": 10,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 1,
    "train batch size": 1,
    "gradient accumulation steps": 1
}
```

| stage | memory_allocated | max_memory_reserved |
|:-----------------------------:|:----------------:|:-------------------:|
| after precomputing conditions | 8.88 | 8.920 |
| after precomputing latents | 9.684 | 11.613 |
| before training start | 3.809 | 10.010 |
| after epoch 1 | 4.26 | 10.916 |
| before validation start | 4.26 | 10.916 |
| after validation end | 13.924 | 17.262 |
| after training end | 4.26 | 14.314 |

Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB.

### Full Finetuning

```
Training configuration: {
    "trainable parameters": 1923385472,
    "total samples": 1,
    "train epochs": 10,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 1,
    "train batch size": 1,
    "gradient accumulation steps": 1
}
```

| stage | memory_allocated | max_memory_reserved |
|:-----------------------------:|:----------------:|:-------------------:|
| after precomputing conditions | 8.89 | 8.937 |
| after precomputing latents | 9.701 | 11.615 |
| before training start | 3.583 | 4.025 |
| after epoch 1 | 10.769 | 20.357 |
| before validation start | 10.769 | 20.357 |
| after validation end | 10.769 | 28.332 |
| after training end | 10.769 | 12.904 |

## Inference

Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:

```diff
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
).to("cuda")
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="ltxv-lora")
+ pipe.set_adapters(["ltxv-lora"], [0.75])

video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)
```

You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:

* [LTX-Video in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video)
* [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
* [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers/documentation_models_optimization.md
ADDED
@@ -0,0 +1,12 @@
# Memory optimizations

To lower memory requirements during training:

- `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory.
- `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass.
- `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`); see the sketch after this list.
- `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states.
- Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example).
- Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
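As a rough illustration of what layerwise upcasting does, the sketch below stores a weight in `torch.float8_e4m3fn` and upcasts it to `bf16` only for the computation. It assumes PyTorch 2.1+ and is purely conceptual; it is not how `finetrainers` wires this internally.

```python
import torch

# Weights stored in 8-bit floating point: half the memory of bf16 storage.
w_bf16 = torch.randn(1024, 1024, dtype=torch.bfloat16)
w_fp8 = w_bf16.to(torch.float8_e4m3fn)
print(w_bf16.element_size(), w_fp8.element_size())  # 2 bytes vs 1 byte per element

# Compute still happens in the --transformer_dtype (bf16 here): the fp8 weight is
# upcast just-in-time for the matmul, so only storage is reduced, not compute precision.
x = torch.randn(8, 1024, dtype=torch.bfloat16)
y = x @ w_fp8.to(torch.bfloat16)
```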
We will continue to add more features that help to reduce memory consumption.
docs/finetrainers/documentation_models_wan.md
ADDED
@@ -0,0 +1,36 @@
# Wan

## Training

For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.

See [this](../../examples/training/sft/wan/crush_smol_lora/) example training script for training Wan with Pika Effects Crush.

## Memory Usage

TODO

## Inference

Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:

```diff
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="wan-lora")
+ pipe.set_adapters(["wan-lora"], [0.75])

video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)
```

You can refer to the following guides to know more about the model pipeline and performing LoRA inference in `diffusers`:

* [Wan in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan)
* [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
* [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
docs/finetrainers/documentation_optimizers.md
ADDED
@@ -0,0 +1,14 @@
# Optimizers

The following optimizers are supported:
- **torch**:
  - `Adam`
  - `AdamW`
- **bitsandbytes**:
  - `Adam`
  - `AdamW`
  - `Adam8Bit`
  - `AdamW8Bit`
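To illustrate what the bitsandbytes variants change, here is a small sketch comparing the standard `torch.optim.AdamW` with the 8-bit bitsandbytes counterpart, which stores optimizer states in 8-bit precision to save memory. In `finetrainers` you would simply select these via `--optimizer`; the sketch assumes `bitsandbytes` is installed and a CUDA device is available, and the layer and hyperparameters are arbitrary.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# Standard torch AdamW: optimizer states (exp_avg, exp_avg_sq) are kept in fp32.
adamw = torch.optim.AdamW(
    model.parameters(), lr=2e-5, betas=(0.9, 0.95), weight_decay=1e-4, eps=1e-8
)

# bitsandbytes AdamW8bit: same update rule, but states are stored in 8-bit,
# which roughly quarters the optimizer-state memory.
adamw_8bit = bnb.optim.AdamW8bit(
    model.parameters(), lr=2e-5, betas=(0.9, 0.95), weight_decay=1e-4, eps=1e-8
)
```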
> [!NOTE]
> Not all optimizers have been tested with all models/parallel settings. They may or may not work, but this will gradually improve over time.
docs/finetrainers/documentation_parallel_processing_README.md
ADDED
@@ -0,0 +1,59 @@
# Finetrainers Parallel Backends

Finetrainers supports parallel training on multiple GPUs & nodes. This is done using the PyTorch DTensor backend. To run parallel training, `torchrun` is utilized.

As an experiment for comparing the performance of different training backends, Finetrainers has implemented multi-backend support. These backends may or may not fully rely on PyTorch's distributed DTensor solution. Currently, only [🤗 Accelerate](https://github.com/huggingface/accelerate) is supported for backwards-compatibility reasons (as we initially started Finetrainers with only Accelerate). In the near future, there are plans for integrating with:
- [DeepSpeed](https://github.com/deepspeedai/DeepSpeed)
- [Nanotron](https://github.com/huggingface/nanotron)
- [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)

> [!IMPORTANT]
> The multi-backend support is completely experimental and only serves to satisfy my curiosity of how much of a tradeoff there is between performance and ease of use. The PyTorch DTensor backend is the only one with stable support, followed by Accelerate.
>
> Users will not have to worry about backwards-breaking changes or dependencies if they stick to the PyTorch DTensor backend.

## Support matrix

There are various algorithms for parallel training. Currently, we only support:
- [DDP](https://pytorch.org/docs/stable/notes/ddp.html)
- [FSDP2](https://pytorch.org/docs/stable/fsdp.html)
- [HSDP](https://pytorch.org/docs/stable/fsdp.html)
- [TP](https://pytorch.org/docs/stable/distributed.tensor.parallel.html)

## Training

The following parameters are relevant for launching training:

- `parallel_backend`: The backend to use for parallel training. Available options are `ptd` & `accelerate`.
- `pp_degree`: The degree of pipeline parallelism. Currently unsupported.
- `dp_degree`: The degree of data parallelism (number of replicas). Defaults to `1`.
- `dp_shards`: The number of shards for data parallelism. Defaults to `1`.
- `cp_degree`: The degree of context parallelism. Currently unsupported.
- `tp_degree`: The degree of tensor parallelism.

For launching training with the PyTorch DTensor backend, use the following:

```bash
# Single node - 8 GPUs available
torchrun --standalone --nodes=1 --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" train.py <YOUR_OTHER_ARGS>

# Single node - 8 GPUs but only 4 available
export CUDA_VISIBLE_DEVICES=0,2,4,5
torchrun --standalone --nodes=1 --nproc_per_node=4 --rdzv_backend c10d --rdzv_endpoint="localhost:0" train.py <YOUR_OTHER_ARGS>

# Multi-node - Nx8 GPUs available
# TODO(aryan): Add slurm script
```

For launching training with the Accelerate backend, use the following:

```bash
# Single node - 8 GPUs available
accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids 0,1,2,3,4,5,6,7 train.py <YOUR_OTHER_ARGS>

# Single node - 8 GPUs but only 4 available
accelerate launch --config_file accelerate_configs/uncompiled_4.yaml --gpu_ids 0,2,4,5 train.py <YOUR_OTHER_ARGS>

# Multi-node - Nx8 GPUs available
# TODO(aryan): Add slurm script
```
docs/finetrainers/examples_training_wan__train.sh
ADDED
@@ -0,0 +1,162 @@
```bash
#!/bin/bash

set -e -x

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL="DEBUG"

# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
# BACKEND="accelerate"
BACKEND="ptd"

# In this setting, I'm using 2 GPUs on a 4-GPU node for training
NUM_GPUS=2
CUDA_VISIBLE_DEVICES="2,3"

# Check the JSON files for the expected JSON format
TRAINING_DATASET_CONFIG="examples/training/sft/wan/crush_smol_lora/training.json"
VALIDATION_DATASET_FILE="examples/training/sft/wan/crush_smol_lora/validation.json"

# Depending on how many GPUs you have available, choose your degree of parallelism and technique!
DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"

# Parallel arguments
parallel_cmd=(
  $DDP_2
)

# Model arguments
model_cmd=(
  --model_name "wan"
  --pretrained_model_name_or_path "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
)

# Dataset arguments
# Here, we know that the dataset size is about ~50 videos. Since we're using 2 GPUs, we precompute
# embeddings of 25 dataset items per GPU. Also, we're using a very small dataset for finetuning, so
# we are okay with precomputing embeddings once and re-using them without having to worry about disk
# space. Currently, however, every new training run performs precomputation even if it's not required
# (which is something we've to improve [TODO(aryan)])
dataset_cmd=(
  --dataset_config $TRAINING_DATASET_CONFIG
  --dataset_shuffle_buffer_size 10
  --precomputation_items 25
  --precomputation_once
)

# Dataloader arguments
dataloader_cmd=(
  --dataloader_num_workers 0
)

# Diffusion arguments
diffusion_cmd=(
  --flow_weighting_scheme "logit_normal"
)

# Training arguments
# We target just the attention projection layers for LoRA training here.
# You can modify as you please and target any layer (regex is supported)
training_cmd=(
  --training_type "lora"
  --seed 42
  --batch_size 1
  --train_steps 3000
  --rank 32
  --lora_alpha 32
  --target_modules "blocks.*(to_q|to_k|to_v|to_out.0)"
  --gradient_accumulation_steps 1
  --gradient_checkpointing
  --checkpointing_steps 500
  --checkpointing_limit 2
  # --resume_from_checkpoint 3000
  --enable_slicing
  --enable_tiling
)

# Optimizer arguments
optimizer_cmd=(
  --optimizer "adamw"
  --lr 5e-5
  --lr_scheduler "constant_with_warmup"
  --lr_warmup_steps 1000
  --lr_num_cycles 1
  --beta1 0.9
  --beta2 0.99
  --weight_decay 1e-4
  --epsilon 1e-8
  --max_grad_norm 1.0
)

# Validation arguments
validation_cmd=(
  --validation_dataset_file "$VALIDATION_DATASET_FILE"
  --validation_steps 500
)

# Miscellaneous arguments
miscellaneous_cmd=(
  --tracker_name "finetrainers-wan"
  --output_dir "/raid/aryan/wan"
  --init_timeout 600
  --nccl_timeout 600
  --report_to "wandb"
)

# Execute the training script
if [ "$BACKEND" == "accelerate" ]; then

  ACCELERATE_CONFIG_FILE=""
  if [ "$NUM_GPUS" == 1 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
  elif [ "$NUM_GPUS" == 2 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
  elif [ "$NUM_GPUS" == 4 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
  elif [ "$NUM_GPUS" == 8 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
  fi

  accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
    "${parallel_cmd[@]}" \
    "${model_cmd[@]}" \
    "${dataset_cmd[@]}" \
    "${dataloader_cmd[@]}" \
    "${diffusion_cmd[@]}" \
    "${training_cmd[@]}" \
    "${optimizer_cmd[@]}" \
    "${validation_cmd[@]}" \
    "${miscellaneous_cmd[@]}"

elif [ "$BACKEND" == "ptd" ]; then

  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

  torchrun \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=$NUM_GPUS \
    --rdzv_backend c10d \
    --rdzv_endpoint="localhost:0" \
    train.py \
      "${parallel_cmd[@]}" \
      "${model_cmd[@]}" \
      "${dataset_cmd[@]}" \
      "${dataloader_cmd[@]}" \
      "${diffusion_cmd[@]}" \
      "${training_cmd[@]}" \
      "${optimizer_cmd[@]}" \
      "${validation_cmd[@]}" \
      "${miscellaneous_cmd[@]}"
fi

echo -ne "-------------------- Finished executing script --------------------\n\n"
```
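The `--precomputation_items 25` value in the script above follows directly from the comment about the dataset: roughly 50 videos split evenly across 2 GPUs. A small, purely illustrative sketch of that arithmetic (the flag itself is what finetrainers actually reads):

```python
import math

# Illustrative only: how the example above arrives at --precomputation_items 25.
dataset_size = 50   # approximate number of videos in finetrainers/crush-smol
num_gpus = 2        # NUM_GPUS in train.sh

precomputation_items = math.ceil(dataset_size / num_gpus)
print(precomputation_items)  # 25
```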
docs/finetrainers/examples_training_wan__training.json
ADDED
@@ -0,0 +1,14 @@
```json
{
  "datasets": [
    {
      "data_root": "finetrainers/crush-smol",
      "dataset_type": "video",
      "id_token": "PIKA_CRUSH",
      "video_resolution_buckets": [
        [49, 480, 832]
      ],
      "reshape_mode": "bicubic",
      "remove_common_llm_caption_prefixes": true
    }
  ]
}
```
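Judging from the validation file below, each entry in `video_resolution_buckets` is a `[num_frames, height, width]` triplet. The following is a minimal sketch for loading and inspecting this config with plain `json` (no finetrainers imports; the bucket interpretation is an assumption based on the files in this example):

```python
import json

# Minimal sketch: load the training dataset config and print the resolution buckets.
with open("examples/training/sft/wan/crush_smol_lora/training.json") as f:
    config = json.load(f)

for dataset in config["datasets"]:
    print(dataset["data_root"], dataset["id_token"])
    for num_frames, height, width in dataset["video_resolution_buckets"]:
        print(f"  {num_frames} frames at {height}x{width}")
```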
docs/finetrainers/examples_training_wan__validation.json
ADDED
@@ -0,0 +1,40 @@
```json
{
  "data": [
    {
      "caption": "PIKA_CRUSH A red toy car is being crushed by a large hydraulic press, which is flattening objects as if they were under a hydraulic press.",
      "image_path": null,
      "video_path": null,
      "num_inference_steps": 50,
      "height": 480,
      "width": 832,
      "num_frames": 49
    },
    {
      "caption": "PIKA_CRUSH A green cube is being compressed by a hydraulic press, which flattens the object as if it were under a hydraulic press. The press is shown in action, with the cube being squeezed into a smaller shape.",
      "image_path": null,
      "video_path": null,
      "num_inference_steps": 50,
      "height": 480,
      "width": 832,
      "num_frames": 49
    },
    {
      "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of colorful jelly beans, flattening them as if they were under a hydraulic press.",
      "image_path": null,
      "video_path": null,
      "num_inference_steps": 50,
      "height": 480,
      "width": 832,
      "num_frames": 49
    },
    {
      "caption": "PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.",
      "image_path": null,
      "video_path": null,
      "num_inference_steps": 50,
      "height": 480,
      "width": 832,
      "num_frames": 49
    }
  ]
}
```
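Once training has produced a LoRA, it can be tried out with the `diffusers` Wan pipeline using the same `PIKA_CRUSH` trigger token and resolution as the validation prompts above. The following is a minimal inference sketch, not an official example: it assumes your installed `diffusers` version exposes `WanPipeline` with LoRA loading support, and `your-username/wan-crush-smol-lora` is a placeholder for wherever the trained weights were saved or pushed:

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

# Minimal inference sketch (assumptions noted above).
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-username/wan-crush-smol-lora")

# Use the PIKA_CRUSH id_token from training.json as the trigger word, and the
# same frame count / resolution as the validation prompts.
video = pipe(
    prompt="PIKA_CRUSH A red toy car is being crushed by a large hydraulic press.",
    height=480,
    width=832,
    num_frames=49,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=16)
```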
training/README.md
DELETED
@@ -1,459 +0,0 @@
# CogVideoX Factory 🧪

[中文阅读](./README_zh.md)

Fine-tune Cog family of video models for custom video generation under 24GB of GPU memory ⚡️📼

<table align="center">
<tr>
  <td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
</tr>
</table>

**Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)!

## Quickstart

Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`.

Then download a dataset:

```bash
# install `huggingface_hub`
huggingface-cli download \
  --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \
  --local-dir video-dataset-disney
```

Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice):

```bash
# For LoRA finetuning of the text-to-video CogVideoX models
./train_text_to_video_lora.sh

# For full finetuning of the text-to-video CogVideoX models
./train_text_to_video_sft.sh

# For LoRA finetuning of the image-to-video CogVideoX models
./train_image_to_video_lora.sh
```

Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:

```diff
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to("cuda")
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
+ pipe.set_adapters(["cogvideox-lora"], [1.0])

video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)
```

For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details):

```python
from diffusers import CogVideoXImageToVideoPipeline

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

# ...

del pipe.transformer.patch_embed.pos_embedding
pipe.transformer.patch_embed.use_learned_positional_embeddings = False
pipe.transformer.config.use_learned_positional_embeddings = False
```

You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).

Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible.

## Prepare Dataset and Training

Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example.

- Configure environment variables as per your choice:

  ```bash
  export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
  export TORCHDYNAMO_VERBOSE=1
  export WANDB_MODE="offline"
  export NCCL_P2P_DISABLE=1
  export TORCH_NCCL_ENABLE_MONITORING=0
  ```

- Configure which GPUs to use for training: `GPU_IDS="0,1"`

- Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example:

  ```bash
  LEARNING_RATES=("1e-4" "1e-3")
  LR_SCHEDULES=("cosine_with_restarts")
  OPTIMIZERS=("adamw" "adam")
  MAX_TRAIN_STEPS=("3000")
  ```

- Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`.

- Specify the absolute paths and columns/files for captions and videos.

  ```bash
  DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
  CAPTION_COLUMN="prompt.txt"
  VIDEO_COLUMN="videos.txt"
  ```

- Launch experiments sweeping different hyperparameters:
```
for learning_rate in "${LEARNING_RATES[@]}"; do
  for lr_schedule in "${LR_SCHEDULES[@]}"; do
    for optimizer in "${OPTIMIZERS[@]}"; do
      for steps in "${MAX_TRAIN_STEPS[@]}"; do
        output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"

        cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \
          --pretrained_model_name_or_path THUDM/CogVideoX-5b \
          --data_root $DATA_ROOT \
          --caption_column $CAPTION_COLUMN \
          --video_column $VIDEO_COLUMN \
          --id_token BW_STYLE \
          --height_buckets 480 \
          --width_buckets 720 \
          --frame_buckets 49 \
          --dataloader_num_workers 8 \
          --pin_memory \
          --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
          --validation_prompt_separator ::: \
          --num_validation_videos 1 \
          --validation_epochs 10 \
          --seed 42 \
          --rank 128 \
          --lora_alpha 128 \
          --mixed_precision bf16 \
          --output_dir $output_dir \
          --max_num_frames 49 \
          --train_batch_size 1 \
          --max_train_steps $steps \
          --checkpointing_steps 1000 \
          --gradient_accumulation_steps 1 \
          --gradient_checkpointing \
          --learning_rate $learning_rate \
          --lr_scheduler $lr_schedule \
          --lr_warmup_steps 400 \
          --lr_num_cycles 1 \
          --enable_slicing \
          --enable_tiling \
          --optimizer $optimizer \
          --beta1 0.9 \
          --beta2 0.95 \
          --weight_decay 0.001 \
          --max_grad_norm 1.0 \
          --allow_tf32 \
          --report_to wandb \
          --nccl_timeout 1800"

        echo "Running command: $cmd"
        eval $cmd
        echo -ne "-------------------- Finished executing script --------------------\n\n"
      done
    done
  done
done
```

To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`.

Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below.

## Memory requirements

<table align="center">
<tr>
  <td align="center" colspan="2"><b>CogVideoX LoRA Finetuning</b></td>
</tr>
<tr>
  <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
  <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
</tr>
<tr>
  <td align="center"><img src="../assets/lora_2b.png" /></td>
  <td align="center"><img src="../assets/lora_5b.png" /></td>
</tr>

<tr>
  <td align="center" colspan="2"><b>CogVideoX Full Finetuning</b></td>
</tr>
<tr>
  <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
  <td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
</tr>
<tr>
  <td align="center"><img src="../assets/sft_2b.png" /></td>
  <td align="center"><img src="../assets/sft_5b.png" /></td>
</tr>
</table>

Supported and verified memory optimizations for training include:

- `CPUOffloadOptimizer` from [`torchao`](https://github.com/pytorch/ao). You can read about its capabilities and limitations [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload). In short, it allows you to use the CPU for storing trainable parameters and gradients. This results in the optimizer step happening on the CPU, which requires a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` or applying `torch.compile` on the optimizer step. Additionally, it is recommended not to `torch.compile` your model for training. Gradient clipping and accumulation is not supported yet either.
- Low-bit optimizers from [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers). TODO: to test and make [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) ones work
- DeepSpeed Zero2: Since we rely on `accelerate`, follow [this guide](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) to configure your `accelerate` installation to enable training with DeepSpeed Zero2 optimizations.

> [!IMPORTANT]
> The memory requirements are reported after running the `training/prepare_dataset.py`, which converts the videos and captions to latents and embeddings. During training, we directly load the latents and embeddings, and do not require the VAE or the T5 text encoder. However, if you perform validation/testing, these must be loaded and increase the amount of required memory. Not performing validation/testing saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
>
> If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying `--enable_model_cpu_offload`.

### LoRA finetuning

> [!NOTE]
> The memory requirements for image-to-video lora finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
>
> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
> or provide a URL to a valid and accessible image.

<details>
<summary> AdamW </summary>

**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

With `train_batch_size = 1`:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |

With `train_batch_size = 4`:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |

</details>

<details>
<summary> AdamW (8-bit bitsandbytes) </summary>

**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

With `train_batch_size = 1`:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |

With `train_batch_size = 4`:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |

</details>

<details>
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>

**Note:** Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

With `train_batch_size = 1`:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |

With `train_batch_size = 4`:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |

</details>

<details>
<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>

**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.

With `train_batch_size = 1`:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |

With `train_batch_size = 4`:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |

</details>

### Full finetuning

> [!NOTE]
> The memory requirements for image-to-video full finetuning are similar to that of text-to-video on `THUDM/CogVideoX-5b`, so it hasn't been reported explicitly.
>
> Additionally, to prepare test images for I2V finetuning, you could either generate them on-the-fly by modifying the script, or extract some frames from your training data using:
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
> or provide a URL to a valid and accessible image.

> [!NOTE]
> Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

<details>
<summary> AdamW </summary>

With `train_batch_size = 1`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |

With `train_batch_size = 4`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |

</details>

<details>
<summary> AdamW (8-bit bitsandbytes) </summary>

With `train_batch_size = 1`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |

With `train_batch_size = 4`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |

</details>

<details>
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>

With `train_batch_size = 1`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |

With `train_batch_size = 4`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |

</details>

<details>
<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>

**Note:** Results are reported with `gradient_checkpointing` enabled, running on a 2x A100.

With `train_batch_size = 1`:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |

With `train_batch_size = 4`:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:---:|:---:|:---:|:---:|:---:|
| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |

</details>

> [!NOTE]
> - `memory_after_validation` is indicative of the peak memory required for training. This is because apart from the activations, parameters and gradients stored for training, you also need to load the vae and text encoder in memory and spend some memory to perform inference. In order to reduce total memory required to perform training, one can choose not to perform validation/testing as part of the training script.
>
> - `memory_before_validation` is the true indicator of the peak memory required for training if you choose to not perform validation/testing.

<table align="center">
<tr>
  <td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
</tr>
<tr>
  <td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
</tr>
</table>

## TODOs

- [x] Make scripts compatible with DDP
- [ ] Make scripts compatible with FSDP
- [x] Make scripts compatible with DeepSpeed
- [ ] vLLM-powered captioning script
- [x] Multi-resolution/frame support in `prepare_dataset.py`
- [ ] Analyzing traces for potential speedups and removing as many syncs as possible
- [x] Test scripts with memory-efficient optimizer from bitsandbytes
- [x] Test scripts with CPUOffloadOptimizer, etc.
- [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao))
- [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out)
- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100)

> [!IMPORTANT]
> Since our goal is to make the scripts as memory-friendly as possible we don't guarantee multi-GPU training.
training/README_zh.md
DELETED
@@ -1,455 +0,0 @@
|
|
1 |
-
# CogVideoX Factory 🧪
|
2 |
-
|
3 |
-
[Read in English](./README.md)
|
4 |
-
|
5 |
-
在 24GB GPU 内存下对 Cog 系列视频模型进行微调以实现自定义视频生成,支持多分辨率 ⚡️📼
|
6 |
-
|
7 |
-
<table align="center">
|
8 |
-
<tr>
|
9 |
-
<td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">您的浏览器不支持视频标签。</video></td>
|
10 |
-
</tr>
|
11 |
-
</table>
|
12 |
-
|
13 |
-
## 快速开始
|
14 |
-
|
15 |
-
克隆此仓库并确保安装了相关依赖:`pip install -r requirements.txt`。
|
16 |
-
|
17 |
-
接着下载数据集:
|
18 |
-
|
19 |
-
```
|
20 |
-
# 安装 `huggingface_hub`
|
21 |
-
huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir video-dataset-disney
|
22 |
-
```
|
23 |
-
|
24 |
-
然后启动 LoRA 微调进行文本到视频的生成(根据您的选择修改不同的超参数、数据集根目录以及其他配置选项):
|
25 |
-
|
26 |
-
```
|
27 |
-
# 对 CogVideoX 模型进行文本到视频的 LoRA 微调
|
28 |
-
./train_text_to_video_lora.sh
|
29 |
-
|
30 |
-
# 对 CogVideoX 模型进行文本到视频的完整微调
|
31 |
-
./train_text_to_video_sft.sh
|
32 |
-
|
33 |
-
# 对 CogVideoX 模型进行图像到视频的 LoRA 微调
|
34 |
-
./train_image_to_video_lora.sh
|
35 |
-
```
|
36 |
-
|
37 |
-
假设您的 LoRA 已保存并推送到 HF Hub,并命名为 `my-awesome-name/my-awesome-lora`,现在我们可以使用微调模型进行推理:
|
38 |
-
|
39 |
-
```
|
40 |
-
import torch
|
41 |
-
from diffusers import CogVideoXPipeline
|
42 |
-
from diffusers import export_to_video
|
43 |
-
|
44 |
-
pipe = CogVideoXPipeline.from_pretrained(
|
45 |
-
"THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
|
46 |
-
).to("cuda")
|
47 |
-
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name=["cogvideox-lora"])
|
48 |
-
+ pipe.set_adapters(["cogvideox-lora"], [1.0])
|
49 |
-
|
50 |
-
video = pipe("<my-awesome-prompt>").frames[0]
|
51 |
-
export_to_video(video, "output.mp4", fps=8)
|
52 |
-
```
|
53 |
-
|
54 |
-
你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。
|
55 |
-
|
56 |
-
**注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装
|
57 |
-
diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。
|
58 |
-
|
59 |
-
以下我们提供了更多探索此仓库选项的额外部分。所有这些都旨在尽可能降低内存需求,使视频模型的微调变得更易于访问。
|
60 |
-
|
61 |
-
## 训练
|
62 |
-
|
63 |
-
在开始训练之前,请你检查是否按照[数据集规范](assets/dataset_zh.md)准备好了数据集。 我们提供了适用于文本到视频 (text-to-video) 和图像到视频 (image-to-video) 生成的训练脚本,兼容 [CogVideoX 模型家族](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce)。训练可以通过 `train*.sh` 脚本启动,具体取决于你想要训练的任务。让我们以文本到视频的 LoRA 微调为例。
|
64 |
-
|
65 |
-
- 根据你的需求配置环境变量:
|
66 |
-
|
67 |
-
```
|
68 |
-
export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
|
69 |
-
export TORCHDYNAMO_VERBOSE=1
|
70 |
-
export WANDB_MODE="offline"
|
71 |
-
export NCCL_P2P_DISABLE=1
|
72 |
-
export TORCH_NCCL_ENABLE_MONITORING=0
|
73 |
-
```
|
74 |
-
|
75 |
-
- 配置用于训练的 GPU:`GPU_IDS="0,1"`
|
76 |
-
|
77 |
-
- 选择训练的超参数。让我们以学习率和优化器类型的超参数遍历为例:
|
78 |
-
|
79 |
-
```
|
80 |
-
LEARNING_RATES=("1e-4" "1e-3")
|
81 |
-
LR_SCHEDULES=("cosine_with_restarts")
|
82 |
-
OPTIMIZERS=("adamw" "adam")
|
83 |
-
MAX_TRAIN_STEPS=("3000")
|
84 |
-
```
|
85 |
-
|
86 |
-
- 选择用于训练的 Accelerate 配置文件:`ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`
|
87 |
-
。我们在 `accelerate_configs/` 目录中提供了一些默认配置 - 单 GPU 编译/未编译、2x GPU DDP、DeepSpeed
|
88 |
-
等。你也可以使用 `accelerate config --config_file my_config.yaml` 自定义配置文件。
|
89 |
-
|
90 |
-
- 指定字幕和视频的绝对路径以及列/文件。
|
91 |
-
|
92 |
-
```
|
93 |
-
DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
|
94 |
-
CAPTION_COLUMN="prompt.txt"
|
95 |
-
VIDEO_COLUMN="videos.txt"
|
96 |
-
```
|
97 |
-
|
98 |
-
- 运行实验,遍历不同的超参数:
|
99 |
-
```
|
100 |
-
for learning_rate in "${LEARNING_RATES[@]}"; do
|
101 |
-
for lr_schedule in "${LR_SCHEDULES[@]}"; do
|
102 |
-
for optimizer in "${OPTIMIZERS[@]}"; do
|
103 |
-
for steps in "${MAX_TRAIN_STEPS[@]}"; do
|
104 |
-
output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
|
105 |
-
|
106 |
-
cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
|
107 |
-
--pretrained_model_name_or_path THUDM/CogVideoX-5b \
|
108 |
-
--data_root $DATA_ROOT \
|
109 |
-
--caption_column $CAPTION_COLUMN \
|
110 |
-
--video_column $VIDEO_COLUMN \
|
111 |
-
--id_token BW_STYLE \
|
112 |
-
--height_buckets 480 \
|
113 |
-
--width_buckets 720 \
|
114 |
-
--frame_buckets 49 \
|
115 |
-
--dataloader_num_workers 8 \
|
116 |
-
--pin_memory \
|
117 |
-
--validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
|
118 |
-
--validation_prompt_separator ::: \
|
119 |
-
--num_validation_videos 1 \
|
120 |
-
--validation_epochs 10 \
|
121 |
-
--seed 42 \
|
122 |
-
--rank 128 \
|
123 |
-
--lora_alpha 128 \
|
124 |
-
--mixed_precision bf16 \
|
125 |
-
--output_dir $output_dir \
|
126 |
-
--max_num_frames 49 \
|
127 |
-
--train_batch_size 1 \
|
128 |
-
--max_train_steps $steps \
|
129 |
-
--checkpointing_steps 1000 \
|
130 |
-
--gradient_accumulation_steps 1 \
|
131 |
-
--gradient_checkpointing \
|
132 |
-
--learning_rate $learning_rate \
|
133 |
-
--lr_scheduler $lr_schedule \
|
134 |
-
--lr_warmup_steps 400 \
|
135 |
-
--lr_num_cycles 1 \
|
136 |
-
--enable_slicing \
|
137 |
-
--enable_tiling \
|
138 |
-
--optimizer $optimizer \
|
139 |
-
--beta1 0.9 \
|
140 |
-
--beta2 0.95 \
|
141 |
-
--weight_decay 0.001 \
|
142 |
-
--max_grad_norm 1.0 \
|
143 |
-
--allow_tf32 \
|
144 |
-
--report_to wandb \
|
145 |
-
--nccl_timeout 1800"
|
146 |
-
|
147 |
-
echo "Running command: $cmd"
|
148 |
-
eval $cmd
|
149 |
-
echo -ne "-------------------- Finished executing script --------------------\n\n"
|
150 |
-
done
|
151 |
-
done
|
152 |
-
done
|
153 |
-
done
|
154 |
-
```
|
155 |
-
|
156 |
-
要了解不同参数的含义,你可以查看 [args](./training/args.py) 文件,或者使用 `--help` 运行训练脚本。
|
157 |
-
|
158 |
-
注意:训练脚本尚未在 MPS 上测试,因此性能和内存要求可能与下面的 CUDA 报告差异很大。
|
159 |
-
|
160 |
-
## 内存需求
|
161 |
-
|
162 |
-
<table align="center">
|
163 |
-
<tr>
|
164 |
-
<td align="center" colspan="2"><b>CogVideoX LoRA 微调</b></td>
|
165 |
-
</tr>
|
166 |
-
<tr>
|
167 |
-
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
|
168 |
-
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
|
169 |
-
</tr>
|
170 |
-
<tr>
|
171 |
-
<td align="center"><img src="assets/lora_2b.png" /></td>
|
172 |
-
<td align="center"><img src="assets/lora_5b.png" /></td>
|
173 |
-
</tr>
|
174 |
-
|
175 |
-
<tr>
|
176 |
-
<td align="center" colspan="2"><b>CogVideoX 全量微调</b></td>
|
177 |
-
</tr>
|
178 |
-
<tr>
|
179 |
-
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-2b">THUDM/CogVideoX-2b</a></td>
|
180 |
-
<td align="center"><a href="https://huggingface.co/THUDM/CogVideoX-5b">THUDM/CogVideoX-5b</a></td>
|
181 |
-
</tr>
|
182 |
-
<tr>
|
183 |
-
<td align="center"><img src="assets/sft_2b.png" /></td>
|
184 |
-
<td align="center"><img src="assets/sft_5b.png" /></td>
|
185 |
-
</tr>
|
186 |
-
</table>
|
187 |
-
|
188 |
-
支持和验证的训练内存优化包括:
|
189 |
-
|
190 |
-
- `CPUOffloadOptimizer` 来自 [`torchao`](https://github.com/pytorch/ao)
|
191 |
-
。你可以在[这里](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload)
|
192 |
-
阅读它的能力和局限性。简而言之,它允许你将可训练参数和梯度存储在 CPU 中,从而在 CPU 上进行优化步骤。这需要快速的 CPU
|
193 |
-
优化器,如 `torch.optim.AdamW(fused=True)`,或者在优化步骤中应用 `torch.compile`
|
194 |
-
。此外,建议不要在训练时对模型应用 `torch.compile`。梯度裁剪和累积目前还不支持。
|
195 |
-
- 来自 [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/optimizers)
|
196 |
-
的低位优化器。TODO:测试并使 [`torchao`](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim) 能正常工作。
|
197 |
-
- DeepSpeed Zero2:由于我们依赖 `accelerate`
|
198 |
-
,请按照[此指南](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) 配置 `accelerate` 以启用 DeepSpeed
|
199 |
-
Zero2 优化训练。
|
200 |
-
|
201 |
-
> [!重要提示]
|
202 |
-
> 内存需求是运行 `training/prepare_dataset.py`
|
203 |
-
>
|
204 |
-
后报告的,该脚本将视频和字幕转换为潜在向量和嵌入。在训练期间,我们直接加载这些潜在向量和嵌入,不需要VAE或T5文本编码器。然而,如果执行验证/测试,则必须加载这些模块,并且会增加所需内存的数量。不进行验证/测试可以节省大量内存,这些内存可以用于较小显存的GPU上专注于训练。
|
205 |
-
>
|
206 |
-
> 如果选择运行验证/测试,可以通过指定 `--enable_model_cpu_offload` 来为较低显存的GPU节省一些内存。
|
207 |
-
|
208 |
-
### LoRA微调
|
209 |
-
|
210 |
-
> [!重要提示]
|
211 |
-
> 图像到视频的LoRA微调的内存需求与文本到视频上的 `THUDM/CogVideoX-5b` 类似,因此没有明确报告。
|
212 |
-
>
|
213 |
-
> 此外,为了准备I2V微调的测试图像,可以通过修改脚本实时生成它们,或使用以下命令从训练数据中提取一些帧:
|
214 |
-
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
|
215 |
-
> 或提供一个有效且可访问的图像URL。
|
216 |
-
|
217 |
-
<details>
|
218 |
-
<summary> AdamW </summary>
|
219 |
-
|
220 |
-
**注意:** 尝试在没有梯度检查点的情况下运行 CogVideoX-5b 即使在 A100(80 GB)上也会导致 OOM(内存不足)错误,因此内存需求尚未列出。
|
221 |
-
|
222 |
-
当 `train_batch_size = 1` 时:
|
223 |
-
|
224 |
-
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
225 |
-
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
226 |
-
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
|
227 |
-
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
|
228 |
-
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
|
229 |
-
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
|
230 |
-
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
|
231 |
-
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
|
232 |
-
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
|
233 |
-
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
|
234 |
-
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |
|
235 |
-
|
236 |
-
当 `train_batch_size = 4` 时:
|
237 |
-
|
238 |
-
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
239 |
-
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
240 |
-
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
|
241 |
-
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
|
242 |
-
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
|
243 |
-
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
|
244 |
-
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
|
245 |
-
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |
|
246 |
-
|
247 |
-
</details>
|
248 |
-
|
249 |
-
<details>
|
250 |
-
<summary> AdamW (8-bit bitsandbytes) </summary>
|
251 |
-
|
252 |
-
**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
|
253 |
-
|
254 |
-
当 `train_batch_size = 1` 时:
|
255 |
-
|
256 |
-
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
257 |
-
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
258 |
-
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
|
259 |
-
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
|
260 |
-
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
|
261 |
-
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
|
262 |
-
| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
|
263 |
-
| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
|
264 |
-
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
|
265 |
-
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
|
266 |
-
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |
|
267 |
-
|
268 |
-
当 `train_batch_size = 4` 时:
|
269 |
-
|
270 |
-
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
271 |
-
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
272 |
-
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
|
273 |
-
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
|
274 |
-
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
|
275 |
-
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
|
276 |
-
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
|
277 |
-
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |
|
278 |
-
|
279 |
-
</details>
|
280 |
-
|
281 |
-
<details>
|
282 |
-
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>
|
283 |
-
|
284 |
-
**注意:** 在没有启用梯度检查点的情况下,尝试运行 CogVideoX-5b 模型即使在 A100(80 GB)上也会导致 OOM(内存不足),因此未列出内存测量数据。
|
285 |
-
|
286 |
-
当 `train_batch_size = 1` 时:
|
287 |
-
|
288 |
-
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
289 |
-
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
290 |
-
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
|
291 |
-
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
|
292 |
-
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
|
293 |
-
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
|
294 |
-
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
|
295 |
-
| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
|
296 |
-
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
|
297 |
-
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
|
298 |
-
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |
|
299 |
-
|
300 |
-
当 `train_batch_size = 4` 时:
|
301 |
-
|
302 |
-
| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
303 |
-
|:------------------:|:---------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
304 |
-
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
|
305 |
-
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
|
306 |
-
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
|
307 |
-
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
|
308 |
-
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
|
309 |
-
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |
|
310 |
-
|
311 |
-
</details>
|
312 |
-
|
313 |
-
<details>
|
314 |
-
<summary> DeepSpeed (AdamW + CPU/Parameter offloading) </summary>
|
315 |
-
|
316 |
-
**注意:** 结果是在启用梯度检查点的情况下,使用 2x A100 运行时记录的。
|
317 |
-
|
318 |
-
当 `train_batch_size = 1` 时:
|
319 |
-
|
320 |
-
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
321 |
-
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
322 |
-
| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
|
323 |
-
| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |
|
324 |
-
|
325 |
-
当 `train_batch_size = 4` 时:
|
326 |
-
|
327 |
-
| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
328 |
-
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
329 |
-
| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
|
330 |
-
| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |
|
331 |
-
|
332 |
-
</details>
|
333 |
-
|
334 |
-
### Full finetuning
|
335 |
-
|
336 |
-
> [!注意]
|
337 |
-
> 图像到视频的完整微调内存需求与 `THUDM/CogVideoX-5b` 的文本到视频微调相似,因此没有单独列出。
|
338 |
-
>
|
339 |
-
> 此外,要准备用于 I2V 微调的测试图像,你可以通过修改脚本实时生成图像,或者从你的训练数据中提取一些帧:
|
340 |
-
> `ffmpeg -i input.mp4 -frames:v 1 frame.png`,
|
341 |
-
> 或提供一个有效且可访问的图像 URL。
|
342 |
-
|
343 |
-
> [!注意]
|
344 |
-
> 在没有使用梯度检查点的情况下运行完整微调,即使是在 A100(80GB)上,也会出现 OOM(内存不足)错误,因此未列出内存需求。
|
345 |
-
|
346 |
-
<details>
|
347 |
-
<summary> AdamW </summary>
|
348 |
-
|
349 |
-
当 `train_batch_size = 1` 时:
|
350 |
-
|
351 |
-
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
352 |
-
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
353 |
-
| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
|
354 |
-
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
|
355 |
-
|
356 |
-
当 `train_batch_size = 4` 时:
|
357 |
-
|
358 |
-
| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|
359 |
-
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
|
360 |
-
| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
|
361 |
-
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |
|
362 |
-
|
363 |
-
</details>
|
364 |
-
|
365 |
-
<details>
<summary> AdamW (8-bit quantized) </summary>

With `train_batch_size = 1`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |

With `train_batch_size = 4`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |

</details>
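The "AdamW (8-bit quantized)" rows correspond to replacing the regular optimizer with an 8-bit AdamW implementation such as the one from `bitsandbytes`. The sketch below is illustrative only; the parameter group and hyperparameter values are placeholders, not the exact training defaults.

```python
# Minimal sketch of constructing an 8-bit AdamW optimizer with bitsandbytes.
# The model and hyperparameters are illustrative placeholders.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(128, 128)  # stand-in for the trainable transformer parameters
optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=1e-4,
)
```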
<details>
<summary> AdamW + CPUOffloadOptimizer (with gradient offloading) </summary>

With `train_batch_size = 1`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |

With `train_batch_size = 4`:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:------------------:|:----------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |

</details>
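The CPUOffloadOptimizer rows keep the optimizer state (and optionally the gradients) in CPU memory via TorchAO. The sketch below shows one way such an optimizer can be constructed; the import path and keyword names are assumptions that may differ across torchao versions, and the hyperparameters are placeholders.

```python
# Hypothetical sketch of a CPU-offloaded optimizer setup like the one measured above.
# NOTE: the import path and keyword arguments are assumptions; check your torchao version.
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(128, 128)  # stand-in for the trainable transformer parameters
optimizer = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,       # base optimizer whose states live on the CPU
    offload_gradients=True,  # also move gradients to the CPU after each backward pass
    lr=1e-4,
    weight_decay=1e-4,
)
```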
<details>
<summary> DeepSpeed (AdamW + CPU/parameter offloading) </summary>

**Note:** Results were reported with `gradient_checkpointing` enabled, running on a 2x A100 setup.

With `train_batch_size = 1`:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |

With `train_batch_size = 4`:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|:------------------:|:----------------------:|:------------------------:|:-----------------------:|:--------------------:|
| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |

</details>
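The DeepSpeed rows come from launching the trainer through 🤗 Accelerate with a DeepSpeed configuration that offloads optimizer states to the CPU. As a rough, hypothetical illustration of the same idea in Python (the usual route is an `accelerate config` YAML file, whose exact contents are not shown here):

```python
# Rough illustration of a DeepSpeed setup with CPU offloading of optimizer states.
# Normally this is configured via `accelerate config` rather than constructed in Python.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,                    # ZeRO stage used for sharding optimizer states
    offload_optimizer_device="cpu",  # keep AdamW states in CPU memory
    gradient_accumulation_steps=1,
)
# Parameter offloading additionally requires ZeRO stage 3 in DeepSpeed.
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)
```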
> [!NOTE]
> - `memory_after_validation` is indicative of the peak memory required for training. This is because, apart from the activations, parameters and gradients stored during training, the VAE and text encoder must also be loaded into memory, and performing inference consumes some memory as well. To reduce the total memory required for training, you can choose to skip validation/testing in the training script.
>
> - If you choose to skip validation/testing, `memory_before_validation` is the true indicator of the memory required for training.
<table align="center">
<tr>
  <td align="center"><a href="https://www.youtube.com/watch?v=UvRl4ansfCg"> Slaying OOMs with PyTorch</a></td>
</tr>
<tr>
  <td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
</tr>
</table>
## TODOs

- [x] Make scripts compatible with DDP
- [ ] Make scripts compatible with FSDP
- [x] Make scripts compatible with DeepSpeed
- [ ] vLLM-powered captioning script
- [x] Multi-resolution/frame support in `prepare_dataset.py`
- [ ] Analyze performance bottlenecks and reduce synchronization operations as much as possible
- [ ] Support QLoRA (priority), and other high-usage LoRA methods
- [x] Test scripts with memory-efficient optimizers from bitsandbytes
- [x] Test scripts with CPUOffloadOptimizer, etc.
- [ ] Test scripts with torchao quantization and low-bit memory optimizers (currently errors out with AdamW (8/4-bit torchao))
- [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (currently errors out)
- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (working with the authors to support the backward pass, and optimize for A100)

> [!IMPORTANT]
> Since our goal is to make the scripts as memory-friendly as possible, we do not guarantee multi-GPU training support.
training/cogvideox/__init__.py
DELETED
File without changes
training/cogvideox/args.py
DELETED
@@ -1,484 +0,0 @@
import argparse


def _get_model_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )


def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--data_root",
        type=str,
        default=None,
        help=("A folder containing the training data."),
    )
    parser.add_argument(
        "--dataset_file",
        type=str,
        default=None,
        help=("Path to a CSV file if loading prompts/video paths using this format."),
    )
    parser.add_argument(
        "--video_column",
        type=str,
        default="video",
        help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
    )
    parser.add_argument(
        "--id_token",
        type=str,
        default=None,
        help="Identifier token appended to the start of each prompt if provided.",
    )
    parser.add_argument(
        "--height_buckets",
        nargs="+",
        type=int,
        default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
    )
    parser.add_argument(
        "--width_buckets",
        nargs="+",
        type=int,
        default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
    )
    parser.add_argument(
        "--frame_buckets",
        nargs="+",
        type=int,
        default=[49],
        help="CogVideoX1.5 need to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, such as 53"
    )
    parser.add_argument(
        "--load_tensors",
        action="store_true",
        help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.",
    )
    parser.add_argument(
        "--random_flip",
        type=float,
        default=None,
        help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--pin_memory",
        action="store_true",
        help="Whether or not to use the pinned memory setting in pytorch dataloader.",
    )


def _get_validation_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
    )
    parser.add_argument(
        "--validation_images",
        type=str,
        default=None,
        help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
        type=str,
        default=":::",
        help="String that separates multiple validation prompts",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=None,
        help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=None,
        help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )
    parser.add_argument(
        "--enable_model_cpu_offload",
        action="store_true",
        default=False,
        help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
    )


def _get_training_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=64,
        help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. "
            "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this "
            "argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideox-sft",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument(
        "--video_reshape_mode",
        type=str,
        default=None,
        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames",
        type=int,
        default=49,
        help="All input videos will be truncated to these many frames.",
    )
    parser.add_argument(
        "--skip_frames_start",
        type=int,
        default=0,
        help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
    )
    parser.add_argument(
        "--skip_frames_end",
        type=int,
        default=0,
        help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )
    parser.add_argument(
        "--noised_image_dropout",
        type=float,
        default=0.05,
        help="Image condition dropout probability when finetuning image-to-video.",
    )
    parser.add_argument(
        "--ignore_learned_positional_embeddings",
        action="store_true",
        default=False,
        help=(
            "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting "
            "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it "
            "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why."
        ),
    )


def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--optimizer",
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw", "prodigy", "came"],
        help=("The optimizer type to use."),
    )
    parser.add_argument(
        "--use_8bit",
        action="store_true",
        help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.",
    )
    parser.add_argument(
        "--use_4bit",
        action="store_true",
        help="Whether or not to use 4-bit optimizers from `torchao`.",
    )
    parser.add_argument(
        "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers."
    )
    parser.add_argument(
        "--beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam and Prodigy optimizers.",
    )
    parser.add_argument(
        "--beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam and Prodigy optimizers.",
    )
    parser.add_argument(
        "--beta3",
        type=float,
        default=None,
        help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
    )
    parser.add_argument(
        "--prodigy_decouple",
        action="store_true",
        help="Use AdamW style decoupled weight decay.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-04,
        help="Weight decay to use for optimizer.",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=1e-8,
        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--prodigy_use_bias_correction",
        action="store_true",
        help="Turn on Adam's bias correction.",
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        action="store_true",
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
    )
    parser.add_argument(
        "--use_cpu_offload_optimizer",
        action="store_true",
        help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.",
    )
    parser.add_argument(
        "--offload_gradients",
        action="store_true",
        help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.",
    )


def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Directory where logs are stored.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--nccl_timeout",
        type=int,
        default=600,
        help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )


def get_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")

    _get_model_args(parser)
    _get_dataset_args(parser)
    _get_training_args(parser)
    _get_validation_args(parser)
    _get_optimizer_args(parser)
    _get_configuration_args(parser)

    return parser.parse_args()
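As a hypothetical illustration of how the deleted `args.py` above was consumed, the training scripts in the same folder import and call `get_args()` once and then read the parsed flags as attributes. Running this sketch would require the deleted module on the import path; the printed flags are examples only.

```python
# Hypothetical invocation of the deleted args.py (the training scripts do `from args import get_args`).
from args import get_args

args = get_args()  # parses e.g. --pretrained_model_name_or_path, --rank, --learning_rate, ...
print(args.pretrained_model_name_or_path, args.rank, args.lora_alpha, args.learning_rate)
```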
training/cogvideox/cogvideox_image_to_video_lora.py
DELETED
@@ -1,1016 +0,0 @@
# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import math
import os
import random
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict

import diffusers
import torch
import transformers
import wandb
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed,
)
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel


from args import get_args  # isort:skip
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
from text_encoder import compute_prompt_embeddings  # isort:skip
from utils import (
    get_gradient_norm,
    get_optimizer,
    prepare_rotary_positional_embeddings,
    print_memory,
    reset_memory,
    unwrap_model,
)


logger = get_logger(__name__)


def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=8,
):
    widget_dict = []
    if videos is not None:
        for i, video in enumerate(videos):
            export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
            widget_dict.append(
                {
                    "text": validation_prompt if validation_prompt else " ",
                    "output": {"url": f"video_{i}.mp4"},
                }
            )

    model_description = f"""
# CogVideoX LoRA Finetune

<Gallery />

## Model description

This is a lora finetune of the CogVideoX model `{base_model}`.

The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).

## Download model

[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.

## Usage

Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.

```py
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")

# The LoRA adapter weights are determined by what was used for training.
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
# It can be made lower or higher from what was used in training to decrease or amplify the effect
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
pipe.set_adapters(["cogvideox-lora"], [32 / 64])

image = load_image("/path/to/image.png")
video = pipe(image=image, prompt="{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
export_to_video(video, "output.mp4", fps=8)
```

For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.

## License

Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "image-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "cogvideox",
        "cogvideox-diffusers",
        "template:sd-lora",
    ]

    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(
    accelerator: Accelerator,
    pipe: CogVideoXImageToVideoPipeline,
    args: Dict[str, Any],
    pipeline_args: Dict[str, Any],
    is_final_validation: bool = False,
):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )

    pipe = pipe.to(accelerator.device)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None

    videos = []
    for _ in range(args.num_validation_videos):
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
        videos.append(video)

    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
        if tracker.name == "wandb":
            video_filenames = []
            for i, video in enumerate(videos):
                prompt = (
                    pipeline_args["prompt"][:25]
                    .replace(" ", "_")
                    .replace(" ", "_")
                    .replace("'", "_")
                    .replace('"', "_")
                    .replace("/", "_")
                )
                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
                export_to_video(video, filename, fps=8)
                video_filenames.append(filename)

            tracker.log(
                {
                    phase_name: [
                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
                        for i, filename in enumerate(video_filenames)
                    ]
                }
            )

    return videos


def run_validation(
    args: Dict[str, Any],
    accelerator: Accelerator,
    transformer,
    scheduler,
    model_config: Dict[str, Any],
    weight_dtype: torch.dtype,
) -> None:
    accelerator.print("===== Memory before validation =====")
    print_memory(accelerator.device)
    torch.cuda.synchronize(accelerator.device)

    pipe = CogVideoXImageToVideoPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        transformer=unwrap_model(accelerator, transformer),
        scheduler=scheduler,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )

    if args.enable_slicing:
        pipe.vae.enable_slicing()
    if args.enable_tiling:
        pipe.vae.enable_tiling()
    if args.enable_model_cpu_offload:
        pipe.enable_model_cpu_offload()

    validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
    validation_images = args.validation_images.split(args.validation_prompt_separator)
    for validation_image, validation_prompt in zip(validation_images, validation_prompts):
        pipeline_args = {
            "image": load_image(validation_image),
            "prompt": validation_prompt,
            "guidance_scale": args.guidance_scale,
            "use_dynamic_cfg": args.use_dynamic_cfg,
            "height": args.height,
            "width": args.width,
            "max_sequence_length": model_config.max_text_seq_length,
        }

        log_validation(
            pipe=pipe,
            args=args,
            accelerator=accelerator,
            pipeline_args=pipeline_args,
        )

    accelerator.print("===== Memory after validation =====")
    print_memory(accelerator.device)
    reset_memory(accelerator.device)

    del pipe
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize(accelerator.device)


class CollateFunction:
    def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
        self.weight_dtype = weight_dtype
        self.load_tensors = load_tensors

    def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        prompts = [x["prompt"] for x in data[0]]

        if self.load_tensors:
            prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)

        images = [x["image"] for x in data[0]]
        images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)

        videos = [x["video"] for x in data[0]]
        videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)

        return {
            "images": images,
            "videos": videos,
            "prompts": prompts,
        }


def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Prepare models and scheduler
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
    )

    text_encoder = T5EncoderModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    # CogVideoX-2b weights are stored in float16
    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
    load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )

    # These changes will also be required when trying to run inference with the trained lora
    if args.ignore_learned_positional_embeddings:
        del transformer.patch_embed.pos_embedding
        transformer.patch_embed.use_learned_positional_embeddings = False
        transformer.config.use_learned_positional_embeddings = False

    vae = AutoencoderKLCogVideoX.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )

    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    if args.enable_slicing:
        vae.enable_slicing()
    if args.enable_tiling:
        vae.enable_tiling()

    # We only train the additional adapter LoRA layers
    text_encoder.requires_grad_(False)
    transformer.requires_grad_(False)
    vae.requires_grad_(False)

    VAE_SCALING_FACTOR = vae.config.scaling_factor
    VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
    RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
    RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.state.deepspeed_plugin:
        # DeepSpeed is handling precision, use what's in the DeepSpeed config
        if (
            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
        ):
            weight_dtype = torch.float16
        if (
            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
        ):
            weight_dtype = torch.bfloat16
    else:
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    text_encoder.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        init_lora_weights=True,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None

            for model in models:
                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
                    model = unwrap_model(accelerator, model)
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"Unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

            CogVideoXImageToVideoPipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        transformer_ = None

        # This is a bit of a hack but I don't know any other solution.
        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
            while len(models) > 0:
                model = models.pop()

                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
                    transformer_ = unwrap_model(accelerator, model)
                else:
                    raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
        else:
            transformer_ = CogVideoXTransformer3DModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="transformer"
            )
            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)

        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params([transformer_])

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params([transformer], dtype=torch.float32)

    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))

    # Optimization parameters
    transformer_parameters_with_lr = {
        "params": transformer_lora_parameters,
        "lr": args.learning_rate,
    }
    params_to_optimize = [transformer_parameters_with_lr]
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])

    use_deepspeed_optimizer = (
        accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )
    use_deepspeed_scheduler = (
        accelerator.state.deepspeed_plugin is not None
        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    optimizer = get_optimizer(
        params_to_optimize=params_to_optimize,
        optimizer_name=args.optimizer,
        learning_rate=args.learning_rate,
        beta1=args.beta1,
        beta2=args.beta2,
        beta3=args.beta3,
        epsilon=args.epsilon,
        weight_decay=args.weight_decay,
        prodigy_decouple=args.prodigy_decouple,
        prodigy_use_bias_correction=args.prodigy_use_bias_correction,
        prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
        use_8bit=args.use_8bit,
        use_4bit=args.use_4bit,
        use_torchao=args.use_torchao,
        use_deepspeed=use_deepspeed_optimizer,
        use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
        offload_gradients=args.offload_gradients,
    )

    # Dataset and DataLoader
    dataset_init_kwargs = {
        "data_root": args.data_root,
        "dataset_file": args.dataset_file,
        "caption_column": args.caption_column,
        "video_column": args.video_column,
        "max_num_frames": args.max_num_frames,
        "id_token": args.id_token,
        "height_buckets": args.height_buckets,
        "width_buckets": args.width_buckets,
        "frame_buckets": args.frame_buckets,
        "load_tensors": args.load_tensors,
        "random_flip": args.random_flip,
        "image_to_video": True,
    }
    if args.video_reshape_mode is None:
        train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
    else:
        train_dataset = VideoDatasetWithResizeAndRectangleCrop(
            video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
        )

    collate_fn = CollateFunction(weight_dtype, args.load_tensors)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.pin_memory,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    if args.use_cpu_offload_optimizer:
        lr_scheduler = None
        accelerator.print(
            "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
            "you are training with those settings, they will be ignored."
        )
    else:
        if use_deepspeed_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=args.max_train_steps * accelerator.num_processes,
                num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
            )
        else:
            lr_scheduler = get_scheduler(
                args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                num_training_steps=args.max_train_steps * accelerator.num_processes,
                num_cycles=args.lr_num_cycles,
                power=args.lr_power,
            )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
        tracker_name = args.tracker_name or "cogvideox-lora"
        accelerator.init_trackers(tracker_name, config=vars(args))

    accelerator.print("===== Memory before training =====")
    reset_memory(accelerator.device)
    print_memory(accelerator.device)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    accelerator.print("***** Running training *****")
    accelerator.print(f"  Num trainable parameters = {num_trainable_parameters}")
    accelerator.print(f"  Num examples = {len(train_dataset)}")
    accelerator.print(f"  Num batches each epoch = {len(train_dataloader)}")
    accelerator.print(f"  Num epochs = {args.num_train_epochs}")
    accelerator.print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    accelerator.print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    accelerator.print(f"  Gradient accumulation steps = {args.gradient_accumulation_steps}")
    accelerator.print(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if not args.resume_from_checkpoint:
        initial_global_step = 0
    else:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # For DeepSpeed training
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

    if args.load_tensors:
        del vae, text_encoder
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize(accelerator.device)

    alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)

    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            logs = {}

            with accelerator.accumulate(models_to_accumulate):
                images = batch["images"].to(accelerator.device, non_blocking=True)
                videos = batch["videos"].to(accelerator.device, non_blocking=True)
                prompts = batch["prompts"]

                # Encode videos
                if not args.load_tensors:
                    images = images.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                    image_noise_sigma = torch.normal(
                        mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
                    )
                    image_noise_sigma = torch.exp(image_noise_sigma)
                    noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
                    image_latent_dist = vae.encode(noisy_images).latent_dist

                    videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                    latent_dist = vae.encode(videos).latent_dist
                else:
                    image_latent_dist = DiagonalGaussianDistribution(images)
                    latent_dist = DiagonalGaussianDistribution(videos)

                image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
                image_latents = image_latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
                image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)

                video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
                video_latents = video_latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
                video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)

                padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
                latent_padding = image_latents.new_zeros(padding_shape)
                image_latents = torch.cat([image_latents, latent_padding], dim=1)

                if random.random() < args.noised_image_dropout:
                    image_latents = torch.zeros_like(image_latents)

                # Encode prompts
                if not args.load_tensors:
                    prompt_embeds = compute_prompt_embeddings(
                        tokenizer,
                        text_encoder,
                        prompts,
                        model_config.max_text_seq_length,
                        accelerator.device,
                        weight_dtype,
                        requires_grad=False,
                    )
                else:
                    prompt_embeds = prompts.to(dtype=weight_dtype)

                # Sample noise that will be added to the latents
                noise = torch.randn_like(video_latents)
                batch_size, num_frames, num_channels, height, width = video_latents.shape

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    scheduler.config.num_train_timesteps,
                    (batch_size,),
                    dtype=torch.int64,
                    device=accelerator.device,
                )

                # Prepare rotary embeds
                image_rotary_emb = (
                    prepare_rotary_positional_embeddings(
                        height=height * VAE_SCALE_FACTOR_SPATIAL,
                        width=width * VAE_SCALE_FACTOR_SPATIAL,
|
790 |
-
num_frames=num_frames,
|
791 |
-
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
792 |
-
patch_size=model_config.patch_size,
|
793 |
-
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
794 |
-
attention_head_dim=model_config.attention_head_dim,
|
795 |
-
device=accelerator.device,
|
796 |
-
base_height=RoPE_BASE_HEIGHT,
|
797 |
-
base_width=RoPE_BASE_WIDTH,
|
798 |
-
)
|
799 |
-
if model_config.use_rotary_positional_embeddings
|
800 |
-
else None
|
801 |
-
)
|
802 |
-
|
803 |
-
# Add noise to the model input according to the noise magnitude at each timestep
|
804 |
-
# (this is the forward diffusion process)
|
805 |
-
noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
|
806 |
-
noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
|
807 |
-
|
808 |
-
ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None,
|
809 |
-
ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)
|
810 |
-
# Predict the noise residual
|
811 |
-
model_output = transformer(
|
812 |
-
hidden_states=noisy_model_input,
|
813 |
-
encoder_hidden_states=prompt_embeds,
|
814 |
-
timestep=timesteps,
|
815 |
-
ofs=ofs_emb,
|
816 |
-
image_rotary_emb=image_rotary_emb,
|
817 |
-
return_dict=False,
|
818 |
-
)[0]
|
819 |
-
|
820 |
-
model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
|
821 |
-
|
822 |
-
weights = 1 / (1 - alphas_cumprod[timesteps])
|
823 |
-
while len(weights.shape) < len(model_pred.shape):
|
824 |
-
weights = weights.unsqueeze(-1)
|
825 |
-
|
826 |
-
target = video_latents
|
827 |
-
|
828 |
-
loss = torch.mean(
|
829 |
-
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
830 |
-
dim=1,
|
831 |
-
)
|
832 |
-
loss = loss.mean()
|
833 |
-
accelerator.backward(loss)
|
834 |
-
|
835 |
-
if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
|
836 |
-
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
837 |
-
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
838 |
-
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
839 |
-
logs.update(
|
840 |
-
{
|
841 |
-
"gradient_norm_before_clip": gradient_norm_before_clip,
|
842 |
-
"gradient_norm_after_clip": gradient_norm_after_clip,
|
843 |
-
}
|
844 |
-
)
|
845 |
-
|
846 |
-
if accelerator.state.deepspeed_plugin is None:
|
847 |
-
optimizer.step()
|
848 |
-
optimizer.zero_grad()
|
849 |
-
|
850 |
-
if not args.use_cpu_offload_optimizer:
|
851 |
-
lr_scheduler.step()
|
852 |
-
|
853 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
854 |
-
if accelerator.sync_gradients:
|
855 |
-
progress_bar.update(1)
|
856 |
-
global_step += 1
|
857 |
-
|
858 |
-
# Checkpointing
|
859 |
-
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
860 |
-
if global_step % args.checkpointing_steps == 0:
|
861 |
-
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
862 |
-
if args.checkpoints_total_limit is not None:
|
863 |
-
checkpoints = os.listdir(args.output_dir)
|
864 |
-
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
865 |
-
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
866 |
-
|
867 |
-
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
868 |
-
if len(checkpoints) >= args.checkpoints_total_limit:
|
869 |
-
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
870 |
-
removing_checkpoints = checkpoints[0:num_to_remove]
|
871 |
-
|
872 |
-
logger.info(
|
873 |
-
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
874 |
-
)
|
875 |
-
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
876 |
-
|
877 |
-
for removing_checkpoint in removing_checkpoints:
|
878 |
-
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
879 |
-
shutil.rmtree(removing_checkpoint)
|
880 |
-
|
881 |
-
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
882 |
-
accelerator.save_state(save_path)
|
883 |
-
logger.info(f"Saved state to {save_path}")
|
884 |
-
|
885 |
-
# Validation
|
886 |
-
should_run_validation = args.validation_prompt is not None and (
|
887 |
-
args.validation_steps is not None and global_step % args.validation_steps == 0
|
888 |
-
)
|
889 |
-
if should_run_validation:
|
890 |
-
run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
|
891 |
-
|
892 |
-
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
893 |
-
logs.update(
|
894 |
-
{
|
895 |
-
"loss": loss.detach().item(),
|
896 |
-
"lr": last_lr,
|
897 |
-
}
|
898 |
-
)
|
899 |
-
progress_bar.set_postfix(**logs)
|
900 |
-
accelerator.log(logs, step=global_step)
|
901 |
-
|
902 |
-
if global_step >= args.max_train_steps:
|
903 |
-
break
|
904 |
-
|
905 |
-
if accelerator.is_main_process:
|
906 |
-
should_run_validation = args.validation_prompt is not None and (
|
907 |
-
args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
|
908 |
-
)
|
909 |
-
if should_run_validation:
|
910 |
-
run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
|
911 |
-
|
912 |
-
accelerator.wait_for_everyone()
|
913 |
-
|
914 |
-
if accelerator.is_main_process:
|
915 |
-
transformer = unwrap_model(accelerator, transformer)
|
916 |
-
dtype = (
|
917 |
-
torch.float16
|
918 |
-
if args.mixed_precision == "fp16"
|
919 |
-
else torch.bfloat16
|
920 |
-
if args.mixed_precision == "bf16"
|
921 |
-
else torch.float32
|
922 |
-
)
|
923 |
-
transformer = transformer.to(dtype)
|
924 |
-
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
925 |
-
|
926 |
-
CogVideoXImageToVideoPipeline.save_lora_weights(
|
927 |
-
save_directory=args.output_dir,
|
928 |
-
transformer_lora_layers=transformer_lora_layers,
|
929 |
-
)
|
930 |
-
|
931 |
-
# Cleanup trained models to save memory
|
932 |
-
if args.load_tensors:
|
933 |
-
del transformer
|
934 |
-
else:
|
935 |
-
del transformer, text_encoder, vae
|
936 |
-
|
937 |
-
gc.collect()
|
938 |
-
torch.cuda.empty_cache()
|
939 |
-
torch.cuda.synchronize(accelerator.device)
|
940 |
-
|
941 |
-
accelerator.print("===== Memory before testing =====")
|
942 |
-
print_memory(accelerator.device)
|
943 |
-
reset_memory(accelerator.device)
|
944 |
-
|
945 |
-
# Final test inference
|
946 |
-
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
947 |
-
args.pretrained_model_name_or_path,
|
948 |
-
revision=args.revision,
|
949 |
-
variant=args.variant,
|
950 |
-
torch_dtype=weight_dtype,
|
951 |
-
)
|
952 |
-
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
953 |
-
|
954 |
-
if args.enable_slicing:
|
955 |
-
pipe.vae.enable_slicing()
|
956 |
-
if args.enable_tiling:
|
957 |
-
pipe.vae.enable_tiling()
|
958 |
-
if args.enable_model_cpu_offload:
|
959 |
-
pipe.enable_model_cpu_offload()
|
960 |
-
|
961 |
-
# Load LoRA weights
|
962 |
-
lora_scaling = args.lora_alpha / args.rank
|
963 |
-
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
964 |
-
pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
|
965 |
-
|
966 |
-
# Run inference
|
967 |
-
validation_outputs = []
|
968 |
-
if args.validation_prompt and args.num_validation_videos > 0:
|
969 |
-
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
970 |
-
validation_images = args.validation_images.split(args.validation_prompt_separator)
|
971 |
-
for validation_image, validation_prompt in zip(validation_images, validation_prompts):
|
972 |
-
pipeline_args = {
|
973 |
-
"image": load_image(validation_image),
|
974 |
-
"prompt": validation_prompt,
|
975 |
-
"guidance_scale": args.guidance_scale,
|
976 |
-
"use_dynamic_cfg": args.use_dynamic_cfg,
|
977 |
-
"height": args.height,
|
978 |
-
"width": args.width,
|
979 |
-
}
|
980 |
-
|
981 |
-
video = log_validation(
|
982 |
-
accelerator=accelerator,
|
983 |
-
pipe=pipe,
|
984 |
-
args=args,
|
985 |
-
pipeline_args=pipeline_args,
|
986 |
-
is_final_validation=True,
|
987 |
-
)
|
988 |
-
validation_outputs.extend(video)
|
989 |
-
|
990 |
-
accelerator.print("===== Memory after testing =====")
|
991 |
-
print_memory(accelerator.device)
|
992 |
-
reset_memory(accelerator.device)
|
993 |
-
torch.cuda.synchronize(accelerator.device)
|
994 |
-
|
995 |
-
if args.push_to_hub:
|
996 |
-
save_model_card(
|
997 |
-
repo_id,
|
998 |
-
videos=validation_outputs,
|
999 |
-
base_model=args.pretrained_model_name_or_path,
|
1000 |
-
validation_prompt=args.validation_prompt,
|
1001 |
-
repo_folder=args.output_dir,
|
1002 |
-
fps=args.fps,
|
1003 |
-
)
|
1004 |
-
upload_folder(
|
1005 |
-
repo_id=repo_id,
|
1006 |
-
folder_path=args.output_dir,
|
1007 |
-
commit_message="End of training",
|
1008 |
-
ignore_patterns=["step_*", "epoch_*"],
|
1009 |
-
)
|
1010 |
-
|
1011 |
-
accelerator.end_training()
|
1012 |
-
|
1013 |
-
|
1014 |
-
if __name__ == "__main__":
|
1015 |
-
args = get_args()
|
1016 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
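For reference, the objective these removed scripts optimized was a weighted v-prediction loss. The sketch below is a minimal, self-contained illustration with dummy tensors standing in for the scheduler's `alphas_cumprod` and the transformer's velocity prediction; it is not a drop-in replacement for the deleted training loop.

```python
# Minimal sketch of the weighted v-prediction objective used by the removed
# CogVideoX training scripts. Dummy tensors stand in for real model outputs.
import torch

batch_size, frames, channels, height, width = 2, 4, 16, 8, 8
num_train_timesteps = 1000

# Stand-in for scheduler.alphas_cumprod (1-D tensor of length num_train_timesteps).
alphas_cumprod = torch.linspace(0.9999, 0.0001, num_train_timesteps)

timesteps = torch.randint(0, num_train_timesteps, (batch_size,))
video_latents = torch.randn(batch_size, frames, channels, height, width)
# In the real scripts: model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
model_pred = torch.randn_like(video_latents)

# Per-sample weight 1 / (1 - alpha_bar_t), broadcast over the latent dimensions.
weights = 1.0 / (1.0 - alphas_cumprod[timesteps])
while weights.dim() < model_pred.dim():
    weights = weights.unsqueeze(-1)

target = video_latents
loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1).mean()
print(loss.item())
```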
training/cogvideox/cogvideox_image_to_video_sft.py
DELETED
@@ -1,947 +0,0 @@
[947 lines removed: the legacy CogVideoX image-to-video full-finetune (SFT) script — model card and validation helpers, Accelerate/DeepSpeed setup, loading of the tokenizer, T5 text encoder, CogVideoX transformer, VAE and scheduler, the bucket-sampled video dataset and collate function, the same weighted v-prediction training loop with first-frame latent conditioning and rotary positional embeddings, checkpointing, validation, full-model saving, a final test inference pass, and optional upload to the Hub.]
training/cogvideox/cogvideox_text_to_video_lora.py
DELETED
@@ -1,955 +0,0 @@
|
|
1 |
-
# Copyright 2024 The HuggingFace Team.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
import gc
|
17 |
-
import logging
|
18 |
-
import math
|
19 |
-
import os
|
20 |
-
import shutil
|
21 |
-
from datetime import timedelta
|
22 |
-
from pathlib import Path
|
23 |
-
from typing import Any, Dict
|
24 |
-
|
25 |
-
import diffusers
|
26 |
-
import torch
|
27 |
-
import transformers
|
28 |
-
import wandb
|
29 |
-
from accelerate import Accelerator, DistributedType
|
30 |
-
from accelerate.logging import get_logger
|
31 |
-
from accelerate.utils import (
|
32 |
-
DistributedDataParallelKwargs,
|
33 |
-
InitProcessGroupKwargs,
|
34 |
-
ProjectConfiguration,
|
35 |
-
set_seed,
|
36 |
-
)
|
37 |
-
from diffusers import (
|
38 |
-
AutoencoderKLCogVideoX,
|
39 |
-
CogVideoXDPMScheduler,
|
40 |
-
CogVideoXPipeline,
|
41 |
-
CogVideoXTransformer3DModel,
|
42 |
-
)
|
43 |
-
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
44 |
-
from diffusers.optimization import get_scheduler
|
45 |
-
from diffusers.training_utils import cast_training_params
|
46 |
-
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video
|
47 |
-
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
48 |
-
from huggingface_hub import create_repo, upload_folder
|
49 |
-
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
50 |
-
from torch.utils.data import DataLoader
|
51 |
-
from tqdm.auto import tqdm
|
52 |
-
from transformers import AutoTokenizer, T5EncoderModel
|
53 |
-
|
54 |
-
|
55 |
-
from args import get_args # isort:skip
|
56 |
-
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
57 |
-
from text_encoder import compute_prompt_embeddings # isort:skip
|
58 |
-
from utils import (
|
59 |
-
get_gradient_norm,
|
60 |
-
get_optimizer,
|
61 |
-
prepare_rotary_positional_embeddings,
|
62 |
-
print_memory,
|
63 |
-
reset_memory,
|
64 |
-
unwrap_model,
|
65 |
-
) # isort:skip
|
66 |
-
|
67 |
-
|
68 |
-
logger = get_logger(__name__)
|
69 |
-
|
70 |
-
|
71 |
-
def save_model_card(
|
72 |
-
repo_id: str,
|
73 |
-
videos=None,
|
74 |
-
base_model: str = None,
|
75 |
-
validation_prompt=None,
|
76 |
-
repo_folder=None,
|
77 |
-
fps=8,
|
78 |
-
):
|
79 |
-
widget_dict = []
|
80 |
-
if videos is not None:
|
81 |
-
for i, video in enumerate(videos):
|
82 |
-
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
|
83 |
-
widget_dict.append(
|
84 |
-
{
|
85 |
-
"text": validation_prompt if validation_prompt else " ",
|
86 |
-
"output": {"url": f"video_{i}.mp4"},
|
87 |
-
}
|
88 |
-
)
|
89 |
-
|
90 |
-
model_description = f"""
|
91 |
-
# CogVideoX LoRA Finetune
|
92 |
-
|
93 |
-
<Gallery />
|
94 |
-
|
95 |
-
## Model description
|
96 |
-
|
97 |
-
This is a lora finetune of the CogVideoX model `{base_model}`.
|
98 |
-
|
99 |
-
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
100 |
-
|
101 |
-
## Download model
|
102 |
-
|
103 |
-
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
104 |
-
|
105 |
-
## Usage
|
106 |
-
|
107 |
-
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
108 |
-
|
109 |
-
```py
|
110 |
-
import torch
|
111 |
-
from diffusers import CogVideoXPipeline
|
112 |
-
from diffusers.utils import export_to_video
|
113 |
-
|
114 |
-
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
|
115 |
-
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
|
116 |
-
|
117 |
-
# The LoRA adapter weights are determined by what was used for training.
|
118 |
-
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
|
119 |
-
# It can be made lower or higher from what was used in training to decrease or amplify the effect
|
120 |
-
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
|
121 |
-
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
|
122 |
-
|
123 |
-
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
124 |
-
export_to_video(video, "output.mp4", fps=8)
|
125 |
-
```
|
126 |
-
|
127 |
-
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
|
128 |
-
|
129 |
-
## License
|
130 |
-
|
131 |
-
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
|
132 |
-
"""
|
133 |
-
model_card = load_or_create_model_card(
|
134 |
-
repo_id_or_path=repo_id,
|
135 |
-
from_training=True,
|
136 |
-
license="other",
|
137 |
-
base_model=base_model,
|
138 |
-
prompt=validation_prompt,
|
139 |
-
model_description=model_description,
|
140 |
-
widget=widget_dict,
|
141 |
-
)
|
142 |
-
tags = [
|
143 |
-
"text-to-video",
|
144 |
-
"diffusers-training",
|
145 |
-
"diffusers",
|
146 |
-
"lora",
|
147 |
-
"cogvideox",
|
148 |
-
"cogvideox-diffusers",
|
149 |
-
"template:sd-lora",
|
150 |
-
]
|
151 |
-
|
152 |
-
model_card = populate_model_card(model_card, tags=tags)
|
153 |
-
model_card.save(os.path.join(repo_folder, "README.md"))
|
154 |
-
|
155 |
-
|
156 |
-
def log_validation(
|
157 |
-
accelerator: Accelerator,
|
158 |
-
pipe: CogVideoXPipeline,
|
159 |
-
args: Dict[str, Any],
|
160 |
-
pipeline_args: Dict[str, Any],
|
161 |
-
epoch,
|
162 |
-
is_final_validation: bool = False,
|
163 |
-
):
|
164 |
-
logger.info(
|
165 |
-
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
166 |
-
)
|
167 |
-
|
168 |
-
pipe = pipe.to(accelerator.device)
|
169 |
-
|
170 |
-
# run inference
|
171 |
-
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
172 |
-
|
173 |
-
videos = []
|
174 |
-
for _ in range(args.num_validation_videos):
|
175 |
-
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
176 |
-
videos.append(video)
|
177 |
-
|
178 |
-
for tracker in accelerator.trackers:
|
179 |
-
phase_name = "test" if is_final_validation else "validation"
|
180 |
-
if tracker.name == "wandb":
|
181 |
-
video_filenames = []
|
182 |
-
for i, video in enumerate(videos):
|
183 |
-
prompt = (
|
184 |
-
pipeline_args["prompt"][:25]
|
185 |
-
.replace(" ", "_")
|
186 |
-
.replace(" ", "_")
|
187 |
-
.replace("'", "_")
|
188 |
-
.replace('"', "_")
|
189 |
-
.replace("/", "_")
|
190 |
-
)
|
191 |
-
filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
|
192 |
-
export_to_video(video, filename, fps=8)
|
193 |
-
video_filenames.append(filename)
|
194 |
-
|
195 |
-
tracker.log(
|
196 |
-
{
|
197 |
-
phase_name: [
|
198 |
-
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
|
199 |
-
for i, filename in enumerate(video_filenames)
|
200 |
-
]
|
201 |
-
}
|
202 |
-
)
|
203 |
-
|
204 |
-
return videos
|
205 |
-
|
206 |
-
|
207 |
-
class CollateFunction:
|
208 |
-
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
|
209 |
-
self.weight_dtype = weight_dtype
|
210 |
-
self.load_tensors = load_tensors
|
211 |
-
|
212 |
-
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
213 |
-
prompts = [x["prompt"] for x in data[0]]
|
214 |
-
|
215 |
-
if self.load_tensors:
|
216 |
-
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
|
217 |
-
|
218 |
-
videos = [x["video"] for x in data[0]]
|
219 |
-
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
|
220 |
-
|
221 |
-
return {
|
222 |
-
"videos": videos,
|
223 |
-
"prompts": prompts,
|
224 |
-
}
|
225 |
-
|
226 |
-
|
227 |
-
def main(args):
|
228 |
-
if args.report_to == "wandb" and args.hub_token is not None:
|
229 |
-
raise ValueError(
|
230 |
-
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
231 |
-
" Please use `huggingface-cli login` to authenticate with the Hub."
|
232 |
-
)
|
233 |
-
|
234 |
-
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
235 |
-
# due to pytorch#99272, MPS does not yet support bfloat16.
|
236 |
-
raise ValueError(
|
237 |
-
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
238 |
-
)
|
239 |
-
|
240 |
-
logging_dir = Path(args.output_dir, args.logging_dir)
|
241 |
-
|
242 |
-
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
243 |
-
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
244 |
-
init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
|
245 |
-
accelerator = Accelerator(
|
246 |
-
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
247 |
-
mixed_precision=args.mixed_precision,
|
248 |
-
log_with=args.report_to,
|
249 |
-
project_config=accelerator_project_config,
|
250 |
-
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
251 |
-
)
|
252 |
-
|
253 |
-
# Disable AMP for MPS.
|
254 |
-
if torch.backends.mps.is_available():
|
255 |
-
accelerator.native_amp = False
|
256 |
-
|
257 |
-
# Make one log on every process with the configuration for debugging.
|
258 |
-
logging.basicConfig(
|
259 |
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
260 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
261 |
-
level=logging.INFO,
|
262 |
-
)
|
263 |
-
logger.info(accelerator.state, main_process_only=False)
|
264 |
-
if accelerator.is_local_main_process:
|
265 |
-
transformers.utils.logging.set_verbosity_warning()
|
266 |
-
diffusers.utils.logging.set_verbosity_info()
|
267 |
-
else:
|
268 |
-
transformers.utils.logging.set_verbosity_error()
|
269 |
-
diffusers.utils.logging.set_verbosity_error()
|
270 |
-
|
271 |
-
# If passed along, set the training seed now.
|
272 |
-
if args.seed is not None:
|
273 |
-
set_seed(args.seed)
|
274 |
-
|
275 |
-
# Handle the repository creation
|
276 |
-
if accelerator.is_main_process:
|
277 |
-
if args.output_dir is not None:
|
278 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
279 |
-
|
280 |
-
if args.push_to_hub:
|
281 |
-
repo_id = create_repo(
|
282 |
-
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
283 |
-
exist_ok=True,
|
284 |
-
).repo_id
|
285 |
-
|
286 |
-
# Prepare models and scheduler
|
287 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
288 |
-
args.pretrained_model_name_or_path,
|
289 |
-
subfolder="tokenizer",
|
290 |
-
revision=args.revision,
|
291 |
-
)
|
292 |
-
|
293 |
-
text_encoder = T5EncoderModel.from_pretrained(
|
294 |
-
args.pretrained_model_name_or_path,
|
295 |
-
subfolder="text_encoder",
|
296 |
-
revision=args.revision,
|
297 |
-
)
|
298 |
-
|
299 |
-
# CogVideoX-2b weights are stored in float16
|
300 |
-
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
301 |
-
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
|
302 |
-
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
303 |
-
args.pretrained_model_name_or_path,
|
304 |
-
subfolder="transformer",
|
305 |
-
torch_dtype=load_dtype,
|
306 |
-
revision=args.revision,
|
307 |
-
variant=args.variant,
|
308 |
-
)
|
309 |
-
|
310 |
-
vae = AutoencoderKLCogVideoX.from_pretrained(
|
311 |
-
args.pretrained_model_name_or_path,
|
312 |
-
subfolder="vae",
|
313 |
-
revision=args.revision,
|
314 |
-
variant=args.variant,
|
315 |
-
)
|
316 |
-
|
317 |
-
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
318 |
-
|
319 |
-
if args.enable_slicing:
|
320 |
-
vae.enable_slicing()
|
321 |
-
if args.enable_tiling:
|
322 |
-
vae.enable_tiling()
|
323 |
-
|
324 |
-
# We only train the additional adapter LoRA layers
|
325 |
-
text_encoder.requires_grad_(False)
|
326 |
-
transformer.requires_grad_(False)
|
327 |
-
vae.requires_grad_(False)
|
328 |
-
|
329 |
-
VAE_SCALING_FACTOR = vae.config.scaling_factor
|
330 |
-
VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
|
331 |
-
RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
|
332 |
-
RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
|
333 |
-
|
334 |
-
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
335 |
-
# as these weights are only used for inference, keeping weights in full precision is not required.
|
336 |
-
weight_dtype = torch.float32
|
337 |
-
if accelerator.state.deepspeed_plugin:
|
338 |
-
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
339 |
-
if (
|
340 |
-
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
341 |
-
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
342 |
-
):
|
343 |
-
weight_dtype = torch.float16
|
344 |
-
if (
|
345 |
-
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
346 |
-
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
347 |
-
):
|
348 |
-
weight_dtype = torch.bfloat16
|
349 |
-
else:
|
350 |
-
if accelerator.mixed_precision == "fp16":
|
351 |
-
weight_dtype = torch.float16
|
352 |
-
elif accelerator.mixed_precision == "bf16":
|
353 |
-
weight_dtype = torch.bfloat16
|
354 |
-
|
355 |
-
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
356 |
-
# due to pytorch#99272, MPS does not yet support bfloat16.
|
357 |
-
raise ValueError(
|
358 |
-
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
359 |
-
)
|
360 |
-
|
361 |
-
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
362 |
-
transformer.to(accelerator.device, dtype=weight_dtype)
|
363 |
-
vae.to(accelerator.device, dtype=weight_dtype)
|
364 |
-
|
365 |
-
if args.gradient_checkpointing:
|
366 |
-
transformer.enable_gradient_checkpointing()
|
367 |
-
|
368 |
-
# now we will add new LoRA weights to the attention layers
|
369 |
-
transformer_lora_config = LoraConfig(
|
370 |
-
r=args.rank,
|
371 |
-
lora_alpha=args.lora_alpha,
|
372 |
-
init_lora_weights=True,
|
373 |
-
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
374 |
-
)
|
375 |
-
transformer.add_adapter(transformer_lora_config)
|
376 |
-
|
377 |
-
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
378 |
-
def save_model_hook(models, weights, output_dir):
|
379 |
-
if accelerator.is_main_process:
|
380 |
-
transformer_lora_layers_to_save = None
|
381 |
-
|
382 |
-
for model in models:
|
383 |
-
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
384 |
-
model = unwrap_model(accelerator, model)
|
385 |
-
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
386 |
-
else:
|
387 |
-
raise ValueError(f"unexpected save model: {model.__class__}")
|
388 |
-
|
389 |
-
# make sure to pop weight so that corresponding model is not saved again
|
390 |
-
if weights:
|
391 |
-
weights.pop()
|
392 |
-
|
393 |
-
CogVideoXPipeline.save_lora_weights(
|
394 |
-
output_dir,
|
395 |
-
transformer_lora_layers=transformer_lora_layers_to_save,
|
396 |
-
)
|
397 |
-
|
398 |
-
def load_model_hook(models, input_dir):
|
399 |
-
transformer_ = None
|
400 |
-
|
401 |
-
# This is a bit of a hack but I don't know any other solution.
|
402 |
-
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
403 |
-
while len(models) > 0:
|
404 |
-
model = models.pop()
|
405 |
-
|
406 |
-
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
407 |
-
transformer_ = unwrap_model(accelerator, model)
|
408 |
-
else:
|
409 |
-
raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
|
410 |
-
else:
|
411 |
-
transformer_ = CogVideoXTransformer3DModel.from_pretrained(
|
412 |
-
args.pretrained_model_name_or_path, subfolder="transformer"
|
413 |
-
)
|
414 |
-
transformer_.add_adapter(transformer_lora_config)
|
415 |
-
|
416 |
-
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
|
417 |
-
|
418 |
-
transformer_state_dict = {
|
419 |
-
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
420 |
-
}
|
421 |
-
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
422 |
-
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
423 |
-
if incompatible_keys is not None:
|
424 |
-
# check only for unexpected keys
|
425 |
-
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
426 |
-
if unexpected_keys:
|
427 |
-
logger.warning(
|
428 |
-
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
429 |
-
f" {unexpected_keys}. "
|
430 |
-
)
|
431 |
-
|
432 |
-
# Make sure the trainable params are in float32. This is again needed since the base models
|
433 |
-
# are in `weight_dtype`. More details:
|
434 |
-
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
435 |
-
if args.mixed_precision == "fp16":
|
436 |
-
# only upcast trainable parameters (LoRA) into fp32
|
437 |
-
cast_training_params([transformer_])
|
438 |
-
|
439 |
-
accelerator.register_save_state_pre_hook(save_model_hook)
|
440 |
-
accelerator.register_load_state_pre_hook(load_model_hook)
|
441 |
-
|
442 |
-
# Enable TF32 for faster training on Ampere GPUs,
|
443 |
-
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
444 |
-
if args.allow_tf32 and torch.cuda.is_available():
|
445 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
446 |
-
|
447 |
-
if args.scale_lr:
|
448 |
-
args.learning_rate = (
|
449 |
-
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
450 |
-
)
|
451 |
-
|
452 |
-
# Make sure the trainable params are in float32.
|
453 |
-
if args.mixed_precision == "fp16":
|
454 |
-
# only upcast trainable parameters (LoRA) into fp32
|
455 |
-
cast_training_params([transformer], dtype=torch.float32)
|
456 |
-
|
457 |
-
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
458 |
-
|
459 |
-
# Optimization parameters
|
460 |
-
transformer_parameters_with_lr = {
|
461 |
-
"params": transformer_lora_parameters,
|
462 |
-
"lr": args.learning_rate,
|
463 |
-
}
|
464 |
-
params_to_optimize = [transformer_parameters_with_lr]
|
465 |
-
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
466 |
-
|
467 |
-
use_deepspeed_optimizer = (
|
468 |
-
accelerator.state.deepspeed_plugin is not None
|
469 |
-
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
470 |
-
)
|
471 |
-
use_deepspeed_scheduler = (
|
472 |
-
accelerator.state.deepspeed_plugin is not None
|
473 |
-
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
474 |
-
)
|
475 |
-
|
476 |
-
optimizer = get_optimizer(
|
477 |
-
params_to_optimize=params_to_optimize,
|
478 |
-
optimizer_name=args.optimizer,
|
479 |
-
learning_rate=args.learning_rate,
|
480 |
-
beta1=args.beta1,
|
481 |
-
beta2=args.beta2,
|
482 |
-
beta3=args.beta3,
|
483 |
-
epsilon=args.epsilon,
|
484 |
-
weight_decay=args.weight_decay,
|
485 |
-
prodigy_decouple=args.prodigy_decouple,
|
486 |
-
prodigy_use_bias_correction=args.prodigy_use_bias_correction,
|
487 |
-
prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
|
488 |
-
use_8bit=args.use_8bit,
|
489 |
-
use_4bit=args.use_4bit,
|
490 |
-
use_torchao=args.use_torchao,
|
491 |
-
use_deepspeed=use_deepspeed_optimizer,
|
492 |
-
use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
|
493 |
-
offload_gradients=args.offload_gradients,
|
494 |
-
)
|
495 |
-
|
496 |
-
# Dataset and DataLoader
|
497 |
-
dataset_init_kwargs = {
|
498 |
-
"data_root": args.data_root,
|
499 |
-
"dataset_file": args.dataset_file,
|
500 |
-
"caption_column": args.caption_column,
|
501 |
-
"video_column": args.video_column,
|
502 |
-
"max_num_frames": args.max_num_frames,
|
503 |
-
"id_token": args.id_token,
|
504 |
-
"height_buckets": args.height_buckets,
|
505 |
-
"width_buckets": args.width_buckets,
|
506 |
-
"frame_buckets": args.frame_buckets,
|
507 |
-
"load_tensors": args.load_tensors,
|
508 |
-
"random_flip": args.random_flip,
|
509 |
-
}
|
510 |
-
if args.video_reshape_mode is None:
|
511 |
-
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
512 |
-
else:
|
513 |
-
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
|
514 |
-
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
515 |
-
)
|
516 |
-
|
517 |
-
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
|
518 |
-
|
519 |
-
train_dataloader = DataLoader(
|
520 |
-
train_dataset,
|
521 |
-
batch_size=1,
|
522 |
-
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
|
523 |
-
collate_fn=collate_fn,
|
524 |
-
num_workers=args.dataloader_num_workers,
|
525 |
-
pin_memory=args.pin_memory,
|
526 |
-
)
|
527 |
-
|
528 |
-
# Scheduler and math around the number of training steps.
|
529 |
-
overrode_max_train_steps = False
|
530 |
-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
531 |
-
if args.max_train_steps is None:
|
532 |
-
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
533 |
-
overrode_max_train_steps = True
|
534 |
-
|
535 |
-
if args.use_cpu_offload_optimizer:
|
536 |
-
lr_scheduler = None
|
537 |
-
accelerator.print(
|
538 |
-
"CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
|
539 |
-
"you are training with those settings, they will be ignored."
|
540 |
-
)
|
541 |
-
else:
|
542 |
-
if use_deepspeed_scheduler:
|
543 |
-
from accelerate.utils import DummyScheduler
|
544 |
-
|
545 |
-
lr_scheduler = DummyScheduler(
|
546 |
-
name=args.lr_scheduler,
|
547 |
-
optimizer=optimizer,
|
548 |
-
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
549 |
-
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
550 |
-
)
|
551 |
-
else:
|
552 |
-
lr_scheduler = get_scheduler(
|
553 |
-
args.lr_scheduler,
|
554 |
-
optimizer=optimizer,
|
555 |
-
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
556 |
-
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
557 |
-
num_cycles=args.lr_num_cycles,
|
558 |
-
power=args.lr_power,
|
559 |
-
)
|
560 |
-
|
561 |
-
# Prepare everything with our `accelerator`.
|
562 |
-
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
563 |
-
transformer, optimizer, train_dataloader, lr_scheduler
|
564 |
-
)
|
565 |
-
|
566 |
-
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
567 |
-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
568 |
-
if overrode_max_train_steps:
|
569 |
-
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
570 |
-
# Afterwards we recalculate our number of training epochs
|
571 |
-
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
572 |
-
|
573 |
-
# We need to initialize the trackers we use, and also store our configuration.
|
574 |
-
# The trackers initializes automatically on the main process.
|
575 |
-
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
576 |
-
tracker_name = args.tracker_name or "cogvideox-lora"
|
577 |
-
accelerator.init_trackers(tracker_name, config=vars(args))
|
578 |
-
|
579 |
-
accelerator.print("===== Memory before training =====")
|
580 |
-
reset_memory(accelerator.device)
|
581 |
-
print_memory(accelerator.device)
|
582 |
-
|
583 |
-
# Train!
|
584 |
-
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
585 |
-
|
586 |
-
accelerator.print("***** Running training *****")
|
587 |
-
accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
|
588 |
-
accelerator.print(f" Num examples = {len(train_dataset)}")
|
589 |
-
accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
|
590 |
-
accelerator.print(f" Num epochs = {args.num_train_epochs}")
|
591 |
-
accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
592 |
-
accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
593 |
-
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
594 |
-
accelerator.print(f" Total optimization steps = {args.max_train_steps}")
|
595 |
-
global_step = 0
|
596 |
-
first_epoch = 0
|
597 |
-
|
598 |
-
# Potentially load in the weights and states from a previous save
|
599 |
-
if not args.resume_from_checkpoint:
|
600 |
-
initial_global_step = 0
|
601 |
-
else:
|
602 |
-
if args.resume_from_checkpoint != "latest":
|
603 |
-
path = os.path.basename(args.resume_from_checkpoint)
|
604 |
-
else:
|
605 |
-
# Get the most recent checkpoint
|
606 |
-
dirs = os.listdir(args.output_dir)
|
607 |
-
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
608 |
-
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
609 |
-
path = dirs[-1] if len(dirs) > 0 else None
|
610 |
-
|
611 |
-
if path is None:
|
612 |
-
accelerator.print(
|
613 |
-
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
614 |
-
)
|
615 |
-
args.resume_from_checkpoint = None
|
616 |
-
initial_global_step = 0
|
617 |
-
else:
|
618 |
-
accelerator.print(f"Resuming from checkpoint {path}")
|
619 |
-
accelerator.load_state(os.path.join(args.output_dir, path))
|
620 |
-
global_step = int(path.split("-")[1])
|
621 |
-
|
622 |
-
initial_global_step = global_step
|
623 |
-
first_epoch = global_step // num_update_steps_per_epoch
|
624 |
-
|
625 |
-
progress_bar = tqdm(
|
626 |
-
range(0, args.max_train_steps),
|
627 |
-
initial=initial_global_step,
|
628 |
-
desc="Steps",
|
629 |
-
# Only show the progress bar once on each machine.
|
630 |
-
disable=not accelerator.is_local_main_process,
|
631 |
-
)
|
632 |
-
|
633 |
-
# For DeepSpeed training
|
634 |
-
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
635 |
-
|
636 |
-
if args.load_tensors:
|
637 |
-
del vae, text_encoder
|
638 |
-
gc.collect()
|
639 |
-
torch.cuda.empty_cache()
|
640 |
-
torch.cuda.synchronize(accelerator.device)
|
641 |
-
|
642 |
-
alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
|
643 |
-
|
644 |
-
for epoch in range(first_epoch, args.num_train_epochs):
|
645 |
-
transformer.train()
|
646 |
-
|
647 |
-
for step, batch in enumerate(train_dataloader):
|
648 |
-
models_to_accumulate = [transformer]
|
649 |
-
logs = {}
|
650 |
-
|
651 |
-
with accelerator.accumulate(models_to_accumulate):
|
652 |
-
videos = batch["videos"].to(accelerator.device, non_blocking=True)
|
653 |
-
prompts = batch["prompts"]
|
654 |
-
|
655 |
-
# Encode videos
|
656 |
-
if not args.load_tensors:
|
657 |
-
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
658 |
-
latent_dist = vae.encode(videos).latent_dist
|
659 |
-
else:
|
660 |
-
latent_dist = DiagonalGaussianDistribution(videos)
|
661 |
-
|
662 |
-
videos = latent_dist.sample() * VAE_SCALING_FACTOR
|
663 |
-
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
664 |
-
videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
665 |
-
model_input = videos
|
666 |
-
|
667 |
-
# Encode prompts
|
668 |
-
if not args.load_tensors:
|
669 |
-
prompt_embeds = compute_prompt_embeddings(
|
670 |
-
tokenizer,
|
671 |
-
text_encoder,
|
672 |
-
prompts,
|
673 |
-
model_config.max_text_seq_length,
|
674 |
-
accelerator.device,
|
675 |
-
weight_dtype,
|
676 |
-
requires_grad=False,
|
677 |
-
)
|
678 |
-
else:
|
679 |
-
prompt_embeds = prompts.to(dtype=weight_dtype)
|
680 |
-
|
681 |
-
# Sample noise that will be added to the latents
|
682 |
-
noise = torch.randn_like(model_input)
|
683 |
-
batch_size, num_frames, num_channels, height, width = model_input.shape
|
684 |
-
|
685 |
-
# Sample a random timestep for each image
|
686 |
-
timesteps = torch.randint(
|
687 |
-
0,
|
688 |
-
scheduler.config.num_train_timesteps,
|
689 |
-
(batch_size,),
|
690 |
-
dtype=torch.int64,
|
691 |
-
device=model_input.device,
|
692 |
-
)
|
693 |
-
|
694 |
-
# Prepare rotary embeds
|
695 |
-
image_rotary_emb = (
|
696 |
-
prepare_rotary_positional_embeddings(
|
697 |
-
height=height * VAE_SCALE_FACTOR_SPATIAL,
|
698 |
-
width=width * VAE_SCALE_FACTOR_SPATIAL,
|
699 |
-
num_frames=num_frames,
|
700 |
-
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
701 |
-
patch_size=model_config.patch_size,
|
702 |
-
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
703 |
-
attention_head_dim=model_config.attention_head_dim,
|
704 |
-
device=accelerator.device,
|
705 |
-
base_height=RoPE_BASE_HEIGHT,
|
706 |
-
base_width=RoPE_BASE_WIDTH,
|
707 |
-
)
|
708 |
-
if model_config.use_rotary_positional_embeddings
|
709 |
-
else None
|
710 |
-
)
|
711 |
-
|
712 |
-
# Add noise to the model input according to the noise magnitude at each timestep
|
713 |
-
# (this is the forward diffusion process)
|
714 |
-
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
|
715 |
-
|
716 |
-
# Predict the noise residual
|
717 |
-
model_output = transformer(
|
718 |
-
hidden_states=noisy_model_input,
|
719 |
-
encoder_hidden_states=prompt_embeds,
|
720 |
-
timestep=timesteps,
|
721 |
-
image_rotary_emb=image_rotary_emb,
|
722 |
-
return_dict=False,
|
723 |
-
)[0]
|
724 |
-
|
725 |
-
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
|
726 |
-
|
727 |
-
weights = 1 / (1 - alphas_cumprod[timesteps])
|
728 |
-
while len(weights.shape) < len(model_pred.shape):
|
729 |
-
weights = weights.unsqueeze(-1)
|
730 |
-
|
731 |
-
target = model_input
|
732 |
-
|
733 |
-
loss = torch.mean(
|
734 |
-
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
735 |
-
dim=1,
|
736 |
-
)
|
737 |
-
loss = loss.mean()
|
738 |
-
accelerator.backward(loss)
|
739 |
-
|
740 |
-
if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
|
741 |
-
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
742 |
-
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
743 |
-
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
744 |
-
logs.update(
|
745 |
-
{
|
746 |
-
"gradient_norm_before_clip": gradient_norm_before_clip,
|
747 |
-
"gradient_norm_after_clip": gradient_norm_after_clip,
|
748 |
-
}
|
749 |
-
)
|
750 |
-
|
751 |
-
if accelerator.state.deepspeed_plugin is None:
|
752 |
-
optimizer.step()
|
753 |
-
optimizer.zero_grad()
|
754 |
-
|
755 |
-
if not args.use_cpu_offload_optimizer:
|
756 |
-
lr_scheduler.step()
|
757 |
-
|
758 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
759 |
-
if accelerator.sync_gradients:
|
760 |
-
progress_bar.update(1)
|
761 |
-
global_step += 1
|
762 |
-
|
763 |
-
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
764 |
-
if global_step % args.checkpointing_steps == 0:
|
765 |
-
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
766 |
-
if args.checkpoints_total_limit is not None:
|
767 |
-
checkpoints = os.listdir(args.output_dir)
|
768 |
-
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
769 |
-
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
770 |
-
|
771 |
-
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
772 |
-
if len(checkpoints) >= args.checkpoints_total_limit:
|
773 |
-
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
774 |
-
removing_checkpoints = checkpoints[0:num_to_remove]
|
775 |
-
|
776 |
-
logger.info(
|
777 |
-
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
778 |
-
)
|
779 |
-
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
780 |
-
|
781 |
-
for removing_checkpoint in removing_checkpoints:
|
782 |
-
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
783 |
-
shutil.rmtree(removing_checkpoint)
|
784 |
-
|
785 |
-
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
786 |
-
accelerator.save_state(save_path)
|
787 |
-
logger.info(f"Saved state to {save_path}")
|
788 |
-
|
789 |
-
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
790 |
-
logs.update(
|
791 |
-
{
|
792 |
-
"loss": loss.detach().item(),
|
793 |
-
"lr": last_lr,
|
794 |
-
}
|
795 |
-
)
|
796 |
-
progress_bar.set_postfix(**logs)
|
797 |
-
accelerator.log(logs, step=global_step)
|
798 |
-
|
799 |
-
if global_step >= args.max_train_steps:
|
800 |
-
break
|
801 |
-
|
802 |
-
if accelerator.is_main_process:
|
803 |
-
if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
|
804 |
-
accelerator.print("===== Memory before validation =====")
|
805 |
-
print_memory(accelerator.device)
|
806 |
-
torch.cuda.synchronize(accelerator.device)
|
807 |
-
|
808 |
-
pipe = CogVideoXPipeline.from_pretrained(
|
809 |
-
args.pretrained_model_name_or_path,
|
810 |
-
transformer=unwrap_model(accelerator, transformer),
|
811 |
-
scheduler=scheduler,
|
812 |
-
revision=args.revision,
|
813 |
-
variant=args.variant,
|
814 |
-
torch_dtype=weight_dtype,
|
815 |
-
)
|
816 |
-
|
817 |
-
if args.enable_slicing:
|
818 |
-
pipe.vae.enable_slicing()
|
819 |
-
if args.enable_tiling:
|
820 |
-
pipe.vae.enable_tiling()
|
821 |
-
if args.enable_model_cpu_offload:
|
822 |
-
pipe.enable_model_cpu_offload()
|
823 |
-
|
824 |
-
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
825 |
-
for validation_prompt in validation_prompts:
|
826 |
-
pipeline_args = {
|
827 |
-
"prompt": validation_prompt,
|
828 |
-
"guidance_scale": args.guidance_scale,
|
829 |
-
"use_dynamic_cfg": args.use_dynamic_cfg,
|
830 |
-
"height": args.height,
|
831 |
-
"width": args.width,
|
832 |
-
"max_sequence_length": model_config.max_text_seq_length,
|
833 |
-
}
|
834 |
-
|
835 |
-
log_validation(
|
836 |
-
pipe=pipe,
|
837 |
-
args=args,
|
838 |
-
accelerator=accelerator,
|
839 |
-
pipeline_args=pipeline_args,
|
840 |
-
epoch=epoch,
|
841 |
-
)
|
842 |
-
|
843 |
-
accelerator.print("===== Memory after validation =====")
|
844 |
-
print_memory(accelerator.device)
|
845 |
-
reset_memory(accelerator.device)
|
846 |
-
|
847 |
-
del pipe
|
848 |
-
gc.collect()
|
849 |
-
torch.cuda.empty_cache()
|
850 |
-
torch.cuda.synchronize(accelerator.device)
|
851 |
-
|
852 |
-
accelerator.wait_for_everyone()
|
853 |
-
|
854 |
-
if accelerator.is_main_process:
|
855 |
-
transformer = unwrap_model(accelerator, transformer)
|
856 |
-
dtype = (
|
857 |
-
torch.float16
|
858 |
-
if args.mixed_precision == "fp16"
|
859 |
-
else torch.bfloat16
|
860 |
-
if args.mixed_precision == "bf16"
|
861 |
-
else torch.float32
|
862 |
-
)
|
863 |
-
transformer = transformer.to(dtype)
|
864 |
-
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
865 |
-
|
866 |
-
CogVideoXPipeline.save_lora_weights(
|
867 |
-
save_directory=args.output_dir,
|
868 |
-
transformer_lora_layers=transformer_lora_layers,
|
869 |
-
)
|
870 |
-
|
871 |
-
# Cleanup trained models to save memory
|
872 |
-
if args.load_tensors:
|
873 |
-
del transformer
|
874 |
-
else:
|
875 |
-
del transformer, text_encoder, vae
|
876 |
-
|
877 |
-
gc.collect()
|
878 |
-
torch.cuda.empty_cache()
|
879 |
-
torch.cuda.synchronize(accelerator.device)
|
880 |
-
|
881 |
-
accelerator.print("===== Memory before testing =====")
|
882 |
-
print_memory(accelerator.device)
|
883 |
-
reset_memory(accelerator.device)
|
884 |
-
|
885 |
-
# Final test inference
|
886 |
-
pipe = CogVideoXPipeline.from_pretrained(
|
887 |
-
args.pretrained_model_name_or_path,
|
888 |
-
revision=args.revision,
|
889 |
-
variant=args.variant,
|
890 |
-
torch_dtype=weight_dtype,
|
891 |
-
)
|
892 |
-
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
893 |
-
|
894 |
-
if args.enable_slicing:
|
895 |
-
pipe.vae.enable_slicing()
|
896 |
-
if args.enable_tiling:
|
897 |
-
pipe.vae.enable_tiling()
|
898 |
-
if args.enable_model_cpu_offload:
|
899 |
-
pipe.enable_model_cpu_offload()
|
900 |
-
|
901 |
-
# Load LoRA weights
|
902 |
-
lora_scaling = args.lora_alpha / args.rank
|
903 |
-
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
904 |
-
pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
|
905 |
-
|
906 |
-
# Run inference
|
907 |
-
validation_outputs = []
|
908 |
-
if args.validation_prompt and args.num_validation_videos > 0:
|
909 |
-
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
910 |
-
for validation_prompt in validation_prompts:
|
911 |
-
pipeline_args = {
|
912 |
-
"prompt": validation_prompt,
|
913 |
-
"guidance_scale": args.guidance_scale,
|
914 |
-
"use_dynamic_cfg": args.use_dynamic_cfg,
|
915 |
-
"height": args.height,
|
916 |
-
"width": args.width,
|
917 |
-
}
|
918 |
-
|
919 |
-
video = log_validation(
|
920 |
-
accelerator=accelerator,
|
921 |
-
pipe=pipe,
|
922 |
-
args=args,
|
923 |
-
pipeline_args=pipeline_args,
|
924 |
-
epoch=epoch,
|
925 |
-
is_final_validation=True,
|
926 |
-
)
|
927 |
-
validation_outputs.extend(video)
|
928 |
-
|
929 |
-
accelerator.print("===== Memory after testing =====")
|
930 |
-
print_memory(accelerator.device)
|
931 |
-
reset_memory(accelerator.device)
|
932 |
-
torch.cuda.synchronize(accelerator.device)
|
933 |
-
|
934 |
-
if args.push_to_hub:
|
935 |
-
save_model_card(
|
936 |
-
repo_id,
|
937 |
-
videos=validation_outputs,
|
938 |
-
base_model=args.pretrained_model_name_or_path,
|
939 |
-
validation_prompt=args.validation_prompt,
|
940 |
-
repo_folder=args.output_dir,
|
941 |
-
fps=args.fps,
|
942 |
-
)
|
943 |
-
upload_folder(
|
944 |
-
repo_id=repo_id,
|
945 |
-
folder_path=args.output_dir,
|
946 |
-
commit_message="End of training",
|
947 |
-
ignore_patterns=["step_*", "epoch_*"],
|
948 |
-
)
|
949 |
-
|
950 |
-
accelerator.end_training()
|
951 |
-
|
952 |
-
|
953 |
-
if __name__ == "__main__":
|
954 |
-
args = get_args()
|
955 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training/cogvideox/cogvideox_text_to_video_sft.py
DELETED
@@ -1,917 +0,0 @@
|
|
1 |
-
# Copyright 2024 The HuggingFace Team.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
import gc
|
17 |
-
import logging
|
18 |
-
import math
|
19 |
-
import os
|
20 |
-
import shutil
|
21 |
-
from datetime import timedelta
|
22 |
-
from pathlib import Path
|
23 |
-
from typing import Any, Dict
|
24 |
-
|
25 |
-
import diffusers
|
26 |
-
import torch
|
27 |
-
import transformers
|
28 |
-
import wandb
|
29 |
-
from accelerate import Accelerator, DistributedType, init_empty_weights
|
30 |
-
from accelerate.logging import get_logger
|
31 |
-
from accelerate.utils import (
|
32 |
-
DistributedDataParallelKwargs,
|
33 |
-
InitProcessGroupKwargs,
|
34 |
-
ProjectConfiguration,
|
35 |
-
set_seed,
|
36 |
-
)
|
37 |
-
from diffusers import (
|
38 |
-
AutoencoderKLCogVideoX,
|
39 |
-
CogVideoXDPMScheduler,
|
40 |
-
CogVideoXPipeline,
|
41 |
-
CogVideoXTransformer3DModel,
|
42 |
-
)
|
43 |
-
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
44 |
-
from diffusers.optimization import get_scheduler
|
45 |
-
from diffusers.training_utils import cast_training_params
|
46 |
-
from diffusers.utils import export_to_video
|
47 |
-
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
48 |
-
from huggingface_hub import create_repo, upload_folder
|
49 |
-
from torch.utils.data import DataLoader
|
50 |
-
from tqdm.auto import tqdm
|
51 |
-
from transformers import AutoTokenizer, T5EncoderModel
|
52 |
-
|
53 |
-
|
54 |
-
from args import get_args # isort:skip
|
55 |
-
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
|
56 |
-
from text_encoder import compute_prompt_embeddings # isort:skip
|
57 |
-
from utils import (
|
58 |
-
get_gradient_norm,
|
59 |
-
get_optimizer,
|
60 |
-
prepare_rotary_positional_embeddings,
|
61 |
-
print_memory,
|
62 |
-
reset_memory,
|
63 |
-
unwrap_model,
|
64 |
-
) # isort:skip
|
65 |
-
|
66 |
-
|
67 |
-
logger = get_logger(__name__)
|
68 |
-
|
69 |
-
|
70 |
-
def save_model_card(
|
71 |
-
repo_id: str,
|
72 |
-
videos=None,
|
73 |
-
base_model: str = None,
|
74 |
-
validation_prompt=None,
|
75 |
-
repo_folder=None,
|
76 |
-
fps=8,
|
77 |
-
):
|
78 |
-
widget_dict = []
|
79 |
-
if videos is not None:
|
80 |
-
for i, video in enumerate(videos):
|
81 |
-
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
|
82 |
-
widget_dict.append(
|
83 |
-
{
|
84 |
-
"text": validation_prompt if validation_prompt else " ",
|
85 |
-
"output": {"url": f"video_{i}.mp4"},
|
86 |
-
}
|
87 |
-
)
|
88 |
-
|
89 |
-
model_description = f"""
|
90 |
-
# CogVideoX Full Finetune
|
91 |
-
|
92 |
-
<Gallery />
|
93 |
-
|
94 |
-
## Model description
|
95 |
-
|
96 |
-
This is a full finetune of the CogVideoX model `{base_model}`.
|
97 |
-
|
98 |
-
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
99 |
-
|
100 |
-
## Download model
|
101 |
-
|
102 |
-
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
103 |
-
|
104 |
-
## Usage
|
105 |
-
|
106 |
-
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
107 |
-
|
108 |
-
```py
|
109 |
-
import torch
|
110 |
-
from diffusers import CogVideoXPipeline
|
111 |
-
from diffusers.utils import export_to_video
|
112 |
-
|
113 |
-
pipe = CogVideoXPipeline.from_pretrained("{repo_id}", torch_dtype=torch.bfloat16).to("cuda")
|
114 |
-
|
115 |
-
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
116 |
-
export_to_video(video, "output.mp4", fps=8)
|
117 |
-
```
|
118 |
-
|
119 |
-
For more details, checkout the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox) for CogVideoX.
|
120 |
-
|
121 |
-
## License
|
122 |
-
|
123 |
-
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
|
124 |
-
"""
|
125 |
-
model_card = load_or_create_model_card(
|
126 |
-
repo_id_or_path=repo_id,
|
127 |
-
from_training=True,
|
128 |
-
license="other",
|
129 |
-
base_model=base_model,
|
130 |
-
prompt=validation_prompt,
|
131 |
-
model_description=model_description,
|
132 |
-
widget=widget_dict,
|
133 |
-
)
|
134 |
-
tags = [
|
135 |
-
"text-to-video",
|
136 |
-
"diffusers-training",
|
137 |
-
"diffusers",
|
138 |
-
"cogvideox",
|
139 |
-
"cogvideox-diffusers",
|
140 |
-
]
|
141 |
-
|
142 |
-
model_card = populate_model_card(model_card, tags=tags)
|
143 |
-
model_card.save(os.path.join(repo_folder, "README.md"))
|
144 |
-
|
145 |
-
|
146 |
-
def log_validation(
|
147 |
-
accelerator: Accelerator,
|
148 |
-
pipe: CogVideoXPipeline,
|
149 |
-
args: Dict[str, Any],
|
150 |
-
pipeline_args: Dict[str, Any],
|
151 |
-
epoch,
|
152 |
-
is_final_validation: bool = False,
|
153 |
-
):
|
154 |
-
logger.info(
|
155 |
-
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
156 |
-
)
|
157 |
-
|
158 |
-
pipe = pipe.to(accelerator.device)
|
159 |
-
|
160 |
-
# run inference
|
161 |
-
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
162 |
-
|
163 |
-
videos = []
|
164 |
-
for _ in range(args.num_validation_videos):
|
165 |
-
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
166 |
-
videos.append(video)
|
167 |
-
|
168 |
-
for tracker in accelerator.trackers:
|
169 |
-
phase_name = "test" if is_final_validation else "validation"
|
170 |
-
if tracker.name == "wandb":
|
171 |
-
video_filenames = []
|
172 |
-
for i, video in enumerate(videos):
|
173 |
-
prompt = (
|
174 |
-
pipeline_args["prompt"][:25]
|
175 |
-
.replace(" ", "_")
|
176 |
-
.replace(" ", "_")
|
177 |
-
.replace("'", "_")
|
178 |
-
.replace('"', "_")
|
179 |
-
.replace("/", "_")
|
180 |
-
)
|
181 |
-
filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
|
182 |
-
export_to_video(video, filename, fps=8)
|
183 |
-
video_filenames.append(filename)
|
184 |
-
|
185 |
-
tracker.log(
|
186 |
-
{
|
187 |
-
phase_name: [
|
188 |
-
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
|
189 |
-
for i, filename in enumerate(video_filenames)
|
190 |
-
]
|
191 |
-
}
|
192 |
-
)
|
193 |
-
|
194 |
-
return videos
|
195 |
-
|
196 |
-
|
197 |
-
class CollateFunction:
|
198 |
-
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
|
199 |
-
self.weight_dtype = weight_dtype
|
200 |
-
self.load_tensors = load_tensors
|
201 |
-
|
202 |
-
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
203 |
-
prompts = [x["prompt"] for x in data[0]]
|
204 |
-
|
205 |
-
if self.load_tensors:
|
206 |
-
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
|
207 |
-
|
208 |
-
videos = [x["video"] for x in data[0]]
|
209 |
-
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
|
210 |
-
|
211 |
-
return {
|
212 |
-
"videos": videos,
|
213 |
-
"prompts": prompts,
|
214 |
-
}
|
215 |
-
|
216 |
-
|
217 |
-
def main(args):
|
218 |
-
if args.report_to == "wandb" and args.hub_token is not None:
|
219 |
-
raise ValueError(
|
220 |
-
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
221 |
-
" Please use `huggingface-cli login` to authenticate with the Hub."
|
222 |
-
)
|
223 |
-
|
224 |
-
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
225 |
-
# due to pytorch#99272, MPS does not yet support bfloat16.
|
226 |
-
raise ValueError(
|
227 |
-
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
228 |
-
)
|
229 |
-
|
230 |
-
logging_dir = Path(args.output_dir, args.logging_dir)
|
231 |
-
|
232 |
-
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
233 |
-
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
234 |
-
init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
|
235 |
-
accelerator = Accelerator(
|
236 |
-
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
237 |
-
mixed_precision=args.mixed_precision,
|
238 |
-
log_with=args.report_to,
|
239 |
-
project_config=accelerator_project_config,
|
240 |
-
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
241 |
-
)
|
242 |
-
|
243 |
-
# Disable AMP for MPS.
|
244 |
-
if torch.backends.mps.is_available():
|
245 |
-
accelerator.native_amp = False
|
246 |
-
|
247 |
-
# Make one log on every process with the configuration for debugging.
|
248 |
-
logging.basicConfig(
|
249 |
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
250 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
251 |
-
level=logging.INFO,
|
252 |
-
)
|
253 |
-
logger.info(accelerator.state, main_process_only=False)
|
254 |
-
if accelerator.is_local_main_process:
|
255 |
-
transformers.utils.logging.set_verbosity_warning()
|
256 |
-
diffusers.utils.logging.set_verbosity_info()
|
257 |
-
else:
|
258 |
-
transformers.utils.logging.set_verbosity_error()
|
259 |
-
diffusers.utils.logging.set_verbosity_error()
|
260 |
-
|
261 |
-
# If passed along, set the training seed now.
|
262 |
-
if args.seed is not None:
|
263 |
-
set_seed(args.seed)
|
264 |
-
|
265 |
-
# Handle the repository creation
|
266 |
-
if accelerator.is_main_process:
|
267 |
-
if args.output_dir is not None:
|
268 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
269 |
-
|
270 |
-
if args.push_to_hub:
|
271 |
-
repo_id = create_repo(
|
272 |
-
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
273 |
-
exist_ok=True,
|
274 |
-
).repo_id
|
275 |
-
|
276 |
-
# Prepare models and scheduler
|
277 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
278 |
-
args.pretrained_model_name_or_path,
|
279 |
-
subfolder="tokenizer",
|
280 |
-
revision=args.revision,
|
281 |
-
)
|
282 |
-
|
283 |
-
text_encoder = T5EncoderModel.from_pretrained(
|
284 |
-
args.pretrained_model_name_or_path,
|
285 |
-
subfolder="text_encoder",
|
286 |
-
revision=args.revision,
|
287 |
-
)
|
288 |
-
|
289 |
-
# CogVideoX-2b weights are stored in float16
|
290 |
-
# CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
|
291 |
-
load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
|
292 |
-
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
293 |
-
args.pretrained_model_name_or_path,
|
294 |
-
subfolder="transformer",
|
295 |
-
torch_dtype=load_dtype,
|
296 |
-
revision=args.revision,
|
297 |
-
variant=args.variant,
|
298 |
-
)
|
299 |
-
|
300 |
-
vae = AutoencoderKLCogVideoX.from_pretrained(
|
301 |
-
args.pretrained_model_name_or_path,
|
302 |
-
subfolder="vae",
|
303 |
-
revision=args.revision,
|
304 |
-
variant=args.variant,
|
305 |
-
)
|
306 |
-
|
307 |
-
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
308 |
-
|
309 |
-
if args.enable_slicing:
|
310 |
-
vae.enable_slicing()
|
311 |
-
if args.enable_tiling:
|
312 |
-
vae.enable_tiling()
|
313 |
-
|
314 |
-
text_encoder.requires_grad_(False)
|
315 |
-
vae.requires_grad_(False)
|
316 |
-
transformer.requires_grad_(True)
|
317 |
-
|
318 |
-
VAE_SCALING_FACTOR = vae.config.scaling_factor
|
319 |
-
VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
|
320 |
-
RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
|
321 |
-
RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL
|
322 |
-
|
323 |
-
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
324 |
-
# as these weights are only used for inference, keeping weights in full precision is not required.
|
325 |
-
weight_dtype = torch.float32
|
326 |
-
if accelerator.state.deepspeed_plugin:
|
327 |
-
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
328 |
-
if (
|
329 |
-
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
330 |
-
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
331 |
-
):
|
332 |
-
weight_dtype = torch.float16
|
333 |
-
if (
|
334 |
-
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
335 |
-
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
336 |
-
):
|
337 |
-
weight_dtype = torch.bfloat16
|
338 |
-
else:
|
339 |
-
if accelerator.mixed_precision == "fp16":
|
340 |
-
weight_dtype = torch.float16
|
341 |
-
elif accelerator.mixed_precision == "bf16":
|
342 |
-
weight_dtype = torch.bfloat16
|
343 |
-
|
344 |
-
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
345 |
-
# due to pytorch#99272, MPS does not yet support bfloat16.
|
346 |
-
raise ValueError(
|
347 |
-
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
348 |
-
)
|
349 |
-
|
350 |
-
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
351 |
-
transformer.to(accelerator.device, dtype=weight_dtype)
|
352 |
-
vae.to(accelerator.device, dtype=weight_dtype)
|
353 |
-
|
354 |
-
if args.gradient_checkpointing:
|
355 |
-
transformer.enable_gradient_checkpointing()
|
356 |
-
|
357 |
-
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
358 |
-
def save_model_hook(models, weights, output_dir):
|
359 |
-
if accelerator.is_main_process:
|
360 |
-
for model in models:
|
361 |
-
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
362 |
-
model: CogVideoXTransformer3DModel
|
363 |
-
model = unwrap_model(accelerator, model)
|
364 |
-
model.save_pretrained(
|
365 |
-
os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
366 |
-
)
|
367 |
-
else:
|
368 |
-
raise ValueError(f"Unexpected save model: {model.__class__}")
|
369 |
-
|
370 |
-
# make sure to pop weight so that corresponding model is not saved again
|
371 |
-
if weights:
|
372 |
-
weights.pop()
|
373 |
-
|
374 |
-
def load_model_hook(models, input_dir):
|
375 |
-
transformer_ = None
|
376 |
-
init_under_meta = False
|
377 |
-
|
378 |
-
# This is a bit of a hack but I don't know any other solution.
|
379 |
-
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
380 |
-
while len(models) > 0:
|
381 |
-
model = models.pop()
|
382 |
-
|
383 |
-
if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
|
384 |
-
transformer_ = unwrap_model(accelerator, model)
|
385 |
-
else:
|
386 |
-
raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
|
387 |
-
else:
|
388 |
-
with init_empty_weights():
|
389 |
-
transformer_ = CogVideoXTransformer3DModel.from_config(
|
390 |
-
args.pretrained_model_name_or_path, subfolder="transformer"
|
391 |
-
)
|
392 |
-
init_under_meta = True
|
393 |
-
|
394 |
-
load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
|
395 |
-
transformer_.register_to_config(**load_model.config)
|
396 |
-
transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
|
397 |
-
del load_model
|
398 |
-
|
399 |
-
# Make sure the trainable params are in float32. This is again needed since the base models
|
400 |
-
# are in `weight_dtype`. More details:
|
401 |
-
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
402 |
-
if args.mixed_precision == "fp16":
|
403 |
-
cast_training_params([transformer_])
|
404 |
-
|
405 |
-
accelerator.register_save_state_pre_hook(save_model_hook)
|
406 |
-
accelerator.register_load_state_pre_hook(load_model_hook)
|
407 |
-
|
408 |
-
# Enable TF32 for faster training on Ampere GPUs,
|
409 |
-
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
410 |
-
if args.allow_tf32 and torch.cuda.is_available():
|
411 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
412 |
-
|
413 |
-
if args.scale_lr:
|
414 |
-
args.learning_rate = (
|
415 |
-
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
416 |
-
)
|
417 |
-
|
418 |
-
# Make sure the trainable params are in float32.
|
419 |
-
if args.mixed_precision == "fp16":
|
420 |
-
# only upcast trainable parameters (LoRA) into fp32
|
421 |
-
cast_training_params([transformer], dtype=torch.float32)
|
422 |
-
|
423 |
-
transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
424 |
-
|
425 |
-
# Optimization parameters
|
426 |
-
transformer_parameters_with_lr = {
|
427 |
-
"params": transformer_parameters,
|
428 |
-
"lr": args.learning_rate,
|
429 |
-
}
|
430 |
-
params_to_optimize = [transformer_parameters_with_lr]
|
431 |
-
num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
|
432 |
-
|
433 |
-
use_deepspeed_optimizer = (
|
434 |
-
accelerator.state.deepspeed_plugin is not None
|
435 |
-
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
436 |
-
)
|
437 |
-
use_deepspeed_scheduler = (
|
438 |
-
accelerator.state.deepspeed_plugin is not None
|
439 |
-
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
440 |
-
)
|
441 |
-
|
442 |
-
optimizer = get_optimizer(
|
443 |
-
params_to_optimize=params_to_optimize,
|
444 |
-
optimizer_name=args.optimizer,
|
445 |
-
learning_rate=args.learning_rate,
|
446 |
-
beta1=args.beta1,
|
447 |
-
beta2=args.beta2,
|
448 |
-
beta3=args.beta3,
|
449 |
-
epsilon=args.epsilon,
|
450 |
-
weight_decay=args.weight_decay,
|
451 |
-
prodigy_decouple=args.prodigy_decouple,
|
452 |
-
prodigy_use_bias_correction=args.prodigy_use_bias_correction,
|
453 |
-
prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
|
454 |
-
use_8bit=args.use_8bit,
|
455 |
-
use_4bit=args.use_4bit,
|
456 |
-
use_torchao=args.use_torchao,
|
457 |
-
use_deepspeed=use_deepspeed_optimizer,
|
458 |
-
use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
|
459 |
-
offload_gradients=args.offload_gradients,
|
460 |
-
)
|
461 |
-
|
462 |
-
# Dataset and DataLoader
|
463 |
-
dataset_init_kwargs = {
|
464 |
-
"data_root": args.data_root,
|
465 |
-
"dataset_file": args.dataset_file,
|
466 |
-
"caption_column": args.caption_column,
|
467 |
-
"video_column": args.video_column,
|
468 |
-
"max_num_frames": args.max_num_frames,
|
469 |
-
"id_token": args.id_token,
|
470 |
-
"height_buckets": args.height_buckets,
|
471 |
-
"width_buckets": args.width_buckets,
|
472 |
-
"frame_buckets": args.frame_buckets,
|
473 |
-
"load_tensors": args.load_tensors,
|
474 |
-
"random_flip": args.random_flip,
|
475 |
-
}
|
476 |
-
if args.video_reshape_mode is None:
|
477 |
-
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
|
478 |
-
else:
|
479 |
-
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
|
480 |
-
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
|
481 |
-
)
|
482 |
-
|
483 |
-
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
|
484 |
-
|
485 |
-
train_dataloader = DataLoader(
|
486 |
-
train_dataset,
|
487 |
-
batch_size=1,
|
488 |
-
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
|
489 |
-
collate_fn=collate_fn,
|
490 |
-
num_workers=args.dataloader_num_workers,
|
491 |
-
pin_memory=args.pin_memory,
|
492 |
-
)
|
493 |
-
|
494 |
-
# Scheduler and math around the number of training steps.
|
495 |
-
overrode_max_train_steps = False
|
496 |
-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
497 |
-
if args.max_train_steps is None:
|
498 |
-
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
499 |
-
overrode_max_train_steps = True
|
500 |
-
|
501 |
-
if args.use_cpu_offload_optimizer:
|
502 |
-
lr_scheduler = None
|
503 |
-
accelerator.print(
|
504 |
-
"CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
|
505 |
-
"you are training with those settings, they will be ignored."
|
506 |
-
)
|
507 |
-
else:
|
508 |
-
if use_deepspeed_scheduler:
|
509 |
-
from accelerate.utils import DummyScheduler
|
510 |
-
|
511 |
-
lr_scheduler = DummyScheduler(
|
512 |
-
name=args.lr_scheduler,
|
513 |
-
optimizer=optimizer,
|
514 |
-
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
515 |
-
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
516 |
-
)
|
517 |
-
else:
|
518 |
-
lr_scheduler = get_scheduler(
|
519 |
-
args.lr_scheduler,
|
520 |
-
optimizer=optimizer,
|
521 |
-
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
522 |
-
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
523 |
-
num_cycles=args.lr_num_cycles,
|
524 |
-
power=args.lr_power,
|
525 |
-
)
|
526 |
-
|
527 |
-
# Prepare everything with our `accelerator`.
|
528 |
-
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
529 |
-
transformer, optimizer, train_dataloader, lr_scheduler
|
530 |
-
)
|
531 |
-
|
532 |
-
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
533 |
-
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
534 |
-
if overrode_max_train_steps:
|
535 |
-
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
536 |
-
# Afterwards we recalculate our number of training epochs
|
537 |
-
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
538 |
-
|
539 |
-
# We need to initialize the trackers we use, and also store our configuration.
|
540 |
-
# The trackers initializes automatically on the main process.
|
541 |
-
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
542 |
-
tracker_name = args.tracker_name or "cogvideox-sft"
|
543 |
-
accelerator.init_trackers(tracker_name, config=vars(args))
|
544 |
-
|
545 |
-
accelerator.print("===== Memory before training =====")
|
546 |
-
reset_memory(accelerator.device)
|
547 |
-
print_memory(accelerator.device)
|
548 |
-
|
549 |
-
# Train!
|
550 |
-
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
551 |
-
|
552 |
-
accelerator.print("***** Running training *****")
|
553 |
-
accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
|
554 |
-
accelerator.print(f" Num examples = {len(train_dataset)}")
|
555 |
-
accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
|
556 |
-
accelerator.print(f" Num epochs = {args.num_train_epochs}")
|
557 |
-
accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
558 |
-
accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
559 |
-
accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
|
560 |
-
accelerator.print(f" Total optimization steps = {args.max_train_steps}")
|
561 |
-
global_step = 0
|
562 |
-
first_epoch = 0
|
563 |
-
|
564 |
-
# Potentially load in the weights and states from a previous save
|
565 |
-
if not args.resume_from_checkpoint:
|
566 |
-
initial_global_step = 0
|
567 |
-
else:
|
568 |
-
if args.resume_from_checkpoint != "latest":
|
569 |
-
path = os.path.basename(args.resume_from_checkpoint)
|
570 |
-
else:
|
571 |
-
# Get the most recent checkpoint
|
572 |
-
dirs = os.listdir(args.output_dir)
|
573 |
-
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
574 |
-
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
575 |
-
path = dirs[-1] if len(dirs) > 0 else None
|
576 |
-
|
577 |
-
if path is None:
|
578 |
-
accelerator.print(
|
579 |
-
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
580 |
-
)
|
581 |
-
args.resume_from_checkpoint = None
|
582 |
-
initial_global_step = 0
|
583 |
-
else:
|
584 |
-
accelerator.print(f"Resuming from checkpoint {path}")
|
585 |
-
accelerator.load_state(os.path.join(args.output_dir, path))
|
586 |
-
global_step = int(path.split("-")[1])
|
587 |
-
|
588 |
-
initial_global_step = global_step
|
589 |
-
first_epoch = global_step // num_update_steps_per_epoch
|
590 |
-
|
591 |
-
progress_bar = tqdm(
|
592 |
-
range(0, args.max_train_steps),
|
593 |
-
initial=initial_global_step,
|
594 |
-
desc="Steps",
|
595 |
-
# Only show the progress bar once on each machine.
|
596 |
-
disable=not accelerator.is_local_main_process,
|
597 |
-
)
|
598 |
-
|
599 |
-
# For DeepSpeed training
|
600 |
-
model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
601 |
-
|
602 |
-
if args.load_tensors:
|
603 |
-
del vae, text_encoder
|
604 |
-
gc.collect()
|
605 |
-
torch.cuda.empty_cache()
|
606 |
-
torch.cuda.synchronize(accelerator.device)
|
607 |
-
|
608 |
-
alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)
|
609 |
-
|
610 |
-
for epoch in range(first_epoch, args.num_train_epochs):
|
611 |
-
transformer.train()
|
612 |
-
|
613 |
-
for step, batch in enumerate(train_dataloader):
|
614 |
-
models_to_accumulate = [transformer]
|
615 |
-
logs = {}
|
616 |
-
|
617 |
-
with accelerator.accumulate(models_to_accumulate):
|
618 |
-
videos = batch["videos"].to(accelerator.device, non_blocking=True)
|
619 |
-
prompts = batch["prompts"]
|
620 |
-
|
621 |
-
# Encode videos
|
622 |
-
if not args.load_tensors:
|
623 |
-
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
624 |
-
latent_dist = vae.encode(videos).latent_dist
|
625 |
-
else:
|
626 |
-
latent_dist = DiagonalGaussianDistribution(videos)
|
627 |
-
|
628 |
-
videos = latent_dist.sample() * VAE_SCALING_FACTOR
|
629 |
-
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
630 |
-
videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
|
631 |
-
model_input = videos
|
632 |
-
|
633 |
-
# Encode prompts
|
634 |
-
if not args.load_tensors:
|
635 |
-
prompt_embeds = compute_prompt_embeddings(
|
636 |
-
tokenizer,
|
637 |
-
text_encoder,
|
638 |
-
prompts,
|
639 |
-
model_config.max_text_seq_length,
|
640 |
-
accelerator.device,
|
641 |
-
weight_dtype,
|
642 |
-
requires_grad=False,
|
643 |
-
)
|
644 |
-
else:
|
645 |
-
prompt_embeds = prompts.to(dtype=weight_dtype)
|
646 |
-
|
647 |
-
# Sample noise that will be added to the latents
|
648 |
-
noise = torch.randn_like(model_input)
|
649 |
-
batch_size, num_frames, num_channels, height, width = model_input.shape
|
650 |
-
|
651 |
-
# Sample a random timestep for each image
|
652 |
-
timesteps = torch.randint(
|
653 |
-
0,
|
654 |
-
scheduler.config.num_train_timesteps,
|
655 |
-
(batch_size,),
|
656 |
-
dtype=torch.int64,
|
657 |
-
device=model_input.device,
|
658 |
-
)
|
659 |
-
|
660 |
-
# Prepare rotary embeds
|
661 |
-
image_rotary_emb = (
|
662 |
-
prepare_rotary_positional_embeddings(
|
663 |
-
height=height * VAE_SCALE_FACTOR_SPATIAL,
|
664 |
-
width=width * VAE_SCALE_FACTOR_SPATIAL,
|
665 |
-
num_frames=num_frames,
|
666 |
-
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
|
667 |
-
patch_size=model_config.patch_size,
|
668 |
-
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
|
669 |
-
attention_head_dim=model_config.attention_head_dim,
|
670 |
-
device=accelerator.device,
|
671 |
-
base_height=RoPE_BASE_HEIGHT,
|
672 |
-
base_width=RoPE_BASE_WIDTH,
|
673 |
-
)
|
674 |
-
if model_config.use_rotary_positional_embeddings
|
675 |
-
else None
|
676 |
-
)
|
677 |
-
|
678 |
-
# Add noise to the model input according to the noise magnitude at each timestep
|
679 |
-
# (this is the forward diffusion process)
|
680 |
-
noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
|
681 |
-
|
682 |
-
# Predict the noise residual
|
683 |
-
model_output = transformer(
|
684 |
-
hidden_states=noisy_model_input,
|
685 |
-
encoder_hidden_states=prompt_embeds,
|
686 |
-
timestep=timesteps,
|
687 |
-
image_rotary_emb=image_rotary_emb,
|
688 |
-
return_dict=False,
|
689 |
-
)[0]
|
690 |
-
|
691 |
-
model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
|
692 |
-
|
693 |
-
weights = 1 / (1 - alphas_cumprod[timesteps])
|
694 |
-
while len(weights.shape) < len(model_pred.shape):
|
695 |
-
weights = weights.unsqueeze(-1)
|
696 |
-
|
697 |
-
target = model_input
|
698 |
-
|
699 |
-
loss = torch.mean(
|
700 |
-
(weights * (model_pred - target) ** 2).reshape(batch_size, -1),
|
701 |
-
dim=1,
|
702 |
-
)
|
703 |
-
loss = loss.mean()
|
704 |
-
accelerator.backward(loss)
|
705 |
-
|
706 |
-
if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
|
707 |
-
gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
|
708 |
-
accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
|
709 |
-
gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
|
710 |
-
logs.update(
|
711 |
-
{
|
712 |
-
"gradient_norm_before_clip": gradient_norm_before_clip,
|
713 |
-
"gradient_norm_after_clip": gradient_norm_after_clip,
|
714 |
-
}
|
715 |
-
)
|
716 |
-
|
717 |
-
if accelerator.state.deepspeed_plugin is None:
|
718 |
-
optimizer.step()
|
719 |
-
optimizer.zero_grad()
|
720 |
-
|
721 |
-
if not args.use_cpu_offload_optimizer:
|
722 |
-
lr_scheduler.step()
|
723 |
-
|
724 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
725 |
-
if accelerator.sync_gradients:
|
726 |
-
progress_bar.update(1)
|
727 |
-
global_step += 1
|
728 |
-
|
729 |
-
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
730 |
-
if global_step % args.checkpointing_steps == 0:
|
731 |
-
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
732 |
-
if args.checkpoints_total_limit is not None:
|
733 |
-
checkpoints = os.listdir(args.output_dir)
|
734 |
-
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
735 |
-
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
736 |
-
|
737 |
-
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
738 |
-
if len(checkpoints) >= args.checkpoints_total_limit:
|
739 |
-
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
740 |
-
removing_checkpoints = checkpoints[0:num_to_remove]
|
741 |
-
|
742 |
-
logger.info(
|
743 |
-
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
744 |
-
)
|
745 |
-
logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
|
746 |
-
|
747 |
-
for removing_checkpoint in removing_checkpoints:
|
748 |
-
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
749 |
-
shutil.rmtree(removing_checkpoint)
|
750 |
-
|
751 |
-
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
752 |
-
accelerator.save_state(save_path)
|
753 |
-
logger.info(f"Saved state to {save_path}")
|
754 |
-
|
755 |
-
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
756 |
-
logs.update(
|
757 |
-
{
|
758 |
-
"loss": loss.detach().item(),
|
759 |
-
"lr": last_lr,
|
760 |
-
}
|
761 |
-
)
|
762 |
-
progress_bar.set_postfix(**logs)
|
763 |
-
accelerator.log(logs, step=global_step)
|
764 |
-
|
765 |
-
if global_step >= args.max_train_steps:
|
766 |
-
break
|
767 |
-
|
768 |
-
if accelerator.is_main_process:
|
769 |
-
if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
|
770 |
-
accelerator.print("===== Memory before validation =====")
|
771 |
-
print_memory(accelerator.device)
|
772 |
-
torch.cuda.synchronize(accelerator.device)
|
773 |
-
|
774 |
-
pipe = CogVideoXPipeline.from_pretrained(
|
775 |
-
args.pretrained_model_name_or_path,
|
776 |
-
transformer=unwrap_model(accelerator, transformer),
|
777 |
-
scheduler=scheduler,
|
778 |
-
revision=args.revision,
|
779 |
-
variant=args.variant,
|
780 |
-
torch_dtype=weight_dtype,
|
781 |
-
)
|
782 |
-
|
783 |
-
if args.enable_slicing:
|
784 |
-
pipe.vae.enable_slicing()
|
785 |
-
if args.enable_tiling:
|
786 |
-
pipe.vae.enable_tiling()
|
787 |
-
if args.enable_model_cpu_offload:
|
788 |
-
pipe.enable_model_cpu_offload()
|
789 |
-
|
790 |
-
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
791 |
-
for validation_prompt in validation_prompts:
|
792 |
-
pipeline_args = {
|
793 |
-
"prompt": validation_prompt,
|
794 |
-
"guidance_scale": args.guidance_scale,
|
795 |
-
"use_dynamic_cfg": args.use_dynamic_cfg,
|
796 |
-
"height": args.height,
|
797 |
-
"width": args.width,
|
798 |
-
"max_sequence_length": model_config.max_text_seq_length,
|
799 |
-
}
|
800 |
-
|
801 |
-
log_validation(
|
802 |
-
accelerator=accelerator,
|
803 |
-
pipe=pipe,
|
804 |
-
args=args,
|
805 |
-
pipeline_args=pipeline_args,
|
806 |
-
epoch=epoch,
|
807 |
-
is_final_validation=False,
|
808 |
-
)
|
809 |
-
|
810 |
-
accelerator.print("===== Memory after validation =====")
|
811 |
-
print_memory(accelerator.device)
|
812 |
-
reset_memory(accelerator.device)
|
813 |
-
|
814 |
-
del pipe
|
815 |
-
gc.collect()
|
816 |
-
torch.cuda.empty_cache()
|
817 |
-
torch.cuda.synchronize(accelerator.device)
|
818 |
-
|
819 |
-
accelerator.wait_for_everyone()
|
820 |
-
|
821 |
-
if accelerator.is_main_process:
|
822 |
-
transformer = unwrap_model(accelerator, transformer)
|
823 |
-
dtype = (
|
824 |
-
torch.float16
|
825 |
-
if args.mixed_precision == "fp16"
|
826 |
-
else torch.bfloat16
|
827 |
-
if args.mixed_precision == "bf16"
|
828 |
-
else torch.float32
|
829 |
-
)
|
830 |
-
transformer = transformer.to(dtype)
|
831 |
-
|
832 |
-
transformer.save_pretrained(
|
833 |
-
os.path.join(args.output_dir, "transformer"),
|
834 |
-
safe_serialization=True,
|
835 |
-
max_shard_size="5GB",
|
836 |
-
)
|
837 |
-
|
838 |
-
# Cleanup trained models to save memory
|
839 |
-
if args.load_tensors:
|
840 |
-
del transformer
|
841 |
-
else:
|
842 |
-
del transformer, text_encoder, vae
|
843 |
-
|
844 |
-
gc.collect()
|
845 |
-
torch.cuda.empty_cache()
|
846 |
-
torch.cuda.synchronize(accelerator.device)
|
847 |
-
|
848 |
-
accelerator.print("===== Memory before testing =====")
|
849 |
-
print_memory(accelerator.device)
|
850 |
-
reset_memory(accelerator.device)
|
851 |
-
|
852 |
-
# Final test inference
|
853 |
-
pipe = CogVideoXPipeline.from_pretrained(
|
854 |
-
args.pretrained_model_name_or_path,
|
855 |
-
revision=args.revision,
|
856 |
-
variant=args.variant,
|
857 |
-
torch_dtype=weight_dtype,
|
858 |
-
)
|
859 |
-
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
860 |
-
|
861 |
-
if args.enable_slicing:
|
862 |
-
pipe.vae.enable_slicing()
|
863 |
-
if args.enable_tiling:
|
864 |
-
pipe.vae.enable_tiling()
|
865 |
-
if args.enable_model_cpu_offload:
|
866 |
-
pipe.enable_model_cpu_offload()
|
867 |
-
|
868 |
-
# Run inference
|
869 |
-
validation_outputs = []
|
870 |
-
if args.validation_prompt and args.num_validation_videos > 0:
|
871 |
-
validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
872 |
-
for validation_prompt in validation_prompts:
|
873 |
-
pipeline_args = {
|
874 |
-
"prompt": validation_prompt,
|
875 |
-
"guidance_scale": args.guidance_scale,
|
876 |
-
"use_dynamic_cfg": args.use_dynamic_cfg,
|
877 |
-
"height": args.height,
|
878 |
-
"width": args.width,
|
879 |
-
}
|
880 |
-
|
881 |
-
video = log_validation(
|
882 |
-
accelerator=accelerator,
|
883 |
-
pipe=pipe,
|
884 |
-
args=args,
|
885 |
-
pipeline_args=pipeline_args,
|
886 |
-
epoch=epoch,
|
887 |
-
is_final_validation=True,
|
888 |
-
)
|
889 |
-
validation_outputs.extend(video)
|
890 |
-
|
891 |
-
accelerator.print("===== Memory after testing =====")
|
892 |
-
print_memory(accelerator.device)
|
893 |
-
reset_memory(accelerator.device)
|
894 |
-
torch.cuda.synchronize(accelerator.device)
|
895 |
-
|
896 |
-
if args.push_to_hub:
|
897 |
-
save_model_card(
|
898 |
-
repo_id,
|
899 |
-
videos=validation_outputs,
|
900 |
-
base_model=args.pretrained_model_name_or_path,
|
901 |
-
validation_prompt=args.validation_prompt,
|
902 |
-
repo_folder=args.output_dir,
|
903 |
-
fps=args.fps,
|
904 |
-
)
|
905 |
-
upload_folder(
|
906 |
-
repo_id=repo_id,
|
907 |
-
folder_path=args.output_dir,
|
908 |
-
commit_message="End of training",
|
909 |
-
ignore_patterns=["step_*", "epoch_*"],
|
910 |
-
)
|
911 |
-
|
912 |
-
accelerator.end_training()
|
913 |
-
|
914 |
-
|
915 |
-
if __name__ == "__main__":
|
916 |
-
args = get_args()
|
917 |
-
main(args)
|
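The core of the removed training step is a velocity-prediction objective with a per-timestep weight of `1 / (1 - alphas_cumprod[t])` against the clean latents. Below is a minimal, self-contained sketch of that loss on dummy tensors; the toy beta schedule, shapes, and the random stand-in for the transformer output are illustrative assumptions, not the deleted script's exact API.

```python
import torch

# Dummy shapes: batch of 2 latent videos, 4 latent frames, 16 channels, 8x8 spatial.
batch_size, frames, channels, height, width = 2, 4, 16, 8, 8
num_train_timesteps = 1000

# Toy linear beta schedule; the deleted script read alphas_cumprod from its scheduler instead.
betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

latents = torch.randn(batch_size, frames, channels, height, width)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, num_train_timesteps, (batch_size,))

# Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
abar = alphas_cumprod[timesteps].view(-1, 1, 1, 1, 1)
noisy_latents = abar.sqrt() * latents + (1.0 - abar).sqrt() * noise

# Stand-in for the converted model output; the real script produced this with
# scheduler.get_velocity(model_output, noisy_model_input, timesteps) before the loss.
model_pred = torch.randn_like(latents)

# Per-sample weight 1 / (1 - abar_t), broadcast over all latent dimensions.
weights = 1.0 / (1.0 - alphas_cumprod[timesteps])
while weights.dim() < model_pred.dim():
    weights = weights.unsqueeze(-1)

target = latents  # the clean latents are the regression target
loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1).mean()
print(loss.item())
```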
training/cogvideox/dataset.py
DELETED
@@ -1,428 +0,0 @@
The 428 deleted lines defined the dataset classes and sampler used by the old training scripts:

- Module setup: `decord` is imported after torch (its bridge set to `"torch"`), and default buckets are declared: `HEIGHT_BUCKETS` / `WIDTH_BUCKETS` = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] and `FRAME_BUCKETS` = [16, 24, 32, 48, 64, 80].
- `VideoDataset`: reads prompts and video paths either from two line-separated files inside `--data_root` or from a CSV given via `--dataset_file`; optionally prefixes an `--id_token` to each caption; decodes videos with `decord`, applies an optional random horizontal flip, scales pixels to [0, 1], and normalizes to [-1, 1]. With `load_tensors=True` it instead returns pre-computed tensors from `image_latents/`, `video_latents/` and `prompt_embeds/` folders next to the videos (the VAE's temporal compression of 4 and spatial compression of 8 are hardcoded when reconstructing frame counts and resolutions), raising descriptive errors when the files expected from `prepare_dataset.py` are missing.
- `VideoDatasetWithResizing`: picks the nearest frame bucket and the nearest `(height, width)` bucket, then resizes frames to that resolution.
- `VideoDatasetWithResizeAndRectangleCrop`: resizes to cover the target bucket and then crops, with `center`, `random`, or `none` reshape modes.
- `BucketSampler`: groups samples into buckets keyed by `(num_frames, height, width)` and yields a bucket as soon as it reaches `batch_size`, shuffling within the bucket if requested; leftover incomplete buckets are yielded at the end unless `drop_last=True`.
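The bucketing behaviour is the part most worth keeping in mind when migrating datasets: samples are keyed by `(num_frames, height, width)` and a batch is emitted only once a bucket fills. A rough standalone sketch of that grouping follows; the metadata dicts are made up for illustration and stand in for what the dataset read from each video.

```python
import random
from collections import defaultdict


def bucketed_batches(samples, batch_size=2, shuffle=True, drop_last=False):
    """Group samples by (num_frames, height, width) and yield full buckets as batches."""
    buckets = defaultdict(list)
    for sample in samples:
        key = (sample["num_frames"], sample["height"], sample["width"])
        buckets[key].append(sample)
        if len(buckets[key]) == batch_size:
            batch = buckets.pop(key)
            if shuffle:
                random.shuffle(batch)
            yield batch
    if not drop_last:
        # Yield whatever incomplete buckets remain at the end of the epoch.
        for batch in buckets.values():
            if batch:
                yield batch


# Hypothetical metadata for three clips: two share a bucket, one does not.
samples = [
    {"name": "a.mp4", "num_frames": 49, "height": 480, "width": 720},
    {"name": "b.mp4", "num_frames": 49, "height": 480, "width": 720},
    {"name": "c.mp4", "num_frames": 49, "height": 512, "width": 768},
]

for batch in bucketed_batches(samples):
    print([s["name"] for s in batch])
```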
training/cogvideox/prepare_dataset.py
DELETED
@@ -1,669 +0,0 @@
(Removed: a 669-line standalone preprocessing script for CogVideoX. It parsed CLI options for data paths, caption/video columns, resolution and frame buckets (heights and widths divisible by 16, frame counts of the form 4k or 4k+1), dtype, VAE slicing/tiling and batch size; loaded clips through `VideoDatasetWithResizing` or `VideoDatasetWithResizeAndRectangleCrop` with a `BucketSampler`; optionally ran multi-GPU via `torch.distributed` (NCCL) with the dataset sliced per rank; encoded videos (and first frames, for image-to-video) to latents with `AutoencoderKLCogVideoX` and captions to embeddings with a T5 tokenizer/encoder; serialized images, videos, latents, prompts and prompt embeddings through a background `ThreadPoolExecutor` queue; and finally, on rank 0, merged the per-rank temporary folders and wrote `prompts.txt`, `videos.txt` and a `data.jsonl` index.)
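For reference, each line of the `data.jsonl` index written by this script tied one caption to its serialized artifacts. A minimal sketch of a single entry follows; the keys and the `metadata` fields match what the script wrote, while the `a1b2c3` stem, the caption and the resolution values are illustrative placeholders.

```python
import json

# Illustrative example of one data.jsonl line (stem, caption and resolution are made up).
entry = {
    "prompt": "A toy robot dances on a wooden table.",
    "prompt_embed": "prompt_embeds/a1b2c3.pt",
    "image": "images/a1b2c3.png",
    "image_latent": "image_latents/a1b2c3.pt",
    "video": "videos/a1b2c3.mp4",
    "video_latent": "video_latents/a1b2c3.pt",
    "metadata": {"num_frames": 49, "height": 480, "width": 720},
}
print(json.dumps(entry))
```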
training/cogvideox/text_encoder/__init__.py
DELETED
@@ -1 +0,0 @@
from .text_encoder import compute_prompt_embeddings
training/cogvideox/text_encoder/text_encoder.py
DELETED
@@ -1,99 +0,0 @@
(Removed: a 99-line helper module duplicating the prompt-encoding utilities above. It defined `_get_t5_prompt_embeds` and `encode_prompt`, which tokenize a prompt or list of prompts with a `T5Tokenizer` to `max_sequence_length` tokens and run them through a `T5EncoderModel`, plus the exported `compute_prompt_embeddings(tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad=False)`, which wraps the encoding in `torch.no_grad()` unless gradients are requested.)
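A minimal usage sketch of the removed helper, assuming the pre-commit package layout and the CogVideoX-2b tokenizer/text encoder (the model ID matches the default in `prepare_dataset.py`; the prompt, dtype and device are illustrative):

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer

from training.cogvideox.text_encoder import compute_prompt_embeddings  # module removed in this commit

model_id = "THUDM/CogVideoX-2b"
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
).to("cuda")

# Runs the T5 encoder under torch.no_grad() unless requires_grad=True is passed.
prompt_embeds = compute_prompt_embeddings(
    tokenizer,
    text_encoder,
    "A toy robot dances on a wooden table.",
    max_sequence_length=226,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
print(prompt_embeds.shape)  # (batch, 226, hidden_size)
```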
training/cogvideox/utils.py
DELETED
@@ -1,260 +0,0 @@
(Removed: a 260-line collection of training utilities. `get_optimizer` built Adam/AdamW/Prodigy/CAME optimizers with optional 8-bit (bitsandbytes), 4-bit and CPU-offload (torchao) and DeepSpeed `DummyOptim` variants; `get_gradient_norm` computed the global L2 gradient norm; `get_resize_crop_region_for_grid` and `prepare_rotary_positional_embeddings` built the 3D rotary position embeddings for CogVideoX 1.0 and 1.5; `reset_memory` and `print_memory` handled CUDA memory bookkeeping; and `unwrap_model` unwrapped accelerate-prepared and compiled models.)
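A minimal sketch of how the removed `get_optimizer` helper could be called, assuming the pre-commit import path; the `transformer` stand-in and hyperparameter values are illustrative. Note that in the removed code the adam/adamw branches did not forward `learning_rate`, so the learning rate had to come in via the parameter groups.

```python
import torch

from training.cogvideox.utils import get_optimizer  # helper removed in this commit

# Stand-in for whichever model's trainable parameters were being optimized.
transformer = torch.nn.Linear(16, 16)

optimizer = get_optimizer(
    params_to_optimize=[{"params": transformer.parameters(), "lr": 1e-4}],  # lr set per group
    optimizer_name="adamw",
    learning_rate=1e-4,  # ignored by the adam/adamw branches of the removed helper
    beta1=0.9,
    beta2=0.95,
    epsilon=1e-8,
    weight_decay=1e-4,
    use_8bit=False,  # True (with bitsandbytes installed) selected AdamW8bit instead
)
```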
training/mochi-1/README.md
DELETED
@@ -1,111 +0,0 @@
(Removed: the 111-line "Simple Mochi-1 finetuner" README, a minimal `diffusers`/`peft` reimplementation of the original Mochi-1 LoRA fine-tuner, last updated December 1, 2024 with checkpoint saving/loading. It walked through installing `requirements.txt`, downloading the `sayakpaul/video-dataset-disney-organized` demo dataset (paired `video_N.mp4`/`video_N.txt` files), running `bash prepare_dataset.sh` (with `--force_upsample`; at least 480x848 resolution is required by Mochi-1), then `bash train.sh`, optionally dropping the validation flags, in which case the script was expected (not rigorously tested) to run under roughly 40 GB of GPU VRAM. It also showed how to load the resulting LoRA with `MochiPipeline` and listed known limitations: no `accelerate` integration, hence no distributed training, no gradient accumulation and no 8-bit optimizers, plus quality issues in the `diffusers` Mochi-1 port tracked in huggingface/diffusers#10033 and a non-batched `embed.py`.)
training/mochi-1/args.py
DELETED
@@ -1,268 +0,0 @@
(Removed: a 268-line argparse module for the Mochi-1 trainer, with defaults taken from the upstream `genmoai/mochi` LoRA config where applicable. It grouped flags into model arguments (`--pretrained_model_name_or_path`, `--revision`, `--variant`, `--cache_dir`, `--cast_dit`, `--compile_dit`), dataset arguments (`--data_root`, `--caption_dropout`, dataloader workers, pinned memory), validation arguments (prompts and separator, number of validation videos, validation epochs, VAE slicing/tiling, model CPU offload, fps, height/width defaulting to 480x848), training arguments (seed, LoRA rank/alpha/target modules, output dir, batch size, epochs or max steps, gradient checkpointing, learning rate, warmup, checkpointing and resume), optimizer arguments (adam/adamw, weight decay) and configuration arguments (tracker name, Hub push options, TF32, `--report_to`), all collected by `get_args()`.)
training/mochi-1/dataset_simple.py
DELETED
@@ -1,50 +0,0 @@
"""
Taken from
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py
"""

from pathlib import Path

import click
import torch
from torch.utils.data import DataLoader, Dataset


def load_to_cpu(x):
    return torch.load(x, map_location=torch.device("cpu"), weights_only=True)


class LatentEmbedDataset(Dataset):
    def __init__(self, file_paths, repeat=1):
        self.items = [
            (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
            for p in file_paths
            if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
        ]
        self.items = self.items * repeat
        print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        latent_path, embed_path = self.items[idx]
        return load_to_cpu(latent_path), load_to_cpu(embed_path)


@click.command()
@click.argument("directory", type=click.Path(exists=True, file_okay=False))
def process_videos(directory):
    dir_path = Path(directory)
    mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
    assert mp4_files, f"No mp4 files found"

    dataset = LatentEmbedDataset(mp4_files)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    for latents, embeds in dataloader:
        print([(k, v.shape) for k, v in latents.items()])


if __name__ == "__main__":
    process_videos()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training/mochi-1/embed.py
DELETED
@@ -1,111 +0,0 @@
[entire file removed (111 lines): click CLI adapted from the genmoai/mochi fine-tuner scripts that walks the dataset directory, encodes each `.mp4` with `AutoencoderKLMochi` into a `.latent.pt` file, and embeds the matching `.txt` caption with the `MochiPipeline` T5 text encoder into a `.embed.pt` file, skipping existing outputs unless `--overwrite` is passed.]

training/mochi-1/prepare_dataset.sh
DELETED
@@ -1,15 +0,0 @@
-#!/bin/bash
-
-GPU_ID=0
-VIDEO_DIR=video-dataset-disney-organized
-OUTPUT_DIR=videos_prepared
-NUM_FRAMES=37
-RESOLUTION=480x848
-
-# Extract width and height from RESOLUTION
-WIDTH=$(echo $RESOLUTION | cut -dx -f1)
-HEIGHT=$(echo $RESOLUTION | cut -dx -f2)
-
-python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample
-
-CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT}

training/mochi-1/requirements.txt
DELETED
@@ -1,8 +0,0 @@
-peft
-transformers
-wandb
-torch
-torchvision
-av==11.0.0
-moviepy==1.0.3
-click

training/mochi-1/text_to_video_lora.py
DELETED
@@ -1,592 +0,0 @@
[entire file removed (592 lines): the standalone Mochi-1 preview LoRA training script — cosine-annealing LR schedule taken from the genmoai/mochi trainer, PEFT LoRA injection into `MochiTransformer3DModel`, flow-matching MSE loss under bf16 autocast on precomputed latents and caption embeddings, optional `cast_dit`/`compile_dit`, checkpointing and resume, wandb validation renders via `MochiPipeline`, LoRA export with `save_lora_weights`, and optional Hub upload with an auto-generated model card.]

training/mochi-1/train.sh
DELETED
@@ -1,37 +0,0 @@
-#!/bin/bash
-export NCCL_P2P_DISABLE=1
-export TORCH_NCCL_ENABLE_MONITORING=0
-
-GPU_IDS="0"
-
-DATA_ROOT="videos_prepared"
-MODEL="genmo/mochi-1-preview"
-OUTPUT_PATH="mochi-lora"
-
-cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python text_to_video_lora.py \
-  --pretrained_model_name_or_path $MODEL \
-  --cast_dit \
-  --data_root $DATA_ROOT \
-  --seed 42 \
-  --output_dir $OUTPUT_PATH \
-  --train_batch_size 1 \
-  --dataloader_num_workers 4 \
-  --pin_memory \
-  --caption_dropout 0.1 \
-  --max_train_steps 2000 \
-  --gradient_checkpointing \
-  --enable_slicing \
-  --enable_tiling \
-  --enable_model_cpu_offload \
-  --optimizer adamw \
-  --validation_prompt \"A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \
-  --validation_prompt_separator ::: \
-  --num_validation_videos 1 \
-  --validation_epochs 1 \
-  --allow_tf32 \
-  --report_to wandb \
-  --push_to_hub"
-
-echo "Running command: $cmd"
-eval $cmd
-echo -ne "-------------------- Finished executing script --------------------\n\n"

training/mochi-1/trim_and_crop_videos.py
DELETED
@@ -1,126 +0,0 @@
[entire file removed (126 lines): click CLI adapted from the genmoai/mochi fine-tuner scripts that trims each input video to the requested number of frames, center-crops/resizes it to the target resolution at 30 fps (optionally upsampling with `--force_upsample`), writes the result as H.264 MP4, and copies or creates the matching caption `.txt` file.]

training/mochi-1/utils.py
DELETED
@@ -1,22 +0,0 @@
-import gc
-import inspect
-from typing import Optional, Tuple, Union
-
-import torch
-
-logger = get_logger(__name__)
-
-def reset_memory(device: Union[str, torch.device]) -> None:
-    gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats(device)
-    torch.cuda.reset_accumulated_memory_stats(device)
-
-
-def print_memory(device: Union[str, torch.device]) -> None:
-    memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
-    max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
-    max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
-    print(f"{memory_allocated=:.3f} GB")
-    print(f"{max_memory_allocated=:.3f} GB")
-    print(f"{max_memory_reserved=:.3f} GB")

training/prepare_dataset.sh
DELETED
@@ -1,48 +0,0 @@
-#!/bin/bash
-
-MODEL_ID="THUDM/CogVideoX-2b"
-
-NUM_GPUS=8
-
-# For more details on the expected data format, please refer to the README.
-DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located.
-CAPTION_COLUMN="prompt.txt"
-VIDEO_COLUMN="videos.txt"
-OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset"
-HEIGHT_BUCKETS="480 720"
-WIDTH_BUCKETS="720 960"
-FRAME_BUCKETS="49"
-MAX_NUM_FRAMES="49"
-MAX_SEQUENCE_LENGTH=226
-TARGET_FPS=8
-BATCH_SIZE=1
-DTYPE=fp32
-
-# To create a folder-style dataset structure without pre-encoding videos and captions
-# For Image-to-Video finetuning, make sure to pass `--save_image_latents`
-CMD_WITHOUT_PRE_ENCODING="\
-  torchrun --nproc_per_node=$NUM_GPUS \
-    training/prepare_dataset.py \
-      --model_id $MODEL_ID \
-      --data_root $DATA_ROOT \
-      --caption_column $CAPTION_COLUMN \
-      --video_column $VIDEO_COLUMN \
-      --output_dir $OUTPUT_DIR \
-      --height_buckets $HEIGHT_BUCKETS \
-      --width_buckets $WIDTH_BUCKETS \
-      --frame_buckets $FRAME_BUCKETS \
-      --max_num_frames $MAX_NUM_FRAMES \
-      --max_sequence_length $MAX_SEQUENCE_LENGTH \
-      --target_fps $TARGET_FPS \
-      --batch_size $BATCH_SIZE \
-      --dtype $DTYPE
-"
-
-CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings"
-
-# Select which you'd like to run
-CMD=$CMD_WITH_PRE_ENCODING
-
-echo "===== Running \`$CMD\` ====="
-eval $CMD
-echo -ne "===== Finished running script =====\n"

training/train_image_to_video_lora.sh
DELETED
@@ -1,82 +0,0 @@
[entire file removed (82 lines): hyperparameter-sweep launcher that loops over learning rates, LR schedules, optimizers and step counts and runs `training/cogvideox/cogvideox_image_to_video_lora.py` via `accelerate launch` against `THUDM/CogVideoX-5b-I2V` with the Disney dataset, LoRA rank/alpha 128, bf16 mixed precision and wandb reporting.]

training/train_image_to_video_sft.sh
DELETED
@@ -1,87 +0,0 @@
[entire file removed (87 lines): hyperparameter-sweep launcher that runs `training/cogvideox/cogvideox_image_to_video_sft.py` with the DeepSpeed accelerate config (`accelerate_configs/deepspeed.yaml`) on GPUs 0,1 against `THUDM/CogVideoX1.5-5B-I2V`, using 77-frame buckets, 20000 steps, bf16 mixed precision and wandb reporting.]

training/train_text_to_video_lora.sh
DELETED
@@ -1,86 +0,0 @@
-export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
-export TORCHDYNAMO_VERBOSE=1
-export WANDB_MODE="offline"
-export NCCL_P2P_DISABLE=1
-export TORCH_NCCL_ENABLE_MONITORING=0
-
-GPU_IDS="0"
-
-# Training Configurations
-# Experiment with as many hyperparameters as you want!
-LEARNING_RATES=("1e-4" "1e-3")
-LR_SCHEDULES=("cosine_with_restarts")
-OPTIMIZERS=("adamw" "adam")
-MAX_TRAIN_STEPS=("3000")
-
-# Single GPU uncompiled training
-ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
-
-# Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
-# This example assumes you downloaded an already prepared dataset from HF CLI as follows:
-# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset
-DATA_ROOT="/path/to/my/datasets/disney-dataset"
-
-CAPTION_COLUMN="prompt.txt"
-VIDEO_COLUMN="videos.txt"
-MODEL_PATH="THUDM/CogVideoX-5b"
-
-# Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process.
-# Launch experiments with different hyperparameters
-
-for learning_rate in "${LEARNING_RATES[@]}"; do
-  for lr_schedule in "${LR_SCHEDULES[@]}"; do
-    for optimizer in "${OPTIMIZERS[@]}"; do
-      for steps in "${MAX_TRAIN_STEPS[@]}"; do
-        output_dir="./cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
-
-        cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_lora.py \
-          --pretrained_model_name_or_path $MODEL_PATH \
-          --data_root $DATA_ROOT \
-          --caption_column $CAPTION_COLUMN \
-          --video_column $VIDEO_COLUMN \
-          --id_token BW_STYLE \
-          --height_buckets 480 \
-          --width_buckets 720 \
-          --frame_buckets 49 \
-          --dataloader_num_workers 8 \
-          --pin_memory \
-          --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
-          --validation_prompt_separator ::: \
-          --num_validation_videos 1 \
-          --validation_epochs 10 \
-          --seed 42 \
-          --rank 128 \
-          --lora_alpha 128 \
-          --mixed_precision bf16 \
-          --output_dir $output_dir \
-          --max_num_frames 49 \
-          --train_batch_size 1 \
-          --max_train_steps $steps \
-          --checkpointing_steps 1000 \
-          --gradient_accumulation_steps 1 \
-          --gradient_checkpointing \
-          --learning_rate $learning_rate \
-          --lr_scheduler $lr_schedule \
-          --lr_warmup_steps 400 \
-          --lr_num_cycles 1 \
-          --enable_slicing \
-          --enable_tiling \
-          --enable_model_cpu_offload \
-          --load_tensors \
-          --optimizer $optimizer \
-          --beta1 0.9 \
-          --beta2 0.95 \
-          --weight_decay 0.001 \
-          --max_grad_norm 1.0 \
-          --allow_tf32 \
-          --report_to wandb \
-          --nccl_timeout 1800"
-
-        echo "Running command: $cmd"
-        eval $cmd
-        echo -ne "-------------------- Finished executing script --------------------\n\n"
-      done
-    done
-  done
-done
training/train_text_to_video_sft.sh
DELETED
@@ -1,77 +0,0 @@
-export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
-export TORCHDYNAMO_VERBOSE=1
-export WANDB_MODE="offline"
-export NCCL_P2P_DISABLE=1
-export TORCH_NCCL_ENABLE_MONITORING=0
-
-GPU_IDS="0"
-
-# Training Configurations
-# Experiment with as many hyperparameters as you want!
-LEARNING_RATES=("1e-4")
-LR_SCHEDULES=("cosine_with_restarts")
-OPTIMIZERS=("adamw")
-MAX_TRAIN_STEPS=("20000")
-
-# Single GPU uncompiled training
-ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
-
-# Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
-# This example assumes you downloaded an already prepared dataset from HF CLI as follows:
-# huggingface-cli download --repo-type dataset Wild-Heart/Tom-and-Jerry-VideoGeneration-Dataset --local-dir /path/to/my/datasets/tom-and-jerry-dataset
-DATA_ROOT="/path/to/my/datasets/tom-and-jerry-dataset"
-CAPTION_COLUMN="captions.txt"
-VIDEO_COLUMN="videos.txt"
-
-# Launch experiments with different hyperparameters
-for learning_rate in "${LEARNING_RATES[@]}"; do
-  for lr_schedule in "${LR_SCHEDULES[@]}"; do
-    for optimizer in "${OPTIMIZERS[@]}"; do
-      for steps in "${MAX_TRAIN_STEPS[@]}"; do
-        output_dir="/path/to/my/models/cogvideox-sft__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
-
-        cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox/cogvideox_text_to_video_sft.py \
-          --pretrained_model_name_or_path THUDM/CogVideoX-5b \
-          --data_root $DATA_ROOT \
-          --caption_column $CAPTION_COLUMN \
-          --video_column $VIDEO_COLUMN \
-          --height_buckets 480 \
-          --width_buckets 720 \
-          --frame_buckets 49 \
-          --dataloader_num_workers 8 \
-          --pin_memory \
-          --validation_prompt \"Tom, the mischievous gray cat, is sprawled out on a vibrant red pillow, his body relaxed and his eyes half-closed, as if he's just woken up or is about to doze off. His white paws are stretched out in front of him, and his tail is casually draped over the edge of the pillow. The setting appears to be a cozy corner of a room, with a warm yellow wall in the background and a hint of a wooden floor. The scene captures a rare moment of tranquility for Tom, contrasting with his usual energetic and playful demeanor:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
-          --validation_prompt_separator ::: \
-          --num_validation_videos 1 \
-          --validation_epochs 1 \
-          --seed 42 \
-          --mixed_precision bf16 \
-          --output_dir $output_dir \
-          --max_num_frames 49 \
-          --train_batch_size 1 \
-          --max_train_steps $steps \
-          --checkpointing_steps 2000 \
-          --gradient_accumulation_steps 4 \
-          --gradient_checkpointing \
-          --learning_rate $learning_rate \
-          --lr_scheduler $lr_schedule \
-          --lr_warmup_steps 800 \
-          --lr_num_cycles 1 \
-          --enable_slicing \
-          --enable_tiling \
-          --optimizer $optimizer \
-          --beta1 0.9 \
-          --beta2 0.95 \
-          --weight_decay 0.001 \
-          --max_grad_norm 1.0 \
-          --allow_tf32 \
-          --report_to wandb \
-          --nccl_timeout 1800"
-
-        echo "Running command: $cmd"
-        eval $cmd
-        echo -ne "-------------------- Finished executing script --------------------\n\n"
-      done
-    done
-  done
-done
vms/config.py
CHANGED
@@ -56,9 +56,17 @@ if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
     raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
 JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
 
+# Expanded model types to include Wan-2.1-T2V
 MODEL_TYPES = {
     "HunyuanVideo (LoRA)": "hunyuan_video",
-    "LTX-Video (LoRA)": "ltx_video"
+    "LTX-Video (LoRA)": "ltx_video",
+    "Wan-2.1-T2V (LoRA)": "wan"
+}
+
+# Training types
+TRAINING_TYPES = {
+    "LoRA Finetune": "lora",
+    "Full Finetune": "full-finetune"
 }
 
 
@@ -136,9 +144,11 @@ MEDIUM_19_9_RATIO_BUCKETS = [
     (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
+# Updated training presets to include Wan-2.1-T2V and support both LoRA and full-finetune
 TRAINING_PRESETS = {
     "HunyuanVideo (normal)": {
         "model_type": "hunyuan_video",
+        "training_type": "lora",
         "lora_rank": "128",
         "lora_alpha": "128",
         "num_epochs": 70,
@@ -146,9 +156,11 @@ TRAINING_PRESETS = {
         "learning_rate": 2e-5,
         "save_iterations": 500,
         "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "none"
     },
     "LTX-Video (normal)": {
         "model_type": "ltx_video",
+        "training_type": "lora",
         "lora_rank": "128",
         "lora_alpha": "128",
         "num_epochs": 70,
@@ -156,9 +168,11 @@ TRAINING_PRESETS = {
         "learning_rate": 3e-5,
         "save_iterations": 500,
         "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal"
     },
     "LTX-Video (16:9, HQ)": {
         "model_type": "ltx_video",
+        "training_type": "lora",
         "lora_rank": "256",
         "lora_alpha": "128",
         "num_epochs": 50,
@@ -166,6 +180,41 @@ TRAINING_PRESETS = {
         "learning_rate": 3e-5,
         "save_iterations": 200,
         "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
+        "flow_weighting_scheme": "logit_normal"
+    },
+    "LTX-Video (Full Finetune)": {
+        "model_type": "ltx_video",
+        "training_type": "full-finetune",
+        "num_epochs": 30,
+        "batch_size": 1,
+        "learning_rate": 1e-5,
+        "save_iterations": 300,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal"
+    },
+    "Wan-2.1-T2V (normal)": {
+        "model_type": "wan",
+        "training_type": "lora",
+        "lora_rank": "32",
+        "lora_alpha": "32",
+        "num_epochs": 70,
+        "batch_size": 1,
+        "learning_rate": 5e-5,
+        "save_iterations": 500,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+        "flow_weighting_scheme": "logit_normal"
+    },
+    "Wan-2.1-T2V (HQ)": {
+        "model_type": "wan",
+        "training_type": "lora",
+        "lora_rank": "64",
+        "lora_alpha": "64",
+        "num_epochs": 50,
+        "batch_size": 1,
+        "learning_rate": 3e-5,
+        "save_iterations": 200,
+        "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
+        "flow_weighting_scheme": "logit_normal"
     }
 }
 
@@ -260,7 +309,8 @@ class TrainingConfig:
             lora_alpha=128,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
             caption_dropout_p=0.05,
-            flow_weighting_scheme="none" # Hunyuan specific
+            flow_weighting_scheme="none", # Hunyuan specific
+            training_type="lora"
         )
 
     @classmethod
@@ -281,7 +331,51 @@ class TrainingConfig:
             lora_alpha=128,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
             caption_dropout_p=0.05,
-            flow_weighting_scheme="logit_normal" # LTX specific
+            flow_weighting_scheme="logit_normal", # LTX specific
+            training_type="lora"
+        )
+
+    @classmethod
+    def ltx_video_full_finetune(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
+        """Configuration for LTX-Video full finetune training"""
+        return cls(
+            model_name="ltx_video",
+            pretrained_model_name_or_path="Lightricks/LTX-Video",
+            data_root=data_path,
+            output_dir=output_path,
+            batch_size=1,
+            train_epochs=30,
+            lr=1e-5,
+            gradient_checkpointing=True,
+            id_token="BW_STYLE",
+            gradient_accumulation_steps=1,
+            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
+            caption_dropout_p=0.05,
+            flow_weighting_scheme="logit_normal", # LTX specific
+            training_type="full-finetune"
+        )
+
+    @classmethod
+    def wan_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
+        """Configuration for Wan T2V LoRA training"""
+        return cls(
+            model_name="wan",
+            pretrained_model_name_or_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+            data_root=data_path,
+            output_dir=output_path,
+            batch_size=1,
+            train_epochs=70,
+            lr=5e-5,
+            gradient_checkpointing=True,
+            id_token=None, # Default is no ID token for Wan
+            gradient_accumulation_steps=1,
+            lora_rank=32,
+            lora_alpha=32,
+            target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
+            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
+            caption_dropout_p=0.05,
+            flow_weighting_scheme="logit_normal", # Wan specific
+            training_type="lora"
         )
 
     def to_args_list(self) -> List[str]:
@@ -302,9 +396,9 @@ class TrainingConfig:
             args.extend(["--cache_dir", self.cache_dir])
 
         # Dataset arguments
-        args.extend(["--
-
-
+        args.extend(["--dataset_config", self.data_root])
+
+        # Add ID token if specified
         if self.id_token:
             args.extend(["--id_token", self.id_token])
 
@@ -313,9 +407,6 @@ class TrainingConfig:
         bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
         args.extend(["--video_resolution_buckets"] + bucket_strs)
 
-        if self.video_reshape_mode:
-            args.extend(["--video_reshape_mode", self.video_reshape_mode])
-
         args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
         args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
         if self.precompute_conditions:
@@ -333,14 +424,18 @@ class TrainingConfig:
         args.extend(["--training_type", self.training_type])
         args.extend(["--seed", str(self.seed)])
 
-        #
+        # We don't use this, because mixed precision is handled by accelerate launch, not by the training script itself.
         #args.extend(["--mixed_precision", self.mixed_precision])
 
        args.extend(["--batch_size", str(self.batch_size)])
-        args.extend(["--
-
-
-
+        args.extend(["--train_steps", str(self.train_epochs * 1000)]) # Convert epochs to steps for compatibility
+
+        # LoRA specific arguments
+        if self.training_type == "lora":
+            args.extend(["--rank", str(self.lora_rank)])
+            args.extend(["--lora_alpha", str(self.lora_alpha)])
+            args.extend(["--target_modules"] + self.target_modules)
+
         args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
         if self.gradient_checkpointing:
             args.append("--gradient_checkpointing")
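
The `to_args_list()` changes above are easiest to see end to end. The following is a minimal, self-contained sketch, not the project code: the simplified `WanLoraSketch` dataclass and the 1000-steps-per-epoch factor are assumptions that simply mirror the fields and conversions visible in the diff.

```python
# Sketch only: how the new Wan LoRA fields turn into a finetrainers-style arg list.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class WanLoraSketch:
    # Field names mirror the TrainingConfig attributes shown in the diff above.
    data_root: str
    output_dir: str
    training_type: str = "lora"
    lora_rank: int = 32
    lora_alpha: int = 32
    target_modules: List[str] = field(
        default_factory=lambda: ["blocks.*(to_q|to_k|to_v|to_out.0)"]
    )
    batch_size: int = 1
    train_epochs: int = 70
    video_resolution_buckets: List[Tuple[int, int, int]] = field(
        default_factory=lambda: [(49, 512, 768)]
    )

    def to_args_list(self) -> List[str]:
        args = ["--dataset_config", self.data_root]
        # Buckets are serialized as "FRAMESxHEIGHTxWIDTH", as in the diff.
        bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
        args += ["--video_resolution_buckets", *bucket_strs]
        args += ["--batch_size", str(self.batch_size)]
        # Same epoch -> step conversion used in the diff (assumed factor of 1000).
        args += ["--train_steps", str(self.train_epochs * 1000)]
        if self.training_type == "lora":
            # LoRA-specific arguments are only emitted for LoRA runs.
            args += ["--rank", str(self.lora_rank)]
            args += ["--lora_alpha", str(self.lora_alpha)]
            args += ["--target_modules", *self.target_modules]
        return args


if __name__ == "__main__":
    cfg = WanLoraSketch(data_root="/data/my-dataset", output_dir="/outputs/wan-lora")
    print(" ".join(cfg.to_args_list()))
```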
vms/services/trainer.py
CHANGED
@@ -20,7 +20,11 @@ from typing import Any, Optional, Dict, List, Union, Tuple
 
 from huggingface_hub import upload_folder, create_repo
 
-from ..config import
+from ..config import (
+    TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
+    STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
+    MODEL_TYPES, TRAINING_TYPES
+)
 from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
 
 logger = logging.getLogger(__name__)
@@ -112,6 +116,7 @@ class TrainingService:
         ui_state_file = OUTPUT_PATH / "ui_state.json"
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
+            "training_type": list(TRAINING_TYPES.keys())[0],
             "lora_rank": "128",
             "lora_alpha": "128",
             "num_epochs": 50,
@@ -153,7 +158,6 @@ class TrainingService:
             # Make sure we have all keys (in case structure changed)
             merged_state = default_state.copy()
             merged_state.update(saved_state)
-            #logger.info(f"Successfully loaded UI state from {ui_state_file}")
             return merged_state
         except json.JSONDecodeError as e:
             logger.error(f"Error parsing UI state JSON: {str(e)}")
@@ -171,6 +175,7 @@ class TrainingService:
             logger.info("Creating new UI state file with default values")
             default_state = {
                 "model_type": list(MODEL_TYPES.keys())[0],
+                "training_type": list(TRAINING_TYPES.keys())[0],
                 "lora_rank": "128",
                 "lora_alpha": "128",
                 "num_epochs": 50,
@@ -203,6 +208,7 @@ class TrainingService:
             # Create a new file with default values
             default_state = {
                 "model_type": list(MODEL_TYPES.keys())[0],
+                "training_type": list(TRAINING_TYPES.keys())[0],
                 "lora_rank": "128",
                 "lora_alpha": "128",
                 "num_epochs": 50,
@@ -331,6 +337,9 @@ class TrainingService:
         elif model_type == "ltx_video":
             if config.batch_size > 4:
                 return "LTX model recommended batch size is 1-4"
+        elif model_type == "wan":
+            if config.batch_size > 4:
+                return "Wan model recommended batch size is 1-4"
 
         logger.info(f"Config validation passed with {len(video_lines)} training files")
         return None
@@ -350,6 +359,7 @@ class TrainingService:
         save_iterations: int,
         repo_id: str,
         preset_name: str,
+        training_type: str = "lora",
         resume_from_checkpoint: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Start training with finetrainers"""
@@ -360,11 +370,13 @@ class TrainingService:
             raise ValueError("model_type cannot be empty")
         if model_type not in MODEL_TYPES.values():
             raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
+        if training_type not in TRAINING_TYPES.values():
+            raise ValueError(f"Invalid training_type: {training_type}. Must be one of {list(TRAINING_TYPES.values())}")
 
         # Check if we're resuming or starting new
         is_resuming = resume_from_checkpoint is not None
         log_prefix = "Resuming" if is_resuming else "Initializing"
-        logger.info(f"{log_prefix} training with model_type={model_type}")
+        logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
 
         try:
             # Get absolute paths - FIXED to look in project root instead of within vms directory
@@ -409,32 +421,66 @@ class TrainingService:
                 logger.error(error_msg)
                 return error_msg, "No training data available"
 
-
             # Get preset configuration
             preset = TRAINING_PRESETS[preset_name]
             training_buckets = preset["training_buckets"]
+            flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
+            preset_training_type = preset.get("training_type", "lora")
 
             # Get config for selected model type with preset buckets
             if model_type == "hunyuan_video":
-
-
-
-
-
-
-
-
-
-
-
+                if training_type == "lora":
+                    config = TrainingConfig.hunyuan_video_lora(
+                        data_path=str(TRAINING_PATH),
+                        output_path=str(OUTPUT_PATH),
+                        buckets=training_buckets
+                    )
+                else:
+                    # Hunyuan doesn't support full finetune in our UI yet
+                    error_msg = "Full finetune is not supported for Hunyuan Video due to memory limitations"
+                    logger.error(error_msg)
+                    return error_msg, "Training configuration error"
+            elif model_type == "ltx_video":
+                if training_type == "lora":
+                    config = TrainingConfig.ltx_video_lora(
+                        data_path=str(TRAINING_PATH),
+                        output_path=str(OUTPUT_PATH),
+                        buckets=training_buckets
+                    )
+                else:
+                    config = TrainingConfig.ltx_video_full_finetune(
+                        data_path=str(TRAINING_PATH),
+                        output_path=str(OUTPUT_PATH),
+                        buckets=training_buckets
+                    )
+            elif model_type == "wan":
+                if training_type == "lora":
+                    config = TrainingConfig.wan_lora(
+                        data_path=str(TRAINING_PATH),
+                        output_path=str(OUTPUT_PATH),
+                        buckets=training_buckets
+                    )
+                else:
+                    error_msg = "Full finetune for Wan is not yet supported in this UI"
+                    logger.error(error_msg)
+                    return error_msg, "Training configuration error"
+            else:
+                error_msg = f"Unsupported model type: {model_type}"
+                logger.error(error_msg)
+                return error_msg, "Unsupported model"
 
             # Update with UI parameters
             config.train_epochs = int(num_epochs)
-            config.lora_rank = int(lora_rank)
-            config.lora_alpha = int(lora_alpha)
             config.batch_size = int(batch_size)
             config.lr = float(learning_rate)
             config.checkpointing_steps = int(save_iterations)
+            config.training_type = training_type
+            config.flow_weighting_scheme = flow_weighting_scheme
+
+            # Update LoRA parameters if using LoRA training type
+            if training_type == "lora":
+                config.lora_rank = int(lora_rank)
+                config.lora_alpha = int(lora_alpha)
 
             # Update with resume_from_checkpoint if provided
             if resume_from_checkpoint:
@@ -469,7 +515,6 @@ class TrainingService:
             # Convert config to command line arguments
             config_args = config.to_args_list()
 
-
             logger.debug("Generated args list: %s", config_args)
 
             # Log the full command for debugging
@@ -505,6 +550,7 @@ class TrainingService:
             # Save session info including repo_id for later hub upload
             self.save_session({
                 "model_type": model_type,
+                "training_type": training_type,
                 "lora_rank": lora_rank,
                 "lora_alpha": lora_alpha,
                 "num_epochs": num_epochs,
@@ -526,13 +572,14 @@ class TrainingService:
                 total_epochs=num_epochs,
                 message='Training started',
                 repo_id=repo_id,
-                model_type=model_type
+                model_type=model_type,
+                training_type=training_type
             )
 
             # Start monitoring process output
            self._start_log_monitor(process)
 
-            success_msg = f"Started training {model_type} model"
+            success_msg = f"Started {training_type} training for {model_type} model"
             self.append_log(success_msg)
             logger.info(success_msg)
 
@@ -668,6 +715,7 @@ class TrainingService:
             last_session = {
                 "params": {
                     "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
+                    "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])),
                     "lora_rank": ui_state.get("lora_rank", "128"),
                     "lora_alpha": ui_state.get("lora_alpha", "128"),
                     "num_epochs": ui_state.get("num_epochs", 70),
@@ -724,10 +772,15 @@ class TrainingService:
                     logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
                     break
 
+            # Get training type (default to LoRA if not present in saved session)
+            training_type_internal = params.get('training_type', 'lora')
+            training_type_display = next((disp for disp, val in TRAINING_TYPES.items() if val == training_type_internal), list(TRAINING_TYPES.keys())[0])
+
             # Add UI updates to restore the training parameters in the UI
             # This shows the user what values are being used for the resumed training
             ui_updates.update({
                 "model_type": model_type_display, # Use the display name for the UI dropdown
+                "training_type": training_type_display, # Use the display name for training type
                 "lora_rank": params.get('lora_rank', "128"),
                 "lora_alpha": params.get('lora_alpha', "128"),
                 "num_epochs": params.get('num_epochs', 70),
@@ -755,6 +808,7 @@ class TrainingService:
                 save_iterations=params.get('save_iterations', 500),
                 repo_id=params.get('repo_id', ''),
                 preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
+                training_type=training_type_internal,
                 resume_from_checkpoint=str(latest_checkpoint)
             )
 
@@ -1029,5 +1083,4 @@ class TrainingService:
             return temp_zip_path
         except Exception as e:
             print(f"Failed to create zip: {str(e)}")
-            raise gr.Error(f"Failed to create zip: {str(e)}")
-
+            raise gr.Error(f"Failed to create zip: {str(e)}")
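
The `start_training()` branching above grows with every new model/training-type combination. As an illustrative sketch only (the factory functions below are hypothetical stand-ins for the `TrainingConfig` classmethods in the diff, not the app's code), the same dispatch can be expressed as a lookup table keyed by `(model_type, training_type)`:

```python
# Sketch of the start_training() dispatch, expressed as a lookup table.
from typing import Callable, Dict, Optional, Tuple

ConfigFactory = Callable[[str, str], object]

# Hypothetical stand-ins for TrainingConfig.hunyuan_video_lora and friends.
def hunyuan_video_lora(data_path: str, output_path: str) -> object: ...
def ltx_video_lora(data_path: str, output_path: str) -> object: ...
def ltx_video_full_finetune(data_path: str, output_path: str) -> object: ...
def wan_lora(data_path: str, output_path: str) -> object: ...

CONFIG_FACTORIES: Dict[Tuple[str, str], ConfigFactory] = {
    ("hunyuan_video", "lora"): hunyuan_video_lora,
    ("ltx_video", "lora"): ltx_video_lora,
    ("ltx_video", "full-finetune"): ltx_video_full_finetune,
    ("wan", "lora"): wan_lora,
    # Combinations that are absent here are the ones the UI rejects with an error.
}

def build_config(model_type: str, training_type: str,
                 data_path: str, output_path: str) -> Tuple[Optional[object], Optional[str]]:
    """Return (config, error_message); mirrors the if/elif chain in start_training()."""
    factory = CONFIG_FACTORIES.get((model_type, training_type))
    if factory is None:
        return None, f"{training_type} training is not supported for {model_type} in this UI"
    return factory(data_path, output_path), None
```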
vms/tabs/train_tab.py
CHANGED
@@ -45,11 +45,21 @@ class TrainTab(BaseTab):
                     label="Model Type",
                     value=list(MODEL_TYPES.keys())[0]
                 )
+                with gr.Column():
+                    self.components["training_type"] = gr.Dropdown(
+                        choices=list(TRAINING_TYPES.keys()),
+                        label="Training Type",
+                        value=list(TRAINING_TYPES.keys())[0]
+                    )
+
+            with gr.Row():
                 self.components["model_info"] = gr.Markdown(
-                    value=self.get_model_info(list(MODEL_TYPES.keys())[0])
+                    value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0])
                 )
 
-
+            # LoRA specific parameters (will show/hide based on training type)
+            with gr.Row(visible=True) as lora_params_row:
+                self.components["lora_params_row"] = lora_params_row
                 self.components["lora_rank"] = gr.Dropdown(
                     label="LoRA Rank",
                     choices=["16", "32", "64", "128", "256", "512", "1024"],
@@ -62,6 +72,7 @@ class TrainTab(BaseTab):
                     value="128",
                     type="value"
                 )
+
             with gr.Row():
                 self.components["num_epochs"] = gr.Number(
                     label="Number of Epochs",
@@ -143,15 +154,18 @@ class TrainTab(BaseTab):
     def connect_events(self) -> None:
         """Connect event handlers to UI components"""
         # Model type change event
-        def update_model_info(model):
-            params = self.get_default_params(MODEL_TYPES[model])
-            info = self.get_model_info(MODEL_TYPES[model])
+        def update_model_info(model, training_type):
+            params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
+            info = self.get_model_info(MODEL_TYPES[model], TRAINING_TYPES[training_type])
+            show_lora_params = training_type == list(TRAINING_TYPES.keys())[0]  # Show if LoRA Finetune
+
             return {
                 self.components["model_info"]: info,
                 self.components["num_epochs"]: params["num_epochs"],
                 self.components["batch_size"]: params["batch_size"],
                 self.components["learning_rate"]: params["learning_rate"],
-                self.components["save_iterations"]: params["save_iterations"]
+                self.components["save_iterations"]: params["save_iterations"],
+                self.components["lora_params_row"]: gr.Row(visible=show_lora_params)
             }
 
         self.components["model_type"].change(
@@ -160,13 +174,32 @@ class TrainTab(BaseTab):
             outputs=[]
         ).then(
            fn=update_model_info,
-            inputs=[self.components["model_type"]],
+            inputs=[self.components["model_type"], self.components["training_type"]],
             outputs=[
                 self.components["model_info"],
                 self.components["num_epochs"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
-                self.components["save_iterations"]
+                self.components["save_iterations"],
+                self.components["lora_params_row"]
+            ]
+        )
+
+        # Training type change event
+        self.components["training_type"].change(
+            fn=lambda v: self.app.update_ui_state(training_type=v),
+            inputs=[self.components["training_type"]],
+            outputs=[]
+        ).then(
+            fn=update_model_info,
+            inputs=[self.components["model_type"], self.components["training_type"]],
+            outputs=[
+                self.components["model_info"],
+                self.components["num_epochs"],
+                self.components["batch_size"],
+                self.components["learning_rate"],
+                self.components["save_iterations"],
+                self.components["lora_params_row"]
             ]
         )
 
@@ -217,13 +250,15 @@ class TrainTab(BaseTab):
             inputs=[self.components["training_preset"]],
             outputs=[
                 self.components["model_type"],
+                self.components["training_type"],
                 self.components["lora_rank"],
                 self.components["lora_alpha"],
                 self.components["num_epochs"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
-                self.components["preset_info"]
+                self.components["preset_info"],
+                self.components["lora_params_row"]
             ]
         )
 
@@ -233,6 +268,7 @@ class TrainTab(BaseTab):
             inputs=[
                 self.components["training_preset"],
                 self.components["model_type"],
+                self.components["training_type"],
                 self.components["lora_rank"],
                 self.components["lora_alpha"],
                 self.components["num_epochs"],
@@ -278,7 +314,7 @@ class TrainTab(BaseTab):
             ]
         )
 
-    def handle_training_start(self, preset, model_type, *args):
+    def handle_training_start(self, preset, model_type, training_type, *args):
         """Handle training start with proper log parser reset and checkpoint detection"""
         # Safely reset log parser if it exists
         if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
@@ -305,12 +341,20 @@ class TrainTab(BaseTab):
             logger.error(f"Invalid model type: {model_type}")
             return f"Error: Invalid model type '{model_type}'", "Model type not recognized"
 
+        # Convert training_type display name to internal name
+        training_internal_type = TRAINING_TYPES.get(training_type)
+
+        if not training_internal_type:
+            logger.error(f"Invalid training type: {training_type}")
+            return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
+
         # Start training (it will automatically use the checkpoint if provided)
         try:
             return self.app.trainer.start_training(
                 model_internal_type, # Use internal model type
                 *args,
                 preset_name=preset,
+                training_type=training_internal_type, # Pass the internal training type
                 resume_from_checkpoint=resume_from
             )
         except Exception as e:
@@ -318,25 +362,56 @@ class TrainTab(BaseTab):
            return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
 
 
-    def get_model_info(self, model_type: str) -> str:
-        """Get information about the selected model type"""
+    def get_model_info(self, model_type: str, training_type: str) -> str:
+        """Get information about the selected model type and training method"""
+        training_method = "LoRA finetune" if training_type == "lora" else "Full finetune"
+
         if model_type == "hunyuan_video":
-
+            base_info = """### HunyuanVideo
             - Required VRAM: ~48GB minimum
             - Recommended batch size: 1-2
             - Typical training time: 2-4 hours
-            - Default resolution: 49x512x768
-
+            - Default resolution: 49x512x768"""
+
+            if training_type == "lora":
+                return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+            else:
+                return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
+
+        elif model_type == "wan":
+            base_info = """### Wan-2.1-T2V
+            - Recommended batch size: 1-2
+            - Typical training time: 1-3 hours
+            - Default resolution: 49x512x768"""
+
+            if training_type == "lora":
+                return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
+            else:
+                return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+        else:
+            return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
+
+        elif model_type == "wan":
+            base_info = """### Wan-2.1-T2V
+            - Recommended batch size: 1-2
+            - Typical training time: 1-3 hours
+            - Default resolution: 49x512x768"""
+
+            if training_type == "lora":
+                return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
+            else:
+                return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Default LoRA rank: 128 (~600 MB)"
+            else:
+                return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
 
         elif model_type == "ltx_video":
-
-            - Required VRAM: ~18GB minimum
+            base_info = """### LTX-Video
             - Recommended batch size: 1-4
             - Typical training time: 1-3 hours
-            - Default resolution: 49x512x768
-
-
-
+            - Default resolution: 49x512x768"""
+
+            if training_type == "lora":
+                return base_
 
     def get_default_params(self, model_type: str) -> Dict[str, Any]:
         """Get default training parameters for model type"""
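
The show/hide wiring for the LoRA parameter row can be reproduced in isolation. Below is a small standalone sketch of the same pattern, assuming a recent Gradio API with `gr.update`; the component choices are illustrative, not the app's exact ones. The row stays visible only while the LoRA training type is selected.

```python
# Standalone sketch of the LoRA-row visibility toggle used in the train tab.
import gradio as gr

TRAINING_TYPES = {"LoRA Finetune": "lora", "Full Finetune": "full-finetune"}

with gr.Blocks() as demo:
    training_type = gr.Dropdown(
        choices=list(TRAINING_TYPES.keys()),
        value=list(TRAINING_TYPES.keys())[0],
        label="Training Type",
    )
    # The whole row is the target of the visibility update.
    with gr.Row(visible=True) as lora_params_row:
        gr.Dropdown(choices=["16", "32", "64", "128"], value="128", label="LoRA Rank")
        gr.Dropdown(choices=["32", "64", "128"], value="128", label="LoRA Alpha")

    def toggle_lora_row(selected: str):
        # Return a visibility update for the row: shown only for LoRA training.
        return gr.update(visible=(TRAINING_TYPES[selected] == "lora"))

    training_type.change(fn=toggle_lora_row, inputs=[training_type], outputs=[lora_params_row])

if __name__ == "__main__":
    demo.launch()
```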
vms/ui/video_trainer_ui.py
CHANGED
@@ -75,7 +75,7 @@ class VideoTrainerUI:
         for tab_id, tab_obj in self.tabs.items():
             tab_obj.connect_events()
 
-        #
+        # app-level timers for auto-refresh functionality
         self._add_timers()
 
         # Initialize app state on load
@@ -89,6 +89,7 @@ class VideoTrainerUI:
             self.tabs["train_tab"].components["pause_resume_btn"],
             self.tabs["train_tab"].components["training_preset"],
             self.tabs["train_tab"].components["model_type"],
+            self.tabs["train_tab"].components["training_type"], # Add the new training_type component to outputs
            self.tabs["train_tab"].components["lora_rank"],
             self.tabs["train_tab"].components["lora_alpha"],
             self.tabs["train_tab"].components["num_epochs"],
@@ -180,9 +181,23 @@ class VideoTrainerUI:
 
         ui_state["model_type"] = model_type_value
 
+        # Handle training_type
+        if "training_type" in recovery_ui:
+            training_type_value = recovery_ui["training_type"]
+
+            # If it's an internal name, convert to display name
+            if training_type_value not in TRAINING_TYPES:
+                for display_name, internal_name in TRAINING_TYPES.items():
+                    if internal_name == training_type_value:
+                        training_type_value = display_name
+                        logger.info(f"Converted internal training type '{recovery_ui['training_type']}' to display name '{training_type_value}'")
+                        break
+
+            ui_state["training_type"] = training_type_value
+
         # Copy other parameters
         for param in ["lora_rank", "lora_alpha", "num_epochs",
-
+                      "batch_size", "learning_rate", "save_iterations", "training_preset"]:
             if param in recovery_ui:
                 ui_state[param] = recovery_ui[param]
 
@@ -205,6 +220,15 @@ class VideoTrainerUI:
                 model_type_val = display_name
                 break
 
+        # Ensure training_type is a display name, not internal name
+        training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
+        if training_type_val not in TRAINING_TYPES:
+            # Convert from internal to display name
+            for display_name, internal_name in TRAINING_TYPES.items():
+                if internal_name == training_type_val:
+                    training_type_val = display_name
+                    break
+
         training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
         lora_rank_val = ui_state.get("lora_rank", "128")
         lora_alpha_val = ui_state.get("lora_alpha", "128")
@@ -222,6 +246,7 @@ class VideoTrainerUI:
             delete_checkpoints_btn,
             training_preset,
             model_type_val,
+            training_type_val,
             lora_rank_val,
             lora_alpha_val,
             num_epochs_val,
@@ -238,6 +263,7 @@ class VideoTrainerUI:
         return (
             ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
             ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
+            ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
             ui_state.get("lora_rank", "128"),
             ui_state.get("lora_alpha", "128"),
             ui_state.get("num_epochs", 70),