import sys
from functools import partial
import torch
from torch import nn
from torch.autograd import Function as Function
from .attention import MultiScaleAttention, attention_pool
from .common import Mlp, TwoStreamFusion, drop_path
from .utils import round_width
class ReversibleMViT(nn.Module):
"""
Reversible model builder. This builds the reversible transformer encoder
and allows reversible training.
Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong,
Christoph Feichtenhofer, Jitendra Malik
"Reversible Vision Transformers"
https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf
"""
def __init__(self, config, model):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
model (nn.Module): parent MViT module this module forms
a reversible encoder in.
"""
super().__init__()
self.cfg = config
embed_dim = self.cfg.MVIT.EMBED_DIM
depth = self.cfg.MVIT.DEPTH
num_heads = self.cfg.MVIT.NUM_HEADS
mlp_ratio = self.cfg.MVIT.MLP_RATIO
qkv_bias = self.cfg.MVIT.QKV_BIAS
drop_path_rate = self.cfg.MVIT.DROPPATH_RATE
self.dropout = config.MVIT.DROPOUT_RATE
self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
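        # For example (illustrative values), depth=4 and drop_path_rate=0.3 give
        # dpr = [0.0, 0.1, 0.2, 0.3]: later layers drop paths more often.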
input_size = model.patch_dims
self.layers = nn.ModuleList([])
self.no_custom_backward = False
if self.cfg.MVIT.NORM == "layernorm":
norm_layer = partial(nn.LayerNorm, eps=1e-6)
else:
raise NotImplementedError("Only supports layernorm.")
dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
for i in range(len(self.cfg.MVIT.DIM_MUL)):
dim_mul[self.cfg.MVIT.DIM_MUL[i][0]] = self.cfg.MVIT.DIM_MUL[i][1]
for i in range(len(self.cfg.MVIT.HEAD_MUL)):
head_mul[self.cfg.MVIT.HEAD_MUL[i][0]] = self.cfg.MVIT.HEAD_MUL[i][
1
]
pool_q = model.pool_q
pool_kv = model.pool_kv
stride_q = model.stride_q
stride_kv = model.stride_kv
for i in range(depth):
num_heads = round_width(num_heads, head_mul[i])
            # The channel up-projection happens inside the MHPA, so the input to
            # the Q-pooling block has the lower channel dimension. This localizes
            # the feature-dimension change to a single block, keeping more of the
            # computation reversible.
embed_dim = round_width(
embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads
)
dim_out = round_width(
embed_dim,
dim_mul[i],
divisor=round_width(num_heads, head_mul[i + 1]),
)
if i in self.cfg.MVIT.REV.BUFFER_LAYERS:
layer_type = StageTransitionBlock
input_mult = 2 if "concat" in self.pre_q_fusion else 1
else:
layer_type = ReversibleBlock
input_mult = 1
dimout_correction = (
2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1
)
self.layers.append(
layer_type(
dim=embed_dim
                    * input_mult,  # added only for concat fusion before Q-pooling layers
input_size=input_size,
dim_out=dim_out * input_mult // dimout_correction,
num_heads=num_heads,
cfg=self.cfg,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
kernel_q=pool_q[i] if len(pool_q) > i else [],
kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
stride_q=stride_q[i] if len(stride_q) > i else [],
stride_kv=stride_kv[i] if len(stride_kv) > i else [],
layer_id=i,
pre_q_fusion=self.pre_q_fusion,
)
)
# F is the attention block
self.layers[-1].F.thw = input_size
if len(stride_q[i]) > 0:
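                # Q-pooling shrinks the token grid by the stride, e.g. (illustrative)
                # input_size [8, 14, 14] with stride_q [1, 2, 2] -> [8, 7, 7].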
input_size = [
size // stride
for size, stride in zip(input_size, stride_q[i])
]
embed_dim = dim_out
@staticmethod
def vanilla_backward(h, layers, buffer):
"""
Using rev layers without rev backpropagation. Debugging purposes only.
Activated with self.no_custom_backward.
"""
# split into hidden states (h) and attention_output (a)
h, a = torch.chunk(h, 2, dim=-1)
for _, layer in enumerate(layers):
a, h = layer(a, h)
return torch.cat([a, h], dim=-1)
def forward(self, x):
        # Partition the layers into reversible stretches and irreversible
        # stage-transition layers, then process them stack by stack.
stack = []
for l_i in range(len(self.layers)):
if isinstance(self.layers[l_i], StageTransitionBlock):
stack.append(("StageTransition", l_i))
else:
if len(stack) == 0 or stack[-1][0] == "StageTransition":
stack.append(("Reversible", []))
stack[-1][1].append(l_i)
for layer_seq in stack:
if layer_seq[0] == "StageTransition":
x = self.layers[layer_seq[1]](x)
else:
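                # Duplicate the features to initialize the two residual streams
                # (X_1, X_2) used by the reversible blocks; the channel dim doubles
                # here and is fused back to one stream at the next stage transition.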
x = torch.cat([x, x], dim=-1)
                # no custom backprop needed in eval or when logging model stats
if not self.training or self.no_custom_backward:
executing_fn = ReversibleMViT.vanilla_backward
else:
executing_fn = RevBackProp.apply
x = executing_fn(
x,
self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1],
[], # buffer activations
)
# Apply dropout
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
return x
class RevBackProp(Function):
"""
Custom Backpropagation function to allow (A) flusing memory in foward
and (B) activation recomputation reversibly in backward for gradient calculation.
Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
"""
@staticmethod
def forward(
ctx,
x,
layers,
        buffer_layers,  # list of layer ids whose intermediate activations are buffered
):
"""
Reversible Forward pass. Any intermediate activations from `buffer_layers` are
cached in ctx for forward pass. This is not necessary for standard usecases.
Each reversible layer implements its own forward pass logic.
"""
buffer_layers.sort()
X_1, X_2 = torch.chunk(x, 2, dim=-1)
intermediate = []
for layer in layers:
X_1, X_2 = layer(X_1, X_2)
if layer.layer_id in buffer_layers:
intermediate.extend([X_1.detach(), X_2.detach()])
if len(buffer_layers) == 0:
all_tensors = [X_1.detach(), X_2.detach()]
else:
intermediate = [torch.LongTensor(buffer_layers), *intermediate]
all_tensors = [X_1.detach(), X_2.detach(), *intermediate]
ctx.save_for_backward(*all_tensors)
ctx.layers = layers
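        # Only the final (detached) pair of activations (plus any explicitly
        # buffered ones) is stored; everything else is recomputed in backward,
        # which is what keeps the activation memory independent of depth.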
return torch.cat([X_1, X_2], dim=-1)
@staticmethod
def backward(ctx, dx):
"""
Reversible Backward pass. Any intermediate activations from `buffer_layers` are
recovered from ctx. Each layer implements its own loic for backward pass (both
activation recomputation and grad calculation).
"""
dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)
# retrieve params from ctx for backward
X_1, X_2, *int_tensors = ctx.saved_tensors
        # recover the ids of explicitly buffered layers, if any
if len(int_tensors) != 0:
buffer_layers = int_tensors[0].tolist()
else:
buffer_layers = []
layers = ctx.layers
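        # Walk the layers in reverse: each backward_pass reconstructs the block's
        # inputs (X_1, X_2) from its outputs and accumulates the gradients.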
for _, layer in enumerate(layers[::-1]):
if layer.layer_id in buffer_layers:
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
Y_1=int_tensors[
buffer_layers.index(layer.layer_id) * 2 + 1
],
Y_2=int_tensors[
buffer_layers.index(layer.layer_id) * 2 + 2
],
dY_1=dX_1,
dY_2=dX_2,
)
else:
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
Y_1=X_1,
Y_2=X_2,
dY_1=dX_1,
dY_2=dX_2,
)
dx = torch.cat([dX_1, dX_2], dim=-1)
del int_tensors
del dX_1, dX_2, X_1, X_2
return dx, None, None
class StageTransitionBlock(nn.Module):
"""
Blocks for changing the feature dimensions in MViT (using Q-pooling).
See Section 3.3.1 in paper for details.
"""
def __init__(
self,
dim,
input_size,
dim_out,
num_heads,
mlp_ratio,
qkv_bias,
drop_path,
kernel_q,
kernel_kv,
stride_q,
stride_kv,
cfg,
norm_layer=nn.LayerNorm,
pre_q_fusion=None,
layer_id=0,
):
"""
        Uses the same structure of F and G functions as the ReversibleBlock,
        but without the reversible forward (and backward) pass.
"""
super().__init__()
self.drop_path_rate = drop_path
embed_dim = dim
self.F = AttentionSubBlock(
dim=embed_dim,
input_size=input_size,
num_heads=num_heads,
cfg=cfg,
dim_out=dim_out,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=norm_layer,
)
self.G = MLPSubblock(
dim=dim_out,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
)
self.layer_id = layer_id
self.is_proj = False
self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON
self.is_conv = False
self.pool_first = cfg.MVIT.POOL_FIRST
self.mode = cfg.MVIT.MODE
self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim)
if cfg.MVIT.REV.RES_PATH == "max":
self.res_conv = False
            self.pool_skip = nn.MaxPool3d(
                # kernel size: Q-pooling stride + 1 wherever the stride is > 1
                [s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride],
                # stride: match the Q-pooling stride
                self.F.attn.pool_q.stride,
                # padding: half the stride
                [int(k // 2) for k in self.F.attn.pool_q.stride],
                ceil_mode=False,
            )
elif cfg.MVIT.REV.RES_PATH == "conv":
self.res_conv = True
else:
raise NotImplementedError
# Add a linear projection in residual branch
if embed_dim != dim_out:
self.is_proj = True
self.res_proj = nn.Linear(embed_dim, dim_out, bias=True)
def forward(
self,
x,
):
"""
Forward logic is similar to MultiScaleBlock with Q-pooling.
"""
x = self.pre_q_fuse(x)
        # fork the tensor for the residual connection
        x_res = x
        # Project the residual features to the output dim before pooling,
        # but only if pool_first is not set (otherwise project after pooling)
if self.is_proj and not self.pool_first:
x_res = self.res_proj(x_res)
if self.res_conv:
# Pooling the hidden features with the same conv as Q
N, L, C = x_res.shape
# This handling is the same as that of q in MultiScaleAttention
if self.mode == "conv_unshared":
fold_dim = 1
else:
fold_dim = self.F.attn.num_heads
            # Shape becomes (N, fold_dim, L, C // fold_dim), matching q's layout
x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute(
0, 2, 1, 3
)
x_res, _ = attention_pool(
x_res,
                self.F.attn.pool_q,
                thw_shape=self.F.thw,
has_cls_embed=self.has_cls_embed,
norm=self.F.attn.norm_q
if hasattr(self.F.attn, "norm_q")
else None,
)
x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C)
else:
# Pooling the hidden features with max op
x_res, _ = attention_pool(
x_res,
self.pool_skip,
thw_shape=self.F.attn.thw,
has_cls_embed=self.has_cls_embed,
)
# If pool_first then project to higher dim now
if self.is_proj and self.pool_first:
x_res = self.res_proj(x_res)
x = self.F(x)
x = x_res + x
x = x + self.G(x)
x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training)
return x
class ReversibleBlock(nn.Module):
"""
Reversible Blocks for Reversible Vision Transformer and also
for state-preserving blocks in Reversible MViT. See Section
3.3.2 in paper for details.
"""
def __init__(
self,
dim,
input_size,
dim_out,
num_heads,
mlp_ratio,
qkv_bias,
drop_path,
kernel_q,
kernel_kv,
stride_q,
stride_kv,
cfg,
norm_layer=nn.LayerNorm,
layer_id=0,
**kwargs
):
"""
        The block is composed entirely of the functions F (attention
        sub-block) and G (MLP sub-block), each including its LayerNorm.
"""
super().__init__()
self.drop_path_rate = drop_path
self.F = AttentionSubBlock(
dim=dim,
input_size=input_size,
num_heads=num_heads,
cfg=cfg,
dim_out=dim_out,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=norm_layer,
)
self.G = MLPSubblock(
dim=dim,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
)
self.layer_id = layer_id
self.seeds = {}
def seed_cuda(self, key):
"""
Fix seeds to allow for stochastic elements such as
dropout to be reproduced exactly in activation
recomputation in the backward pass.
"""
# randomize seeds
# use cuda generator if available
if (
hasattr(torch.cuda, "default_generators")
and len(torch.cuda.default_generators) > 0
):
# GPU
device_idx = torch.cuda.current_device()
seed = torch.cuda.default_generators[device_idx].seed()
else:
# CPU
seed = int(torch.seed() % sys.maxsize)
self.seeds[key] = seed
torch.manual_seed(self.seeds[key])
def forward(self, X_1, X_2):
"""
forward pass equations:
Y_1 = X_1 + Attention(X_2), F = Attention
Y_2 = X_2 + MLP(Y_1), G = MLP
"""
self.seed_cuda("attn")
# Y_1 : attn_output
f_X_2 = self.F(X_2)
self.seed_cuda("droppath")
f_X_2_dropped = drop_path(
f_X_2, drop_prob=self.drop_path_rate, training=self.training
)
# Y_1 = X_1 + f(X_2)
Y_1 = X_1 + f_X_2_dropped
# free memory
del X_1
self.seed_cuda("FFN")
g_Y_1 = self.G(Y_1)
torch.manual_seed(self.seeds["droppath"])
g_Y_1_dropped = drop_path(
g_Y_1, drop_prob=self.drop_path_rate, training=self.training
)
# Y_2 = X_2 + g(Y_1)
Y_2 = X_2 + g_Y_1_dropped
del X_2
return Y_1, Y_2
def backward_pass(
self,
Y_1,
Y_2,
dY_1,
dY_2,
):
"""
        Equations for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention
        """
        # temporarily record the intermediate activations of G
        # and use them for the gradient calculation of G
with torch.enable_grad():
Y_1.requires_grad = True
torch.manual_seed(self.seeds["FFN"])
g_Y_1 = self.G(Y_1)
torch.manual_seed(self.seeds["droppath"])
g_Y_1 = drop_path(
g_Y_1, drop_prob=self.drop_path_rate, training=self.training
)
g_Y_1.backward(dY_2, retain_graph=True)
        # the activation recomputation is intentionally kept out of the
        # forward pass's computation graph
with torch.no_grad():
X_2 = Y_2 - g_Y_1
del g_Y_1
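            # dL/dY_1 gets two contributions: the residual path (dY_1) and the
            # path through G, which backward() above stored in Y_1.grad.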
dY_1 = dY_1 + Y_1.grad
Y_1.grad = None
# record F activations and calc gradients on F
with torch.enable_grad():
X_2.requires_grad = True
torch.manual_seed(self.seeds["attn"])
f_X_2 = self.F(X_2)
torch.manual_seed(self.seeds["droppath"])
f_X_2 = drop_path(
f_X_2, drop_prob=self.drop_path_rate, training=self.training
)
f_X_2.backward(dY_1, retain_graph=True)
        # propagate the reverse-computed activations at the start of the
        # previous block for backprop
with torch.no_grad():
X_1 = Y_1 - f_X_2
del f_X_2, Y_1
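            # dL/dX_2 likewise combines the residual path (dY_2) with the
            # gradient through F accumulated in X_2.grad.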
dY_2 = dY_2 + X_2.grad
X_2.grad = None
X_2 = X_2.detach()
return X_1, X_2, dY_1, dY_2
class MLPSubblock(nn.Module):
"""
    This creates the function G (the MLP sub-block, including the
    pre-LayerNorm) used by the reversible and stage-transition blocks.
"""
def __init__(
self,
dim,
mlp_ratio,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=nn.GELU,
)
def forward(self, x):
return self.mlp(self.norm(x))
class AttentionSubBlock(nn.Module):
"""
    This creates the function F (the attention sub-block, including the
    pre-LayerNorm) used by the reversible and stage-transition blocks.
"""
def __init__(
self,
dim,
input_size,
num_heads,
cfg,
dim_out=None,
kernel_q=(1, 1, 1),
kernel_kv=(1, 1, 1),
stride_q=(1, 1, 1),
stride_kv=(1, 1, 1),
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)
# This will be set externally during init
self.thw = None
        # The attention details are the same as the multiscale attention of
        # MViTv2 (with the channel up-projection inside the block).
        # A no-up-projection variant can also be implemented for vanilla ViT.
self.attn = MultiScaleAttention(
dim,
dim_out,
input_size=input_size,
num_heads=num_heads,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=norm_layer,
drop_rate=cfg.MVIT.DROPOUT_RATE,
qkv_bias=cfg.MVIT.QKV_BIAS,
has_cls_embed=cfg.MVIT.CLS_EMBED_ON,
mode=cfg.MVIT.MODE,
pool_first=cfg.MVIT.POOL_FIRST,
rel_pos_spatial=cfg.MVIT.REL_POS_SPATIAL,
rel_pos_temporal=cfg.MVIT.REL_POS_TEMPORAL,
rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
separate_qkv=cfg.MVIT.SEPARATE_QKV,
)
def forward(self, x):
out, _ = self.attn(self.norm(x), self.thw)
return out