|
import torch.nn as nn
|
|
from transformers import PretrainedConfig
|
|
|
|
class cceVAEConfig(PretrainedConfig):
    """Hyperparameter container for the ``cceVAE`` model.

    Each architecture argument is stored unchanged as an attribute of the
    same name. Every other keyword argument is forwarded as-is to
    ``PretrainedConfig.__init__``.

    Args:
        d: Stored as ``self.d`` (defaults to 2). Presumably the spatial
            dimensionality of the conv ops (2D vs 3D) — confirm against
            the model code.
        input_size: Stored as ``self.input_size``. The default
            ``(1, 256, 256)`` looks like (channels, height, width) — an
            assumption to verify.
        z_dim: Stored as ``self.z_dim`` (defaults to 1024). Presumably the
            latent size.
        fmap_sizes: Stored as ``self.fmap_sizes``. Presumably per-stage
            feature-map counts; the default has four entries.
        to_1x1: Stored as ``self.to_1x1`` (defaults to True).
        conv_params, tconv_params: Stored verbatim (default None).
        normalization_op, normalization_params: Stored verbatim
            (default None).
        activation_op: Stored verbatim (defaults to the string ``"prelu"``).
        activation_params: Stored verbatim (default None).
        block_op, block_params: Stored verbatim (default None).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "cceVAE"

    def __init__(
        self,
        d=2,
        input_size=(1, 256, 256),
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="prelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        **kwargs,
    ):
        # Overall geometry of the network.
        self.d = d
        self.input_size = input_size
        self.z_dim = z_dim
        self.fmap_sizes = fmap_sizes
        self.to_1x1 = to_1x1

        # Convolution / transposed-convolution settings, kept opaque here.
        self.conv_params = conv_params
        self.tconv_params = tconv_params

        # Normalization, activation and block choices with their options.
        self.normalization_op = normalization_op
        self.normalization_params = normalization_params
        self.activation_op = activation_op
        self.activation_params = activation_params
        self.block_op = block_op
        self.block_params = block_params

        # The base initializer runs last and receives only the leftover
        # keyword arguments; the model-specific ones were consumed above.
        super().__init__(**kwargs)