|
from enum import Enum |
|
from torch import nn |
|
|
|
|
|
class TrainMode(Enum):
    """
    Training modes, plus predicates that group the modes.

    Each member's value is the string identifier of the mode.
    """

    manipulate = 'manipulate'
    diffusion = 'diffusion'
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self) -> bool:
        """Return True only for ``manipulate``."""
        return self is TrainMode.manipulate

    def is_diffusion(self) -> bool:
        """Return True for ``diffusion`` and ``latent_diffusion``."""
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self) -> bool:
        """Return True only for ``diffusion``."""
        # NOTE(review): ``latent_diffusion`` is not counted here, unlike in
        # ``is_diffusion`` -- kept exactly as the original membership list.
        return self is TrainMode.diffusion

    def is_latent_diffusion(self) -> bool:
        """Return True only for ``latent_diffusion``."""
        return self is TrainMode.latent_diffusion

    def use_latent_net(self) -> bool:
        """Return the same answer as ``is_latent_diffusion``."""
        return self.is_latent_diffusion()

    def require_dataset_infer(self) -> bool:
        """
        Whether training in this mode requires the latent variables to be
        available.

        True for ``latent_diffusion`` and ``manipulate``.
        """
        return self in (TrainMode.latent_diffusion, TrainMode.manipulate)
|
|
|
|
|
class ManipulateMode(Enum):
    """
    How to train the classifier to manipulate.

    Each member's value is the string identifier of the mode.
    """

    celebahq_all = 'celebahq_all'
    d2c_fewshot = 'd2cfewshot'
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self) -> bool:
        """Return True for every mode currently defined in this enum."""
        return self in (
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_all,
        )

    def is_single_class(self) -> bool:
        """Return True for ``d2c_fewshot`` and ``d2c_fewshot_allneg``."""
        # Same membership set as ``is_fewshot``; the two predicates are kept
        # separate so either can change without affecting the other.
        return self in (
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        )

    def is_fewshot(self) -> bool:
        """Return True for ``d2c_fewshot`` and ``d2c_fewshot_allneg``."""
        return self in (
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        )

    def is_fewshot_allneg(self) -> bool:
        """Return True only for ``d2c_fewshot_allneg``."""
        return self is ManipulateMode.d2c_fewshot_allneg
|
|
|
|
|
class ModelType(Enum):
    """
    Kinds of the backbone models.

    Each member's value is the string identifier of the kind.
    """

    ddpm = 'ddpm'
    autoencoder = 'autoencoder'

    def has_autoenc(self) -> bool:
        """Return True only for ``autoencoder``."""
        return self is ModelType.autoencoder

    def can_sample(self) -> bool:
        """Return True only for ``ddpm``."""
        return self is ModelType.ddpm
|
|
|
|
|
class ModelName(Enum):
    """
    List of all supported model classes.

    Each member's value is the string identifier of the model class.
    """

    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'
|
|
|
|
|
class ModelMeanType(Enum):
    """
    Which type of output the model predicts.

    Only a single option, ``eps``, is defined.
    """

    # NOTE(review): presumably the noise term (epsilon) -- confirm with callers.
    eps = 'eps'
|
|
|
|
|
class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    Two options are defined here: ``fixed_small`` and ``fixed_large``.

    NOTE(review): earlier documentation of this enum described a
    LEARNED_RANGE option, added to allow the model to predict values between
    FIXED_SMALL and FIXED_LARGE, making its job easier. No such member is
    defined in this enum -- confirm whether it was removed on purpose.
    """

    fixed_small = 'fixed_small'
    fixed_large = 'fixed_large'
|
|
|
|
|
class LossType(Enum):
    """
    Identifiers of the available loss types.

    Each member's value is the string identifier of the loss.
    """

    mse = 'mse'
    l1 = 'l1'
|
|
|
|
|
class GenerativeType(Enum):
    """
    How a sample is generated.

    Each member's value is the string identifier of the generation method.
    """

    ddpm = 'ddpm'
    ddim = 'ddim'
|
|
|
|
|
class OptimizerType(Enum):
    """
    Identifiers of the available optimizers.

    Each member's value is the string identifier of the optimizer.
    """

    adam = 'adam'
    adamw = 'adamw'
|
|
|
|
|
class Activation(Enum):
    """
    Identifiers of the available activation functions.

    Each member's value is the string identifier of the activation;
    ``get_act`` turns a member into the matching ``torch.nn`` module.
    """

    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self) -> nn.Module:
        """
        Build a new ``torch.nn`` module for this activation.

        Returns:
            ``nn.Identity`` for ``none``, ``nn.ReLU`` for ``relu``,
            ``nn.LeakyReLU(negative_slope=0.2)`` for ``lrelu``, ``nn.SiLU``
            for ``silu`` and ``nn.Tanh`` for ``tanh``. A fresh instance is
            created on every call.

        Raises:
            NotImplementedError: if no module is registered for this value.
        """
        # The table is built inside the method on purpose: a plain attribute
        # assigned in an Enum class body would itself become an enum member.
        factories = {
            Activation.none: nn.Identity,
            Activation.relu: nn.ReLU,
            Activation.lrelu: lambda: nn.LeakyReLU(negative_slope=0.2),
            Activation.silu: nn.SiLU,
            Activation.tanh: nn.Tanh,
        }
        factory = factories.get(self)
        if factory is None:
            raise NotImplementedError()
        return factory()
|
|
|
|
|
class ManipulateLossType(Enum):
    """
    Identifiers of the loss types available for manipulation training.

    Each member's value is the string identifier of the loss.
    """

    bce = 'bce'
    mse = 'mse'