# herrius's picture
# Upload 259 files
# 32b542e
import torch
from torch import nn
from uniperceiver.config import configurable
from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import NNEmbeddingEncoding
from einops import rearrange, repeat
from uniperceiver.modeling.layers import FP16LayerNorm
# Only VideoBaseEmbedding is exported; Divide_ST_POS is a helper that
# VideoBaseEmbedding instantiates in this file.
__all__ = ["VideoBaseEmbedding"]
@EMBEDDING_REGISTRY.register()
class VideoBaseEmbedding(nn.Module):
    """Turn video clips (or single images) into a sequence of patch tokens.

    Two tokenisation modes exist, selected by the ``embeddings_pos`` kwarg:

    * ``"divide_st_pos"``: every frame is patchified by an ``nn.Conv2d`` and a
      factorised space-time positional embedding (:class:`Divide_ST_POS`) is
      produced.
    * otherwise: patches are flattened with ``einops.rearrange`` and projected
      by an ``nn.Linear``; ``embeddings_pos`` (if not ``None``) is a module
      that supplies the positional embedding.

    Afterwards an optional token-type embedding, activation, norm and dropout
    are applied. ``pos_before`` chooses whether the positional embedding is
    added before (True) or after (False) the activation/norm.
    """

    @configurable
    def __init__(
        self,
        *,
        cfg: dict,
        in_dim: int,
        out_dim: int,
        patch_size: int,
        time_span: int,
        max_time_len: int,
        max_spatial_size=196,
        **kwargs
    ):
        super(VideoBaseEmbedding, self).__init__()
        self.cfg = cfg
        # Default tokenizer: linear projection of flattened patches. It is
        # swapped for a Conv2d below in divided space-time mode.
        self.embeddings = nn.Linear(in_dim, out_dim)
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop('embeddings_pos', None)
        self.embeddings_type = kwargs.pop("embeddings_token_type", None)
        self.random_temporal_pos = kwargs.pop("random_temporal_pos", True)
        self.patch_size = patch_size
        self.time_span = time_span
        self.pos_before = kwargs.pop('pos_before', True)

        self.add_type_embedding = cfg.MODEL.VIDEO_EMBED.ADD_TYPE_EMBED
        if self.add_type_embedding:
            assert self.embeddings_type is not None, (
                "ADD_TYPE_EMBED requires a token-type embedding "
                "(MODEL.VIDEO_EMBED.TYPE_SIZE > 0)")

        self.embeddings_st_pos = None
        self.max_spatial_size = max_spatial_size
        if isinstance(self.embeddings_pos, str) and self.embeddings_pos == 'divide_st_pos':
            self.embeddings_st_pos = Divide_ST_POS(
                max_spatial_size, max_time_len, out_dim,
                self.random_temporal_pos)
            # The string was only a mode flag; there is no 1-D position module.
            self.embeddings_pos = None
            del self.embeddings
            # Per-frame patchify: kernel == stride == patch_size.
            self.embeddings = nn.Conv2d(
                in_dim // (self.patch_size ** 2), out_dim,
                kernel_size=self.patch_size, stride=self.patch_size)

    def replace_weight(self, visual_embed):
        """Reuse ``visual_embed.patch_embed.proj`` as this module's patch
        projection (no-op when ``visual_embed`` is None)."""
        if visual_embed is not None:
            del self.embeddings
            self.embeddings = visual_embed.patch_embed.proj

    def share_spatial_pos(self, visual_embed):
        """Share the spatial positional table with ``visual_embed``.

        Only applies in divided space-time mode and when ``visual_embed`` is
        given. If the shared table has a different number of rows than ours,
        it is assumed to carry an extra leading row (a cls token in the image
        patch tokenizer), so spatial lookups are shifted by one.
        """
        if self.embeddings_st_pos is not None and visual_embed is not None:
            ours = self.embeddings_st_pos.spatial_pos_embed.weight.shape[0]
            theirs = visual_embed.patch_embed.pos_embed.weight.shape[0]
            if ours == theirs:
                self.embeddings_st_pos.spatial_pos_embed_index = 0
            else:
                # cls token in image patch tokenizer
                self.embeddings_st_pos.spatial_pos_embed_index = 1
            del self.embeddings_st_pos.spatial_pos_embed
            self.embeddings_st_pos.spatial_pos_embed = visual_embed.patch_embed.pos_embed

    @classmethod
    def from_config(cls, cfg):
        """Build ``__init__`` kwargs from ``cfg.MODEL.VIDEO_EMBED`` plus a few
        ``cfg.MODEL`` / ``cfg.SOLVER`` keys."""
        kwargs = {
            "in_dim": cfg.MODEL.VIDEO_EMBED.IN_DIM,
            "out_dim": cfg.MODEL.VIDEO_EMBED.OUT_DIM,
            "patch_size": cfg.MODEL.PATCH_SIZE,
            "time_span": cfg.MODEL.VIDEO_EMBED.PATCH_SIZE_T,
            "max_time_len": cfg.MODEL.VIDEO_EMBED.MAX_FRAMES,
        }
        # Patches per frame: (IMG_INPUT_SIZE / PATCH_SIZE) ** 2.
        max_spatial_size = int(cfg.MODEL.IMG_INPUT_SIZE / cfg.MODEL.PATCH_SIZE) ** 2
        kwargs['max_spatial_size'] = max_spatial_size

        activation_name = (cfg.MODEL.VIDEO_EMBED.ACTIVATION).lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None
            act_kwargs = {}
            if activation_name in {"elu", "celu"}:
                act_kwargs["alpha"] = cfg.MODEL.VIDEO_EMBED.ELU_ALPHA
            kwargs['embeddings_act'] = activation(**act_kwargs)

        if cfg.MODEL.VIDEO_EMBED.DROPOUT > 0:
            kwargs['embeddings_dropout'] = nn.Dropout(cfg.MODEL.VIDEO_EMBED.DROPOUT)

        if cfg.MODEL.VIDEO_EMBED.USE_NORM:
            if cfg.SOLVER.FORCE_LN_FP16:
                embeddings_norm = FP16LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
            else:
                embeddings_norm = nn.LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
            kwargs['embeddings_norm'] = embeddings_norm

        if cfg.MODEL.VIDEO_EMBED.DIVIDE_ST_POS:
            # Mode flag consumed by __init__, not a module.
            kwargs['embeddings_pos'] = "divide_st_pos"
        elif cfg.MODEL.VIDEO_EMBED.POSITION.lower() != 'none':
            kwargs['embeddings_pos'] = NNEmbeddingEncoding(
                cfg.MODEL.VIDEO_EMBED.OUT_DIM, cfg.MODEL.VIDEO_EMBED.MAX_LENGTH)

        if cfg.MODEL.VIDEO_EMBED.TYPE_SIZE > 0:
            kwargs['embeddings_token_type'] = nn.Embedding(
                cfg.MODEL.VIDEO_EMBED.TYPE_SIZE, cfg.MODEL.VIDEO_EMBED.OUT_DIM)

        kwargs['random_temporal_pos'] = cfg.MODEL.VIDEO_EMBED.POS_RANDOM
        kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
        kwargs['cfg'] = cfg
        return kwargs

    def forward(self, data, **kwargs):
        """Embed ``data`` into a token sequence.

        Args:
            data: 5-D tensor in the ``(batch, frames, channels, height, width)``
                layout that the ``rearrange`` patterns below assume, or a 4-D
                single-image tensor, which gets a frame axis of length 1.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            A 3-D ``(batch, num_tokens, dim)`` tensor.
        """
        if data.dim() == 4:
            # images: insert a frame axis of length 1
            data = data.unsqueeze(1)

        # Stays None when no positional embedding is configured; previously
        # that case crashed with UnboundLocalError.
        embeddings_pos = None
        if self.embeddings_st_pos is not None:
            bs = data.size(0)
            # NOTE(review): the conv expects in_dim // patch_size**2 input
            # channels while each frame carries `c` channels; these agree only
            # if in_dim == c * patch_size**2 - confirm for time_span > 1.
            x = self.embeddings(data.flatten(0, 1))  # (b*t, dim, h', w')
            x = x.flatten(2)
            embeddings = rearrange(x, '(b t s) c hw -> b t hw (s c)', b=bs, s=self.time_span)
            embeddings_pos = self.embeddings_st_pos(embeddings).unsqueeze(0).flatten(1, 2)
            embeddings = embeddings.flatten(1, 2)
        else:
            x = rearrange(
                data, 'b (t s) c (h p1) (w p2) -> b (t h w) (s c p1 p2)',
                s=self.time_span, p1=self.patch_size, p2=self.patch_size)
            embeddings = self.embeddings(x)
            if self.embeddings_pos is not None:
                embeddings_pos = self.embeddings_pos(x).unsqueeze(0)

        if self.pos_before and embeddings_pos is not None:
            embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.add_type_embedding:
            # Row 0 of the token-type table, broadcast over batch and tokens.
            embeddings = embeddings + self.embeddings_type.weight[0].unsqueeze(0).unsqueeze(1).to(embeddings.dtype)

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if not self.pos_before and embeddings_pos is not None:
            # Cast like the pre-norm branch so the output dtype follows the
            # embeddings rather than the positional table.
            embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings
class Divide_ST_POS(nn.Module):
    """Factorised (divided) space-time positional embedding.

    The embedding for frame ``t`` and patch ``p`` is the sum of a learned
    temporal vector for ``t`` and a learned spatial vector for ``p``.
    """

    def __init__(self, num_patches, max_time_len, out_dim,
                 random_temporal_pos):
        super().__init__()
        self.spatial_pos_embed = nn.Embedding(num_patches, out_dim)
        self.temporal_pos_embed = nn.Embedding(max_time_len, out_dim)
        # First row used from the spatial table; set to 1 when the table is
        # shared with an image tokenizer whose row 0 belongs to a cls token.
        self.spatial_pos_embed_index = 0
        self.max_frames = max_time_len
        self.random_temporal_pos = random_temporal_pos

    def forward(self, x):
        """Return the positional embedding matching ``x``.

        ``x.size(1)`` is read as the number of frames and ``x.size(2)`` as the
        number of patches per frame. The result has shape
        ``(frames, patches, dim)`` and ``x``'s dtype. In training mode with
        ``random_temporal_pos`` enabled, the frame ids are shifted by one
        random offset so the clip lands somewhere inside ``max_frames``.
        """
        n_frames = x.size(1)
        n_patches = x.size(2)
        device = x.device

        frame_ids = torch.arange(n_frames, dtype=torch.long, device=device)
        if self.training and self.random_temporal_pos:
            shift = torch.randint(
                0, self.max_frames - n_frames + 1,
                size=(1,), dtype=torch.long, device=device)
            frame_ids = frame_ids + shift

        first_patch = self.spatial_pos_embed_index
        patch_ids = torch.arange(
            first_patch, first_patch + n_patches,
            dtype=torch.long, device=device)

        temporal = self.temporal_pos_embed(frame_ids).unsqueeze(1).to(dtype=x.dtype)
        spatial = self.spatial_pos_embed(patch_ids).unsqueeze(0).to(dtype=x.dtype)
        return temporal + spatial