# UKBBLatent_Cardiac_20208_DiffAE3D_L128_S42 / DiffAE_support_choices.py
# Uploaded by soumickmj ("Upload DiffAE"), commit c2ced9d (verified), 4.07 kB.
from enum import Enum
from torch import nn
class TrainMode(Enum):
    """
    Training regimes understood by the trainer.
    """
    # train the classifier used for manipulation
    manipulate = 'manipulate'
    # the default training mode
    diffusion = 'diffusion'
    # the default latent training mode:
    # fit a DDPM to a given latent
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        """Return True only for ``TrainMode.manipulate``."""
        return self is TrainMode.manipulate

    def is_diffusion(self):
        """Return True for both the image-level and the latent diffusion modes."""
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self):
        """Return True when the network may perform autoencoding (``diffusion`` only)."""
        return self is TrainMode.diffusion

    def is_latent_diffusion(self):
        """Return True only for ``TrainMode.latent_diffusion``."""
        return self is TrainMode.latent_diffusion

    def use_latent_net(self):
        """Return True when a latent network is needed; same as ``is_latent_diffusion``."""
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        Return whether training in this mode needs the latent variables up front.

        For these modes all latents are precalculated beforehand, and the
        dataset then consists of the predicted latents.
        """
        return self in (TrainMode.latent_diffusion, TrainMode.manipulate)
class ManipulateMode(Enum):
    """
    Dataset/protocol used to train the manipulation classifier.
    """
    # the whole CelebA-HQ attribute dataset
    celebahq_all = 'celebahq_all'
    # CelebA with D2C's crop (few-shot variants)
    d2c_fewshot = 'd2cfewshot'
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        """Return True for the CelebA attribute-based modes (all three defined here)."""
        return self in (
            ManipulateMode.celebahq_all,
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        )

    def is_single_class(self):
        """Return True for the D2C few-shot modes."""
        return self in (
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        )

    def is_fewshot(self):
        """Return True for the D2C few-shot modes (currently the same set as ``is_single_class``)."""
        return self.is_single_class()

    def is_fewshot_allneg(self):
        """Return True only for ``ManipulateMode.d2c_fewshot_allneg``."""
        return self is ManipulateMode.d2c_fewshot_allneg
class ModelType(Enum):
    """
    Families of backbone model.
    """
    # plain, unconditional DDPM
    ddpm = 'ddpm'
    # autoencoding DDPM; it cannot generate unconditionally
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        """Return True for the autoencoding backbone."""
        return self is ModelType.autoencoder

    def can_sample(self):
        """Return True when unconditional sampling is possible (``ddpm`` only)."""
        return self is ModelType.ddpm
class ModelName(Enum):
    """
    Identifiers of the supported model classes.
    """
    # "beatgans" DDPM variant
    beatgans_ddpm = 'beatgans_ddpm'
    # "beatgans" autoencoding variant
    beatgans_autoenc = 'beatgans_autoenc'
class ModelMeanType(Enum):
    """
    What quantity the model's output represents.
    """
    # the model predicts the noise epsilon
    eps = 'eps'
class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    Only fixed (non-learned) variances are defined here:

    - ``fixed_small``: the posterior beta_t (the smaller fixed choice).
    - ``fixed_large``: beta_t itself.
    """
    # posterior beta_t
    fixed_small = 'fixed_small'
    # beta_t
    fixed_large = 'fixed_large'
class LossType(Enum):
    """
    Training loss options.
    """
    # raw MSE loss. The original note also mentions adding KL "when learning
    # variances"; NOTE(review): no learned-variance option exists in
    # ModelVarType here — confirm whether that remark still applies.
    mse = 'mse'
    # L1 loss
    l1 = 'l1'
class GenerativeType(Enum):
    """
    Which sampler is used to generate samples.
    """
    # DDPM sampler
    ddpm = 'ddpm'
    # DDIM sampler
    ddim = 'ddim'
class OptimizerType(Enum):
    """
    Optimizer choices.
    """
    # Adam
    adam = 'adam'
    # AdamW
    adamw = 'adamw'
class Activation(Enum):
    """
    Selectable activation functions; ``get_act`` builds the matching module.
    """
    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        """
        Build a new ``torch.nn`` module for this activation.

        ``none`` maps to ``nn.Identity``; ``lrelu`` uses a negative slope of 0.2.

        Raises:
            NotImplementedError: if this member has no registered factory.
        """
        # Map each member to a zero-argument factory so only the requested
        # module is instantiated, and a fresh one on every call.
        factories = {
            Activation.none: nn.Identity,
            Activation.relu: nn.ReLU,
            Activation.lrelu: lambda: nn.LeakyReLU(negative_slope=0.2),
            Activation.silu: nn.SiLU,
            Activation.tanh: nn.Tanh,
        }
        factory = factories.get(self)
        if factory is None:
            raise NotImplementedError()
        return factory()
class ManipulateLossType(Enum):
    """
    Loss options for training the manipulation classifier.
    """
    # binary cross-entropy
    bce = 'bce'
    # mean squared error
    mse = 'mse'