---
license: mit
pipeline_tag: unconditional-image-generation
library_name: diffusers
datasets:
- ILSVRC/imagenet-1k
---
# Marrying Autoregressive Transformer and Diffusion with Multi-Reference Autoregression <br><sub>Official PyTorch Implementation</sub>
[arXiv](https://arxiv.org/pdf/2506.09482)
[Hugging Face](https://huggingface.co/zhendch/Transdiff)
[GitHub](https://github.com/TransDiff/TransDiff)
<p align="center">
<img src="figs/visual.png" width="720">
</p>
This is a PyTorch/GPU implementation of the paper [Marrying Autoregressive Transformer and Diffusion with Multi-Reference Autoregression](https://arxiv.org/pdf/2506.09482):
```
@article{zhen2025marrying,
  title={Marrying Autoregressive Transformer and Diffusion with Multi-Reference Autoregression},
  author={Zhen, Dingcheng and Qiao, Qian and Yu, Tan and Wu, Kangxi and Zhang, Ziwei and Liu, Siyuan and Yin, Shunshun and Tao, Ming},
  journal={arXiv preprint arXiv:2506.09482},
  year={2025}
}
```
This repo contains:
* 🪐 A simple PyTorch implementation of [TransDiff Model](models/transdiff.py) and [TransDiff Model with MRAR](models/transdiff_mrar.py)
* ⚡️ Pre-trained class-conditional TransDiff models for ImageNet 256x256 and 512x512
* 💥 A self-contained [notebook](demo.ipynb) for running various pre-trained TransDiff models
* 🛸 A TransDiff [training and evaluation script](main.py) using PyTorch DDP
## Preparation
### Dataset
Download [ImageNet](http://image-net.org/download) dataset, and place it in your `IMAGENET_PATH`.
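The training script reads the dataset in the standard `ImageFolder`-style layout, with class subdirectories under a `train/` split (an assumption carried over from [MAR](https://github.com/LTH14/mar), which this codebase builds on):
```
IMAGENET_PATH/
  train/
    n01440764/
      n01440764_10026.JPEG
      ...
    n01443537/
      ...
```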
### VAE Model
We adopt the VAE model from [MAR](https://github.com/LTH14/mar); you can also download it directly [here](https://huggingface.co/zhendch/Transdiff/resolve/main/vae/checkpoint-last.pth?download=true).
### Installation
Download the code:
```
git clone https://github.com/TransDiff/TransDiff
cd TransDiff
```
A suitable [conda](https://conda.io/) environment named `transdiff` can be created and activated with:
```
conda env create -f environment.yaml
conda activate transdiff
```
For convenience, our pre-trained TransDiff models can also be downloaded directly here:
| TransDiff Model | FID-50K | Inception Score | #params |
|--------------------------------------------------------------------------------------------------------------------------------|---------|-----------------|---------|
| [TransDiff-B](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_b/checkpoint-last.pth?download=true) | 2.47 | 244.2 | 290M |
| [TransDiff-L](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_l/checkpoint-last.pth?download=true) | 2.25 | 244.3 | 683M |
| [TransDiff-H](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_h/checkpoint-last.pth?download=true) | 1.69 | 282.0 | 1.3B |
| [TransDiff-B MRAR](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_b_mrar/checkpoint-last.pth?download=true) | 1.49 | 282.2 | 290M |
| [TransDiff-L MRAR](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_l_mrar/checkpoint-last.pth?download=true) | 1.61 | 293.4 | 683M |
| [TransDiff-H MRAR](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_h_mrar/checkpoint-last.pth?download=true) | 1.42 | 301.2 | 1.3B |
| [TransDiff-L 512x512](https://huggingface.co/zhendch/Transdiff/resolve/main/transdiff_l_512/checkpoint-last.pth?download=true) | 2.51 | 286.6 | 683M |
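If you prefer to fetch checkpoints programmatically, here is a minimal sketch using `huggingface_hub` (the repo id and filenames are taken from the links above; the `ckpt/` target directory matches the paths used in the commands below):
```python
# Minimal sketch: download a TransDiff checkpoint from the Hugging Face Hub.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="zhendch/Transdiff",
    filename="transdiff_l/checkpoint-last.pth",  # see the table above for other variants
    local_dir="ckpt",                            # the evaluation commands below resume from ckpt/
)
print(ckpt_path)  # ckpt/transdiff_l/checkpoint-last.pth
```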
### (Optional) Download Other Files
If you want to run evaluation on ImageNet 512x512, download the [FID statistics file](https://huggingface.co/zhendch/Transdiff/resolve/main/VIRTUAL_imagenet512.npz?download=true) and put it in the `fid_stats/` folder.
If you want to train TransDiff with MRAR, download the [MRAR index file](https://huggingface.co/zhendch/Transdiff/resolve/main/Imagenet2012_mrar_files.txt?download=true) and put it in the project root.
### (Optional) Caching VAE Latents
Since our data augmentation consists only of simple center cropping and random flipping,
the VAE latents can be pre-computed and saved to `CACHED_PATH` to save computation during TransDiff training:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_cache.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 \
--batch_size 128 \
--data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH}
```
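Conceptually, `main_cache.py` runs every training image through the frozen VAE encoder once and stores the resulting latent. A simplified sketch of that idea (illustrative only; the posterior-style encoder interface is an assumption based on MAR's KL-16 VAE, not the script's exact code):
```python
# Illustrative sketch of VAE latent caching (not the exact main_cache.py logic).
import torch
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.CenterCrop(256),                    # the simple augmentation noted above
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

@torch.no_grad()
def cache_latent(vae, image_path, out_path, device="cuda"):
    x = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
    posterior = vae.encode(x)                      # assumed: encoder returns a Gaussian posterior
    torch.save(posterior.sample().cpu(), out_path) # one cached latent per image
```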
## Usage
### Demo
Run our interactive visualization [demo](demo.ipynb).
### Training
Script for the TransDiff-L 1StepAR setting (pretraining TransDiff-L with a width of 1024 channels for 800 epochs):
```
torchrun --nproc_per_node=8 --nnodes=8 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--diffusion_batch_mul 4 \
--epochs 800 --warmup_epochs 100 --blr 1.0e-4 --batch_size 32 \
--output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
--data_path ${IMAGENET_PATH}
```
- Training time is ~115h on 64 A100 GPUs with `--batch_size 32`.
- Add `--online_eval` to evaluate FID during training (every 50 epochs).
- (Optional) To train with cached VAE latents, add `--use_cached --cached_path ${CACHED_PATH}` to the arguments.
- (Optional) If the error `Loss is nan, stopping training` occurs frequently when training with mixed precision via `torch.cuda.amp.autocast()`, add `--bf16` to the arguments.
- (Optional) If necessary, you can use gradient accumulation by setting `--gradient_accumulation_steps n` (see the sketch below).
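The last two options combine as in this generic PyTorch pattern, shown here as a minimal sketch (not `main.py`'s actual training loop; the function and variable names are illustrative):
```python
import torch

def train_epoch(model, loader, optimizer, accum_steps=2):
    # bf16 autocast sidesteps fp16 overflow, the usual cause of 'Loss is nan'.
    optimizer.zero_grad()
    for step, (imgs, labels) in enumerate(loader):
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            loss = model(imgs.cuda(), labels.cuda())
        (loss / accum_steps).backward()  # scale so accumulated gradients match one large batch
        if (step + 1) % accum_steps == 0:
            optimizer.step()             # update once every accum_steps micro-batches
            optimizer.zero_grad()
```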
Script for the TransDiff-L MRAR setting (fine-tuning TransDiff-L with MRAR at a width of 1024 channels for 40 epochs):
```
torchrun --nproc_per_node=8 --nnodes=8 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 --mrar --bf16 \
--diffusion_batch_mul 2 \
--epochs 40 --warmup_epochs 10 --lr 5.0e-5 --batch_size 16 --gradient_accumulation_steps 2 \
--output_dir ${OUTPUT_DIR} --resume ${TRANSDIFF_L_1STEPAR_DIR} \
--data_path ${IMAGENET_PATH}
```
Script for the TransDiff-L 512x512 setting (fine-tuning TransDiff-L at 512x512 with a width of 1024 channels for 150 epochs):
```
torchrun --nproc_per_node=8 --nnodes=8 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main.py \
--img_size 512 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 --ema_rate 0.999 --bf16 \
--diffusion_batch_mul 4 \
--epochs 150 --warmup_epochs 10 --lr 1.0e-4 --batch_size 16 --gradient_accumulation_steps 2 \
--only_train_diff \
--output_dir ${OUTPUT_DIR} --resume ${TRANSDIFF_L_1STEPAR_DIR} \
--data_path ${IMAGENET_PATH}
```
### Evaluation (ImageNet 256x256 and 512x512)
Evaluate TransDiff-L 1StepAR with classifier-free guidance:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--output_dir ${OUTPUT_DIR} --resume ckpt/transdiff_l/ \
--evaluate --eval_bsz 256 --num_images 50000 \
--cfg 1.3 --scale_0 0.89 --scale_1 0.95
```
Evaluate TransDiff-L MRAR with classifier-free guidance:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--output_dir ${OUTPUT_DIR} --resume ckpt/transdiff_l_mrar/ \
--evaluate --eval_bsz 256 --num_images 50000 \
--cfg 1.3 --scale_0 0.91 --scale_1 0.93
```
Evaluate TransDiff-L 512x512 with classifier-free guidance:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main.py \
--img_size 512 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--output_dir ${OUTPUT_DIR} --resume ckpt/transdiff_l_512/ \
--evaluate --eval_bsz 64 --num_images 50000 \
--cfg 1.3 --scale_0 0.87 --scale_1 0.87
```
CFG and scale settings for all models benchmarked in the paper:
| TransDiff Model | cfg | scale_0 | scale_1 |
|---------------------|------|---------|---------|
| TransDiff-B | 1.30 | 0.87 | 0.91 |
| TransDiff-L | 1.30 | 0.89 | 0.95 |
| TransDiff-H | 1.23 | 0.87 | 0.93 |
| TransDiff-B MRAR | 1.30 | 0.87 | 0.91 |
| TransDiff-L MRAR | 1.30 | 0.91 | 0.93 |
| TransDiff-H MRAR | 1.28 | 0.87 | 0.91 |
| TransDiff-L 512x512 | 1.30 | 0.87 | 0.87 |
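For scripted evaluation across all variants, the table can be transcribed into a small lookup (a convenience sketch; the keys mirror the checkpoint directory names from the download table above):
```python
# CFG / scale settings per model, transcribed from the table above.
BENCH_SETTINGS = {
    "transdiff_b":      dict(cfg=1.30, scale_0=0.87, scale_1=0.91),
    "transdiff_l":      dict(cfg=1.30, scale_0=0.89, scale_1=0.95),
    "transdiff_h":      dict(cfg=1.23, scale_0=0.87, scale_1=0.93),
    "transdiff_b_mrar": dict(cfg=1.30, scale_0=0.87, scale_1=0.91),
    "transdiff_l_mrar": dict(cfg=1.30, scale_0=0.91, scale_1=0.93),
    "transdiff_h_mrar": dict(cfg=1.28, scale_0=0.87, scale_1=0.91),
    "transdiff_l_512":  dict(cfg=1.30, scale_0=0.87, scale_1=0.87),
}

def cfg_flags(name: str) -> str:
    """Build the guidance flags for main.py from the table entries."""
    s = BENCH_SETTINGS[name]
    return f"--cfg {s['cfg']} --scale_0 {s['scale_0']} --scale_1 {s['scale_1']}"
```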
## Acknowledgements
A large portion of the code in this repo is based on [MAR](https://github.com/LTH14/mar), [diffusers](https://github.com/huggingface/diffusers) and [timm](https://github.com/huggingface/pytorch-image-models).
## Contact
If you have any questions, feel free to contact me by email ([email protected]). Enjoy!