# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import (build_norm_layer, constant_init, normal_init,
                      trunc_normal_init)
from mmcv.runner import _load_checkpoint, load_state_dict

from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock,
                     TokenConv, cluster_dpc_knn, merge_tokens,
                     tcformer_convert, token2map)


class CTM(nn.Module):
    """Clustering-based Token Merging module in TCFormer.

    The module projects the incoming tokens with a stride-2 ``TokenConv``,
    predicts one score per token, clusters the tokens with
    ``cluster_dpc_knn`` and merges them with ``merge_tokens``, passing the
    exponentiated scores as token weights.

    Args:
        sample_ratio (float): The sample ratio of tokens.
        embed_dim (int): Input token feature dimension.
        dim_out (int): Output token feature dimension.
        k (int): Number of the nearest neighbors used in the DPC-kNN
            algorithm. Default: 5.
    """

    def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
        super().__init__()
        self.sample_ratio = sample_ratio
        self.dim_out = dim_out
        self.conv = TokenConv(
            in_channels=embed_dim,
            out_channels=dim_out,
            kernel_size=3,
            stride=2,
            padding=1)
        self.norm = nn.LayerNorm(self.dim_out)
        self.score = nn.Linear(self.dim_out, 1)
        self.k = k

    def forward(self, token_dict):
        """Down-sample a token dict by clustering and merging its tokens.

        Args:
            token_dict (dict): Token information. It must contain
                ``'map_size'`` plus whatever ``TokenConv``,
                ``cluster_dpc_knn`` and ``merge_tokens`` read from it.

        Returns:
            tuple: ``(down_dict, token_dict)``. ``down_dict`` is the
            result of ``merge_tokens`` with its ``'map_size'`` updated;
            ``token_dict`` is a shallow copy of the input whose ``'x'``
            holds the projected features and which additionally stores
            ``'token_score'``.
        """
        # Shallow copy so the keys written below do not leak into the
        # caller's dict.
        token_dict = token_dict.copy()

        x = self.norm(self.conv(token_dict))
        token_score = self.score(x)
        token_weight = token_score.exp()

        token_dict['x'] = x
        token_dict['token_score'] = token_score

        # x unpacks as three dimensions; only the token count is needed.
        _, N, _ = x.shape
        # Always keep at least one cluster.
        cluster_num = max(math.ceil(N * self.sample_ratio), 1)
        idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num,
                                                   self.k)
        down_dict = merge_tokens(token_dict, idx_cluster, cluster_num,
                                 token_weight)

        # Output size of a kernel-3 / stride-2 / padding-1 convolution.
        H, W = token_dict['map_size']
        H = math.floor((H - 1) / 2 + 1)
        W = math.floor((W - 1) / 2 + 1)
        down_dict['map_size'] = [H, W]

        return down_dict, token_dict


@BACKBONES.register_module()
class TCFormer(nn.Module):
    """Token Clustering Transformer (TCFormer)

    Implementation of `Not All Tokens Are Equal: Human-centric Visual
    Analysis via Token Clustering Transformer`

    Stage 1 uses ``TCFormerRegularBlock`` layers; every later stage first
    merges tokens with a :class:`CTM` module and then applies
    ``TCFormerDynamicBlock`` layers.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list[int]): Embedding dimension.
            Default: [64, 128, 256, 512].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
        mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
            embedding dim of each transformer block.
            Default: [4, 4, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): stochastic depth rate. Default 0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        num_layers (Sequence[int]): The layer number of each transformer
            encode layer. Default: [3, 4, 6, 3].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer block. Default: [8, 4, 2, 1].
        num_stages (int): The num of stages. Default: 4.
        pretrained (str, optional): model pretrained path. Default: None.
        k (int): number of the nearest neighbor used for local density.
            Default: 5.
        sample_ratios (list[float]): The sample ratios of CTM modules.
            Default: [0.25, 0.25, 0.25]
        return_map (bool): If True, transfer dynamic tokens to feature map at
            last. Default: False
        convert_weights (bool): The flag indicates whether the
            pre-trained model is from the original repo. We may need
            to convert some keys to make it compatible.
            Default: True.

    Raises:
        ValueError: If a per-stage sequence has fewer entries than there
            are stages, or ``sample_ratios`` has fewer entries than there
            are CTM modules.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 num_layers=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 num_stages=4,
                 pretrained=None,
                 k=5,
                 sample_ratios=[0.25, 0.25, 0.25],
                 return_map=False,
                 convert_weights=True):
        super().__init__()

        # Every per-stage sequence is indexed by stage id below. Report a
        # short sequence up front instead of an IndexError in the middle of
        # construction. Stage 1 is always built, hence at least one entry.
        required = max(num_stages, 1)
        per_stage_args = {
            'embed_dims': embed_dims,
            'num_heads': num_heads,
            'mlp_ratios': mlp_ratios,
            'num_layers': num_layers,
            'sr_ratios': sr_ratios,
        }
        for arg_name, arg_value in per_stage_args.items():
            if len(arg_value) < required:
                raise ValueError(
                    f'{arg_name} must provide at least {required} entries '
                    f'(one per stage), but got {len(arg_value)}')
        if len(sample_ratios) < required - 1:
            raise ValueError(
                f'sample_ratios must provide at least {required - 1} '
                f'entries (one per CTM module), but got '
                f'{len(sample_ratios)}')

        # Store copies: the signature defaults are shared list objects, so
        # keeping a reference would let a mutation of one model's attribute
        # silently change the defaults of every later instance.
        self.num_layers = list(num_layers)
        self.num_stages = num_stages
        self.grid_stride = sr_ratios[0]
        self.embed_dims = list(embed_dims)
        self.sr_ratios = list(sr_ratios)
        self.mlp_ratios = list(mlp_ratios)
        self.sample_ratios = list(sample_ratios)
        self.return_map = return_map
        self.convert_weights = convert_weights

        # stochastic depth decay rule: one drop-path rate per layer,
        # linearly spaced from 0 to drop_path_rate.
        dpr = [
            rate.item()
            for rate in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]
        cur = 0

        # In stage 1, use the standard transformer blocks. The patch
        # embedding always uses LN with eps=1e-6, independent of norm_cfg.
        self.patch_embed1 = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims[0],
            kernel_size=7,
            stride=4,
            padding=3,
            bias=True,
            norm_cfg=dict(type='LN', eps=1e-6))
        self.block1 = nn.ModuleList([
            TCFormerRegularBlock(
                dim=embed_dims[0],
                num_heads=num_heads[0],
                mlp_ratio=mlp_ratios[0],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + j],
                norm_cfg=norm_cfg,
                sr_ratio=sr_ratios[0]) for j in range(num_layers[0])
        ])
        self.norm1 = build_norm_layer(norm_cfg, embed_dims[0])[1]
        cur += num_layers[0]

        # In stage 2~4, use TCFormerDynamicBlock for dynamic tokens
        for i in range(1, num_stages):
            ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i],
                      k)
            block = nn.ModuleList([
                TCFormerDynamicBlock(
                    dim=embed_dims[i],
                    num_heads=num_heads[i],
                    mlp_ratio=mlp_ratios[i],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + j],
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
            ])
            norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
            cur += num_layers[i]

            # Attribute names are part of the state_dict keys; keep them.
            setattr(self, f'ctm{i}', ctm)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Initialize the weights, optionally from a checkpoint.

        Args:
            pretrained (str, optional): Checkpoint path. If a str, the
                checkpoint is loaded non-strictly (after key conversion
                when ``self.convert_weights`` is True). If None, Linear,
                LayerNorm and Conv2d modules are initialized in place.
                Default: None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                pretrained, logger=logger, map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')

            # The weights may be nested under 'state_dict' or 'model', or
            # the checkpoint may be the state dict itself.
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if self.convert_weights:
                # We need to convert pre-trained weights to match this
                # implementation.
                state_dict = tcformer_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    # Normal init with std = sqrt(2 / fan_out), where
                    # fan_out is divided by the number of groups.
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run all stages and collect one output per stage.

        Args:
            x (torch.Tensor): Input passed to the stage-1 patch embedding
                (presumably an image batch -- confirm against
                ``PatchEmbed``).

        Returns:
            list: One entry per stage. Each entry is a token dict, or the
            result of ``token2map`` on that dict when ``self.return_map``
            is True.
        """
        outs = []

        # stage 1: regular blocks on the grid tokens
        x, (H, W) = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)

        # init token dict
        B, N, _ = x.shape
        # Build the index directly on the feature device instead of
        # creating it on the CPU and copying it over on every call.
        idx_token = torch.arange(N, device=x.device)[None, :].repeat(B, 1)
        agg_weight = x.new_ones(B, N, 1)
        token_dict = {
            'x': x,
            'token_num': N,
            'map_size': [H, W],
            'init_grid_size': [H, W],
            'idx_token': idx_token,
            'agg_weight': agg_weight
        }
        outs.append(token_dict.copy())

        # stage 2~4
        for i in range(1, self.num_stages):
            ctm = getattr(self, f'ctm{i}')
            block = getattr(self, f'block{i + 1}')
            norm = getattr(self, f'norm{i + 1}')

            # down sample. CTM returns a (down_dict, token_dict) pair.
            # NOTE(review): the first dynamic block presumably consumes the
            # pair and returns a single dict -- confirm against
            # TCFormerDynamicBlock.
            token_dict = ctm(token_dict)
            for blk in block:
                token_dict = blk(token_dict)

            token_dict['x'] = norm(token_dict['x'])
            outs.append(token_dict)

        if self.return_map:
            outs = [token2map(token_dict) for token_dict in outs]

        return outs