import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .aes import VAE
from .cceVAEConfig import cceVAEConfig


class cceVAE(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the ``VAE`` defined in ``.aes``.

    The wrapped network is stored as ``self.model``. Every ``VAE`` constructor
    argument is taken from the attribute of the same name on the config.
    """

    # Ties this model class to its configuration class.
    config_class = cceVAEConfig

    def __init__(self, config):
        """Construct the wrapped ``VAE`` from *config*.

        Args:
            config: a ``cceVAEConfig``. Each attribute listed in
                ``vae_arg_names`` is passed unchanged to ``VAE`` as the
                keyword argument of the same name.
        """
        super().__init__(config)
        # Config attribute names; each one is also the VAE keyword it feeds.
        vae_arg_names = (
            "d",
            "input_size",
            "z_dim",
            "fmap_sizes",
            "to_1x1",
            "conv_params",
            "tconv_params",
            "normalization_op",
            "normalization_params",
            "activation_op",
            "activation_params",
            "block_op",
            "block_params",
        )
        vae_kwargs = {name: getattr(config, name) for name in vae_arg_names}
        self.model = VAE(**vae_kwargs)
        # NOTE(review): transformers models usually end __init__ with
        # self.post_init(); it is not called here — confirm this is intended.

    def forward(self, x):
        """Run the wrapped VAE on *x* and return its output unchanged."""
        outputs = self.model(x)
        return outputs