# Pyramid Flow's VAE Training Guide

This is the training guide for a [MAGVIT-v2](https://arxiv.org/abs/2310.05737)-like continuous 3D VAE. The training code is intended to be flexible, so feel free to build your own video generative model on top of it. For DiT fine-tuning, please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT).

## Hardware Requirements

+ VAE training: at least 8 A100 GPUs.

## Prepare the Dataset

Our causal video VAE is trained on both image and video data. Both are listed in a single JSON annotation file, where each entry has either a `video` or an `image` field. The final training annotation file should follow this format:

```
# For Video
{"video": video_path}

# For Image
{"image": image_path}
```
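
A small helper can generate this annotation file from a folder of media. This is a hypothetical sketch and not part of the repository. The function name `build_vae_annotation`, the file-extension lists, and the JSON Lines layout (one object per line) are all assumptions. Check them against the dataset loader used by the training script before relying on the output.

```python
import json
from pathlib import Path

# Assumed extension lists; adjust them to match your data.
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}


def build_vae_annotation(media_root, out_path):
    """Hypothetical helper: write one {"video": ...} or {"image": ...} object per line.

    Returns a (num_videos, num_images) tuple with the counts written.
    """
    num_videos = num_images = 0
    with open(out_path, "w", encoding="utf-8") as f:
        # sorted() keeps the output order deterministic across runs
        for path in sorted(Path(media_root).rglob("*")):
            if not path.is_file():
                continue  # skip directories
            suffix = path.suffix.lower()
            if suffix in VIDEO_EXTS:
                record = {"video": str(path.resolve())}
                num_videos += 1
            elif suffix in IMAGE_EXTS:
                record = {"image": str(path.resolve())}
                num_images += 1
            else:
                continue  # skip captions, metadata, and other files
            f.write(json.dumps(record) + "\n")
    return num_videos, num_images
```

For example, `build_vae_annotation("/data/vae_media", "annotation/vae_train.jsonl")` returns the number of video and image entries it wrote. Both paths in that call are hypothetical. The helper writes absolute paths, so the annotation file does not depend on the working directory of the training job.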

## Run Training

The causal video VAE is trained in two stages:

+ Stage-1: mixed image and video training.

+ Stage-2: pure video training, using context parallelism to load videos with more frames.

The VAE training script is `scripts/train_causal_video_vae.sh`. Run it as follows:

```bash
sh scripts/train_causal_video_vae.sh
```

We also provide a VAE demo notebook, `causal_video_vae_demo.ipynb`, for image and video reconstruction.

> The original VGG LPIPS checkpoint download URL is no longer available, so I have shared the one we used at this [URL](https://drive.google.com/file/d/1YeFlX5BKKw-HGkjNd1r7DSwas1iJJwqC/view). Download it and set `LPIPS_CKPT` to its local path.

## Tips

+ Stage-1 uses mixed image and video training. Add the `--use_image_video_mixed_training` argument to enable it. The image ratio is set to 0.1 by default.

+ Setting `resolution` to 256 is sufficient for VAE training.

+ In stage-1, `max_frames` is set to 17, meaning each training clip consists of 17 sampled video frames.

+ In stage-2, enable `use_context_parallel` to distribute the frames of long videos across multiple GPUs. Make sure that `GPUS % CONTEXT_SIZE == 0` and `NUM_FRAMES = 17 * CONTEXT_SIZE + 1`.
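
You can check the two stage-2 constraints before launching a run. The minimal sketch below encodes only the rules stated in the tips. The function name `check_context_parallel` is hypothetical. The extra `parallel_groups` value rests on our own assumption that each group of `CONTEXT_SIZE` GPUs shares one video. That assumption is not taken from the training code.

```python
def check_context_parallel(gpus: int, context_size: int) -> dict:
    """Hypothetical pre-launch check for stage-2 context-parallel settings."""
    if context_size < 1:
        raise ValueError("CONTEXT_SIZE must be a positive integer")
    # Rule 1: the GPUs must split evenly into context-parallel groups.
    if gpus % context_size != 0:
        raise ValueError(
            f"GPUS={gpus} is not divisible by CONTEXT_SIZE={context_size}"
        )
    # Rule 2: the frame count is tied to the context size.
    num_frames = 17 * context_size + 1
    return {
        "NUM_FRAMES": num_frames,
        # Assumption: GPUs within one group share a single video, so the
        # number of videos processed in parallel equals the number of groups.
        "parallel_groups": gpus // context_size,
    }


print(check_context_parallel(gpus=8, context_size=4))
# {'NUM_FRAMES': 69, 'parallel_groups': 2}
```

For example, with 8 GPUs and `CONTEXT_SIZE = 3`, the check raises an error, because 8 GPUs cannot be split into groups of 3.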
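
The following example shows what an image ratio of 0.1 could mean in stage-1. It assumes the ratio is the fraction of training samples drawn from images. This is not the repository's data loader. The class name `MixedSampler` and the per-sample coin flip are illustrative assumptions. The actual training code may instead mix modalities at the batch level.

```python
import random
from typing import List, Tuple


class MixedSampler:
    """Hypothetical sampler: returns an image with probability image_ratio, else a video."""

    def __init__(self, images: List[str], videos: List[str],
                 image_ratio: float = 0.1, seed: int = 0):
        if not 0.0 <= image_ratio <= 1.0:
            raise ValueError("image_ratio must be in [0, 1]")
        if not images or not videos:
            raise ValueError("both image and video lists must be non-empty")
        self.images = images
        self.videos = videos
        self.image_ratio = image_ratio
        self.rng = random.Random(seed)  # private RNG for reproducible draws

    def sample(self) -> Tuple[str, str]:
        # One Bernoulli draw per sample decides the modality.
        if self.rng.random() < self.image_ratio:
            return "image", self.rng.choice(self.images)
        return "video", self.rng.choice(self.videos)


sampler = MixedSampler(["a.jpg", "b.jpg"], ["x.mp4", "y.mp4"], image_ratio=0.1)
draws = [sampler.sample()[0] for _ in range(10_000)]
print(draws.count("image") / len(draws))  # close to 0.1
```

A per-sample draw keeps the sketch simple. A batch-level split, such as a fixed number of images per batch, would give a steadier mix per step.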