|
import torch
|
|
import torch.nn as nn
|
|
from transformers import PreTrainedModel
|
|
from .aes import VAE
|
|
from .cceVAEConfig import cceVAEConfig
|
|
|
|
class cceVAE(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``VAE``.

    Each ``VAE`` constructor argument is taken from the config attribute of
    the same name. ``forward`` simply delegates to the wrapped network.
    """

    config_class = cceVAEConfig

    def __init__(self, config):
        """Initialise the base model with *config* and build the inner ``VAE``.

        Args:
            config: expected to be a ``cceVAEConfig`` (the declared
                ``config_class``). Its attributes named below are passed
                unchanged as keyword arguments to ``VAE``.

        Raises:
            AttributeError: if *config* lacks any of the required attributes.
        """
        super().__init__(config)

        # Config attribute names forwarded verbatim (and read in this order)
        # as same-named keyword arguments of ``VAE``.
        vae_fields = (
            "d",
            "input_size",
            "z_dim",
            "fmap_sizes",
            "to_1x1",
            "conv_params",
            "tconv_params",
            "normalization_op",
            "normalization_params",
            "activation_op",
            "activation_params",
            "block_op",
            "block_params",
        )
        vae_kwargs = {field: getattr(config, field) for field in vae_fields}
        self.model = VAE(**vae_kwargs)

    def forward(self, x):
        """Run *x* through the wrapped ``VAE`` and return its output as-is."""
        output = self.model(x)
        return output