InfMLLM_7B / pooler.py
mightyzau's picture
Upload folder using huggingface_hub
8cf6c16
import random
import math
import torch
import torch.nn as nn
class Pooler(nn.Module):
    """Average-pool a square grid of patch tokens and project them with an MLP.

    The input carries one leading cls token followed by ``s * s`` patch
    tokens.  The cls token is dropped, the patch grid is split into
    ``p x p`` equally sized windows that are averaged, and the pooled tokens
    are passed through a two-layer MLP (``Linear -> GELU -> Linear``).

    Args:
        dim_in: Feature size of the incoming tokens.
        dim_out: Feature size produced by the MLP.
        pool_out_size: One pool spec or a list/tuple of pool specs.  A spec
            is an int (``16``), a numeric string (``"16"``), or several sizes
            joined by ``+`` (``"16+32"``), meaning the outputs pooled at each
            of those sizes are concatenated.  When several specs are given,
            one is picked at random on every forward call.
    """

    def __init__(self, dim_in, dim_out, pool_out_size):
        super().__init__()
        if not isinstance(pool_out_size, (list, tuple)):
            pool_out_size = [pool_out_size]
        self.pool_out_size = pool_out_size
        print("pool_out_size: {}".format(self.pool_out_size))

        self.mlp = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.Linear(dim_out, dim_out),
        )

    @staticmethod
    def _parse_pool_sizes(spec):
        """Turn one pool spec (int, "16" or "16+32") into a descending int list."""
        if isinstance(spec, str):
            sizes = [int(part) for part in spec.split('+')]
        else:
            # Plain numbers are accepted too; the substring test used before
            # raised TypeError for them.
            sizes = [int(spec)]
        sizes.sort(reverse=True)
        return sizes

    def forward(self, x):
        """Pool and project image features.

        Args:
            x (torch.Tensor): image features of shape (b, T, F, v, D) with
                T == F == 1 and v == 1 + s * s (cls token + square grid).

        Returns:
            torch.Tensor of shape (b * T * F, 1, N, dim_out), where N is the
            sum of ``p * p`` over the pool sizes of the chosen spec, largest
            pool size first.

        Raises:
            ValueError: if ``v - 1`` is not a positive perfect square, or a
                pool size is not a positive divisor of the grid side ``s``.
        """
        b, t, f, v, d = x.shape
        assert t == 1 and f == 1, "expected T == 1 and F == 1"

        s = int(math.sqrt(v - 1))
        if s < 1 or s * s != v - 1:
            raise ValueError(
                "expected 1 + s*s tokens (cls + square grid), got {}".format(v)
            )

        patches = x[:, :, :, 1:, :]  # remove cls_token
        grid = patches.reshape(b, t, f, s, s, d)

        spec = random.choice(self.pool_out_size)
        pooled_outputs = []
        for p in self._parse_pool_sizes(spec):
            if p <= 0 or s % p != 0:
                raise ValueError(
                    "pool size {} must be a positive divisor of grid side {}".format(p, s)
                )
            k = s // p  # side length of one pooling window
            # Split rows and columns into (window index, offset in window).
            windows = grid.reshape(b, t, f, p, k, p, k, d)
            # Move both in-window offsets to the end and average over them.
            pooled = (
                windows.permute([0, 1, 2, 3, 5, 7, 4, 6])
                .reshape(b, t, f, p * p, d, -1)
                .mean(-1)
            )
            projected = self.mlp(pooled)  # [b, t, f, p*p, dim_out]
            pooled_outputs.append(projected.flatten(0, 2))

        tokens = torch.cat(pooled_outputs, dim=-2)
        return tokens.unsqueeze(1)