from enum import Enum

from torch import nn


class TrainMode(Enum):
    # manipulate mode = training the classifier
    manipulate = 'manipulate'
    # default training mode!
    diffusion = 'diffusion'
    # default latent training mode!
    # fitting a DDPM to a given latent
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        return self in [
            TrainMode.manipulate,
        ]

    def is_diffusion(self):
        return self in [
            TrainMode.diffusion,
            TrainMode.latent_diffusion,
        ]

    def is_autoenc(self):
        # the network possibly does autoencoding
        return self in [
            TrainMode.diffusion,
        ]

    def is_latent_diffusion(self):
        return self in [
            TrainMode.latent_diffusion,
        ]

    def use_latent_net(self):
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        whether training in this mode requires the latent variables to be
        available beforehand
        """
        # all latents are precalculated up front, and the dataset then
        # consists of the predicted latents
        return self in [
            TrainMode.latent_diffusion,
            TrainMode.manipulate,
        ]
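

# A minimal usage sketch (an assumption, not part of the original file):
# how a trainer might branch on these predicates. The branch bodies are
# hypothetical placeholders.
def _example_train_mode_dispatch(mode: TrainMode):
    if mode.require_dataset_infer():
        # e.g. encode the whole dataset once and iterate over the
        # precalculated latents instead of raw images
        pass
    if mode.use_latent_net():
        # e.g. fit the latent DDPM rather than the image-space DDPM
        pass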


class ManipulateMode(Enum):
    """
    how to train the classifier for manipulation
    """
    # train on the whole CelebA attribute dataset
    celebahq_all = 'celebahq_all'
    # CelebA with D2C's crop
    d2c_fewshot = 'd2cfewshot'
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_all,
        ]

    def is_single_class(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        ]

    def is_fewshot(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        ]

    def is_fewshot_allneg(self):
        return self in [
            ManipulateMode.d2c_fewshot_allneg,
        ]


class ModelType(Enum):
    """
    Kinds of backbone models
    """
    # unconditional DDPM
    ddpm = 'ddpm'
    # autoencoding DDPM; cannot do unconditional generation
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        return self in [
            ModelType.autoencoder,
        ]

    def can_sample(self):
        return self in [ModelType.ddpm]
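

# Sketch of the intended invariant (an assumption, not in the original
# file): the autoencoding backbone needs a conditioning latent, so only
# the plain DDPM backbone may sample unconditionally.
def _example_check_unconditional_sampling(model_type: ModelType):
    if not model_type.can_sample():
        raise ValueError('this backbone cannot sample unconditionally')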


class ModelName(Enum):
    """
    List of all supported model classes
    """
    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'


class ModelMeanType(Enum):
    """
    Which type of output the model predicts.
    """
    eps = 'eps'  # the model predicts epsilon
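

# A minimal sketch (an assumption, not part of the original file) of what
# eps-prediction means: the forward process gives
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# so a predicted epsilon can be inverted into a predicted x_0.
def _example_eps_to_x0(x_t, eps, alpha_bar_t):
    return (x_t - (1 - alpha_bar_t)**0.5 * eps) / alpha_bar_t**0.5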


class ModelVarType(Enum):
    """
    What is used as the model's output variance.
    """
    # the posterior variance beta_tilde_t
    fixed_small = 'fixed_small'
    # beta_t
    fixed_large = 'fixed_large'
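

# A minimal sketch (an assumption, not in the original file) of the two
# fixed variance choices from Ho et al. (DDPM): `fixed_large` uses beta_t
# itself, while `fixed_small` uses the true posterior variance
#     beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.
def _example_fixed_variance(beta_t, alpha_bar_t, alpha_bar_prev):
    fixed_large = beta_t
    fixed_small = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return fixed_small, fixed_large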


class LossType(Enum):
    mse = 'mse'  # use raw MSE loss (and KL when learning variances)
    l1 = 'l1'


class GenerativeType(Enum):
    """
    How a sample is generated
    """
    ddpm = 'ddpm'
    ddim = 'ddim'


class OptimizerType(Enum):
    adam = 'adam'
    adamw = 'adamw'


class Activation(Enum):
    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        if self == Activation.none:
            return nn.Identity()
        elif self == Activation.relu:
            return nn.ReLU()
        elif self == Activation.lrelu:
            return nn.LeakyReLU(negative_slope=0.2)
        elif self == Activation.silu:
            return nn.SiLU()
        elif self == Activation.tanh:
            return nn.Tanh()
        else:
            raise NotImplementedError()
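

# Usage sketch (an assumption, not part of the original file): each call to
# get_act() constructs a fresh nn.Module, so it can be dropped into layers.
def _example_mlp_block(in_dim: int, out_dim: int, act: Activation) -> nn.Module:
    # e.g. _example_mlp_block(512, 512, Activation.silu)
    return nn.Sequential(nn.Linear(in_dim, out_dim), act.get_act())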


class ManipulateLossType(Enum):
    bce = 'bce'
    mse = 'mse'
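

# A small self-check sketch (not in the original file). Enum values are the
# plain strings used in configs, so options round-trip via the constructor.
if __name__ == '__main__':
    assert TrainMode('latentdiffusion') is TrainMode.latent_diffusion
    assert TrainMode.latent_diffusion.require_dataset_infer()
    assert ModelType('autoencoder').has_autoenc()
    assert not ModelType.autoencoder.can_sample()
    print(Activation('lrelu').get_act())  # LeakyReLU(negative_slope=0.2)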