LinB203
m
a220803
from ..configuration_videobase import VideoBaseConfiguration
from typing import Union, Tuple
class VQVAEConfiguration(VideoBaseConfiguration):
    """Hyperparameter container for a VQ-VAE model.

    Each constructor argument is stored on the instance under the same
    name; ``n_hiddens`` is additionally exposed as ``hidden_size``. Any
    extra keyword arguments are forwarded untouched to
    ``VideoBaseConfiguration.__init__``.

    NOTE(review): the parameter descriptions below are inferred from the
    argument names only -- confirm against the model code that consumes
    this configuration.

    Args:
        embedding_dim: Presumably the size of each codebook embedding.
        n_codes: Presumably the number of codebook entries.
        n_hiddens: Presumably the hidden channel width; also stored as
            ``hidden_size``.
        n_res_layers: Presumably the number of residual layers.
        resolution: Presumably the spatial size of the input frames.
        sequence_length: Presumably the number of frames per clip.
        downsample: Either a 3-tuple of ints or a comma-separated string
            such as ``"4,4,4"``. Strings are parsed into a tuple of ints;
            any other value is stored as given.
        no_pos_embd: Presumably disables positional embeddings when true.
        **kwargs: Passed through to the base configuration class.

    Raises:
        ValueError: If ``downsample`` is a string containing a piece that
            ``int`` cannot parse.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        n_codes: int = 2048,
        n_hiddens: int = 240,
        n_res_layers: int = 4,
        resolution: int = 128,
        sequence_length: int = 16,
        downsample: Union[Tuple[int, int, int], str] = (4, 4, 4),
        no_pos_embd: bool = True,
        **kwargs,
    ):
        # Let the base class consume whatever options it understands first.
        super().__init__(**kwargs)

        self.embedding_dim = embedding_dim
        self.n_codes = n_codes
        self.n_hiddens = n_hiddens
        self.n_res_layers = n_res_layers
        self.resolution = resolution
        self.sequence_length = sequence_length

        # Accept the textual form "a,b,c" (e.g. from a command line) and
        # turn it into a tuple of ints; non-string values pass through.
        self.downsample = (
            tuple(int(factor) for factor in downsample.split(","))
            if isinstance(downsample, str)
            else downsample
        )

        self.no_pos_embd = no_pos_embd

        # Alias of n_hiddens under a second attribute name.
        self.hidden_size = n_hiddens