from .DiffAE_support_config import *


def _beatgans_base():
    """Return a TrainConfig holding the settings shared by ``ddpm`` and ``autoenc_base``.

    ``model_name`` and the autoencoder-specific fields are left to the caller,
    which is also responsible for calling ``make_model_conf()`` afterwards.
    """
    conf = TrainConfig()
    conf.batch_size = 32
    conf.beatgans_gen_type = GenerativeType.ddim
    conf.beta_scheduler = 'linear'
    conf.data_name = 'ffhq'
    conf.diffusion_type = 'beatgans'
    conf.eval_ema_every_samples = 200_000
    conf.eval_every_samples = 200_000
    conf.fp16 = True
    conf.lr = 1e-4
    conf.net_attn = (16, )
    conf.net_beatgans_attn_head = 1
    conf.net_beatgans_embed_channels = 512
    conf.net_ch_mult = (1, 2, 4, 8)
    conf.net_ch = 64
    conf.sample_size = 32
    conf.T_eval = 20
    conf.T = 1000
    return conf


def _attach_pretrain(conf, source_name, pretrain_name):
    """Point ``conf`` at the checkpoint and latents of the run named ``source_name``.

    Sets ``conf.pretrain`` to a ``PretrainConfig`` whose path is
    ``checkpoints/<source_name>/last.ckpt``. Sets ``conf.latent_infer_path`` to
    ``checkpoints/<source_name>/latent.pkl``. Returns ``conf``.
    """
    ckpt_dir = f'checkpoints/{source_name}'
    conf.pretrain = PretrainConfig(
        name=pretrain_name,
        path=f'{ckpt_dir}/last.ckpt',
    )
    conf.latent_infer_path = f'{ckpt_dir}/latent.pkl'
    return conf


def ddpm():
    """
    base configuration for all DDIM-based models.
    """
    conf = _beatgans_base()
    conf.model_name = ModelName.beatgans_ddpm
    conf.make_model_conf()
    return conf


def autoenc_base():
    """
    base configuration for all Diff-AE models.
    """
    conf = _beatgans_base()
    conf.model_name = ModelName.beatgans_autoenc
    conf.net_beatgans_resnet_two_cond = True
    conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
    conf.net_enc_pool = 'adaptivenonzero'
    conf.make_model_conf()
    return conf


def ffhq64_ddpm():
    """DDPM on 'ffhqlmdb256' data, 72M training samples, no warmup, ``scale_up_gpus(4)``."""
    conf = ddpm()
    conf.data_name = 'ffhqlmdb256'
    conf.warmup = 0
    conf.total_samples = 72_000_000
    conf.scale_up_gpus(4)
    return conf


def ffhq64_autoenc():
    """Diff-AE on 'ffhqlmdb256' data, 72M samples, evaluated every 1M samples before ``scale_up_gpus(4)``."""
    conf = autoenc_base()
    conf.data_name = 'ffhqlmdb256'
    conf.warmup = 0
    conf.total_samples = 72_000_000
    # Restated explicitly; these equal the values already set by autoenc_base().
    conf.net_ch_mult = (1, 2, 4, 8)
    conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
    # Set before scale_up_gpus() so that call sees these values (order preserved).
    conf.eval_every_samples = 1_000_000
    conf.eval_ema_every_samples = 1_000_000
    conf.scale_up_gpus(4)
    conf.make_model_conf()
    return conf


def celeba64d2c_ddpm():
    """DDPM on 'celebalmdb', 72M samples, evaluated every 10M samples."""
    # NOTE(review): this builds on ffhq128_ddpm (img_size 128, 5-level net),
    # whereas celeba64d2c_autoenc builds on ffhq64_autoenc. Confirm the 128
    # base is intentional before changing it, since checkpoints may depend on it.
    conf = ffhq128_ddpm()
    conf.data_name = 'celebalmdb'
    conf.eval_every_samples = 10_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.total_samples = 72_000_000
    conf.name = 'celeba64d2c_ddpm'
    return conf


def celeba64d2c_autoenc():
    """Diff-AE (ffhq64_autoenc architecture) on 'celebalmdb', 72M samples."""
    conf = ffhq64_autoenc()
    conf.data_name = 'celebalmdb'
    conf.eval_every_samples = 10_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.total_samples = 72_000_000
    conf.name = 'celeba64d2c_autoenc'
    return conf


def ffhq128_ddpm():
    """128px DDPM on 'ffhqlmdb256' data, 48M samples, evaluated every 10M samples."""
    conf = ddpm()
    conf.data_name = 'ffhqlmdb256'
    conf.warmup = 0
    conf.total_samples = 48_000_000
    conf.img_size = 128
    conf.net_ch = 128
    # channels:
    # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
    # sizes:
    # 128 => 128 => 64 => 32 => 16 => 8
    conf.net_ch_mult = (1, 1, 2, 3, 4)
    # NOTE(review): these 1M values are overwritten with 10M after
    # scale_up_gpus() below. They look redundant; confirm scale_up_gpus()
    # does not depend on them before removing.
    conf.eval_every_samples = 1_000_000
    conf.eval_ema_every_samples = 1_000_000
    conf.scale_up_gpus(4)
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.make_model_conf()
    return conf


def ffhq128_autoenc_base():
    """128px Diff-AE on 'ffhqlmdb256' data; shared base for the 128px autoencoder runs."""
    conf = autoenc_base()
    conf.data_name = 'ffhqlmdb256'
    conf.scale_up_gpus(4)
    conf.img_size = 128
    conf.net_ch = 128
    # final resolution = 8x8
    conf.net_ch_mult = (1, 1, 2, 3, 4)
    # final resolution = 4x4
    conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.make_model_conf()
    return conf


def ffhq256_autoenc():
    """256px Diff-AE derived from ffhq128_autoenc_base: deeper nets, 200M samples, batch size 64."""
    conf = ffhq128_autoenc_base()
    conf.img_size = 256
    conf.net_ch = 128
    conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
    conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
    conf.eval_every_samples = 10_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.total_samples = 200_000_000
    conf.batch_size = 64
    conf.make_model_conf()
    conf.name = 'ffhq256_autoenc'
    return conf


def ffhq256_autoenc_eco():
    """Same settings as ``ffhq256_autoenc``; only ``name`` differs."""
    # NOTE(review): the "eco" variant currently has no settings of its own;
    # confirm whether a cheaper configuration was intended.
    conf = ffhq256_autoenc()
    conf.name = 'ffhq256_autoenc_eco'
    return conf


def ffhq128_ddpm_72M():
    """ffhq128_ddpm trained for 72M samples."""
    conf = ffhq128_ddpm()
    conf.total_samples = 72_000_000
    conf.name = 'ffhq128_ddpm_72M'
    return conf


def ffhq128_autoenc_72M():
    """ffhq128_autoenc_base trained for 72M samples."""
    conf = ffhq128_autoenc_base()
    conf.total_samples = 72_000_000
    conf.name = 'ffhq128_autoenc_72M'
    return conf


def ffhq128_ddpm_130M():
    """ffhq128_ddpm trained for 130M samples, evaluated every 10M samples."""
    conf = ffhq128_ddpm()
    conf.total_samples = 130_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.name = 'ffhq128_ddpm_130M'
    return conf


def ffhq128_autoenc_130M():
    """ffhq128_autoenc_base trained for 130M samples, evaluated every 10M samples."""
    conf = ffhq128_autoenc_base()
    conf.total_samples = 130_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.name = 'ffhq128_autoenc_130M'
    return conf


#created from ffhq128_autoenc_130M
def ukbb_autoenc(ds_name="ukbb", n_latents=128):
    """Diff-AE config for dataset ``ds_name`` with ``n_latents`` style/embedding channels.

    Settings were adapted from ``ffhq128_autoenc_130M`` but start from a fresh
    ``TrainConfig``. Any field not set here (e.g. ``batch_size``, ``lr``,
    ``img_size``, ``net_ch``, ``total_samples``) therefore keeps
    TrainConfig's default.

    NOTE(review): unlike the other templates this never calls
    ``make_model_conf()``; presumably the caller does, so confirm. Also,
    ``conf.name`` is fixed regardless of ``ds_name``.
    """
    conf = TrainConfig()
    conf.beatgans_gen_type = GenerativeType.ddim
    conf.beta_scheduler = 'linear'
    conf.diffusion_type = 'beatgans'
    conf.fp16 = True
    conf.model_name = ModelName.beatgans_autoenc
    conf.net_attn = (16, )
    conf.net_beatgans_attn_head = 1
    # Embedding and style widths both follow the requested latent size.
    conf.net_beatgans_embed_channels = n_latents
    conf.style_ch = n_latents
    conf.net_beatgans_resnet_two_cond = True
    conf.net_enc_pool = 'adaptivenonzero'
    conf.sample_size = 32
    conf.T_eval = 20
    conf.T = 1000
    conf.T_inv = 200
    conf.T_step = 100
    conf.data_name = ds_name
    conf.net_ch_mult = (1, 1, 2, 3, 4)
    conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
    conf.name = 'ukbb_ffhq128_autoenc'
    return conf


def horse128_ddpm():
    """ffhq128_ddpm architecture on 'horse256', 130M samples."""
    conf = ffhq128_ddpm()
    conf.data_name = 'horse256'
    conf.total_samples = 130_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.name = 'horse128_ddpm'
    return conf


def horse128_autoenc():
    """ffhq128_autoenc_base architecture on 'horse256', 130M samples."""
    conf = ffhq128_autoenc_base()
    conf.data_name = 'horse256'
    conf.total_samples = 130_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.name = 'horse128_autoenc'
    return conf


def bedroom128_ddpm():
    """ffhq128_ddpm architecture on 'bedroom256', 120M samples."""
    conf = ffhq128_ddpm()
    conf.data_name = 'bedroom256'
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.total_samples = 120_000_000
    conf.name = 'bedroom128_ddpm'
    return conf


def bedroom128_autoenc():
    """ffhq128_autoenc_base architecture on 'bedroom256', 120M samples."""
    conf = ffhq128_autoenc_base()
    conf.data_name = 'bedroom256'
    conf.eval_ema_every_samples = 10_000_000
    conf.eval_every_samples = 10_000_000
    conf.total_samples = 120_000_000
    conf.name = 'bedroom128_autoenc'
    return conf


def pretrain_celeba64d2c_72M():
    """celeba64d2c_autoenc config loading its own 72M checkpoint and latents."""
    conf = celeba64d2c_autoenc()
    # conf was just built by celeba64d2c_autoenc(), so conf.name is that run's name;
    # no need to rebuild the config to read it.
    return _attach_pretrain(conf, conf.name, '72M')


def pretrain_ffhq128_autoenc72M():
    """ffhq128_autoenc_base config loading the ffhq128_autoenc_72M checkpoint and latents."""
    conf = ffhq128_autoenc_base()
    conf.postfix = ''
    return _attach_pretrain(conf, ffhq128_autoenc_72M().name, '72M')


def pretrain_ffhq128_autoenc130M():
    """ffhq128_autoenc_base config loading the ffhq128_autoenc_130M checkpoint and latents."""
    conf = ffhq128_autoenc_base()
    return _attach_pretrain(conf, ffhq128_autoenc_130M().name, '130M')


def pretrain_ffhq256_autoenc():
    """ffhq256_autoenc config loading its own checkpoint (labelled '90M') and latents."""
    conf = ffhq256_autoenc()
    return _attach_pretrain(conf, conf.name, '90M')


def pretrain_horse128():
    """horse128_autoenc config loading its own checkpoint (labelled '82M') and latents."""
    conf = horse128_autoenc()
    return _attach_pretrain(conf, conf.name, '82M')


def pretrain_bedroom128():
    """bedroom128_autoenc config loading its own checkpoint (labelled '120M') and latents."""
    conf = bedroom128_autoenc()
    return _attach_pretrain(conf, conf.name, '120M')