from .DiffAE_support_templates import *
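

# Config templates for the latent-DDIM stage of DiffAE: a small diffusion
# model trained over the semantic latents of a pretrained autoencoder.
# Each helper below mutates and returns the given TrainConfig.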
def latent_diffusion_config(conf: TrainConfig):
    conf.batch_size = 128
    conf.train_mode = TrainMode.latent_diffusion
    conf.latent_gen_type = GenerativeType.ddim
    conf.latent_loss_type = LossType.mse
    conf.latent_model_mean_type = ModelMeanType.eps
    conf.latent_model_var_type = ModelVarType.fixed_large
    conf.latent_rescale_timesteps = False
    conf.latent_clip_sample = False
    conf.latent_T_eval = 20
    conf.latent_znormalize = True
    conf.total_samples = 96_000_000
    conf.sample_every_samples = 400_000
    conf.eval_every_samples = 20_000_000
    conf.eval_ema_every_samples = 20_000_000
    conf.save_every_samples = 2_000_000
    return conf
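

# Variant of the base config with a smaller evaluation batch, used by the
# 128px/256px templates below.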
def latent_diffusion128_config(conf: TrainConfig):
    conf = latent_diffusion_config(conf)
    conf.batch_size_eval = 32
    return conf
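

# Latent denoiser architectures: skip-connected MLPs (LatentNetType.skip)
# with SiLU activations, normalization, and 2048 hidden channels.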
def latent_mlp_2048_norm_10layers(conf: TrainConfig):
    conf.net_latent_net_type = LatentNetType.skip
    conf.net_latent_layers = 10
    conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
    conf.net_latent_activation = Activation.silu
    conf.net_latent_num_hid_channels = 2048
    conf.net_latent_use_norm = True
    conf.net_latent_condition_bias = 1
    return conf


def latent_mlp_2048_norm_20layers(conf: TrainConfig):
    conf = latent_mlp_2048_norm_10layers(conf)
    conf.net_latent_layers = 20
    conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
    return conf
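

# Batch-size presets. The *_every_samples cadences are counted in training
# samples seen and override the defaults from latent_diffusion_config.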
def latent_256_batch_size(conf: TrainConfig):
    conf.batch_size = 256
    conf.eval_ema_every_samples = 100_000_000
    conf.eval_every_samples = 100_000_000
    conf.sample_every_samples = 1_000_000
    conf.save_every_samples = 2_000_000
    conf.total_samples = 301_000_000
    return conf


def latent_512_batch_size(conf: TrainConfig):
    conf.batch_size = 512
    conf.eval_ema_every_samples = 100_000_000
    conf.eval_every_samples = 100_000_000
    conf.sample_every_samples = 1_000_000
    conf.save_every_samples = 5_000_000
    conf.total_samples = 501_000_000
    return conf


def latent_2048_batch_size(conf: TrainConfig):
    conf.batch_size = 2048
    conf.eval_ema_every_samples = 200_000_000
    conf.eval_every_samples = 200_000_000
    conf.sample_every_samples = 4_000_000
    conf.save_every_samples = 20_000_000
    conf.total_samples = 1_501_000_000
    return conf
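

# Optimizer preset: AdamW with a small weight decay.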
def adamw_weight_decay(conf: TrainConfig):
    conf.optimizer = OptimizerType.adamw
    conf.weight_decay = 0.01
    return conf
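

# Dataset-specific templates, composed from the helpers above. All of them
# use an L1 latent loss and a constant beta schedule ('const0.008').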
def ffhq128_autoenc_latent():
    conf = pretrain_ffhq128_autoenc130M()
    conf = latent_diffusion128_config(conf)
    conf = latent_mlp_2048_norm_10layers(conf)
    conf = latent_256_batch_size(conf)
    conf = adamw_weight_decay(conf)
    conf.total_samples = 101_000_000
    conf.latent_loss_type = LossType.l1
    conf.latent_beta_scheduler = 'const0.008'
    conf.name = 'ffhq128_autoenc_latent'
    return conf


def ffhq256_autoenc_latent():
    conf = pretrain_ffhq256_autoenc()
    conf = latent_diffusion128_config(conf)
    conf = latent_mlp_2048_norm_10layers(conf)
    conf = latent_256_batch_size(conf)
    conf = adamw_weight_decay(conf)
    conf.total_samples = 101_000_000
    conf.latent_loss_type = LossType.l1
    conf.latent_beta_scheduler = 'const0.008'
    conf.eval_ema_every_samples = 200_000_000
    conf.eval_every_samples = 200_000_000
    conf.sample_every_samples = 4_000_000
    conf.name = 'ffhq256_autoenc_latent'
    return conf


def horse128_autoenc_latent():
    conf = pretrain_horse128()
    conf = latent_diffusion128_config(conf)
    conf = latent_2048_batch_size(conf)
    conf = latent_mlp_2048_norm_20layers(conf)
    conf.total_samples = 2_001_000_000
    conf.latent_beta_scheduler = 'const0.008'
    conf.latent_loss_type = LossType.l1
    conf.name = 'horse128_autoenc_latent'
    return conf


def bedroom128_autoenc_latent():
    conf = pretrain_bedroom128()
    conf = latent_diffusion128_config(conf)
    conf = latent_2048_batch_size(conf)
    conf = latent_mlp_2048_norm_20layers(conf)
    conf.total_samples = 2_001_000_000
    conf.latent_beta_scheduler = 'const0.008'
    conf.latent_loss_type = LossType.l1
    conf.name = 'bedroom128_autoenc_latent'
    return conf


def celeba64d2c_autoenc_latent():
    conf = pretrain_celeba64d2c_72M()
    conf = latent_diffusion_config(conf)
    conf = latent_512_batch_size(conf)
    conf = latent_mlp_2048_norm_10layers(conf)
    conf = adamw_weight_decay(conf)
    # just for the name
    conf.continue_from = PretrainConfig('200M',
                                        f'log-latent/{conf.name}/last.ckpt')
    conf.postfix = '_300M'
    conf.total_samples = 301_000_000
    conf.latent_beta_scheduler = 'const0.008'
    conf.latent_loss_type = LossType.l1
    conf.name = 'celeba64d2c_autoenc_latent'
    return conf
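

# Usage sketch (illustrative, not part of the original module). All fields
# printed below are set in this file; because of the relative import at the
# top, this must be executed as a module (python -m <package>.<this_module>).
if __name__ == '__main__':
    conf = ffhq128_autoenc_latent()
    # Expect: name='ffhq128_autoenc_latent', batch_size=256,
    # total_samples=101_000_000, latent loss LossType.l1.
    print(conf.name, conf.batch_size, conf.total_samples, conf.latent_loss_type)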