File size: 1,352 Bytes
0fd2f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

import torch


# Shapes of the per-layer feature-cache tensors that accompany ``z`` and
# ``use_cache`` as inputs (see ALL_INPUTS_NAMES below). Kept as plain tuples so
# the layout can be inspected, and the input names derived, without allocating
# any tensors. Order matters: index i corresponds to input "vae_cache_{i}".
# NOTE(review): the layout looks like (batch, channels, frames, height, width),
# with the last two dims doubling 60x104 -> 120x208 -> 240x416 -> 480x832
# across groups; confirm the axis meaning against the consuming decoder.
VAE_CACHE_SHAPES = (
    ((1, 16, 2, 60, 104),)
    + ((1, 384, 2, 60, 104),) * 11
    + ((1, 192, 2, 120, 208),)
    + ((1, 384, 2, 120, 208),) * 6
    + ((1, 192, 2, 240, 416),) * 6
    + ((1, 96, 2, 480, 832),) * 7
)


def make_zero_vae_cache(*, dtype=None, device=None):
    """Return a fresh list of zero tensors, one per entry of VAE_CACHE_SHAPES.

    Args:
        dtype: ``torch.dtype`` for every tensor; ``None`` uses torch's
            default dtype.
        device: device for every tensor; ``None`` uses torch's default
            device.

    Returns:
        A new ``list`` of ``torch.Tensor`` in the same order as
        VAE_CACHE_SHAPES. Each call allocates new storage, so the result can
        be modified without affecting ZERO_VAE_CACHE or other callers, and can
        be created directly on the target device/dtype without a host copy.
    """
    return [
        torch.zeros(shape, dtype=dtype, device=device)
        for shape in VAE_CACHE_SHAPES
    ]


# Module-level zero cache, allocated eagerly at import with torch's default
# dtype and device (equivalent to one ``torch.zeros(*shape)`` per entry).
# NOTE(review): at float32 this is ~944M elements (~3.8 GB) kept alive for the
# life of the process, and these tensors are shared by every importer, so an
# in-place write would be seen by all of them. Prefer
# make_zero_vae_cache(dtype=..., device=...) when a private or device-resident
# copy is needed.
ZERO_VAE_CACHE = make_zero_vae_cache()

# Input names in positional order: "z", "use_cache", then one name per cache
# tensor. Derived from the shape table so it stays in sync with it.
feat_names = [f"vae_cache_{i}" for i in range(len(VAE_CACHE_SHAPES))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names