# Provenance (copied from the Hugging Face Hub file view): uploaded by
# soumickmj in commit 485c2ee ("Upload DiffAE", verified); file size 2.1 kB.
from transformers import PretrainedConfig
class DiffAEConfig(PretrainedConfig):
    """Hugging Face configuration for the ``DiffAE`` model.

    Each explicit constructor argument is stored unchanged as an instance
    attribute of the same name. All remaining keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    Args:
        is3D: Stored as ``self.is3D``. Presumably selects 3D vs. 2D networks;
            the default ``input_shape`` has three dimensions. Confirm against
            the model code.
        in_channels: Stored as ``self.in_channels`` (presumably the number of
            input channels).
        out_channels: Stored as ``self.out_channels`` (presumably the number
            of output channels).
        latent_dim: Stored as ``self.latent_dim`` (presumably the size of the
            latent code; confirm).
        net_ch: Stored as ``self.net_ch`` (presumably the base channel width).
        sample_every_batches: Log samples during training every this many
            batches. Set it to 0 to disable.
        sample_size: Number of samples in the buffer used for consistent
            sampling, i.e. the batch size of ``x_T``.
        test_with_TEval: Stored as-is; its meaning is defined by the
            consuming code.
        ampmode: Mixed-precision mode string, default ``"16-mixed"``. This
            looks like a PyTorch Lightning ``precision`` value; confirm.
        grey2RGB: Stored as-is. The meaning of the default ``-1`` is defined
            by the consuming code.
        test_emb_only: Stored as-is.
        test_ema: Stored as-is (presumably "evaluate with EMA weights";
            confirm).
        batch_size: Stored as ``self.batch_size``.
        data_name: Dataset identifier, default ``"ukbb"``.
        diffusion_type: Diffusion backbone identifier, default ``"beatgans"``.
        lr: Stored as ``self.lr`` (presumably the learning rate).
        seed: Stored as ``self.seed`` (presumably the random seed).
        input_shape: Stored as ``self.input_shape``; default
            ``(50, 128, 128)``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier that transformers uses for this config class.
    model_type = "DiffAE"

    # Hyperparameters that appeared commented out in the original signature.
    # They are kept here for reference only and are NOT explicit constructor
    # parameters. If a caller passes them anyway, they end up in **kwargs and
    # are handed to PretrainedConfig.__init__.
    #   beta_scheduler='linear', latent_beta_scheduler='linear',
    #   eval_ema_every_samples=200_000, eval_every_samples=200_000,
    #   net_beatgans_attn_head=1, net_beatgans_embed_channels=128,
    #   net_ch_mult=(1, 1, 2, 3, 4), T_eval=20, latent_T_eval=1000,
    #   group_norm_limit=32, dropout=0.1

    def __init__(
        self,
        is3D=True,
        in_channels=1,
        out_channels=1,
        latent_dim=128,
        net_ch=32,
        sample_every_batches=1000,  # log samples during training; 0 disables
        sample_size=4,  # samples in the buffer for consistent sampling (batch size of x_T)
        test_with_TEval=True,
        ampmode="16-mixed",
        grey2RGB=-1,
        test_emb_only=True,
        test_ema=True,
        batch_size=9,
        data_name="ukbb",
        diffusion_type="beatgans",
        lr=0.0001,
        seed=1701,
        input_shape=(50, 128, 128),
        **kwargs,
    ):
        self.is3D = is3D
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_dim = latent_dim
        self.net_ch = net_ch
        self.sample_every_batches = sample_every_batches
        self.sample_size = sample_size
        self.test_with_TEval = test_with_TEval
        self.ampmode = ampmode
        self.grey2RGB = grey2RGB
        self.test_emb_only = test_emb_only
        self.test_ema = test_ema
        self.batch_size = batch_size
        self.data_name = data_name
        self.diffusion_type = diffusion_type
        self.lr = lr
        self.seed = seed
        self.input_shape = input_shape
        # The base class is initialised last, after the model-specific
        # attributes are set, with whatever keyword arguments remain.
        super().__init__(**kwargs)