File size: 950 Bytes
c3bcb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .aes import VAE
from .cceVAEConfig import cceVAEConfig

class cceVAE(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around the project's ``VAE``.

    Every architecture hyper-parameter is read from a ``cceVAEConfig`` and
    passed, under the same name, as a keyword argument to ``VAE``. The built
    network is stored in ``self.model``, and ``forward`` delegates to it.
    """

    # Tells transformers' PreTrainedModel machinery which config class
    # belongs to this model.
    config_class = cceVAEConfig

    def __init__(self, config):
        """Build the wrapped ``VAE`` from ``config``.

        Args:
            config: a ``cceVAEConfig`` (presumably) exposing every attribute
                listed in ``forwarded`` below; a missing one raises
                ``AttributeError`` as with direct attribute access.
        """
        # The base initializer must run before any submodule is assigned.
        super().__init__(config)

        # Config attribute names, each forwarded verbatim as a VAE keyword.
        forwarded = (
            "d",
            "input_size",
            "z_dim",
            "fmap_sizes",
            "to_1x1",
            "conv_params",
            "tconv_params",
            "normalization_op",
            "normalization_params",
            "activation_op",
            "activation_params",
            "block_op",
            "block_params",
        )
        vae_kwargs = {key: getattr(config, key) for key in forwarded}
        self.model = VAE(**vae_kwargs)

    def forward(self, x):
        """Run the wrapped ``VAE`` on ``x`` and return its output unchanged.

        The expected shape and dtype of ``x`` and the structure of the
        returned value are whatever ``VAE.forward`` defines (not visible
        here; confirm against ``.aes.VAE``).
        """
        vae = self.model
        return vae(x)