|
import torch |
|
import torch.nn as nn |
|
import math |
|
import random |
|
|
|
|
|
class Pooler(nn.Module):
    """Multi-scale average pooler over a square grid of tokens, followed by an MLP.

    Drops the first token of the input sequence and arranges the remaining
    ``s * s`` tokens into a row-major ``s x s`` grid. That grid is
    average-pooled down to a ``p x p`` grid for one or more output sizes
    ``p``. Every pooled token is then projected through a two-layer MLP
    (``dim_in -> dim_out -> dim_out`` with a GELU in between).

    Args:
        dim_in: feature size of the input tokens.
        dim_out: feature size of the output tokens.
        pool_out_size: comma-separated list of candidate configurations,
            e.g. ``"2,4,2+4"``. One candidate is drawn with ``random.choice``
            on every forward call. A candidate made of several sizes joined
            with ``+`` yields the concatenation of all those pooled grids,
            largest grid first.
    """

    def __init__(self, dim_in, dim_out, pool_out_size):
        super().__init__()
        assert isinstance(pool_out_size, str), (
            "pool_out_size must be a str such as '2,4' or '2+4'"
        )
        # Stored as raw strings and parsed in forward(), so reassigning this
        # attribute after construction takes effect on the next call.
        self.pool_out_size = pool_out_size.split(",")
        print("pool_out_size: {}".format(self.pool_out_size))

        self.mlp = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.Linear(dim_out, dim_out),
        )

    def forward(self, x):
        """Pool the token grid of ``x`` at a randomly chosen set of scales.

        Args:
            x (torch.Tensor): features of shape ``(b, v, dim_in)`` with
                ``v = 1 + s * s``. The first token is discarded and the
                remaining ``s * s`` tokens are read as a row-major
                ``s x s`` grid.

        Returns:
            torch.Tensor: shape ``(b, n, dim_out)``, where ``n`` is the sum
            of ``p * p`` over the pool sizes ``p`` of the chosen candidate.
            Blocks are ordered largest ``p`` first, and each block is in
            row-major order.

        Raises:
            ValueError: if ``v - 1`` is not a perfect square.
            AssertionError: if a chosen pool size is not a positive divisor
                of ``s``.
        """
        b, v, d = x.shape
        # Exact integer square root; the float sqrt + int() it replaces
        # silently floored non-square token counts.
        s = math.isqrt(v - 1)
        if s * s != v - 1:
            raise ValueError(
                "expected 1 + s*s tokens (leading token plus a square grid), "
                "got {}".format(v)
            )
        grid = x[:, 1:, :].reshape(b, s, s, d)

        candidate = random.choice(self.pool_out_size)
        # "4" -> [4]; "2+4" -> [4, 2]. Largest grid first.
        sizes = sorted((int(p) for p in candidate.split("+")), reverse=True)

        pooled = []
        for p in sizes:
            assert 0 < p and s % p == 0, (
                "pool size {} must be a positive divisor of the grid side {}".format(p, s)
            )
            k = s // p
            # Split each grid axis into p windows of k tokens, then average
            # inside every k x k window -> (b, p, p, d) -> p*p row-major tokens.
            windows = grid.reshape(b, p, k, p, k, d).mean(dim=(2, 4))
            pooled.append(windows.reshape(b, p * p, d))

        # The MLP acts on each token independently, so a single call on the
        # concatenation matches applying it to every scale separately.
        return self.mlp(torch.cat(pooled, dim=-2))