# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
                     to_2tuple)
from .base_backbone import BaseBackbone


class T2TTransformerLayer(BaseModule):
    """Transformer Layer for T2T_ViT.

    Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
    different ``input_dims`` and ``embed_dims``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs
        input_dims (int, optional): The input token dimension.
            Defaults to None, in which case ``embed_dims`` is used.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        qk_scale (float, optional): Override default qk scale of
            ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Notes:
        In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
        ``(embed_dims // num_heads) ** -0.5``. However, in the official
        code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
        keep the same with the official implementation.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 input_dims=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # ``v_shortcut`` is enabled only when the caller gives an explicit
        # ``input_dims``; ``forward`` then skips the outer residual add
        # around the attention.
        self.v_shortcut = input_dims is not None
        input_dims = input_dims or embed_dims

        self.ln1 = build_norm_layer(norm_cfg, input_dims)

        self.attn = MultiheadAttention(
            input_dims=input_dims,
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
            v_shortcut=self.v_shortcut)

        self.ln2 = build_norm_layer(norm_cfg, embed_dims)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    def forward(self, x):
        """Apply pre-norm attention followed by the pre-norm FFN.

        Args:
            x (torch.Tensor): Input tokens.

        Returns:
            torch.Tensor: Output tokens.
        """
        if self.v_shortcut:
            # No outer residual here: the attention module was built with
            # ``v_shortcut=True``.
            x = self.attn(self.ln1(x))
        else:
            x = x + self.attn(self.ln1(x))
        x = self.ffn(self.ln2(x), identity=x)
        return x


class T2TModule(BaseModule):
    """Tokens-to-Token module.

    "Tokens-to-Token module" (T2T Module) can model the local structure
    information of images and reduce the length of tokens progressively.

    Args:
        img_size (int): Input image size
        in_channels (int): Number of input channels
        embed_dims (int): Embedding dimension
        token_dims (int): Tokens dimension in T2TModuleAttention.
        use_performer (bool): If True, use Performer version self-attention
            instead of regular self-attention. The Performer version is not
            implemented and raises ``NotImplementedError``.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.

    Notes:
        Usually, ``token_dims`` is set as a small value (32 or 64) to reduce
        MACs
    """

    def __init__(
        self,
        img_size=224,
        in_channels=3,
        embed_dims=384,
        token_dims=64,
        use_performer=False,
        init_cfg=None,
    ):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims

        self.soft_split0 = nn.Unfold(
            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        if not use_performer:
            # Each attention layer consumes the unfolded patches of the
            # preceding soft split, hence ``channels * kernel_h * kernel_w``.
            self.attention1 = T2TTransformerLayer(
                input_dims=in_channels * 7 * 7,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.attention2 = T2TTransformerLayer(
                input_dims=token_dims * 3 * 3,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
        else:
            raise NotImplementedError("Performer hasn't been implemented.")

        # there are 3 soft split, stride are 4,2,2 separately
        out_side = img_size // (4 * 2 * 2)
        self.init_out_size = [out_side, out_side]
        self.num_patches = out_side**2

    @staticmethod
    def _get_unfold_size(unfold: nn.Unfold, input_size):
        """Compute the sliding-window grid size produced by ``unfold``.

        Args:
            unfold (nn.Unfold): The unfold layer whose kernel size, stride,
                padding and dilation are used.
            input_size (Sequence[int]): Spatial size ``(h, w)`` of the input.

        Returns:
            tuple[int, int]: ``(h_out, w_out)``, the number of window
            positions along each spatial axis.
        """
        h, w = input_size
        kernel_size = to_2tuple(unfold.kernel_size)
        stride = to_2tuple(unfold.stride)
        padding = to_2tuple(unfold.padding)
        dilation = to_2tuple(unfold.dilation)

        h_out = (h + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        w_out = (w + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (h_out, w_out)

    def forward(self, x):
        """Turn an image batch into projected tokens.

        Args:
            x (torch.Tensor): Input whose dimensions from index 2 onward are
                the two spatial dimensions (``x.shape[2:]`` is unpacked as
                ``(h, w)``).

        Returns:
            tuple: The projected tokens and the ``(h, w)`` token grid size
            after the last soft split.
        """
        # step0: soft split
        hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
        x = self.soft_split0(x).transpose(1, 2)

        for step in [1, 2]:
            # re-structurization/reconstruction
            attn = getattr(self, f'attention{step}')
            x = attn(x).transpose(1, 2)
            B, C, _ = x.shape
            x = x.reshape(B, C, hw_shape[0], hw_shape[1])
            # soft split
            soft_split = getattr(self, f'soft_split{step}')
            hw_shape = self._get_unfold_size(soft_split, hw_shape)
            x = soft_split(x).transpose(1, 2)

        # final tokens
        x = self.project(x)
        return x, hw_shape


def get_sinusoid_encoding(n_position, embed_dims):
    """Generate sinusoid encoding table.

    Builds a fixed (non-learned) position encoding from sine and cosine
    functions of the position index, in the style of the paper
    "Attention Is All You Need".

    Args:
        n_position (int): The length of the input token.
        embed_dims (int): The position embedding dimension.

    Returns:
        :obj:`torch.FloatTensor`: The sinusoid encoding table with shape
        ``(1, n_position, embed_dims)``.
    """

    def get_position_angle_vec(position):
        # Channels 2i and 2i+1 share one frequency, hence ``i // 2``.
        return [
            position / np.power(10000, 2 * (i // 2) / embed_dims)
            for i in range(embed_dims)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos) for pos in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


@MODELS.register_module()
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT)

    A PyTorch implementation of the paper "Tokens-to-Token ViT: Training
    Vision Transformers from Scratch on ImageNet".

    Args:
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        in_channels (int): Number of input channels.
        embed_dims (int): Embedding dimension.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage. The given sequence is
            copied and never modified.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 num_layers=14,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 interpolate_mode='bicubic',
                 t2t_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super().__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        self.patch_resolution = self.tokens_to_token.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
            self.num_extra_tokens = 1
        elif out_type != 'cls_token':
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        sinusoid_table = get_sinusoid_encoding(
            num_patches + self.num_extra_tokens, embed_dims)
        self.register_buffer('pos_embed', sinusoid_table)
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalize on a private copy: assigning into the caller's object
        # would mutate a shared config list and fails outright for tuples.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
            # NOTE(review): the upper bound is inclusive, so ``num_layers``
            # passes validation although ``forward`` only visits layer
            # indices ``0 .. num_layers - 1`` -- confirm whether intended.
            assert 0 <= out_indices[i] <= num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = list(np.linspace(0, drop_path_rate, num_layers))

        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)

            # User-provided entries come last so they override the defaults.
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)
        else:
            self.norm = nn.Identity()

    def init_weights(self):
        """Initialize weights; the class token gets a truncated normal init.

        The custom class-token init is skipped when ``init_cfg`` requests a
        pretrained checkpoint, and when the model has no class token.
        """
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        # ``cls_token`` is None when built with ``with_cls_token=False``.
        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Resize a checkpoint's ``pos_embed`` to this model's shape.

        Registered as a ``load_state_dict`` pre-hook. If the checkpoint has
        no ``pos_embed`` entry, or its shape already matches, the state dict
        is left untouched; otherwise the entry is replaced in place.

        Args:
            state_dict (dict): The state dict being loaded.
            prefix (str): The key prefix of this module in ``state_dict``.
        """
        name = prefix + 'pos_embed'
        if name not in state_dict:
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # The checkpoint's patch grid is taken to be square: its side is
            # the integer square root of the patch-token count.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.tokens_to_token.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def forward(self, x):
        """Run the backbone.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            tuple: One formatted output (see ``out_type``) for every encoder
            layer whose index is in ``self.out_indices``.
        """
        B = x.shape[0]
        x, patch_resolution = self.tokens_to_token(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        # The position table is resized from the configured resolution to
        # the resolution actually produced for this input.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            # The final norm is applied only after the last encoder layer.
            if i == len(self.encoder) - 1 and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _format_output(self, x, hw):
        """Convert encoder tokens to the layout selected by ``out_type``.

        Args:
            x (torch.Tensor): Encoder tokens with shape (B, L, C).
            hw (Sequence[int]): The ``(H, W)`` patch grid size, used for the
                ``"featmap"`` output.

        Returns:
            torch.Tensor: The formatted feature.
        """
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            return x[:, 0]

        patch_token = x[:, self.num_extra_tokens:]
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return patch_token.mean(dim=1)