#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from .common import DropPath, Mlp


def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
    """
    Apply a 3D pooling op to a (B, num_heads, L, C) or (B, L, C) token tensor
    by folding the L tokens back into their (T, H, W) grid, pooling, and
    flattening again. The class token, if present, bypasses the pooling.
    Returns the pooled tensor and the new (T, H, W) shape.
    """
    if pool is None:
        return tensor, thw_shape
    tensor_dim = tensor.ndim
    if tensor_dim == 4:
        pass
    elif tensor_dim == 3:
        tensor = tensor.unsqueeze(1)
    else:
        raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")

    if has_cls_embed:
        cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]

    B, N, L, C = tensor.shape
    T, H, W = thw_shape
    tensor = (
        tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
    )

    tensor = pool(tensor)

    thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
    L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
    tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
    if has_cls_embed:
        tensor = torch.cat((cls_tok, tensor), dim=2)
    if norm is not None:
        tensor = norm(tensor)
    # Assert tensor_dim in [3, 4]
    if tensor_dim == 4:
        pass
    else:  # tensor_dim == 3:
        tensor = tensor.squeeze(1)
    return tensor, thw_shape


def get_rel_pos(rel_pos, d):
    """
    Return a relative positional embedding table of length `d`, linearly
    interpolating the trained table if its length differs from `d`.
    """
    if isinstance(d, int):
        ori_d = rel_pos.shape[0]
        if ori_d == d:
            return rel_pos
        else:
            # Interpolate rel pos.
            new_pos_embed = F.interpolate(
                rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
                size=d,
                mode="linear",
            )
            return new_pos_embed.reshape(-1, d).permute(1, 0)


def cal_rel_pos_spatial(
    attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w
):
    """
    Decomposed Spatial Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = (
        torch.arange(q_h)[:, None] * q_h_ratio
        - torch.arange(k_h)[None, :] * k_h_ratio
    )
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = (
        torch.arange(q_w)[:, None] * q_w_ratio
        - torch.arange(k_w)[None, :] * k_w_ratio
    )
    dist_w += (k_w - 1) * k_w_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = get_rel_pos(rel_pos_h, dh)
    rel_pos_w = get_rel_pos(rel_pos_w, dw)
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum(
        "bythwc,hkc->bythwk", r_q, Rh
    )  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum(
        "bythwc,wkc->bythwk", r_q, Rw
    )  # [B, H, q_t, qh, qw, k_w]

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
        + rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
    ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)

    return attn
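

# Hedged usage sketch (illustrative, not part of the module's public API; the
# `_sketch_attention_pool` name and the dummy sizes below are assumptions).
# It exercises the `attention_pool` contract: a (B, heads, 1 + T*H*W, C) token
# tensor plus a 3D pooling op yields a shorter token sequence and the pooled
# (T, H, W) grid, with the class token carried through unpooled.
def _sketch_attention_pool():
    B, heads, C = 2, 4, 24
    T, H, W = 4, 8, 8
    x = torch.randn(B, heads, 1 + T * H * W, C)  # one leading CLS token
    pool = nn.MaxPool3d(
        kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)
    )
    out, out_shape = attention_pool(x, pool, [T, H, W], has_cls_embed=True)
    # Stride-2 pooling halves the spatial grid; the temporal extent is kept.
    assert out_shape == [T, H // 2, W // 2]
    assert out.shape == (B, heads, 1 + T * (H // 2) * (W // 2), C)
    return out, out_shape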


def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
    """
    Temporal Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape
    dt = int(2 * max(q_t, k_t) - 1)
    # Interpolate rel pos if needed.
    rel_pos_t = get_rel_pos(rel_pos_t, dt)

    # Scale up rel pos if shapes for q and k are different.
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = (
        torch.arange(q_t)[:, None] * q_t_ratio
        - torch.arange(k_t)[None, :] * k_t_ratio
    )
    dist_t += (k_t - 1) * k_t_ratio
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(
        q_t, B * n_head * q_h * q_w, dim
    )
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
        + rel[:, :, :, :, :, :, None, None]
    ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)

    return attn
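

# Hedged sketch of the table resizing performed by `get_rel_pos` (the
# `_sketch_get_rel_pos_resize` name and the sizes are illustrative
# assumptions): a relative-position table trained for 2*7 - 1 offsets is
# linearly interpolated to 2*14 - 1 offsets when the attended resolution grows.
def _sketch_get_rel_pos_resize():
    head_dim = 32
    rel_pos = torch.randn(2 * 7 - 1, head_dim)  # trained at size 7
    resized = get_rel_pos(rel_pos, 2 * 14 - 1)  # queried at size 14
    assert resized.shape == (2 * 14 - 1, head_dim)
    return resized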


class MultiScaleAttention(nn.Module):
    """
    Multi-head attention with pooled query/key/value tokens (MViT-style
    pooling attention) and optional decomposed relative positional embeddings.
    """

    def __init__(
        self,
        dim,
        dim_out,
        input_size,
        num_heads=8,
        qkv_bias=False,
        drop_rate=0.0,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
        has_cls_embed=True,
        # Options include `conv`, `avg`, and `max`.
        mode="conv",
        # If True, perform pool before projection.
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=False,
        separate_qkv=False,
    ):
        super().__init__()
        self.pool_first = pool_first
        self.separate_qkv = separate_qkv
        self.drop_rate = drop_rate
        self.num_heads = num_heads
        self.dim_out = dim_out
        head_dim = dim_out // num_heads
        self.scale = head_dim**-0.5
        self.has_cls_embed = has_cls_embed
        self.mode = mode
        padding_q = [int(q // 2) for q in kernel_q]
        padding_kv = [int(kv // 2) for kv in kernel_kv]

        if pool_first or separate_qkv:
            self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)
        if drop_rate > 0.0:
            self.proj_drop = nn.Dropout(drop_rate)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
            kernel_q = ()
        if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
            kernel_kv = ()

        if mode in ("avg", "max"):
            pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
            self.pool_q = (
                pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
                if len(kernel_q) > 0
                else None
            )
            self.pool_k = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
            self.pool_v = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
        elif mode == "conv" or mode == "conv_unshared":
            if pool_first:
                dim_conv = dim // num_heads if mode == "conv" else dim
            else:
                dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            self.pool_q = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_q) > 0
                else None
            )
            self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
            self.pool_k = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
            self.pool_v = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
        else:
            raise NotImplementedError(f"Unsupported mode {mode}")

        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_temporal = rel_pos_temporal
        if self.rel_pos_spatial:
            assert input_size[1] == input_size[2]
            size = input_size[1]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)
        if self.rel_pos_temporal:
            self.rel_pos_t = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, head_dim)
            )
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_t, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, thw_shape):
        B, N, _ = x.shape

        if self.pool_first:
            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.num_heads
            x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
            q = k = v = x
        else:
            assert self.mode != "conv_unshared"
            if not self.separate_qkv:
                qkv = (
                    self.qkv(x)
                    .reshape(B, N, 3, self.num_heads, -1)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k, v = qkv[0], qkv[1], qkv[2]
            else:
                q = k = v = x
                q = (
                    self.q(q)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )
                k = (
                    self.k(k)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )
                v = (
                    self.v(v)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )

        q, q_shape = attention_pool(
            q,
            self.pool_q,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_q if hasattr(self, "norm_q") else None,
        )
        k, k_shape = attention_pool(
            k,
            self.pool_k,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_k if hasattr(self, "norm_k") else None,
        )
        v, v_shape = attention_pool(
            v,
            self.pool_v,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_v if hasattr(self, "norm_v") else None,
        )

        if self.pool_first:
            q_N = (
                numpy.prod(q_shape) + 1
                if self.has_cls_embed
                else numpy.prod(q_shape)
            )
            k_N = (
                numpy.prod(k_shape) + 1
                if self.has_cls_embed
                else numpy.prod(k_shape)
            )
            v_N = (
                numpy.prod(v_shape) + 1
                if self.has_cls_embed
                else numpy.prod(v_shape)
            )

            q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
            q = (
                self.q(q)
                .reshape(B, q_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

            v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
            v = (
                self.v(v)
                .reshape(B, v_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

            k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
            k = (
                self.k(k)
                .reshape(B, k_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

        N = q.shape[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_spatial:
            attn = cal_rel_pos_spatial(
                attn,
                q,
                k,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_h,
                self.rel_pos_w,
            )

        if self.rel_pos_temporal:
            attn = cal_rel_pos_temporal(
                attn,
                q,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = attn @ v

        if self.residual_pooling:
            if self.has_cls_embed:
                x[:, :, 1:, :] += q[:, :, 1:, :]
            else:
                x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        if self.drop_rate > 0.0:
            x = self.proj_drop(x)
        return x, q_shape
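

# Hedged end-to-end sketch for `MultiScaleAttention` (the hyperparameters and
# the `_sketch_multiscale_attention` name are assumptions chosen only to make
# the shapes easy to follow). With a query pooling stride of (1, 2, 2), the
# output token sequence shrinks together with the returned (T, H, W) shape.
def _sketch_multiscale_attention():
    T, H, W = 4, 8, 8
    dim = 96
    attn = MultiScaleAttention(
        dim,
        dim,
        input_size=[T, H, W],
        num_heads=2,
        qkv_bias=True,
        kernel_q=(3, 3, 3),
        kernel_kv=(3, 3, 3),
        stride_q=(1, 2, 2),
        stride_kv=(1, 2, 2),
        mode="conv",
        has_cls_embed=True,
    )
    x = torch.randn(2, 1 + T * H * W, dim)
    out, out_shape = attn(x, [T, H, W])
    assert out_shape == [T, H // 2, W // 2]
    assert out.shape == (2, 1 + T * (H // 2) * (W // 2), dim)
    return out, out_shape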


class MultiScaleBlock(nn.Module):
    """
    Transformer block pairing MultiScaleAttention with an MLP; supports
    channel expansion (dim -> dim_out) and a pooled skip connection that
    matches the query stride.
    """

    def __init__(
        self,
        dim,
        dim_out,
        num_heads,
        input_size,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        drop_path=0.0,
        layer_scale_init_value=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        up_rate=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        mode="conv",
        has_cls_embed=True,
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=False,
        dim_mul_in_att=False,
        separate_qkv=False,
    ):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)
        self.dim_mul_in_att = dim_mul_in_att
        kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
        stride_skip = stride_q
        padding_skip = [int(skip // 2) for skip in kernel_skip]
        att_dim = dim_out if dim_mul_in_att else dim
        self.attn = MultiScaleAttention(
            dim,
            att_dim,
            num_heads=num_heads,
            input_size=input_size,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_embed=has_cls_embed,
            mode=mode,
            pool_first=pool_first,
            rel_pos_spatial=rel_pos_spatial,
            rel_pos_temporal=rel_pos_temporal,
            rel_pos_zero_init=rel_pos_zero_init,
            residual_pooling=residual_pooling,
            separate_qkv=separate_qkv,
        )
        self.drop_path = (
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )
        self.norm2 = norm_layer(att_dim)
        mlp_hidden_dim = int(att_dim * mlp_ratio)
        self.has_cls_embed = has_cls_embed
        # TODO: check the use case for up_rate, and merge the following lines
        if up_rate is not None and up_rate > 1:
            mlp_dim_out = dim * up_rate
        else:
            mlp_dim_out = dim_out
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=mlp_hidden_dim,
            out_features=mlp_dim_out,
            act_layer=act_layer,
            drop_rate=drop_rate,
        )
        if layer_scale_init_value > 0:
            self.gamma_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim_out)),
                requires_grad=True,
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

        self.pool_skip = (
            nn.MaxPool3d(
                kernel_skip, stride_skip, padding_skip, ceil_mode=False
            )
            if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
            else None
        )

    def forward(self, x, thw_shape=None):
        x_norm = self.norm1(x)
        x_block, thw_shape_new = self.attn(x_norm, thw_shape)
        if self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        x_res, _ = attention_pool(
            x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed
        )
        if self.gamma_1 is not None:
            x = x_res + self.drop_path(self.gamma_1 * x_block)
        else:
            x = x_res + self.drop_path(x_block)
        x_norm = self.norm2(x)
        x_mlp = self.mlp(x_norm)
        if not self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        if self.gamma_2 is not None:
            x = x + self.drop_path(self.gamma_2 * x_mlp)
        else:
            x = x + self.drop_path(x_mlp)
        if thw_shape:
            return x, thw_shape_new
        else:
            return x
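

# Hedged sketch for a full `MultiScaleBlock` (the hyperparameters and the
# `_sketch_multiscale_block` name are illustrative assumptions). It shows the
# two things the block changes at once: the channel width grows from `dim` to
# `dim_out` (the MLP outputs `dim_out` and the residual is projected to
# match), while the spatial token grid shrinks by the query stride, which the
# `pool_skip` op also applies to the skip path.
def _sketch_multiscale_block():
    T, H, W = 8, 14, 14
    dim, dim_out = 96, 192
    block = MultiScaleBlock(
        dim=dim,
        dim_out=dim_out,
        num_heads=2,
        input_size=[T, H, W],
        qkv_bias=True,
        kernel_q=(3, 3, 3),
        kernel_kv=(3, 3, 3),
        stride_q=(1, 2, 2),
        stride_kv=(1, 2, 2),
        mode="conv",
        has_cls_embed=True,
    )
    x = torch.randn(2, 1 + T * H * W, dim)
    out, out_shape = block(x, [T, H, W])
    assert out_shape == [T, H // 2, W // 2]
    assert out.shape == (2, 1 + T * (H // 2) * (W // 2), dim_out)
    return out, out_shape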