|
from .DiffAE_model_blocks import ScaleAt |
|
from .DiffAE_model import * |
|
from .DiffAE_diffusion_resample import UniformSampler |
|
from .DiffAE_diffusion_diffusion import space_timesteps |
|
from typing import Tuple |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from .DiffAE_support_config_base import BaseConfig |
|
from .DiffAE_support_choices import GenerativeType, LossType, ModelMeanType, ModelVarType |
|
from .DiffAE_diffusion_base import get_named_beta_schedule |
|
from .DiffAE_support_choices import * |
|
from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGansConfig |
|
from multiprocessing import get_context |
|
import os |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
from dataclasses import dataclass |
|
|
|
# Default on-disk locations for datasets and annotation files, keyed by the
# dataset name used in TrainConfig.data_name (looked up by
# TrainConfig.data_path). Paths are expanded relative to the user's home
# where applicable.
data_paths = {
    key: os.path.expanduser(location)
    for key, location in (
        ('ffhqlmdb256', 'datasets/ffhq256.lmdb'),
        ('celeba', 'datasets/celeba'),
        ('celebalmdb', 'datasets/celeba.lmdb'),
        ('celebahq', 'datasets/celebahq256.lmdb'),
        ('horse256', 'datasets/horse256.lmdb'),
        ('bedroom256', 'datasets/bedroom256.lmdb'),
        ('celeba_anno', 'datasets/celeba_anno/list_attr_celeba.txt'),
        ('celebahq_anno',
         'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
        ('celeba_relight', 'datasets/celeba_hq_light/celeba_light.txt'),
    )
}
|
|
|
|
|
@dataclass
class PretrainConfig(BaseConfig):
    """Pointer to a pretrained checkpoint: a human-readable run name plus
    the filesystem path of the weights to load (used by
    TrainConfig.pretrain / TrainConfig.continue_from)."""
    # display name of the pretrained run
    name: str
    # filesystem path to the checkpoint file
    path: str
|
|
|
|
|
@dataclass
class TrainConfig(BaseConfig):
    """Master configuration for a DiffAE training run.

    Besides the hyperparameter fields, this class carries ``make_*`` factory
    helpers that turn the stored settings into concrete objects: diffusion
    configs (train/eval, image-space and latent-space), the timestep
    sampler, datasets, data loaders, and the network config
    (``make_model_conf``).

    NOTE(review): many optional fields are annotated with their bare type
    but default to ``None`` (e.g. ``manipulate_cls: str = None``); read
    those annotations as ``Optional[...]``.
    """

    # -- input/output tensor layout; re-derived by refresh_values() --
    n_dims: int = 2
    in_channels: int = 3
    out_channels: int = 3
    # upper bound on the group count for group norm (capped at net_ch in
    # refresh_values)
    group_norm_limit: int = 32

    # -- training mode and attribute-manipulation (classifier) settings --
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
    manipulate_znormalize: bool = False
    manipulate_seed: int = 0

    # -- batching / gradient accumulation --
    accum_batches: int = 1  # see batch_size_effective
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    batch_size_eval: int = None  # None -> batch_size (see __post_init__)

    # -- image-space ("beatgans") diffusion settings --
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False

    # -- latent-space diffusion settings (DPM over the semantic code) --
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'

    # -- data, schedule, and optimization --
    beta_scheduler: str = 'linear'
    data_name: str = ''
    data_val_name: str = None  # None -> data_name (see __post_init__)
    diffusion_type: str = None  # only 'beatgans' is implemented
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    model_conf: ModelConfig = None  # filled in by make_model_conf()
    model_name: ModelName = None
    model_type: ModelType = None  # set as a side effect of make_model_conf()

    # -- UNet / encoder / latent-net architecture --
    net_attn: Tuple[int] = None
    net_beatgans_attn_head: int = 1
    # number of style channels
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
    # number of resblocks for the encoder (half-unet)
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    # number of resblocks for the UNet input half; None -> same as
    # net_num_res_blocks
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    style_ch: int = 512
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str] = None

    # -- paths / caching --
    # path to the separate evaluation checkpoint, if any
    eval_path: str = None
    base_dir: str = 'checkpoints'
    use_cache_dataset: bool = False
    data_cache_dir: str = os.path.expanduser('~/cache')
    work_cache_dir: str = os.path.expanduser('~/mycache')

    # run name; combined with base_dir to form logdir
    name: str = ''

    def refresh_values(self):
        """Re-derive dependent fields after input shape / net width change.

        NOTE(review): ``input_shape`` and ``is3D`` are not declared on this
        class — presumably provided by BaseConfig or assigned by the
        caller; confirm before relying on this method.
        """
        self.img_size = max(self.input_shape)
        self.n_dims = 3 if self.is3D else 2
        # group norm cannot use more groups than there are channels
        self.group_norm_limit = min(32, self.net_ch)

    def __post_init__(self):
        # resolve the "None means same as train setting" defaults
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Scale per-sample intervals and batch sizes by the total worker
        count (num_gpus * num_nodes); returns self for chaining."""
        self.eval_ema_every_samples *= num_gpus * num_nodes
        self.eval_every_samples *= num_gpus * num_nodes
        self.sample_every_samples *= num_gpus * num_nodes
        self.batch_size *= num_gpus * num_nodes
        self.batch_size_eval *= num_gpus * num_nodes
        return self

    @property
    def batch_size_effective(self):
        """Effective batch size after gradient accumulation."""
        return self.batch_size * self.accum_batches

    @property
    def fid_cache(self):
        """Directory for cached evaluation images used for FID, keyed by
        dataset, image size, and number of eval images."""
        return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'

    @property
    def data_path(self):
        """Resolve data_name to a dataset path via the module-level
        data_paths table, optionally rerouting through the local cache.

        NOTE(review): ``use_cached_dataset_path`` is not defined in this
        file — it comes from the star import; confirm its copy semantics.
        """
        path = data_paths[self.data_name]
        if self.use_cache_dataset and path is not None:
            path = use_cached_dataset_path(
                path, f'{self.data_cache_dir}/{self.data_name}')
        return path

    @property
    def logdir(self):
        """Checkpoint/log directory for this run."""
        return f'{self.base_dir}/{self.name}'

    @property
    def generate_dir(self):
        """Directory for generated images (kept on the work cache disk)."""
        return f'{self.work_cache_dir}/gen_images/{self.name}'

    def _make_diffusion_conf(self, T=None):
        """Build the image-space diffusion config with T sampling steps.

        The beta schedule is always generated for the full ``self.T``
        steps; ``T`` only selects the timestep subsequence through
        ``space_timesteps``. Raises NotImplementedError for unsupported
        diffusion or generator types.
        """
        if self.diffusion_type == 'beatgans':
            # ddpm uses an explicit list of section counts; ddim uses the
            # 'ddim{T}' striding understood by space_timesteps
            if self.beatgans_gen_type == GenerativeType.ddpm:
                section_counts = [T]
            elif self.beatgans_gen_type == GenerativeType.ddim:
                section_counts = f'ddim{T}'
            else:
                raise NotImplementedError()

            return SpacedDiffusionBeatGansConfig(
                gen_type=self.beatgans_gen_type,
                model_type=self.model_type,
                betas=get_named_beta_schedule(self.beta_scheduler, self.T),
                model_mean_type=self.beatgans_model_mean_type,
                model_var_type=self.beatgans_model_var_type,
                loss_type=self.beatgans_loss_type,
                rescale_timesteps=self.beatgans_rescale_timesteps,
                use_timesteps=space_timesteps(num_timesteps=self.T,
                                              section_counts=section_counts),
                fp16=self.fp16,
            )
        else:
            raise NotImplementedError()

    def _make_latent_diffusion_conf(self, T=None):
        """Build the latent-space diffusion config with T sampling steps.

        Mirrors ``_make_diffusion_conf`` but uses the ``latent_*`` settings
        and its own beta scheduler; betas still span the full ``self.T``.
        """
        if self.latent_gen_type == GenerativeType.ddpm:
            section_counts = [T]
        elif self.latent_gen_type == GenerativeType.ddim:
            section_counts = f'ddim{T}'
        else:
            raise NotImplementedError()

        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            # latent's model is always ddpm
            model_type=ModelType.ddpm,
            # latent shares the beta scheduler and full T
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Output channels of the network; currently just out_channels."""
        return self.out_channels

    def make_T_sampler(self):
        """Return the training timestep sampler (only 'uniform' supported)."""
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        else:
            raise NotImplementedError()

    def make_diffusion_conf(self):
        """Image-space diffusion config at the full training T."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Image-space diffusion config at the (usually smaller) eval T."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent diffusion config at the full training T."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent diffusion config at latent_T_eval (final T by default)."""
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, path=None, **kwargs):
        """Instantiate the dataset named by data_name.

        ``path`` overrides the default location from ``data_path``.
        Raises NotImplementedError for unknown dataset names.
        """
        if self.data_name == 'ffhqlmdb256':
            return FFHQlmdb(path=path or self.data_path,
                            image_size=self.img_size,
                            **kwargs)
        elif self.data_name == 'horse256':
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
        elif self.data_name == 'bedroom256':
            # bedroom256 reuses the same lmdb wrapper as horse256
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
        elif self.data_name == 'celebalmdb':
            # always use d2c crop
            return CelebAlmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              original_resolution=None,
                              crop_d2c=True,
                              **kwargs)
        else:
            raise NotImplementedError()

    def make_loader(self,
                    dataset,
                    shuffle: bool,
                    num_worker: int = None,
                    drop_last: bool = True,
                    batch_size: int = None,
                    parallel: bool = False):
        """Wrap ``dataset`` in a DataLoader.

        With ``parallel`` set and the process group initialized, a
        DistributedSampler shards the data (and then shuffling is delegated
        to the sampler). ``num_worker``/``batch_size`` of None (or 0) fall
        back to the config's values.

        NOTE(review): ``distributed`` is not imported in this file — it
        comes from the star import (presumably torch.distributed); confirm.
        """
        if parallel and distributed.is_initialized():
            # drop last to make sure that there is no added special indexes
            sampler = DistributedSampler(dataset,
                                         shuffle=shuffle,
                                         drop_last=True)
        else:
            sampler = None
        return DataLoader(
            dataset,
            batch_size=batch_size or self.batch_size,
            sampler=sampler,
            # with sampler, shuffle must be False
            shuffle=False if sampler else shuffle,
            num_workers=num_worker or self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
            # 'fork' keeps worker startup cheap and inherits the parent state
            multiprocessing_context=get_context('fork'),
        )

    def make_model_conf(self):
        """Build and return the network config for model_name.

        Side effects: sets ``self.model_type`` and caches the config in
        ``self.model_conf``. Raises NotImplementedError for unsupported
        model names or latent-net types.
        """
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            self.model_conf = BeatGANsUNetConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                group_norm_limit=self.group_norm_limit,
                dims=self.n_dims,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                image_size=self.img_size,
                in_channels=self.in_channels,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.
                net_beatgans_resnet_use_zero_module,
            )
        elif self.model_name in [
                ModelName.beatgans_autoenc,
        ]:
            cls = BeatGANsAutoencConfig
            # supports both autoenc and vaeddpm
            if self.model_name == ModelName.beatgans_autoenc:
                self.model_type = ModelType.autoencoder
            else:
                raise NotImplementedError()

            # choose the latent net attached to the autoencoder
            if self.net_latent_net_type == LatentNetType.none:
                latent_net_conf = None
            elif self.net_latent_net_type == LatentNetType.skip:
                latent_net_conf = MLPSkipNetConfig(
                    num_channels=self.style_ch,
                    skip_layers=self.net_latent_skip_layers,
                    num_hid_channels=self.net_latent_num_hid_channels,
                    num_layers=self.net_latent_layers,
                    num_time_emb_channels=self.net_latent_time_emb_channels,
                    activation=self.net_latent_activation,
                    use_norm=self.net_latent_use_norm,
                    condition_bias=self.net_latent_condition_bias,
                    dropout=self.net_latent_dropout,
                    last_act=self.net_latent_net_last_act,
                    num_time_layers=self.net_latent_num_time_layers,
                    time_last_act=self.net_latent_time_last_act,
                )
            else:
                raise NotImplementedError()

            self.model_conf = cls(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                group_norm_limit=self.group_norm_limit,
                dims=self.n_dims,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                image_size=self.img_size,
                in_channels=self.in_channels,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.
                net_beatgans_resnet_use_zero_module,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
            )
        else:
            raise NotImplementedError(self.model_name)

        return self.model_conf
|
|