|
from transformers import PretrainedConfig |
|
|
|
class DiffAEConfig(PretrainedConfig):
    """Hugging Face configuration object for the DiffAE model.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name. Any remaining keyword arguments are forwarded
    to ``PretrainedConfig.__init__`` after those attributes are set.

    The interpretations below are inferred from argument names only; the code
    that consumes these settings is not visible here, so confirm against it.

    Args:
        is3D: Boolean switch, presumably selecting a 3D (volumetric) rather
            than 2D setup. Default ``True``.
        in_channels: Presumably the number of input channels. Default ``1``.
        out_channels: Presumably the number of output channels. Default ``1``.
        latent_dim: Presumably the size of the latent code. Default ``128``.
        net_ch: Presumably the base channel width of the network.
            Default ``32``.
        sample_every_batches: Presumably how often (in batches) samples are
            generated during training. Default ``1000``.
        sample_size: Presumably how many samples are generated each time.
            Default ``4``.
        test_with_TEval: Boolean test-time switch; semantics not visible
            here. Default ``True``.
        ampmode: Presumably a mixed-precision mode string (the default
            ``"16-mixed"`` resembles a PyTorch Lightning precision value).
        grey2RGB: Integer setting whose meaning is not visible here.
            Default ``-1``.
        test_emb_only: Boolean test-time switch; semantics not visible here.
            Default ``True``.
        test_ema: Boolean test-time switch, presumably about EMA weights.
            Default ``True``.
        batch_size: Presumably the training batch size. Default ``9``.
        data_name: Dataset identifier string. Default ``"ukbb"``.
        diffusion_type: Diffusion variant name. Default ``"beatgans"``.
        lr: Presumably the learning rate. Default ``0.0001``.
        seed: Presumably the random seed. Default ``1701``.
        input_shape: Presumably the expected input size. The default is the
            3-tuple ``(50, 128, 128)``; the axis order is not established
            here.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "DiffAE"

    def __init__(
        self,
        is3D=True,
        in_channels=1,
        out_channels=1,
        latent_dim=128,
        net_ch=32,
        sample_every_batches=1000,
        sample_size=4,
        test_with_TEval=True,
        ampmode="16-mixed",
        grey2RGB=-1,
        test_emb_only=True,
        test_ema=True,
        batch_size=9,
        data_name="ukbb",
        diffusion_type="beatgans",
        lr=0.0001,
        seed=1701,
        input_shape=(50, 128, 128),
        **kwargs,
    ):
        # Model-specific settings, listed in declaration order so the
        # attributes are assigned in the same sequence as the arguments.
        own_settings = {
            "is3D": is3D,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "latent_dim": latent_dim,
            "net_ch": net_ch,
            "sample_every_batches": sample_every_batches,
            "sample_size": sample_size,
            "test_with_TEval": test_with_TEval,
            "ampmode": ampmode,
            "grey2RGB": grey2RGB,
            "test_emb_only": test_emb_only,
            "test_ema": test_ema,
            "batch_size": batch_size,
            "data_name": data_name,
            "diffusion_type": diffusion_type,
            "lr": lr,
            "seed": seed,
            # NOTE(review): a tuple here will likely come back as a list after
            # a JSON save/load round trip of the config -- confirm whether
            # consumers care about the container type.
            "input_shape": input_shape,
        }
        for attr_name, attr_value in own_settings.items():
            setattr(self, attr_name, attr_value)

        # Everything not consumed above is handled by the base config class.
        super().__init__(**kwargs)