# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Video models."""

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from .attention import MultiScaleBlock
from .common import TwoStreamFusion
from .reversible_mvit import ReversibleMViT
from .utils import (
    calc_mvit_feature_geometry,
    get_3d_sincos_pos_embed,
    round_width,
    validate_checkpoint_wrapper_import,
)
from . import head_helper, stem_helper  # noqa


class MViT(nn.Module):
    """
    Model builder for MViTv1 and MViTv2.

    "MViTv2: Improved Multiscale Vision Transformers for Classification
    and Detection"
    Yanghao Li, Chao-Yuan Wu, Haoqi Fan, Karttikeya Mangalam, Bo Xiong,
    Jitendra Malik, Christoph Feichtenhofer
    https://arxiv.org/abs/2112.01526

    "Multiscale Vision Transformers"
    Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan,
    Jitendra Malik, Christoph Feichtenhofer
    https://arxiv.org/abs/2104.11227
    """

    def __init__(self, cfg):
        super().__init__()
        # Get parameters.
        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.cfg = cfg
        pool_first = cfg.MVIT.POOL_FIRST
        # Prepare input.
        spatial_size = cfg.DATA.TRAIN_CROP_SIZE
        temporal_size = cfg.DATA.NUM_FRAMES
        in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
        self.use_2d_patch = cfg.MVIT.PATCH_2D
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_rev = cfg.MVIT.REV.ENABLE
        self.patch_stride = cfg.MVIT.PATCH_STRIDE
        if self.use_2d_patch:
            self.patch_stride = [1] + self.patch_stride
        self.T = cfg.DATA.NUM_FRAMES // self.patch_stride[0]
        self.H = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[1]
        self.W = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[2]
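        # Worked example (assuming the common 16-frame, 224-px setting with
        # PATCH_STRIDE = [2, 4, 4]): the patchified token grid is
        # T = 16 // 2 = 8 and H = W = 224 // 4 = 56, i.e.
        # 8 * 56 * 56 = 25088 tokens enter the first block before any pooling.
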
        # Prepare output.
        num_classes = cfg.MODEL.NUM_CLASSES
        embed_dim = cfg.MVIT.EMBED_DIM
        # Prepare backbone.
        num_heads = cfg.MVIT.NUM_HEADS
        mlp_ratio = cfg.MVIT.MLP_RATIO
        qkv_bias = cfg.MVIT.QKV_BIAS
        self.drop_rate = cfg.MVIT.DROPOUT_RATE
        depth = cfg.MVIT.DEPTH
        drop_path_rate = cfg.MVIT.DROPPATH_RATE
        layer_scale_init_value = cfg.MVIT.LAYER_SCALE_INIT_VALUE
        head_init_scale = cfg.MVIT.HEAD_INIT_SCALE
        mode = cfg.MVIT.MODE
        self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
        self.use_mean_pooling = cfg.MVIT.USE_MEAN_POOLING
        # Params for positional embedding.
        self.use_abs_pos = cfg.MVIT.USE_ABS_POS
        self.use_fixed_sincos_pos = cfg.MVIT.USE_FIXED_SINCOS_POS
        self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
        self.rel_pos_spatial = cfg.MVIT.REL_POS_SPATIAL
        self.rel_pos_temporal = cfg.MVIT.REL_POS_TEMPORAL
        if cfg.MVIT.NORM == "layernorm":
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        else:
            raise NotImplementedError("Only supports layernorm.")
        self.num_classes = num_classes
        self.patch_embed = stem_helper.PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.MVIT.PATCH_KERNEL,
            stride=cfg.MVIT.PATCH_STRIDE,
            padding=cfg.MVIT.PATCH_PADDING,
            conv_2d=self.use_2d_patch,
        )
        self.input_dims = [temporal_size, spatial_size, spatial_size]
        assert self.input_dims[1] == self.input_dims[2]
        self.patch_dims = [
            self.input_dims[i] // self.patch_stride[i]
            for i in range(len(self.input_dims))
        ]
        num_patches = math.prod(self.patch_dims)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        if self.cls_embed_on:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            pos_embed_dim = num_patches + 1
        else:
            pos_embed_dim = num_patches

        if self.use_abs_pos:
            if self.sep_pos_embed:
                self.pos_embed_spatial = nn.Parameter(
                    torch.zeros(
                        1, self.patch_dims[1] * self.patch_dims[2], embed_dim
                    )
                )
                self.pos_embed_temporal = nn.Parameter(
                    torch.zeros(1, self.patch_dims[0], embed_dim)
                )
                if self.cls_embed_on:
                    self.pos_embed_class = nn.Parameter(
                        torch.zeros(1, 1, embed_dim)
                    )
            else:
                self.pos_embed = nn.Parameter(
                    torch.zeros(1, pos_embed_dim, embed_dim),
                    requires_grad=not self.use_fixed_sincos_pos,
                )

        if self.drop_rate > 0.0:
            self.pos_drop = nn.Dropout(p=self.drop_rate)

        dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
        for i in range(len(cfg.MVIT.DIM_MUL)):
            dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
        for i in range(len(cfg.MVIT.HEAD_MUL)):
            head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

        pool_q = [[] for i in range(cfg.MVIT.DEPTH)]
        pool_kv = [[] for i in range(cfg.MVIT.DEPTH)]
        stride_q = [[] for i in range(cfg.MVIT.DEPTH)]
        stride_kv = [[] for i in range(cfg.MVIT.DEPTH)]

        for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
            stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][
                1:
            ]
            if cfg.MVIT.POOL_KVQ_KERNEL is not None:
                pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
            else:
                pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                    s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
                ]

        # If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
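        # Illustrative numbers (not the shipped defaults): with
        # POOL_KV_STRIDE_ADAPTIVE = [1, 8, 8] and a q-stride of [1, 2, 2] at a
        # given layer, that layer and all later ones get kv stride [1, 4, 4];
        # each subsequent q-pooling layer halves it again, floored at 1 per
        # dimension by the max(..., 1) below.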
        if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
            _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
            cfg.MVIT.POOL_KV_STRIDE = []
            for i in range(cfg.MVIT.DEPTH):
                if len(stride_q[i]) > 0:
                    _stride_kv = [
                        max(_stride_kv[d] // stride_q[i][d], 1)
                        for d in range(len(_stride_kv))
                    ]
                cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)

        for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
            stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[
                i
            ][1:]
            if cfg.MVIT.POOL_KVQ_KERNEL is not None:
                pool_kv[
                    cfg.MVIT.POOL_KV_STRIDE[i][0]
                ] = cfg.MVIT.POOL_KVQ_KERNEL
            else:
                pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                    s + 1 if s > 1 else s
                    for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
                ]

        self.pool_q = pool_q
        self.pool_kv = pool_kv
        self.stride_q = stride_q
        self.stride_kv = stride_kv

        self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

        input_size = self.patch_dims

        if self.enable_rev:
            # The reversible backbone does not allow a cls token.
            assert not self.cls_embed_on

            self.rev_backbone = ReversibleMViT(cfg, self)

            embed_dim = round_width(
                embed_dim, dim_mul.prod(), divisor=num_heads
            )

            self.fuse = TwoStreamFusion(
                cfg.MVIT.REV.RESPATH_FUSE, dim=2 * embed_dim
            )

            if "concat" in self.cfg.MVIT.REV.RESPATH_FUSE:
                self.norm = norm_layer(2 * embed_dim)
            else:
                self.norm = norm_layer(embed_dim)
        else:
            self.blocks = nn.ModuleList()
            for i in range(depth):
                num_heads = round_width(num_heads, head_mul[i])
                if cfg.MVIT.DIM_MUL_IN_ATT:
                    dim_out = round_width(
                        embed_dim,
                        dim_mul[i],
                        divisor=round_width(num_heads, head_mul[i]),
                    )
                else:
                    dim_out = round_width(
                        embed_dim,
                        dim_mul[i + 1],
                        divisor=round_width(num_heads, head_mul[i + 1]),
                    )
                attention_block = MultiScaleBlock(
                    dim=embed_dim,
                    dim_out=dim_out,
                    num_heads=num_heads,
                    input_size=input_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_rate=self.drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    kernel_q=pool_q[i] if len(pool_q) > i else [],
                    kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                    stride_q=stride_q[i] if len(stride_q) > i else [],
                    stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                    mode=mode,
                    has_cls_embed=self.cls_embed_on,
                    pool_first=pool_first,
                    rel_pos_spatial=self.rel_pos_spatial,
                    rel_pos_temporal=self.rel_pos_temporal,
                    rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
                    residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
                    dim_mul_in_att=cfg.MVIT.DIM_MUL_IN_ATT,
                    separate_qkv=cfg.MVIT.SEPARATE_QKV,
                )
                self.blocks.append(attention_block)
                if len(stride_q[i]) > 0:
                    input_size = [
                        size // stride
                        for size, stride in zip(input_size, stride_q[i])
                    ]
                embed_dim = dim_out

            self.norm = norm_layer(embed_dim)

        if self.enable_detection:
            raise NotImplementedError("Detection is not supported.")
        else:
            self.head = head_helper.TransformerBasicHead(
                2 * embed_dim
                if ("concat" in cfg.MVIT.REV.RESPATH_FUSE and self.enable_rev)
                else embed_dim,
                num_classes,
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
                cfg=cfg,
            )

        if self.use_abs_pos:
            if self.sep_pos_embed:
                trunc_normal_(self.pos_embed_spatial, std=0.02)
                trunc_normal_(self.pos_embed_temporal, std=0.02)
                if self.cls_embed_on:
                    trunc_normal_(self.pos_embed_class, std=0.02)
            else:
                trunc_normal_(self.pos_embed, std=0.02)
                if self.use_fixed_sincos_pos:
                    pos_embed = get_3d_sincos_pos_embed(
                        self.pos_embed.shape[-1],
                        self.H,
                        self.T,
                        cls_token=self.cls_embed_on,
                    )
                    self.pos_embed.data.copy_(
                        torch.from_numpy(pos_embed).float().unsqueeze(0)
                    )
        if self.cls_embed_on:
            trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

        self.head.projection.weight.data.mul_(head_init_scale)
        self.head.projection.bias.data.mul_(head_init_scale)

        self.feat_size, self.feat_stride = calc_mvit_feature_geometry(cfg)
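
    # Sketch (illustrative only, not this repo's own optimizer code) of how a
    # training loop might consume `no_weight_decay` below, following the timm
    # convention of exempting positional/class parameters from weight decay:
    #
    #     skip = set(model.no_weight_decay())
    #     decay, no_decay = [], []
    #     for name, p in model.named_parameters():
    #         if not p.requires_grad:
    #             continue
    #         (no_decay if any(k in name for k in skip) else decay).append(p)
    #     optim = torch.optim.AdamW(
    #         [{"params": decay, "weight_decay": 0.05},
    #          {"params": no_decay, "weight_decay": 0.0}],
    #         lr=1.6e-3,
    #     )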

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0.02)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        names = []
        if self.cfg.MVIT.ZERO_DECAY_POS_CLS:
            if self.use_abs_pos:
                if self.sep_pos_embed:
                    names.extend(
                        [
                            "pos_embed_spatial",
                            "pos_embed_temporal",
                            "pos_embed_class",
                        ]
                    )
                else:
                    names.append("pos_embed")
            if self.rel_pos_spatial:
                names.extend(["rel_pos_h", "rel_pos_w", "rel_pos_hw"])
            if self.rel_pos_temporal:
                names.extend(["rel_pos_t"])
            if self.cls_embed_on:
                names.append("cls_token")
        return names

    def _get_pos_embed(self, pos_embed, bcthw):
        if len(bcthw) == 4:
            t, h, w = 1, bcthw[-2], bcthw[-1]
        else:
            t, h, w = bcthw[-3], bcthw[-2], bcthw[-1]
        if self.cls_embed_on:
            cls_pos_embed = pos_embed[:, 0:1, :]
            pos_embed = pos_embed[:, 1:]
        txy_num = pos_embed.shape[1]
        p_t, p_h, p_w = self.patch_dims
        assert p_t * p_h * p_w == txy_num

        if (p_t, p_h, p_w) != (t, h, w):
            new_pos_embed = F.interpolate(
                pos_embed[:, :, :]
                .reshape(1, p_t, p_h, p_w, -1)
                .permute(0, 4, 1, 2, 3),
                size=(t, h, w),
                mode="trilinear",
            )
            pos_embed = new_pos_embed.reshape(1, -1, t * h * w).permute(0, 2, 1)

        if self.cls_embed_on:
            pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)
        return pos_embed
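
    # Example of what `_get_pos_embed` does, with illustrative shapes: for a
    # model whose patch_dims are (8, 56, 56), a clip whose token grid is
    # instead (16, 56, 56) reshapes the (1, 25088, C) embedding table to
    # (1, C, 8, 56, 56), trilinearly resizes it to (1, C, 16, 56, 56), and
    # flattens it back to (1, 50176, C) before re-attaching the cls slot.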

    def _forward_reversible(self, x):
        """
        Reversible specific code for forward computation.
        """
        # The reversible backbone does not support a cls token or detection.
        assert not self.cls_embed_on
        assert not self.enable_detection

        x = self.rev_backbone(x)

        if self.use_mean_pooling:
            x = self.fuse(x)
            x = x.mean(1)
            x = self.norm(x)
        else:
            x = self.norm(x)
            x = self.fuse(x)
            x = x.mean(1)

        x = self.head(x)
        return x

    def forward(self, x, bboxes=None, return_attn=False):
        x = x[0]
        x, bcthw = self.patch_embed(x)
        bcthw = list(bcthw)
        if len(bcthw) == 4:  # Fix bcthw in case of a 4D tensor.
            bcthw.insert(2, torch.tensor(self.T))
        T, H, W = bcthw[-3], bcthw[-2], bcthw[-1]
        assert len(bcthw) == 5 and (T, H, W) == (self.T, self.H, self.W), bcthw
        B, N, C = x.shape

        s = 1 if self.cls_embed_on else 0
        if self.use_fixed_sincos_pos:
            x += self.pos_embed[:, s:, :]  # s: on/off cls token

        if self.cls_embed_on:
            cls_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole cls_tokens impl from Phil Wang, thanks
            if self.use_fixed_sincos_pos:
                cls_tokens = cls_tokens + self.pos_embed[:, :s, :]
            x = torch.cat((cls_tokens, x), dim=1)

        if self.use_abs_pos:
            if self.sep_pos_embed:
                pos_embed = self.pos_embed_spatial.repeat(
                    1, self.patch_dims[0], 1
                ) + torch.repeat_interleave(
                    self.pos_embed_temporal,
                    self.patch_dims[1] * self.patch_dims[2],
                    dim=1,
                )
                if self.cls_embed_on:
                    pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
                x += self._get_pos_embed(pos_embed, bcthw)
            else:
                x += self._get_pos_embed(self.pos_embed, bcthw)

        if self.drop_rate:
            x = self.pos_drop(x)

        if self.norm_stem:
            x = self.norm_stem(x)

        thw = [T, H, W]

        if self.enable_rev:
            x = self._forward_reversible(x)
        else:
            for blk in self.blocks:
                x, thw = blk(x, thw)

            if self.enable_detection:
                assert not self.enable_rev

                x = self.norm(x)
                if self.cls_embed_on:
                    x = x[:, 1:]

                B, _, C = x.shape
                x = x.transpose(1, 2).reshape(B, C, thw[0], thw[1], thw[2])

                x = self.head([x], bboxes)
            else:
                if self.use_mean_pooling:
                    if self.cls_embed_on:
                        x = x[:, 1:]
                    x = x.mean(1)
                    x = self.norm(x)
                elif self.cls_embed_on:
                    x = self.norm(x)
                    x = x[:, 0]
                else:  # this is the default: [norm -> mean]
                    x = self.norm(x)
                    x = x.mean(1)

                x = self.head(x)

        return x
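
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It assumes a SlowFast-style
# config; `get_cfg` lives in `slowfast.config.defaults` in the SlowFast repo,
# but any namespace exposing the fields read above (cfg.MVIT.*, cfg.DATA.*,
# cfg.MODEL.*, cfg.DETECTION.ENABLE) would work:
#
#     from slowfast.config.defaults import get_cfg
#
#     cfg = get_cfg()
#     cfg.MODEL.NUM_CLASSES = 400  # e.g. Kinetics-400
#     model = MViT(cfg)
#     # The forward pass takes a list with one pathway tensor of shape
#     # (B, C, T, H, W) matching DATA.NUM_FRAMES and DATA.TRAIN_CROP_SIZE.
#     clip = torch.randn(2, 3, cfg.DATA.NUM_FRAMES,
#                        cfg.DATA.TRAIN_CROP_SIZE, cfg.DATA.TRAIN_CROP_SIZE)
#     logits = model([clip])  # -> (2, cfg.MODEL.NUM_CLASSES)
# ---------------------------------------------------------------------------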