Spaces:
Paused
Paused
from enum import Enum | |
from torch import nn | |
class TrainMode(Enum):
    """
    Top-level training mode; the helper predicates below group modes by
    what the training pipeline needs to set up.
    """
    # manipulate mode: train the (attribute) classifier
    manipulate = 'manipulate'
    # the default training mode
    diffusion = 'diffusion'
    # the default latent training mode: fit a DDPM to a given latent
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        """True only for the classifier-training mode."""
        return self is TrainMode.manipulate

    def is_diffusion(self):
        """True for both the image-space and the latent diffusion modes."""
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self):
        """True when the network may perform autoencoding (``diffusion`` only)."""
        return self is TrainMode.diffusion

    def is_latent_diffusion(self):
        """True only for the latent diffusion mode."""
        return self is TrainMode.latent_diffusion

    def use_latent_net(self):
        """Whether a latent network is used; same answer as ``is_latent_diffusion``."""
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        Whether training in this mode needs the latent variables to be available.

        For these modes all latents are precalculated beforehand and the
        dataset consists of those predicted latents.
        """
        return self in (TrainMode.latent_diffusion, TrainMode.manipulate)
class ManipulateMode(Enum):
    """
    Selects how the manipulation classifier is trained (which data it sees).
    """
    # the whole CelebA attribute dataset
    celebahq_all = 'celebahq_all'
    # CelebA using D2C's crop, few-shot
    d2c_fewshot = 'd2cfewshot'
    # as above, with the all-negative variant
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        """True for every mode that trains on CelebA attributes."""
        return self.is_fewshot() or self is ManipulateMode.celebahq_all

    def is_single_class(self):
        """True for single-class modes, which are exactly the few-shot ones."""
        return self.is_fewshot()

    def is_fewshot(self):
        """True for both D2C few-shot variants."""
        return self in (ManipulateMode.d2c_fewshot,
                        ManipulateMode.d2c_fewshot_allneg)

    def is_fewshot_allneg(self):
        """True only for the all-negative few-shot variant."""
        return self is ManipulateMode.d2c_fewshot_allneg
class ModelType(Enum):
    """
    The kind of backbone model.
    """
    # an unconditional DDPM
    ddpm = 'ddpm'
    # an autoencoding DDPM; it cannot generate unconditionally
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        """True when the backbone contains an autoencoder."""
        return self is ModelType.autoencoder

    def can_sample(self):
        """True when the backbone can sample unconditionally (plain DDPM only)."""
        return self is ModelType.ddpm
class ModelName(Enum):
    """
    Every supported model class, identified by name.
    """
    # BeatGANs-style network used as a plain DDPM
    beatgans_ddpm = 'beatgans_ddpm'
    # BeatGANs-style network with an autoencoder
    beatgans_autoenc = 'beatgans_autoenc'
class ModelMeanType(Enum):
    """
    The quantity the model is trained to output.
    """
    # the model outputs the noise epsilon
    eps = 'eps'
class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    Only fixed (non-learned) variances are available here:
    ``fixed_small`` uses the posterior beta_t and ``fixed_large`` uses beta_t.
    """
    # NOTE: the previous docstring described a LEARNED_RANGE option that this
    # enum does not define; it was removed to avoid misleading readers.

    # posterior beta_t
    fixed_small = 'fixed_small'
    # beta_t
    fixed_large = 'fixed_large'
class LossType(Enum):
    """
    The training loss applied to the model's output.
    """
    # raw mean-squared-error loss (plus KL when variances are learned)
    mse = 'mse'
    # absolute-error (L1) loss
    l1 = 'l1'
class GenerativeType(Enum):
    """
    Which sampling procedure is used to generate a sample (DDPM or DDIM).
    """
    ddpm = 'ddpm'
    ddim = 'ddim'
class OptimizerType(Enum):
    """
    The supported optimizer choices.
    """
    adam = 'adam'
    adamw = 'adamw'
class Activation(Enum):
    """
    The selectable activation functions; ``get_act`` builds the matching
    ``torch.nn`` module.
    """
    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        """
        Return a newly constructed ``nn.Module`` for this activation.

        ``none`` maps to ``nn.Identity``, and ``lrelu`` uses a negative slope of 0.2.

        Raises:
            NotImplementedError: if this member has no registered factory.
        """
        # Build the table inside the method: a dict assigned in the Enum body
        # would itself become an enum member. Store factories, not instances,
        # so every call returns a fresh module, as the original chain did.
        factories = {
            Activation.none: nn.Identity,
            Activation.relu: nn.ReLU,
            Activation.lrelu: lambda: nn.LeakyReLU(negative_slope=0.2),
            Activation.silu: nn.SiLU,
            Activation.tanh: nn.Tanh,
        }
        make = factories.get(self)
        if make is None:
            raise NotImplementedError()
        return make()
class ManipulateLossType(Enum):
    """
    The loss options for the manipulate mode: ``bce`` or ``mse``.
    """
    bce = 'bce'
    mse = 'mse'