# vqvae/configuration_vqvae.py
# Uploaded by frankleeeee ("Upload VQVAE", commit 8ff4a33, 576 bytes).
from typing import List, Optional

from transformers import PretrainedConfig
class VQVAEConfig(PretrainedConfig):
    """Hugging Face configuration object holding VQVAE hyperparameters.

    The hyperparameters are stored as plain attributes. Any extra keyword
    arguments are forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        embedding_dim: Stored as ``self.embedding_dim`` (presumably the
            dimensionality of each codebook vector — confirm against the model).
        n_codes: Stored as ``self.n_codes`` (presumably the codebook size).
        n_hiddens: Stored as ``self.n_hiddens`` (presumably the hidden channel
            width of the encoder/decoder).
        n_res_layers: Stored as ``self.n_res_layers`` (presumably the number
            of residual layers).
        downsample: Stored as ``self.downsample`` (presumably per-axis
            downsampling factors). Defaults to ``[2, 4, 4]``. A new list is
            created for every instance so configs never share state.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Identifier that Hugging Face records for this config class; left
    # unchanged so previously saved configs remain loadable.
    model_type = "VQVAE"

    def __init__(
        self,
        embedding_dim: int = 256,
        n_codes: int = 2048,
        n_hiddens: int = 240,
        n_res_layers: int = 4,
        downsample: Optional[List[int]] = None,
        **kwargs,
    ):
        self.embedding_dim = embedding_dim
        self.n_codes = n_codes
        self.n_hiddens = n_hiddens
        self.n_res_layers = n_res_layers
        # Previously the default was a single list literal shared by every
        # instance, so mutating one config's ``downsample`` altered the default
        # for all later configs. Build a fresh list instead, and copy
        # caller-supplied values so the config does not alias them.
        self.downsample = [2, 4, 4] if downsample is None else list(downsample)
        super().__init__(**kwargs)