soumickmj's picture
Upload cceVAE
c3bcb92 verified
raw
history blame contribute delete
950 Bytes
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .aes import VAE
from .cceVAEConfig import cceVAEConfig
class cceVAE(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``VAE``.

    Every architecture hyper-parameter is read from the model config and
    handed to ``VAE`` as a keyword argument of the same name. The resulting
    network is kept on ``self.model`` and does all the work in ``forward``.
    """

    config_class = cceVAEConfig

    def __init__(self, config):
        """Build the wrapped ``VAE`` from ``config``.

        Args:
            config: model configuration; expected to be a ``cceVAEConfig``
                (the declared ``config_class``). Every name listed below must
                exist on it as an attribute, otherwise ``AttributeError`` is
                raised.
        """
        super().__init__(config)
        # Config attributes forwarded to VAE, in the order they are read.
        # Each is passed through unchanged under its own name.
        vae_field_names = (
            "d",
            "input_size",
            "z_dim",
            "fmap_sizes",
            "to_1x1",
            "conv_params",
            "tconv_params",
            "normalization_op",
            "normalization_params",
            "activation_op",
            "activation_params",
            "block_op",
            "block_params",
        )
        vae_kwargs = {name: getattr(config, name) for name in vae_field_names}
        self.model = VAE(**vae_kwargs)

    def forward(self, x):
        """Feed ``x`` to the wrapped ``VAE`` and return its output as-is."""
        return self.model(x)