from typing import List, Optional

from transformers import PretrainedConfig
class VQVAEConfig(PretrainedConfig):
    """Configuration class for a VQ-VAE model.

    Stores the hyperparameters read by the model implementation; all extra
    keyword arguments are forwarded to :class:`PretrainedConfig`.

    Args:
        embedding_dim: Dimensionality of each codebook embedding vector.
        n_codes: Number of entries in the discrete codebook.
        n_hiddens: Number of hidden channels/units in the encoder-decoder
            (exact role not visible in this file).
        n_res_layers: Number of residual layers.
        downsample: Per-axis downsampling factors; defaults to ``[2, 4, 4]``.
            Presumably (time, height, width) for video VQ-VAEs — confirm
            against the model code.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "VQVAE"

    def __init__(
        self,
        embedding_dim: int = 256,
        n_codes: int = 2048,
        n_hiddens: int = 240,
        n_res_layers: int = 4,
        downsample: Optional[List[int]] = None,
        **kwargs,
    ):
        self.embedding_dim = embedding_dim
        self.n_codes = n_codes
        self.n_hiddens = n_hiddens
        self.n_res_layers = n_res_layers
        # BUGFIX: the original used a mutable default ([2, 4, 4]) shared across
        # every call — mutating one config's `downsample` would change the
        # default for all future instances. Use a None sentinel and copy the
        # caller's list so the stored value is always an independent list.
        self.downsample = [2, 4, 4] if downsample is None else list(downsample)
        super().__init__(**kwargs)