# UKBBLatent_Cardiac_20208_DiffAE3D_L128_S42 / DiffAE_support_config.py
# Uploaded by soumickmj ("Upload DiffAE", commit c2ced9d, verified); 17.5 kB.
from .DiffAE_model_blocks import ScaleAt
from .DiffAE_model import *
from .DiffAE_diffusion_resample import UniformSampler
from .DiffAE_diffusion_diffusion import space_timesteps
from typing import Tuple
from torch.utils.data import DataLoader
from .DiffAE_support_config_base import BaseConfig
from .DiffAE_support_choices import GenerativeType, LossType, ModelMeanType, ModelVarType
from .DiffAE_diffusion_base import get_named_beta_schedule
from .DiffAE_support_choices import *
from .DiffAE_diffusion_diffusion import SpacedDiffusionBeatGansConfig
from multiprocessing import get_context
import os
from torch.utils.data.distributed import DistributedSampler
from dataclasses import dataclass
# Dataset name -> location of the dataset / annotation file.
# ``os.path.expanduser`` is applied to every entry, so a ``~`` prefix would
# be expanded; the current entries are all relative paths.
data_paths = {
    dataset_key: os.path.expanduser(location)
    for dataset_key, location in (
        ('ffhqlmdb256', 'datasets/ffhq256.lmdb'),
        # used for training a classifier
        ('celeba', 'datasets/celeba'),
        # used for training DPM models
        ('celebalmdb', 'datasets/celeba.lmdb'),
        ('celebahq', 'datasets/celebahq256.lmdb'),
        ('horse256', 'datasets/horse256.lmdb'),
        ('bedroom256', 'datasets/bedroom256.lmdb'),
        ('celeba_anno', 'datasets/celeba_anno/list_attr_celeba.txt'),
        ('celebahq_anno',
         'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
        ('celeba_relight', 'datasets/celeba_hq_light/celeba_light.txt'),
    )
}
@dataclass
class PretrainConfig(BaseConfig):
    """Named reference to a pretrained run.

    This is the type of ``TrainConfig.pretrain`` and ``TrainConfig.continue_from``.
    """

    # label identifying the pretrained source
    name: str
    # where to load it from (presumably a checkpoint path -- confirm with the loader)
    path: str
@dataclass
class TrainConfig(BaseConfig):
    """Training / evaluation configuration for the DiffAE models.

    The fields hold hyper-parameters. The methods turn them into:

    * diffusion configs (``make_*diffusion_conf``);
    * a timestep sampler;
    * a dataset and a ``DataLoader``;
    * a network config.

    Several fields (e.g. ``manipulate_*``, ``eval_*``, ``fid_use_torch``,
    ``ema_decay``, ``lr``) are not read by any method of this class. They are
    presumably consumed by the training scripts.
    """

    # new params added (Soumick)
    n_dims: int = 2  # passed as ``dims`` to the UNet configs
    in_channels: int = 3
    out_channels: int = 3
    group_norm_limit: int = 32
    # random seed
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
    manipulate_znormalize: bool = False
    manipulate_seed: int = 0
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    batch_size_eval: int = None  # defaults to batch_size in __post_init__
    # settings of the image diffusion process (see _make_diffusion_conf)
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    # settings of the latent diffusion process (see _make_latent_diffusion_conf)
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    data_name: str = ''  # key into the module-level ``data_paths`` dict
    data_val_name: str = None  # defaults to data_name in __post_init__
    diffusion_type: str = None  # only 'beatgans' is supported
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    model_conf: ModelConfig = None  # filled in by make_model_conf
    model_name: ModelName = None
    model_type: ModelType = None  # set by make_model_conf
    # network settings (consumed by make_model_conf)
    net_attn: Tuple[int] = None
    net_beatgans_attn_head: int = 1
    # not necessarily the same as the number of style channels
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64  # base width, passed as ``model_channels``
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
    # number of resblocks for the encoder (half-unet)
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    # latent MLP settings (used when net_latent_net_type == LatentNetType.skip)
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    # number of resblocks for the UNET
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    num_workers: int = 4  # default DataLoader worker count
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    # encoder output width; also the latent MLP width
    style_ch: int = 512
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000  # full number of diffusion steps
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str] = None
    # if present load the checkpoint from this path instead
    eval_path: str = None
    base_dir: str = 'checkpoints'
    use_cache_dataset: bool = False
    data_cache_dir: str = os.path.expanduser('~/cache')
    work_cache_dir: str = os.path.expanduser('~/mycache')
    # to be overridden
    name: str = ''

    def refresh_values(self):
        """Derive ``img_size``, ``n_dims`` and ``group_norm_limit`` from other settings.

        NOTE(review): this reads ``self.input_shape`` and ``self.is3D``.
        Neither is declared as a field of this dataclass, so both must be
        assigned on the instance before this method is called.
        """
        self.img_size = max(self.input_shape)
        self.n_dims = 3 if self.is3D else 2
        # Presumably caps the GroupNorm group count at the base channel width.
        self.group_norm_limit = min(32, self.net_ch)

    def __post_init__(self):
        # Evaluation falls back to the training batch size and dataset.
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Scale batch sizes and sample-count intervals by the total device count.

        Returns ``self`` so the call can be chained.
        NOTE(review): ``save_every_samples`` is not scaled; confirm that this
        is intended.
        """
        factor = num_gpus * num_nodes
        self.eval_ema_every_samples *= factor
        self.eval_every_samples *= factor
        self.sample_every_samples *= factor
        self.batch_size *= factor
        self.batch_size_eval *= factor
        return self

    @property
    def batch_size_effective(self):
        """Batch size after gradient accumulation (``batch_size * accum_batches``)."""
        return self.batch_size * self.accum_batches

    @property
    def fid_cache(self):
        """Directory for cached evaluation images, under ``work_cache_dir``."""
        # we try to use the local dirs to reduce the load over network drives
        # hopefully, this would reduce the disconnection problems with sshfs
        return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'

    @property
    def data_path(self):
        """Path of the training dataset, optionally redirected to the cache dir.

        Raises ``KeyError`` if ``data_name`` is not a key of ``data_paths``.
        NOTE(review): ``use_cached_dataset_path`` is not defined or explicitly
        imported in this file. It is presumably provided by one of the star
        imports.
        """
        path = data_paths[self.data_name]
        if self.use_cache_dataset and path is not None:
            path = use_cached_dataset_path(
                path, f'{self.data_cache_dir}/{self.data_name}')
        return path

    @property
    def logdir(self):
        """Run directory: ``<base_dir>/<name>``."""
        return f'{self.base_dir}/{self.name}'

    @property
    def generate_dir(self):
        """Directory for generated images, under ``work_cache_dir``."""
        # we try to use the local dirs to reduce the load over network drives
        # hopefully, this would reduce the disconnection problems with sshfs
        return f'{self.work_cache_dir}/gen_images/{self.name}'

    @staticmethod
    def _section_counts(gen_type, T):
        """Return the ``space_timesteps`` spec that keeps ``T`` evenly spaced steps.

        Raises:
            NotImplementedError: if ``gen_type`` is neither DDPM nor DDIM.
        """
        # follows the guided-diffusion repo conventions
        if gen_type == GenerativeType.ddpm:
            return [T]
        if gen_type == GenerativeType.ddim:
            return f'ddim{T}'
        raise NotImplementedError(gen_type)

    def _make_diffusion_conf(self, T=None):
        """Build the spaced-diffusion config for the image model.

        Args:
            T: number of steps to keep out of the full ``self.T``. A value
                ``T < self.T`` can be used for evaluation. ``None`` (the
                default) means ``self.T``.

        Raises:
            NotImplementedError: if ``diffusion_type`` is not ``'beatgans'``
                or ``beatgans_gen_type`` is unsupported.
        """
        if self.diffusion_type != 'beatgans':
            raise NotImplementedError(self.diffusion_type)
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.beatgans_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            # betas are always built over the full T
            betas=get_named_beta_schedule(self.beta_scheduler, self.T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    def _make_latent_diffusion_conf(self, T=None):
        """Build the spaced-diffusion config for the latent model.

        Args:
            T: number of steps to keep out of the full ``self.T``. ``None``
                (the default) means ``self.T``.

        Raises:
            NotImplementedError: if ``latent_gen_type`` is unsupported.
        """
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.latent_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            # latent's model is always ddpm
            model_type=ModelType.ddpm,
            # latent shares the beta scheduler and full T
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Output channel count of the network (equal to ``out_channels``)."""
        return self.out_channels

    def make_T_sampler(self):
        """Return the timestep sampler selected by ``T_sampler``.

        Raises:
            NotImplementedError: for any value other than ``'uniform'``.
        """
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        raise NotImplementedError(self.T_sampler)

    def make_diffusion_conf(self):
        """Image diffusion config for training (all ``T`` steps)."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Image diffusion config for evaluation (``T_eval`` steps)."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent diffusion config for training (all ``T`` steps)."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent diffusion config for evaluation (``latent_T_eval`` steps)."""
        # latent can have different eval T
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, path=None, **kwargs):
        """Instantiate the dataset named by ``data_name``.

        Args:
            path: dataset location. Defaults to ``self.data_path``.
            **kwargs: forwarded to the dataset constructor.

        Raises:
            NotImplementedError: for an unsupported ``data_name``.
        """
        if self.data_name == 'ffhqlmdb256':
            return FFHQlmdb(path=path or self.data_path,
                            image_size=self.img_size,
                            **kwargs)
        elif self.data_name == 'horse256':
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
        elif self.data_name == 'bedroom256':
            # the bedroom lmdb is read with the same loader as horse256
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
        elif self.data_name == 'celebalmdb':
            # always use d2c crop
            return CelebAlmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              original_resolution=None,
                              crop_d2c=True,
                              **kwargs)
        else:
            raise NotImplementedError(self.data_name)

    def make_loader(self,
                    dataset,
                    shuffle: bool,
                    num_worker: int = None,
                    drop_last: bool = True,
                    batch_size: int = None,
                    parallel: bool = False):
        """Wrap ``dataset`` in a ``DataLoader``.

        Args:
            dataset: the dataset to load from.
            shuffle: whether to shuffle. When a distributed sampler is used,
                the sampler shuffles instead of the loader.
            num_worker: worker-process count. ``None`` means
                ``self.num_workers``. ``0`` loads in the main process.
            drop_last: drop the last incomplete batch (non-distributed case).
            batch_size: batch size. ``None`` means ``self.batch_size``.
            parallel: use a ``DistributedSampler`` if a process group is
                initialized.
        """
        import torch.distributed as dist

        if num_worker is None:
            num_worker = self.num_workers

        if parallel and dist.is_available() and dist.is_initialized():
            # drop last to make sure that there is no added special indexes
            sampler = DistributedSampler(dataset,
                                         shuffle=shuffle,
                                         drop_last=True)
        else:
            sampler = None

        return DataLoader(
            dataset,
            batch_size=batch_size or self.batch_size,
            sampler=sampler,
            # With a sampler, the sampler shuffles. Compare against None
            # rather than relying on truthiness: an empty sampler has len 0.
            shuffle=shuffle if sampler is None else False,
            num_workers=num_worker,
            pin_memory=True,
            drop_last=drop_last,
            # DataLoader rejects a multiprocessing context when num_workers == 0
            multiprocessing_context=(get_context('fork')
                                     if num_worker > 0 else None),
        )

    def make_model_conf(self):
        """Build the network config for ``model_name``.

        Side effects: sets ``self.model_type`` and ``self.model_conf``.

        Returns:
            The new ``self.model_conf``.

        Raises:
            NotImplementedError: for an unsupported ``model_name`` or
                ``net_latent_net_type``.
        """
        # Keyword arguments shared by the plain UNet and the autoencoder config.
        common = dict(
            attention_resolutions=self.net_attn,
            channel_mult=self.net_ch_mult,
            conv_resample=True,
            group_norm_limit=self.group_norm_limit,
            dims=self.n_dims,
            dropout=self.dropout,
            embed_channels=self.net_beatgans_embed_channels,
            image_size=self.img_size,
            in_channels=self.in_channels,
            model_channels=self.net_ch,
            num_classes=None,
            num_head_channels=-1,
            num_heads_upsample=-1,
            num_heads=self.net_beatgans_attn_head,
            num_res_blocks=self.net_num_res_blocks,
            num_input_res_blocks=self.net_num_input_res_blocks,
            out_channels=self.model_out_channels,
            resblock_updown=self.net_resblock_updown,
            use_checkpoint=self.net_beatgans_gradient_checkpoint,
            use_new_attention_order=False,
            resnet_two_cond=self.net_beatgans_resnet_two_cond,
            resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
        )

        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            self.model_conf = BeatGANsUNetConfig(**common)
        elif self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder

            if self.net_latent_net_type == LatentNetType.none:
                latent_net_conf = None
            elif self.net_latent_net_type == LatentNetType.skip:
                latent_net_conf = MLPSkipNetConfig(
                    num_channels=self.style_ch,
                    skip_layers=self.net_latent_skip_layers,
                    num_hid_channels=self.net_latent_num_hid_channels,
                    num_layers=self.net_latent_layers,
                    num_time_emb_channels=self.net_latent_time_emb_channels,
                    activation=self.net_latent_activation,
                    use_norm=self.net_latent_use_norm,
                    condition_bias=self.net_latent_condition_bias,
                    dropout=self.net_latent_dropout,
                    last_act=self.net_latent_net_last_act,
                    num_time_layers=self.net_latent_num_time_layers,
                    time_last_act=self.net_latent_time_last_act,
                )
            else:
                raise NotImplementedError(self.net_latent_net_type)

            self.model_conf = BeatGANsAutoencConfig(
                **common,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
            )
        else:
            raise NotImplementedError(self.model_name)
        return self.model_conf