File size: 1,801 Bytes
8cf6c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import random
import math
import torch 
import torch.nn as nn


class Pooler(nn.Module):
    """Average-pool a square grid of patch tokens, then project them with an MLP.

    Args:
        dim_in: feature size of the incoming tokens (last dim of ``x``).
        dim_out: feature size produced by the MLP.
        pool_out_size: pooled grid side length, or a list/tuple of choices.
            Each choice is an int or a string such as ``"16"`` or ``"16+32"``.
            A ``"+"``-joined string concatenates the tokens pooled at every
            listed size. One choice is drawn with ``random.choice`` on every
            forward call.
    """

    def __init__(self, dim_in, dim_out, pool_out_size):
        super().__init__()
        if not isinstance(pool_out_size, (list, tuple)):
            pool_out_size = [pool_out_size]

        self.pool_out_size = pool_out_size
        print("pool_out_size: {}".format(self.pool_out_size))

        # Projection applied to every pooled token: last dim dim_in -> dim_out.
        self.mlp = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.Linear(dim_out, dim_out)
        )

    @staticmethod
    def _parse_pool_sizes(spec):
        """Convert one pool-size choice (int, "16", or "16+32") to ints, largest first."""
        # str() makes plain ints work too; `'+' in spec` raises TypeError on an int.
        sizes = [int(part) for part in str(spec).split('+')]
        sizes.sort(reverse=True)
        return sizes

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (b, T, F, v, D) with
                T == F == 1. Token 0 along v is dropped (treated as a cls
                token); the remaining v - 1 tokens must form a square s x s grid.
        Returns:
            torch.Tensor of shape (b, 1, n, dim_out), where n is the sum of
            p * p over the pool sizes p selected for this call.
        Raises:
            ValueError: if T or F is not 1, if v - 1 is not a perfect square,
                or if a selected pool size does not evenly divide s.
        """
        b, t, f, v, d = x.shape
        if t != 1 or f != 1:
            raise ValueError("Pooler expects T == 1 and F == 1, got T={}, F={}".format(t, f))
        s = math.isqrt(v - 1)
        if s * s != v - 1:
            raise ValueError("v - 1 = {} patch tokens do not form a square grid".format(v - 1))
        # Drop the cls token and lay the patches out as an s x s grid.
        x_in = x[:, :, :, 1:, :].reshape(b, t, f, s, s, d)

        # NOTE(review): the size is drawn at random in eval mode as well, so
        # inference output varies when several choices are configured.
        sizes = self._parse_pool_sizes(random.choice(self.pool_out_size))

        x_out = []
        for p in sizes:
            if p <= 0 or s < p or s % p != 0:
                raise ValueError("pool size {} does not evenly divide grid side {}".format(p, s))
            k = s // p  # side of each averaging window
            # (b,t,f, p,k, p,k, d) -> (b,t,f, p,p, d, k,k), then mean over each k x k window.
            pooled = x_in.reshape(b, t, f, p, k, p, k, d)
            pooled = pooled.permute(0, 1, 2, 3, 5, 7, 4, 6).reshape(b, t, f, p * p, d, -1).mean(-1)
            pooled = self.mlp(pooled)               # [b, t, f, p*p, dim_out]
            x_out.append(pooled.flatten(0, 2))      # [b, p*p, dim_out] since t == f == 1
        return torch.cat(x_out, dim=-2).unsqueeze(1)