# Uploaded by soumickmj: "Upload cceVAE" (commit c3bcb92, verified)
import torch.nn as nn
from transformers import PretrainedConfig
class cceVAEConfig(PretrainedConfig):
    """Configuration container for the cceVAE model.

    Every constructor argument is stored as an attribute of the same name.
    The only change applied is that ``input_size`` and ``fmap_sizes`` are
    converted from lists back to tuples (see the comment in ``__init__``).
    The model code that consumes these values is not part of this file, so
    descriptions that go beyond the argument names are marked as assumptions.

    Args:
        d: Integer, default 2. Presumably the number of spatial dimensions
            (2 -> 2D layers, 3 -> 3D layers); TODO confirm against model code.
        input_size: Default ``(1, 256, 256)``. Presumably the input shape
            as ``(channels, *spatial)``; TODO confirm.
        z_dim: Default 1024. Presumably the latent-code dimensionality.
        fmap_sizes: Default ``(16, 64, 256, 1024)``. Presumably the number of
            feature maps per encoder/decoder stage; TODO confirm.
        to_1x1: Default True. Presumably controls whether the spatial extent
            is reduced to 1x1 before the latent layer; TODO confirm.
        conv_params: Optional extra parameters for convolution layers.
        tconv_params: Optional extra parameters for transposed convolutions.
        normalization_op: Optional normalization-layer specification.
        normalization_params: Optional parameters for ``normalization_op``.
        activation_op: Activation specification, default ``"prelu"``.
        activation_params: Optional parameters for ``activation_op``.
        block_op: Optional building-block specification.
        block_params: Optional parameters for ``block_op``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers uses to associate this config with the model.
    model_type = "cceVAE"

    def __init__(
        self,
        d=2,
        input_size=(1, 256, 256),
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="prelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        **kwargs,
    ):
        # Configs are saved as JSON, which has no tuple type, so a config
        # reloaded via from_pretrained() receives lists for these fields.
        # Convert back so reloaded configs match freshly constructed ones.
        if isinstance(input_size, list):
            input_size = tuple(input_size)
        if isinstance(fmap_sizes, list):
            fmap_sizes = tuple(fmap_sizes)

        self.d = d
        self.input_size = input_size
        self.z_dim = z_dim
        self.fmap_sizes = fmap_sizes
        self.to_1x1 = to_1x1
        self.conv_params = conv_params
        self.tconv_params = tconv_params
        self.normalization_op = normalization_op
        self.normalization_params = normalization_params
        self.activation_op = activation_op
        self.activation_params = activation_params
        self.block_op = block_op
        self.block_params = block_params
        # Let the base class consume the generic config kwargs
        # (e.g. name_or_path, transformers_version) after our fields are set.
        super().__init__(**kwargs)