self-forcing / demo_utils /constant.py
multimodalart's picture
Upload 80 files
0fd2f06 verified
raw
history blame
1.35 kB
import torch
ZERO_VAE_CACHE = [
torch.zeros(1, 16, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 192, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832)
]
feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names