# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Video models."""
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
# import slowfast.utils.weight_init_helper as init_helper
from .attention import MultiScaleBlock
# from slowfast.models.batchnorm_helper import get_norm
from .common import TwoStreamFusion
from .reversible_mvit import ReversibleMViT
from .utils import (
calc_mvit_feature_geometry,
get_3d_sincos_pos_embed,
round_width,
validate_checkpoint_wrapper_import,
)
from . import head_helper, stem_helper  # noqa
class MViT(nn.Module):
"""
Model builder for MViTv1 and MViTv2.
"MViTv2: Improved Multiscale Vision Transformers for Classification and Detection"
Yanghao Li, Chao-Yuan Wu, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, Christoph Feichtenhofer
https://arxiv.org/abs/2112.01526
"Multiscale Vision Transformers"
Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, Christoph Feichtenhofer
https://arxiv.org/abs/2104.11227
"""
def __init__(self, cfg):
super().__init__()
# Get parameters.
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.cfg = cfg
pool_first = cfg.MVIT.POOL_FIRST
# Prepare input.
spatial_size = cfg.DATA.TRAIN_CROP_SIZE
temporal_size = cfg.DATA.NUM_FRAMES
in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
self.use_2d_patch = cfg.MVIT.PATCH_2D
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_rev = cfg.MVIT.REV.ENABLE
self.patch_stride = cfg.MVIT.PATCH_STRIDE
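        # With 2D patchification the configured stride covers [H, W] only, so a
        # temporal stride of 1 is prepended to keep the [T, H, W] size math below valid.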
if self.use_2d_patch:
self.patch_stride = [1] + self.patch_stride
self.T = cfg.DATA.NUM_FRAMES // self.patch_stride[0]
self.H = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[1]
self.W = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[2]
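        # Token-grid size after the patch stem, e.g. a 16-frame 224x224 clip with a
        # typical PATCH_STRIDE of [2, 4, 4] gives T=8, H=W=56.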
# Prepare output.
num_classes = cfg.MODEL.NUM_CLASSES
embed_dim = cfg.MVIT.EMBED_DIM
# Prepare backbone
num_heads = cfg.MVIT.NUM_HEADS
mlp_ratio = cfg.MVIT.MLP_RATIO
qkv_bias = cfg.MVIT.QKV_BIAS
self.drop_rate = cfg.MVIT.DROPOUT_RATE
depth = cfg.MVIT.DEPTH
drop_path_rate = cfg.MVIT.DROPPATH_RATE
layer_scale_init_value = cfg.MVIT.LAYER_SCALE_INIT_VALUE
head_init_scale = cfg.MVIT.HEAD_INIT_SCALE
mode = cfg.MVIT.MODE
self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
self.use_mean_pooling = cfg.MVIT.USE_MEAN_POOLING
# Params for positional embedding
self.use_abs_pos = cfg.MVIT.USE_ABS_POS
self.use_fixed_sincos_pos = cfg.MVIT.USE_FIXED_SINCOS_POS
self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
self.rel_pos_spatial = cfg.MVIT.REL_POS_SPATIAL
self.rel_pos_temporal = cfg.MVIT.REL_POS_TEMPORAL
if cfg.MVIT.NORM == "layernorm":
norm_layer = partial(nn.LayerNorm, eps=1e-6)
else:
raise NotImplementedError("Only supports layernorm.")
self.num_classes = num_classes
self.patch_embed = stem_helper.PatchEmbed(
dim_in=in_chans,
dim_out=embed_dim,
kernel=cfg.MVIT.PATCH_KERNEL,
stride=cfg.MVIT.PATCH_STRIDE,
padding=cfg.MVIT.PATCH_PADDING,
conv_2d=self.use_2d_patch,
)
self.input_dims = [temporal_size, spatial_size, spatial_size]
assert self.input_dims[1] == self.input_dims[2]
self.patch_dims = [
self.input_dims[i] // self.patch_stride[i]
for i in range(len(self.input_dims))
]
num_patches = math.prod(self.patch_dims)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
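        # e.g. DEPTH=16 with DROPPATH_RATE=0.1 assigns per-block drop-path rates that
        # increase linearly from 0.0 (first block) to 0.1 (last block).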
if self.cls_embed_on:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed_dim = num_patches + 1
else:
pos_embed_dim = num_patches
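        # Absolute position embeddings come in two flavors: a separate spatial (H*W)
        # table plus temporal (T) table whose sum is broadcast over the grid, or a
        # single joint table over all tokens (optionally a frozen sin-cos table).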
if self.use_abs_pos:
if self.sep_pos_embed:
self.pos_embed_spatial = nn.Parameter(
torch.zeros(
1, self.patch_dims[1] * self.patch_dims[2], embed_dim
)
)
self.pos_embed_temporal = nn.Parameter(
torch.zeros(1, self.patch_dims[0], embed_dim)
)
if self.cls_embed_on:
self.pos_embed_class = nn.Parameter(
torch.zeros(1, 1, embed_dim)
)
else:
self.pos_embed = nn.Parameter(
torch.zeros(
1,
pos_embed_dim,
embed_dim,
),
requires_grad=not self.use_fixed_sincos_pos,
)
if self.drop_rate > 0.0:
self.pos_drop = nn.Dropout(p=self.drop_rate)
dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
for i in range(len(cfg.MVIT.DIM_MUL)):
dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
for i in range(len(cfg.MVIT.HEAD_MUL)):
head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]
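        # Per-block pooling-attention settings: kernel sizes (pool_*) and strides
        # (stride_*) for the query and key/value pooling, each given as [T, H, W].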
        pool_q = [[] for _ in range(cfg.MVIT.DEPTH)]
        pool_kv = [[] for _ in range(cfg.MVIT.DEPTH)]
        stride_q = [[] for _ in range(cfg.MVIT.DEPTH)]
        stride_kv = [[] for _ in range(cfg.MVIT.DEPTH)]
for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][
1:
]
if cfg.MVIT.POOL_KVQ_KERNEL is not None:
pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
else:
pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
]
# If POOL_KV_STRIDE_ADAPTIVE is not None, initialize POOL_KV_STRIDE.
if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
_stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
cfg.MVIT.POOL_KV_STRIDE = []
for i in range(cfg.MVIT.DEPTH):
if len(stride_q[i]) > 0:
_stride_kv = [
max(_stride_kv[d] // stride_q[i][d], 1)
for d in range(len(_stride_kv))
]
cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)
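        # e.g. POOL_KV_STRIDE_ADAPTIVE=[1, 8, 8] with a [1, 2, 2] query stride at each
        # stage-transition block evolves the KV stride [1, 8, 8] -> [1, 4, 4] ->
        # [1, 2, 2] -> [1, 1, 1], keeping K/V pooling proportional to the query grid.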
for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[
i
][1:]
if cfg.MVIT.POOL_KVQ_KERNEL is not None:
pool_kv[
cfg.MVIT.POOL_KV_STRIDE[i][0]
] = cfg.MVIT.POOL_KVQ_KERNEL
else:
pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
s + 1 if s > 1 else s
for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
]
self.pool_q = pool_q
self.pool_kv = pool_kv
self.stride_q = stride_q
self.stride_kv = stride_kv
self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None
input_size = self.patch_dims
if self.enable_rev:
# rev does not allow cls token
assert not self.cls_embed_on
self.rev_backbone = ReversibleMViT(cfg, self)
embed_dim = round_width(
embed_dim, dim_mul.prod(), divisor=num_heads
)
self.fuse = TwoStreamFusion(
cfg.MVIT.REV.RESPATH_FUSE, dim=2 * embed_dim
)
if "concat" in self.cfg.MVIT.REV.RESPATH_FUSE:
self.norm = norm_layer(2 * embed_dim)
else:
self.norm = norm_layer(embed_dim)
else:
self.blocks = nn.ModuleList()
for i in range(depth):
num_heads = round_width(num_heads, head_mul[i])
if cfg.MVIT.DIM_MUL_IN_ATT:
dim_out = round_width(
embed_dim,
dim_mul[i],
divisor=round_width(num_heads, head_mul[i]),
)
else:
dim_out = round_width(
embed_dim,
dim_mul[i + 1],
divisor=round_width(num_heads, head_mul[i + 1]),
)
attention_block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
input_size=input_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_rate=self.drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
kernel_q=pool_q[i] if len(pool_q) > i else [],
kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
stride_q=stride_q[i] if len(stride_q) > i else [],
stride_kv=stride_kv[i] if len(stride_kv) > i else [],
mode=mode,
has_cls_embed=self.cls_embed_on,
pool_first=pool_first,
rel_pos_spatial=self.rel_pos_spatial,
rel_pos_temporal=self.rel_pos_temporal,
rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
dim_mul_in_att=cfg.MVIT.DIM_MUL_IN_ATT,
separate_qkv=cfg.MVIT.SEPARATE_QKV,
)
self.blocks.append(attention_block)
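                # Query pooling shrinks the token grid; update [T, H, W] so the next
                # block (and its relative-position tables) is built for the pooled size.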
if len(stride_q[i]) > 0:
input_size = [
size // stride
for size, stride in zip(input_size, stride_q[i])
]
embed_dim = dim_out
self.norm = norm_layer(embed_dim)
if self.enable_detection:
            raise NotImplementedError("Detection is not supported")
else:
self.head = head_helper.TransformerBasicHead(
2 * embed_dim
if ("concat" in cfg.MVIT.REV.RESPATH_FUSE and self.enable_rev)
else embed_dim,
num_classes,
dropout_rate=cfg.MODEL.DROPOUT_RATE,
act_func=cfg.MODEL.HEAD_ACT,
cfg=cfg,
)
if self.use_abs_pos:
if self.sep_pos_embed:
trunc_normal_(self.pos_embed_spatial, std=0.02)
trunc_normal_(self.pos_embed_temporal, std=0.02)
if self.cls_embed_on:
trunc_normal_(self.pos_embed_class, std=0.02)
else:
trunc_normal_(self.pos_embed, std=0.02)
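                # Optionally overwrite the random init with a fixed 3D sin-cos table;
                # in that case the embedding was created with requires_grad=False above.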
if self.use_fixed_sincos_pos:
pos_embed = get_3d_sincos_pos_embed(
self.pos_embed.shape[-1],
self.H,
self.T,
cls_token=self.cls_embed_on,
)
self.pos_embed.data.copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
if self.cls_embed_on:
trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
self.head.projection.weight.data.mul_(head_init_scale)
self.head.projection.bias.data.mul_(head_init_scale)
self.feat_size, self.feat_stride = calc_mvit_feature_geometry(cfg)
def _init_weights(self, m):
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.02)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0.02)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
names = []
if self.cfg.MVIT.ZERO_DECAY_POS_CLS:
if self.use_abs_pos:
if self.sep_pos_embed:
names.extend(
[
"pos_embed_spatial",
"pos_embed_temporal",
"pos_embed_class",
]
)
else:
names.append("pos_embed")
if self.rel_pos_spatial:
names.extend(["rel_pos_h", "rel_pos_w", "rel_pos_hw"])
if self.rel_pos_temporal:
names.extend(["rel_pos_t"])
if self.cls_embed_on:
names.append("cls_token")
return names
def _get_pos_embed(self, pos_embed, bcthw):
if len(bcthw) == 4:
t, h, w = 1, bcthw[-2], bcthw[-1]
else:
t, h, w = bcthw[-3], bcthw[-2], bcthw[-1]
if self.cls_embed_on:
cls_pos_embed = pos_embed[:, 0:1, :]
pos_embed = pos_embed[:, 1:]
txy_num = pos_embed.shape[1]
p_t, p_h, p_w = self.patch_dims
assert p_t * p_h * p_w == txy_num
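        # If the runtime grid differs from the build-time patch grid (e.g. a different
        # input resolution), trilinearly resize the positional table to (t, h, w).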
if (p_t, p_h, p_w) != (t, h, w):
new_pos_embed = F.interpolate(
pos_embed[:, :, :]
.reshape(1, p_t, p_h, p_w, -1)
.permute(0, 4, 1, 2, 3),
size=(t, h, w),
mode="trilinear",
)
pos_embed = new_pos_embed.reshape(1, -1, t * h * w).permute(0, 2, 1)
if self.cls_embed_on:
pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)
return pos_embed
def _forward_reversible(self, x):
"""
Reversible specific code for forward computation.
"""
# rev does not support cls token or detection
assert not self.cls_embed_on
assert not self.enable_detection
x = self.rev_backbone(x)
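        # Fuse the two reversible residual streams, then reduce over tokens; the
        # order of fusion, pooling, and normalization depends on USE_MEAN_POOLING.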
if self.use_mean_pooling:
x = self.fuse(x)
x = x.mean(1)
x = self.norm(x)
else:
x = self.norm(x)
x = self.fuse(x)
x = x.mean(1)
x = self.head(x)
return x
def forward(self, x, bboxes=None, return_attn=False):
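        # x is a single-pathway input: a list holding one (B, C, T, H, W) tensor.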
x = x[0]
x, bcthw = self.patch_embed(x)
bcthw = list(bcthw)
if len(bcthw) == 4: # Fix bcthw in case of 4D tensor
bcthw.insert(2, torch.tensor(self.T))
T, H, W = bcthw[-3], bcthw[-2], bcthw[-1]
assert len(bcthw) == 5 and (T, H, W) == (self.T, self.H, self.W), bcthw
B, N, C = x.shape
s = 1 if self.cls_embed_on else 0
if self.use_fixed_sincos_pos:
x += self.pos_embed[:, s:, :] # s: on/off cls token
if self.cls_embed_on:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
if self.use_fixed_sincos_pos:
cls_tokens = cls_tokens + self.pos_embed[:, :s, :]
x = torch.cat((cls_tokens, x), dim=1)
if self.use_abs_pos:
if self.sep_pos_embed:
pos_embed = self.pos_embed_spatial.repeat(
1, self.patch_dims[0], 1
) + torch.repeat_interleave(
self.pos_embed_temporal,
self.patch_dims[1] * self.patch_dims[2],
dim=1,
)
if self.cls_embed_on:
pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
x += self._get_pos_embed(pos_embed, bcthw)
else:
x += self._get_pos_embed(self.pos_embed, bcthw)
if self.drop_rate:
x = self.pos_drop(x)
if self.norm_stem:
x = self.norm_stem(x)
thw = [T, H, W]
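        # Run the transformer blocks; pooling blocks update thw to the pooled grid size.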
if self.enable_rev:
x = self._forward_reversible(x)
else:
for blk in self.blocks:
x, thw = blk(x, thw)
if self.enable_detection:
assert not self.enable_rev
x = self.norm(x)
if self.cls_embed_on:
x = x[:, 1:]
B, _, C = x.shape
x = x.transpose(1, 2).reshape(B, C, thw[0], thw[1], thw[2])
x = self.head([x], bboxes)
else:
if self.use_mean_pooling:
if self.cls_embed_on:
x = x[:, 1:]
x = x.mean(1)
x = self.norm(x)
elif self.cls_embed_on:
x = self.norm(x)
x = x[:, 0]
else: # this is default, [norm->mean]
x = self.norm(x)
x = x.mean(1)
x = self.head(x)
return x