Spaces:
Paused
Paused
from model.unet import ScaleAt | |
from model.latentnet import * | |
from diffusion.resample import UniformSampler | |
from diffusion.diffusion import space_timesteps | |
from typing import Tuple | |
from torch.utils.data import DataLoader | |
from config_base import BaseConfig | |
from diffusion import * | |
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule | |
from model import * | |
from choices import * | |
from multiprocessing import get_context | |
import os | |
from dataset_util import * | |
from torch.utils.data.distributed import DistributedSampler | |
from dataset import LatentDataLoader | |
class PretrainConfig(BaseConfig):
    """Reference to a previously trained run: a name plus a path.

    Used as the type of ``TrainConfig.pretrain`` and
    ``TrainConfig.continue_from``.
    """
    # identifier of the referenced run
    name: str
    # where it is stored -- presumably a checkpoint file path; confirm with the loader
    path: str
class TrainConfig(BaseConfig):
    """All training / evaluation settings, declared as class-level defaults.

    A default of ``None`` means "not set"; several of those are filled in
    later (see ``__post_init__`` and ``make_model_conf``).

    NOTE(review): ``__post_init__`` is a dataclass hook, but no decorator is
    visible on this class in this file -- confirm the class is processed by
    ``dataclasses.dataclass`` (or that ``BaseConfig`` takes care of it).
    """
    # random seed
    seed: int = 0

    # train_* settings
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False

    # manipulate_* settings
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
    manipulate_znormalize: bool = False
    manipulate_seed: int = 0

    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    # falls back to batch_size in __post_init__ when left unset
    batch_size_eval: int = None

    # beatgans_* (image diffusion) settings
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False

    # latent_* (latent diffusion) settings
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'

    beta_scheduler: str = 'linear'
    data_name: str = ''
    # falls back to data_name in __post_init__ when left unset
    data_val_name: str = None
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0

    # model_conf and model_type are assigned by make_model_conf()
    model_conf: ModelConfig = None
    model_name: ModelName = None
    model_type: ModelType = None

    # net_* (network architecture) settings
    # the tuples are variable-length, hence Tuple[int, ...]
    net_attn: Tuple[int, ...] = None
    net_beatgans_attn_head: int = 1
    # not necessarily the same as the number of style channels
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int, ...] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int, ...] = None
    net_enc_k: int = None
    # number of resblocks for the encoder (half-unet)
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int, ...] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int, ...] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int, ...] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    # number of resblocks for the UNET
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None

    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    style_ch: int = 512
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str, ...] = None
    # if present load the checkpoint from this path instead
    eval_path: str = None
    base_dir: str = 'checkpoints'
    use_cache_dataset: bool = False
    # both cache dirs are resolved once, when the class body is executed
    data_cache_dir: str = os.path.expanduser('~/cache')
    work_cache_dir: str = os.path.expanduser('~/mycache')
    # to be overridden
    name: str = ''
def __post_init__(self):
    """Fill in the evaluation-side settings that were left unset.

    ``batch_size_eval`` falls back to ``batch_size`` and ``data_val_name``
    falls back to ``data_name`` whenever the current value is falsy
    (``None``, ``0`` or ``''``).
    """
    if not self.batch_size_eval:
        self.batch_size_eval = self.batch_size
    if not self.data_val_name:
        self.data_val_name = self.data_name
def scale_up_gpus(self, num_gpus, num_nodes=1):
    """Multiply the sample-interval and batch-size settings by the device count.

    Args:
        num_gpus: multiplier contributed by the GPUs.
        num_nodes: multiplier contributed by the nodes (default 1).

    Returns:
        ``self``, so calls can be chained.
    """
    factor = num_gpus * num_nodes
    scaled_fields = (
        'eval_ema_every_samples',
        'eval_every_samples',
        'sample_every_samples',
        'batch_size',
        'batch_size_eval',
    )
    for field_name in scaled_fields:
        setattr(self, field_name, getattr(self, field_name) * factor)
    return self
def batch_size_effective(self):
    """Return ``batch_size`` multiplied by ``accum_batches``."""
    per_step = self.batch_size
    accumulated = self.accum_batches
    return per_step * accumulated
def fid_cache(self):
    """Return the directory used for cached evaluation images.

    The directory sits under ``work_cache_dir`` and is keyed by dataset
    name, image size and number of evaluation images.
    """
    # kept on a local dir to reduce the load over network drives;
    # hopefully this also reduces the disconnection problems with sshfs
    key = f'{self.data_name}_size{self.img_size}_{self.eval_num_images}'
    return f'{self.work_cache_dir}/eval_images/{key}'
def data_path(self):
    """Return the dataset location registered for ``data_name``.

    When ``use_cache_dataset`` is set and a location is registered, the
    location is routed through ``use_cached_dataset_path`` with
    ``<data_cache_dir>/<data_name>`` as its second argument.
    """
    location = data_paths[self.data_name]
    if not self.use_cache_dataset or location is None:
        return location
    # may use the cache dir
    cache_target = f'{self.data_cache_dir}/{self.data_name}'
    return use_cached_dataset_path(location, cache_target)
def logdir(self):
    """Return ``<base_dir>/<name>``."""
    root, run_name = self.base_dir, self.name
    return f'{root}/{run_name}'
def generate_dir(self):
    """Return ``<work_cache_dir>/gen_images/<name>``."""
    # kept on a local dir to reduce the load over network drives;
    # hopefully this also reduces the disconnection problems with sshfs
    cache_root, run_name = self.work_cache_dir, self.name
    return f'{cache_root}/gen_images/{run_name}'
def _make_diffusion_conf(self, T=None):
    """Build the spaced-diffusion config for the image model.

    Args:
        T: number of timesteps to keep out of the full ``self.T``; it can be
            smaller than ``self.T`` for evaluation. ``None`` (the default)
            means ``self.T``.

    Returns:
        A ``SpacedDiffusionBeatGansConfig``.

    Raises:
        NotImplementedError: if ``diffusion_type`` is not ``'beatgans'`` or
            ``beatgans_gen_type`` is neither ddpm nor ddim.
    """
    if self.diffusion_type != 'beatgans':
        raise NotImplementedError(self.diffusion_type)

    # the default used to flow through as [None] / 'ddimNone', which cannot
    # be turned into a timestep spacing; fall back to the full schedule
    if T is None:
        T = self.T

    # follows the guided-diffusion repo conventions
    # t's are evenly spaced
    if self.beatgans_gen_type == GenerativeType.ddpm:
        section_counts = [T]
    elif self.beatgans_gen_type == GenerativeType.ddim:
        section_counts = f'ddim{T}'
    else:
        raise NotImplementedError(self.beatgans_gen_type)

    return SpacedDiffusionBeatGansConfig(
        gen_type=self.beatgans_gen_type,
        model_type=self.model_type,
        # betas always cover the full T; use_timesteps picks the subset
        betas=get_named_beta_schedule(self.beta_scheduler, self.T),
        model_mean_type=self.beatgans_model_mean_type,
        model_var_type=self.beatgans_model_var_type,
        loss_type=self.beatgans_loss_type,
        rescale_timesteps=self.beatgans_rescale_timesteps,
        use_timesteps=space_timesteps(num_timesteps=self.T,
                                      section_counts=section_counts),
        fp16=self.fp16,
    )
def _make_latent_diffusion_conf(self, T=None):
    """Build the spaced-diffusion config for the latent model.

    Args:
        T: number of timesteps to keep out of the full ``self.T``; it can be
            smaller than ``self.T`` for evaluation. ``None`` (the default)
            means ``self.T``.

    Returns:
        A ``SpacedDiffusionBeatGansConfig`` with ``model_type`` fixed to
        ``ModelType.ddpm``.

    Raises:
        NotImplementedError: if ``latent_gen_type`` is neither ddpm nor ddim.
    """
    # the default used to flow through as [None] / 'ddimNone', which cannot
    # be turned into a timestep spacing; fall back to the full schedule
    if T is None:
        T = self.T

    # follows the guided-diffusion repo conventions
    # t's are evenly spaced
    if self.latent_gen_type == GenerativeType.ddpm:
        section_counts = [T]
    elif self.latent_gen_type == GenerativeType.ddim:
        section_counts = f'ddim{T}'
    else:
        raise NotImplementedError(self.latent_gen_type)

    return SpacedDiffusionBeatGansConfig(
        train_pred_xstart_detach=self.train_pred_xstart_detach,
        gen_type=self.latent_gen_type,
        # latent's model is always ddpm
        model_type=ModelType.ddpm,
        # latent has its own beta scheduler but shares the full T
        betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
        model_mean_type=self.latent_model_mean_type,
        model_var_type=self.latent_model_var_type,
        loss_type=self.latent_loss_type,
        rescale_timesteps=self.latent_rescale_timesteps,
        use_timesteps=space_timesteps(num_timesteps=self.T,
                                      section_counts=section_counts),
        fp16=self.fp16,
    )
def model_out_channels(self):
    """Return the model's output channel count, which is fixed at 3."""
    channels = 3
    return channels
def make_T_sampler(self):
    """Create the timestep sampler selected by ``T_sampler``.

    Only ``'uniform'`` is supported; it yields ``UniformSampler(self.T)``.

    Raises:
        NotImplementedError: for any other ``T_sampler`` value.
    """
    if self.T_sampler != 'uniform':
        raise NotImplementedError()
    return UniformSampler(self.T)
def make_diffusion_conf(self):
    """Return the image diffusion config built with the full ``self.T``."""
    full_steps = self.T
    return self._make_diffusion_conf(full_steps)
def make_eval_diffusion_conf(self):
    """Return the image diffusion config built with ``self.T_eval`` steps."""
    eval_steps = self.T_eval
    return self._make_diffusion_conf(T=eval_steps)
def make_latent_diffusion_conf(self):
    """Return the latent diffusion config built with the full ``self.T``."""
    full_steps = self.T
    return self._make_latent_diffusion_conf(T=full_steps)
def make_latent_eval_diffusion_conf(self):
    """Return the latent diffusion config built with ``latent_T_eval`` steps."""
    # latent can have different eval T
    eval_steps = self.latent_T_eval
    return self._make_latent_diffusion_conf(T=eval_steps)
def make_dataset(self, path=None, **kwargs):
    """Instantiate ``LatentDataLoader`` from attributes of this config.

    ``path`` and ``kwargs`` are accepted but not used.

    NOTE(review): the attributes read here (``window_size``,
    ``frame_jpgs``, ``lmd_feats_prefix``, ``audio_prefix``,
    ``raw_audio_prefix``, ``motion_latents_prefix``, ``pose_prefix``,
    ``db_name``, ``audio_hz``) have no class-level default in this file;
    presumably they are set on the instance elsewhere -- confirm.
    """
    positional = (
        self.window_size,
        self.frame_jpgs,
        self.lmd_feats_prefix,
        self.audio_prefix,
        self.raw_audio_prefix,
        self.motion_latents_prefix,
        self.pose_prefix,
        self.db_name,
    )
    return LatentDataLoader(*positional, audio_hz=self.audio_hz)
def make_loader(self,
                dataset,
                shuffle: bool,
                num_worker: int = None,
                drop_last: bool = True,
                batch_size: int = None,
                parallel: bool = False):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: the dataset to load from.
        shuffle: whether to shuffle; handed to the sampler when one is used.
        num_worker: number of worker processes. ``None`` means
            ``self.num_workers``; an explicit ``0`` is honoured.
        drop_last: forwarded to the ``DataLoader``.
        batch_size: falsy means ``self.batch_size``.
        parallel: use a ``DistributedSampler`` when the distributed package
            is initialized.

    Returns:
        The configured ``DataLoader`` (``pin_memory`` is always on).
    """
    if parallel and distributed.is_initialized():
        # drop last to make sure that there is no added special indexes
        sampler = DistributedSampler(dataset,
                                     shuffle=shuffle,
                                     drop_last=True)
    else:
        sampler = None

    # `is None` rather than `or`, so that 0 workers can actually be requested
    workers = self.num_workers if num_worker is None else num_worker
    # DataLoader rejects a multiprocessing context when there are no workers
    mp_context = get_context('fork') if workers > 0 else None

    return DataLoader(
        dataset,
        batch_size=batch_size or self.batch_size,
        sampler=sampler,
        # with a sampler, shuffling is the sampler's job; test identity, not
        # truthiness, because a sampler with zero length is falsy
        shuffle=shuffle if sampler is None else False,
        num_workers=workers,
        pin_memory=True,
        drop_last=drop_last,
        multiprocessing_context=mp_context,
    )
def make_model_conf(self):
    """Build the model config for ``model_name`` and store it on ``self``.

    Sets ``self.model_type`` and ``self.model_conf`` as a side effect.

    Returns:
        The new ``self.model_conf``: a ``BeatGANsUNetConfig`` for
        ``ModelName.beatgans_ddpm`` or a ``BeatGANsAutoencConfig`` for
        ``ModelName.beatgans_autoenc``.

    Raises:
        NotImplementedError: for any other ``model_name``, or for a
            ``net_latent_net_type`` other than none / skip.
    """
    if self.model_name not in (ModelName.beatgans_ddpm,
                               ModelName.beatgans_autoenc):
        raise NotImplementedError(self.model_name)

    # keyword arguments both model configs have in common
    shared = dict(
        attention_resolutions=self.net_attn,
        channel_mult=self.net_ch_mult,
        conv_resample=True,
        dims=2,
        dropout=self.dropout,
        embed_channels=self.net_beatgans_embed_channels,
        image_size=self.img_size,
        in_channels=3,
        model_channels=self.net_ch,
        num_classes=None,
        num_head_channels=-1,
        num_heads_upsample=-1,
        num_heads=self.net_beatgans_attn_head,
        num_res_blocks=self.net_num_res_blocks,
        num_input_res_blocks=self.net_num_input_res_blocks,
        # NOTE(review): passed without being called; this is only the
        # channel count if model_out_channels is a property -- confirm
        out_channels=self.model_out_channels,
        resblock_updown=self.net_resblock_updown,
        use_checkpoint=self.net_beatgans_gradient_checkpoint,
        use_new_attention_order=False,
        resnet_two_cond=self.net_beatgans_resnet_two_cond,
        resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
    )

    if self.model_name == ModelName.beatgans_ddpm:
        self.model_type = ModelType.ddpm
        self.model_conf = BeatGANsUNetConfig(**shared)
        return self.model_conf

    # remaining case: ModelName.beatgans_autoenc
    self.model_type = ModelType.autoencoder

    if self.net_latent_net_type == LatentNetType.none:
        latent_net_conf = None
    elif self.net_latent_net_type == LatentNetType.skip:
        latent_net_conf = MLPSkipNetConfig(
            num_channels=self.style_ch,
            skip_layers=self.net_latent_skip_layers,
            num_hid_channels=self.net_latent_num_hid_channels,
            num_layers=self.net_latent_layers,
            num_time_emb_channels=self.net_latent_time_emb_channels,
            activation=self.net_latent_activation,
            use_norm=self.net_latent_use_norm,
            condition_bias=self.net_latent_condition_bias,
            dropout=self.net_latent_dropout,
            last_act=self.net_latent_net_last_act,
            num_time_layers=self.net_latent_num_time_layers,
            time_last_act=self.net_latent_time_last_act,
        )
    else:
        raise NotImplementedError(self.net_latent_net_type)

    self.model_conf = BeatGANsAutoencConfig(
        # encoder-specific and latent-specific settings on top of the shared ones
        enc_out_channels=self.style_ch,
        enc_pool=self.net_enc_pool,
        enc_num_res_block=self.net_enc_num_res_blocks,
        enc_channel_mult=self.net_enc_channel_mult,
        enc_grad_checkpoint=self.net_enc_grad_checkpoint,
        enc_attn_resolutions=self.net_enc_attn,
        latent_net_conf=latent_net_conf,
        resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
        **shared,
    )
    return self.model_conf