File size: 2,096 Bytes
c2ced9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import PretrainedConfig

class DiffAEConfig(PretrainedConfig):
    """Configuration container for the ``DiffAE`` model type.

    Every constructor argument listed below is stored unchanged on the
    instance as an attribute of the same name. Any additional keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        is3D: Defaults to ``True``. Presumably selects a 3D (volumetric)
            model variant -- TODO confirm against the model code.
        in_channels: Defaults to ``1``.
        out_channels: Defaults to ``1``.
        latent_dim: Defaults to ``128``.
        net_ch: Defaults to ``32``.
        sample_every_batches: Log samples during training every this many
            batches. Set it to ``0`` to disable. Defaults to ``1000``.
        sample_size: Number of samples in the buffer for consistent
            sampling (batch size of x_T). Defaults to ``4``.
        test_with_TEval: Defaults to ``True``.
        ampmode: Defaults to ``"16-mixed"``. Presumably the mixed-precision
            mode handed to the trainer -- TODO confirm.
        grey2RGB: Defaults to ``-1``.
        test_emb_only: Defaults to ``True``.
        test_ema: Defaults to ``True``.
        batch_size: Defaults to ``9``.
        data_name: Defaults to ``"ukbb"``.
        diffusion_type: Defaults to ``'beatgans'``.
        lr: Defaults to ``0.0001``.
        seed: Defaults to ``1701``.
        input_shape: Defaults to ``(50, 128, 128)``.
        **kwargs: Passed through to the ``PretrainedConfig`` constructor.
    """

    model_type = "DiffAE"

    # Options that are currently NOT exposed by this config (they were kept
    # only as commented-out parameters in the signature):
    #   beta_scheduler='linear', latent_beta_scheduler='linear',
    #   eval_ema_every_samples=200_000, eval_every_samples=200_000,
    #   net_beatgans_attn_head=1, net_beatgans_embed_channels=128,
    #   net_ch_mult=(1, 1, 2, 3, 4), T_eval=20, latent_T_eval=1000,
    #   group_norm_limit=32, dropout=0.1

    def __init__(
        self,
        is3D=True,
        in_channels=1,
        out_channels=1,
        latent_dim=128,
        net_ch=32,
        sample_every_batches=1000,
        sample_size=4,
        test_with_TEval=True,
        ampmode="16-mixed",
        grey2RGB=-1,
        test_emb_only=True,
        test_ema=True,
        batch_size=9,
        data_name="ukbb",
        diffusion_type='beatgans',
        lr=0.0001,
        seed=1701,
        input_shape=(50, 128, 128),
        **kwargs,
    ):
        # Collect the explicit settings in declaration order so they are
        # assigned in the same sequence as the parameters appear above.
        settings = {
            "is3D": is3D,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "latent_dim": latent_dim,
            "net_ch": net_ch,
            "sample_every_batches": sample_every_batches,
            "sample_size": sample_size,
            "test_with_TEval": test_with_TEval,
            "ampmode": ampmode,
            "grey2RGB": grey2RGB,
            "test_emb_only": test_emb_only,
            "test_ema": test_ema,
            "batch_size": batch_size,
            "data_name": data_name,
            "diffusion_type": diffusion_type,
            "lr": lr,
            "seed": seed,
            "input_shape": input_shape,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

        # The base-class constructor runs last, after the model-specific
        # attributes are in place; it receives only the leftover kwargs.
        super().__init__(**kwargs)