import sys
from functools import partial

import torch
from torch import nn
from torch.autograd import Function as Function

from .attention import MultiScaleAttention, attention_pool
from .common import Mlp, TwoStreamFusion, drop_path
from .utils import round_width
class ReversibleMViT(nn.Module):
    """
    Reversible model builder. This builds the reversible transformer encoder
    and allows reversible training.

    Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong,
    Christoph Feichtenhofer, Jitendra Malik
    "Reversible Vision Transformers"
    https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf
    """

    def __init__(self, config, model):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            config (CfgNode): model building configs, details are in the
                comments of the config file.
            model (nn.Module): parent MViT module this module forms
                a reversible encoder in.
        """
        super().__init__()

        self.cfg = config

        embed_dim = self.cfg.MVIT.EMBED_DIM
        depth = self.cfg.MVIT.DEPTH
        num_heads = self.cfg.MVIT.NUM_HEADS
        mlp_ratio = self.cfg.MVIT.MLP_RATIO
        qkv_bias = self.cfg.MVIT.QKV_BIAS
        drop_path_rate = self.cfg.MVIT.DROPPATH_RATE
        self.dropout = config.MVIT.DROPOUT_RATE
        self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        input_size = model.patch_dims

        self.layers = nn.ModuleList([])
        self.no_custom_backward = False

        if self.cfg.MVIT.NORM == "layernorm":
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        else:
            raise NotImplementedError("Only supports layernorm.")

        dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
        for i in range(len(self.cfg.MVIT.DIM_MUL)):
            dim_mul[self.cfg.MVIT.DIM_MUL[i][0]] = self.cfg.MVIT.DIM_MUL[i][1]
        for i in range(len(self.cfg.MVIT.HEAD_MUL)):
            head_mul[self.cfg.MVIT.HEAD_MUL[i][0]] = self.cfg.MVIT.HEAD_MUL[i][1]

        pool_q = model.pool_q
        pool_kv = model.pool_kv
        stride_q = model.stride_q
        stride_kv = model.stride_kv
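        # Per-block schedule built by the loop below: dim_mul / head_mul scale
        # the channel width and number of heads at the depths listed in the
        # config; blocks whose index appears in cfg.MVIT.REV.BUFFER_LAYERS are
        # built as StageTransitionBlocks (non-reversible, around Q-pooling and
        # dimension changes, with an input width doubled for "concat"
        # pre-Q fusion), while all remaining blocks are ReversibleBlocks.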
        for i in range(depth):
            num_heads = round_width(num_heads, head_mul[i])

            # Upsampling inside the MHPA, input to the Q-pooling block is lower C dimension.
            # This localizes the feature changes in a single block, making more computation reversible.
            embed_dim = round_width(
                embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads
            )
            dim_out = round_width(
                embed_dim,
                dim_mul[i],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )

            if i in self.cfg.MVIT.REV.BUFFER_LAYERS:
                layer_type = StageTransitionBlock
                input_mult = 2 if "concat" in self.pre_q_fusion else 1
            else:
                layer_type = ReversibleBlock
                input_mult = 1

            dimout_correction = (
                2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1
            )

            self.layers.append(
                layer_type(
                    dim=embed_dim
                    * input_mult,  # added only for concat fusion before Q-pooling layers
                    input_size=input_size,
                    dim_out=dim_out * input_mult // dimout_correction,
                    num_heads=num_heads,
                    cfg=self.cfg,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    kernel_q=pool_q[i] if len(pool_q) > i else [],
                    kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                    stride_q=stride_q[i] if len(stride_q) > i else [],
                    stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                    layer_id=i,
                    pre_q_fusion=self.pre_q_fusion,
                )
            )

            # F is the attention block
            self.layers[-1].F.thw = input_size

            if len(stride_q[i]) > 0:
                input_size = [
                    size // stride
                    for size, stride in zip(input_size, stride_q[i])
                ]

            embed_dim = dim_out
    @staticmethod
    def vanilla_backward(h, layers, buffer):
        """
        Using rev layers without rev backpropagation. Debugging purposes only.
        Activated with self.no_custom_backward.
        """
        # split into hidden states (h) and attention output (a)
        h, a = torch.chunk(h, 2, dim=-1)
        for _, layer in enumerate(layers):
            a, h = layer(a, h)

        return torch.cat([a, h], dim=-1)
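    # Forward over the encoder: the input is duplicated into two residual
    # streams (torch.cat([x, x], dim=-1)). Each run of ReversibleBlocks between
    # StageTransitionBlocks is executed either through RevBackProp.apply (the
    # custom autograd Function that recomputes activations in backward) or,
    # in eval mode or when self.no_custom_backward is set, through
    # vanilla_backward, which keeps standard autograd behavior.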
    def forward(self, x):
        # process the layers in a reversible stack and an irreversible stack.
        stack = []
        for l_i in range(len(self.layers)):
            if isinstance(self.layers[l_i], StageTransitionBlock):
                stack.append(("StageTransition", l_i))
            else:
                if len(stack) == 0 or stack[-1][0] == "StageTransition":
                    stack.append(("Reversible", []))
                stack[-1][1].append(l_i)

        for layer_seq in stack:
            if layer_seq[0] == "StageTransition":
                x = self.layers[layer_seq[1]](x)
            else:
                x = torch.cat([x, x], dim=-1)

                # no need for custom backprop in eval/model stat log
                if not self.training or self.no_custom_backward:
                    executing_fn = ReversibleMViT.vanilla_backward
                else:
                    executing_fn = RevBackProp.apply

                x = executing_fn(
                    x,
                    self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1],
                    [],  # buffer activations
                )

        # apply dropout
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        return x
class RevBackProp(Function):
    """
    Custom backpropagation function to allow (A) flushing memory in forward
    and (B) activation recomputation reversibly in backward for gradient calculation.

    Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
    """
    @staticmethod
    def forward(
        ctx,
        x,
        layers,
        buffer_layers,  # List of layer ids for int activation to buffer
    ):
        """
        Reversible Forward pass. Any intermediate activations from `buffer_layers`
        are cached in ctx for the backward pass. This is not necessary for standard
        use cases. Each reversible layer implements its own forward pass logic.
        """
        buffer_layers.sort()

        X_1, X_2 = torch.chunk(x, 2, dim=-1)

        intermediate = []

        for layer in layers:
            X_1, X_2 = layer(X_1, X_2)

            if layer.layer_id in buffer_layers:
                intermediate.extend([X_1.detach(), X_2.detach()])

        if len(buffer_layers) == 0:
            all_tensors = [X_1.detach(), X_2.detach()]
        else:
            intermediate = [torch.LongTensor(buffer_layers), *intermediate]
            all_tensors = [X_1.detach(), X_2.detach(), *intermediate]

        ctx.save_for_backward(*all_tensors)
        ctx.layers = layers

        return torch.cat([X_1, X_2], dim=-1)
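    # Only the final (X_1, X_2) pair, plus any explicitly buffered
    # intermediates, is stored via ctx.save_for_backward; everything else is
    # recomputed layer by layer in backward, so activation memory does not
    # grow with the number of reversible layers in the span.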
    @staticmethod
    def backward(ctx, dx):
        """
        Reversible Backward pass. Any intermediate activations from `buffer_layers`
        are recovered from ctx. Each layer implements its own logic for the backward
        pass (both activation recomputation and grad calculation).
        """
        dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)

        # retrieve params from ctx for backward
        X_1, X_2, *int_tensors = ctx.saved_tensors

        # recover buffered layer ids (empty if no buffering was requested)
        if len(int_tensors) != 0:
            buffer_layers = int_tensors[0].tolist()
        else:
            buffer_layers = []

        layers = ctx.layers

        for _, layer in enumerate(layers[::-1]):
            if layer.layer_id in buffer_layers:
                X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                    Y_1=int_tensors[
                        buffer_layers.index(layer.layer_id) * 2 + 1
                    ],
                    Y_2=int_tensors[
                        buffer_layers.index(layer.layer_id) * 2 + 2
                    ],
                    dY_1=dX_1,
                    dY_2=dX_2,
                )
            else:
                X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                    Y_1=X_1,
                    Y_2=X_2,
                    dY_1=dX_1,
                    dY_2=dX_2,
                )

        dx = torch.cat([dX_1, dX_2], dim=-1)

        del int_tensors
        del dX_1, dX_2, X_1, X_2

        return dx, None, None
class StageTransitionBlock(nn.Module):
    """
    Blocks for changing the feature dimensions in MViT (using Q-pooling).
    See Section 3.3.1 in paper for details.
    """

    def __init__(
        self,
        dim,
        input_size,
        dim_out,
        num_heads,
        mlp_ratio,
        qkv_bias,
        drop_path,
        kernel_q,
        kernel_kv,
        stride_q,
        stride_kv,
        cfg,
        norm_layer=nn.LayerNorm,
        pre_q_fusion=None,
        layer_id=0,
    ):
        """
        Uses the same structure of F and G functions as Reversible Block except
        without using reversible forward (and backward) pass.
        """
        super().__init__()

        self.drop_path_rate = drop_path

        embed_dim = dim

        self.F = AttentionSubBlock(
            dim=embed_dim,
            input_size=input_size,
            num_heads=num_heads,
            cfg=cfg,
            dim_out=dim_out,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
        )

        self.G = MLPSubblock(
            dim=dim_out,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        self.is_proj = False
        self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON

        self.is_conv = False
        self.pool_first = cfg.MVIT.POOL_FIRST
        self.mode = cfg.MVIT.MODE

        self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim)

        if cfg.MVIT.REV.RES_PATH == "max":
            self.res_conv = False
            self.pool_skip = nn.MaxPool3d(
                # self.attention.attn.pool_q.kernel_size,
                [s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride],
                self.F.attn.pool_q.stride,
                [int(k // 2) for k in self.F.attn.pool_q.stride],
                # self.attention.attn.pool_q.padding,
                ceil_mode=False,
            )
        elif cfg.MVIT.REV.RES_PATH == "conv":
            self.res_conv = True
        else:
            raise NotImplementedError

        # Add a linear projection in residual branch
        if embed_dim != dim_out:
            self.is_proj = True
            self.res_proj = nn.Linear(embed_dim, dim_out, bias=True)
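    # The residual (skip) branch has to track the Q-pooled shape of the main
    # branch: when the channel dimension changes it is linearly projected with
    # res_proj, and it is spatially pooled either with the same conv used for
    # Q inside the attention (RES_PATH == "conv") or with a max-pool whose
    # kernel, stride, and padding mirror pool_q (RES_PATH == "max").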
    def forward(
        self,
        x,
    ):
        """
        Forward logic is similar to MultiScaleBlock with Q-pooling.
        """
        x = self.pre_q_fuse(x)

        # fork tensor for residual connections
        x_res = x

        # Project the residual hidden features to the output dim here,
        # before pooling, unless pool_first is set.
        if self.is_proj and not self.pool_first:
            x_res = self.res_proj(x_res)

        if self.res_conv:
            # Pooling the hidden features with the same conv as Q
            N, L, C = x_res.shape

            # This handling is the same as that of q in MultiScaleAttention
            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.F.attn.num_heads

            # Output is (B, N, L, C)
            x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute(
                0, 2, 1, 3
            )
            x_res, _ = attention_pool(
                x_res,
                self.F.attn.pool_q,
                # thw_shape = self.attention.attn.thw,
                thw_shape=self.F.thw,
                has_cls_embed=self.has_cls_embed,
                norm=self.F.attn.norm_q
                if hasattr(self.F.attn, "norm_q")
                else None,
            )
            x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C)
        else:
            # Pooling the hidden features with max op
            x_res, _ = attention_pool(
                x_res,
                self.pool_skip,
                thw_shape=self.F.attn.thw,
                has_cls_embed=self.has_cls_embed,
            )

        # If pool_first then project to higher dim now
        if self.is_proj and self.pool_first:
            x_res = self.res_proj(x_res)

        x = self.F(x)
        x = x_res + x
        x = x + self.G(x)
        x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training)

        return x
class ReversibleBlock(nn.Module):
    """
    Reversible Blocks for Reversible Vision Transformer and also
    for state-preserving blocks in Reversible MViT. See Section
    3.3.2 in paper for details.
    """

    def __init__(
        self,
        dim,
        input_size,
        dim_out,
        num_heads,
        mlp_ratio,
        qkv_bias,
        drop_path,
        kernel_q,
        kernel_kv,
        stride_q,
        stride_kv,
        cfg,
        norm_layer=nn.LayerNorm,
        layer_id=0,
        **kwargs
    ):
        """
        Block is composed entirely of function F (Attention
        sub-block) and G (MLP sub-block) including layernorm.
        """
        super().__init__()

        self.drop_path_rate = drop_path

        self.F = AttentionSubBlock(
            dim=dim,
            input_size=input_size,
            num_heads=num_heads,
            cfg=cfg,
            dim_out=dim_out,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
        )

        self.G = MLPSubblock(
            dim=dim,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        self.seeds = {}
    def seed_cuda(self, key):
        """
        Fix seeds to allow for stochastic elements such as
        dropout to be reproduced exactly in activation
        recomputation in the backward pass.
        """
        # randomize seeds
        # use cuda generator if available
        if (
            hasattr(torch.cuda, "default_generators")
            and len(torch.cuda.default_generators) > 0
        ):
            # GPU
            device_idx = torch.cuda.current_device()
            seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            seed = int(torch.seed() % sys.maxsize)

        self.seeds[key] = seed
        torch.manual_seed(self.seeds[key])
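    # The seeds recorded here are replayed with torch.manual_seed(...) before
    # F, G, and drop_path are re-run in backward_pass, so the recomputed
    # activations (including dropout / drop-path masks) match the ones
    # produced in the forward pass.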
    def forward(self, X_1, X_2):
        """
        forward pass equations:
        Y_1 = X_1 + Attention(X_2), F = Attention
        Y_2 = X_2 + MLP(Y_1), G = MLP
        """
        self.seed_cuda("attn")
        # Y_1 : attn_output
        f_X_2 = self.F(X_2)

        self.seed_cuda("droppath")
        f_X_2_dropped = drop_path(
            f_X_2, drop_prob=self.drop_path_rate, training=self.training
        )

        # Y_1 = X_1 + f(X_2)
        Y_1 = X_1 + f_X_2_dropped

        # free memory
        del X_1

        self.seed_cuda("FFN")
        g_Y_1 = self.G(Y_1)

        torch.manual_seed(self.seeds["droppath"])
        g_Y_1_dropped = drop_path(
            g_Y_1, drop_prob=self.drop_path_rate, training=self.training
        )

        # Y_2 = X_2 + g(Y_1)
        Y_2 = X_2 + g_Y_1_dropped

        del X_2

        return Y_1, Y_2
    def backward_pass(
        self,
        Y_1,
        Y_2,
        dY_1,
        dY_2,
    ):
        """
        equations for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention
        """
        # temporarily record intermediate activations for G
        # and use them for gradient calculation of G
        with torch.enable_grad():
            Y_1.requires_grad = True

            torch.manual_seed(self.seeds["FFN"])
            g_Y_1 = self.G(Y_1)

            torch.manual_seed(self.seeds["droppath"])
            g_Y_1 = drop_path(
                g_Y_1, drop_prob=self.drop_path_rate, training=self.training
            )

            g_Y_1.backward(dY_2, retain_graph=True)

        # activation recomputation is by design and not part of
        # the computation graph in forward pass.
        with torch.no_grad():
            X_2 = Y_2 - g_Y_1
            del g_Y_1

            dY_1 = dY_1 + Y_1.grad
            Y_1.grad = None

        # record F activations and calc gradients on F
        with torch.enable_grad():
            X_2.requires_grad = True

            torch.manual_seed(self.seeds["attn"])
            f_X_2 = self.F(X_2)

            torch.manual_seed(self.seeds["droppath"])
            f_X_2 = drop_path(
                f_X_2, drop_prob=self.drop_path_rate, training=self.training
            )

            f_X_2.backward(dY_1, retain_graph=True)

        # propagate reverse-computed activations at the start of
        # the previous block for backprop
        with torch.no_grad():
            X_1 = Y_1 - f_X_2

            del f_X_2, Y_1

            dY_2 = dY_2 + X_2.grad

            X_2.grad = None
            X_2 = X_2.detach()

        return X_1, X_2, dY_1, dY_2
class MLPSubblock(nn.Module):
    """
    This creates the function G such that the entire block can be
    expressed as F(G(X)). Includes pre-LayerNorm.
    """

    def __init__(
        self,
        dim,
        mlp_ratio,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=nn.GELU,
        )

    def forward(self, x):
        return self.mlp(self.norm(x))
class AttentionSubBlock(nn.Module):
    """
    This creates the function F such that the entire block can be
    expressed as F(G(X)). Includes pre-LayerNorm.
    """

    def __init__(
        self,
        dim,
        input_size,
        num_heads,
        cfg,
        dim_out=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)

        # This will be set externally during init
        self.thw = None

        # the actual attention details are the same as Multiscale
        # attention for MViTv2 (with channel up-projection inside the block);
        # attention without up-projection can also be implemented for vanilla ViT
        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            input_size=input_size,
            num_heads=num_heads,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            drop_rate=cfg.MVIT.DROPOUT_RATE,
            qkv_bias=cfg.MVIT.QKV_BIAS,
            has_cls_embed=cfg.MVIT.CLS_EMBED_ON,
            mode=cfg.MVIT.MODE,
            pool_first=cfg.MVIT.POOL_FIRST,
            rel_pos_spatial=cfg.MVIT.REL_POS_SPATIAL,
            rel_pos_temporal=cfg.MVIT.REL_POS_TEMPORAL,
            rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
            residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
            separate_qkv=cfg.MVIT.SEPARATE_QKV,
        )

    def forward(self, x):
        out, _ = self.attn(self.norm(x), self.thw)
        return out