import torch
from torch import nn
from einops import rearrange

from uniperceiver.config import configurable
from uniperceiver.modeling.layers import FP16LayerNorm

from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import NNEmbeddingEncoding

__all__ = ["VideoBaseEmbedding"]


@EMBEDDING_REGISTRY.register()
class VideoBaseEmbedding(nn.Module):
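    """Patch-token embedding for video (and single-image) input.

    Frames are split into patches, projected to ``out_dim``, and combined
    with position embeddings (a joint learned table, or the factorized
    space/time variant in ``Divide_ST_POS``), an optional token-type
    embedding, activation, layer norm, and dropout, as configured. Input is
    ``(B, T, C, H, W)``; a 4-D ``(B, C, H, W)`` batch is treated as a
    one-frame video.
    """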

    @configurable
    def __init__(
        self,
        *,
        cfg: dict,
        in_dim: int,
        out_dim: int,
        patch_size: int,
        time_span: int,
        max_time_len: int,
        max_spatial_size: int = 196,
        **kwargs
    ):
        super(VideoBaseEmbedding, self).__init__()
        self.cfg = cfg
        # Default patch projection: a linear layer over flattened patches.
        self.embeddings = nn.Linear(in_dim, out_dim)
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop("embeddings_pos", None)
        self.embeddings_type = kwargs.pop("embeddings_token_type", None)
        self.random_temporal_pos = kwargs.pop("random_temporal_pos", True)
        self.patch_size = patch_size
        self.time_span = time_span
        self.pos_before = kwargs.pop("pos_before", True)
        self.add_type_embedding = cfg.MODEL.VIDEO_EMBED.ADD_TYPE_EMBED
        if self.add_type_embedding:
            assert self.embeddings_type is not None

        self.embeddings_st_pos = None
        self.max_spatial_size = max_spatial_size
        if isinstance(self.embeddings_pos, str) and self.embeddings_pos == 'divide_st_pos':
            # Factorized space/time positions: swap the linear projection for
            # an equivalent strided Conv2d that extracts patches on the fly.
            self.embeddings_st_pos = Divide_ST_POS(
                max_spatial_size, max_time_len, out_dim,
                self.random_temporal_pos)
            self.embeddings_pos = None
            del self.embeddings
            self.embeddings = nn.Conv2d(in_dim // (self.patch_size**2),
                                        out_dim,
                                        kernel_size=self.patch_size,
                                        stride=self.patch_size)

    def replace_weight(self, visual_embed):
        # Reuse the patch projection weights from a shared image embedding.
        if visual_embed is not None:
            del self.embeddings
            self.embeddings = visual_embed.patch_embed.proj

    def share_spatial_pos(self, visual_embed):
        # Share the spatial position table with an image embedding module
        # when the factorized spatio-temporal positions are in use.
        if self.embeddings_st_pos is not None and visual_embed is not None:
            if self.embeddings_st_pos.spatial_pos_embed.weight.shape[0] == \
                    visual_embed.patch_embed.pos_embed.weight.shape[0]:
                self.embeddings_st_pos.spatial_pos_embed_index = 0
            else:
                # The shared table is longer (e.g. it carries an extra leading
                # slot), so spatial indexing starts at 1 instead of 0.
                self.embeddings_st_pos.spatial_pos_embed_index = 1
            del self.embeddings_st_pos.spatial_pos_embed
            self.embeddings_st_pos.spatial_pos_embed = visual_embed.patch_embed.pos_embed

    @classmethod
    def from_config(cls, cfg):
        kwargs = {
            "in_dim": cfg.MODEL.VIDEO_EMBED.IN_DIM,
            "out_dim": cfg.MODEL.VIDEO_EMBED.OUT_DIM,
            "patch_size": cfg.MODEL.PATCH_SIZE,
            "time_span": cfg.MODEL.VIDEO_EMBED.PATCH_SIZE_T,
            "max_time_len": cfg.MODEL.VIDEO_EMBED.MAX_FRAMES,
        }
        # Number of spatial patches per frame.
        kwargs['max_spatial_size'] = int(cfg.MODEL.IMG_INPUT_SIZE / cfg.MODEL.PATCH_SIZE)**2

        activation_name = cfg.MODEL.VIDEO_EMBED.ACTIVATION.lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None

            act_kwargs = {}
            if activation_name in {"elu", "celu"}:
                act_kwargs["alpha"] = cfg.MODEL.VIDEO_EMBED.ELU_ALPHA
            kwargs['embeddings_act'] = activation(**act_kwargs)

        if cfg.MODEL.VIDEO_EMBED.DROPOUT > 0:
            kwargs['embeddings_dropout'] = nn.Dropout(cfg.MODEL.VIDEO_EMBED.DROPOUT)

        if cfg.MODEL.VIDEO_EMBED.USE_NORM:
            if cfg.SOLVER.FORCE_LN_FP16:
                kwargs['embeddings_norm'] = FP16LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
            else:
                kwargs['embeddings_norm'] = nn.LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)

        if cfg.MODEL.VIDEO_EMBED.DIVIDE_ST_POS:
            # Sentinel string; __init__ replaces it with a Divide_ST_POS module.
            kwargs['embeddings_pos'] = "divide_st_pos"
        elif cfg.MODEL.VIDEO_EMBED.POSITION.lower() != 'none':
            kwargs['embeddings_pos'] = NNEmbeddingEncoding(
                cfg.MODEL.VIDEO_EMBED.OUT_DIM, cfg.MODEL.VIDEO_EMBED.MAX_LENGTH)

        if cfg.MODEL.VIDEO_EMBED.TYPE_SIZE > 0:
            kwargs['embeddings_token_type'] = nn.Embedding(
                cfg.MODEL.VIDEO_EMBED.TYPE_SIZE, cfg.MODEL.VIDEO_EMBED.OUT_DIM)

        kwargs['random_temporal_pos'] = cfg.MODEL.VIDEO_EMBED.POS_RANDOM
        kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
        kwargs['cfg'] = cfg
        return kwargs
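
    # Dimension notes (inferred from forward(), not enforced here):
    # IN_DIM must equal PATCH_SIZE_T * C * PATCH_SIZE**2 for the linear
    # path's rearrange to line up, and MAX_LENGTH must cover the total
    # number of output tokens (t * h * w) produced per clip.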

    def forward(self, data, **kwargs):
        # A 4-D image batch (B, C, H, W) is treated as a one-frame video
        # (B, 1, C, H, W).
        if data.dim() == 4:
            data = data.unsqueeze(1)
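
        # At most one of the two branches below is active: __init__ nulls
        # embeddings_pos when the factorized (divide_st_pos) variant is used.
        # The config is expected to enable one of them; otherwise no
        # embeddings are produced.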

        if self.embeddings_st_pos is not None:
            # Factorized positions: run the Conv2d patch projection per frame,
            # then add separate spatial and temporal position tables. The
            # feature dims only match the position table when time_span == 1.
            bs = data.size(0)
            x = self.embeddings(data.flatten(0, 1))  # (B*T*S, out_dim, h, w)
            x = x.flatten(2)                         # (B*T*S, out_dim, h*w)
            embeddings = rearrange(x, '(b t s) c hw -> b t hw (s c)', b=bs, s=self.time_span)
            embeddings_pos = self.embeddings_st_pos(embeddings).unsqueeze(0).flatten(1, 2)
            embeddings = embeddings.flatten(1, 2)    # (B, T*h*w, S*out_dim)
            if self.pos_before:
                embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.embeddings_pos is not None:
            # Joint positions: unfold P x P patches across time_span frames,
            # project with the linear layer, add one learned table per token.
            x = rearrange(data, 'b (t s) c (h p1) (w p2) -> b (t h w) (s c p1 p2)',
                          s=self.time_span, p1=self.patch_size, p2=self.patch_size)
            embeddings = self.embeddings(x)
            embeddings_pos = self.embeddings_pos(x).unsqueeze(0)
            if self.pos_before:
                embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.add_type_embedding:
            embeddings = embeddings + self.embeddings_type.weight[0].unsqueeze(0).unsqueeze(1).to(embeddings.dtype)

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if not self.pos_before:
            embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings


class Divide_ST_POS(nn.Module):
    """Factorized ("divided") spatial and temporal position embeddings."""

    def __init__(self, num_patches, max_time_len, out_dim, random_temporal_pos):
        super(Divide_ST_POS, self).__init__()
        self.spatial_pos_embed = nn.Embedding(num_patches, out_dim)
        self.temporal_pos_embed = nn.Embedding(max_time_len, out_dim)
        # Start offset into the spatial table; set to 1 when a longer table is
        # shared in (see VideoBaseEmbedding.share_spatial_pos).
        self.spatial_pos_embed_index = 0
        self.max_frames = max_time_len
        self.random_temporal_pos = random_temporal_pos

    def forward(self, x):
        # x: (B, T, S, C) token grid; returns (T, S, C) positions that
        # broadcast over the batch dimension.
        dtype = x.dtype
        temp_len, spatial_size = x.size(1), x.size(2)

        if self.training and self.random_temporal_pos:
            # Randomly shift the temporal window during training so all
            # max_frames temporal slots are exercised even for short clips.
            temporal_pos_ids = torch.arange(temp_len, dtype=torch.long, device=x.device) + \
                torch.randint(0, self.max_frames - temp_len + 1, size=(1,), dtype=torch.long, device=x.device)
        else:
            temporal_pos_ids = torch.arange(temp_len, dtype=torch.long, device=x.device)

        pos_embed = self.temporal_pos_embed(temporal_pos_ids).unsqueeze(1).to(dtype=dtype) + \
            self.spatial_pos_embed(torch.arange(start=self.spatial_pos_embed_index,
                                                end=spatial_size + self.spatial_pos_embed_index,
                                                dtype=torch.long,
                                                device=x.device)).unsqueeze(0).to(dtype=dtype)
        return pos_embed
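
# A minimal usage sketch for Divide_ST_POS (illustrative values; shapes
# follow the (B, T, S, C) convention used in VideoBaseEmbedding.forward).
if __name__ == "__main__":
    pos = Divide_ST_POS(num_patches=196, max_time_len=8, out_dim=512,
                        random_temporal_pos=True)
    pos.eval()  # eval mode -> deterministic temporal ids starting at 0
    tokens = torch.randn(2, 4, 196, 512)  # B=2 clips, T=4 frames, S=196 patches
    print(pos(tokens).shape)  # torch.Size([4, 196, 512]), broadcast over batch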