# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Non-local helper."""

import torch
import torch.nn as nn


class Nonlocal(nn.Module):
    """
    Builds Non-local Neural Networks as a generic family of building
    blocks for capturing long-range dependencies. A Non-local block
    computes the response at a position as a weighted sum of the
    features at all positions. This building block can be plugged into
    many computer vision architectures.
    More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(
        self,
        dim,
        dim_inner,
        pool_size=None,
        instantiation="softmax",
        zero_init_final_conv=False,
        zero_init_final_norm=True,
        norm_eps=1e-5,
        norm_momentum=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Args:
            dim (int): number of channels of the input.
            dim_inner (int): number of channels inside the Non-local block.
            pool_size (list): kernel size of the spatial-temporal pooling,
                given in the order [temporal kernel size, spatial kernel
                size, spatial kernel size]. If pool_size is None (the
                default), no pooling is used.
            instantiation (string): supports two instantiation methods:
                "dot_product": normalizing the correlation matrix by the
                    number of positions.
                "softmax": normalizing the correlation matrix with Softmax.
            zero_init_final_conv (bool): if True, zero-initialize the final
                convolution of the Non-local block.
            zero_init_final_norm (bool): if True, zero-initialize the final
                batch norm of the Non-local block.
            norm_eps (float): epsilon of the normalization layer.
            norm_momentum (float): momentum of the normalization layer.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
        """
        super(Nonlocal, self).__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.pool_size = pool_size
        self.instantiation = instantiation
        self.use_pool = (
            False
            if pool_size is None
            else any(size > 1 for size in pool_size)
        )
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self._construct_nonlocal(
            zero_init_final_conv, zero_init_final_norm, norm_module
        )

    def _construct_nonlocal(
        self, zero_init_final_conv, zero_init_final_norm, norm_module
    ):
        # Three convolution heads: theta, phi, and g.
        self.conv_theta = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )
        self.conv_phi = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )
        self.conv_g = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )

        # Final convolution output.
        self.conv_out = nn.Conv3d(
            self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0
        )
        # Flag the final convolution for zero initialization.
        self.conv_out.zero_init = zero_init_final_conv

        # TODO: change the name to `norm`.
        self.bn = norm_module(
            num_features=self.dim,
            eps=self.norm_eps,
            momentum=self.norm_momentum,
        )
        # Flag the final bn for zero initialization.
        self.bn.transform_final_bn = zero_init_final_norm

        # Optionally add spatial-temporal pooling.
        if self.use_pool:
            self.pool = nn.MaxPool3d(
                kernel_size=self.pool_size,
                stride=self.pool_size,
                padding=[0, 0, 0],
            )

    def forward(self, x):
        x_identity = x
        N, C, T, H, W = x.size()

        theta = self.conv_theta(x)

        # Perform temporal-spatial pooling to reduce the computation.
        if self.use_pool:
            x = self.pool(x)

        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW).
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))

        # Following the original Non-local paper, there are two main ways to
        # normalize the affinity tensor:
        #   1) Softmax normalization (norm on exp).
        #   2) dot_product normalization.
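        # In the softmax instantiation this reduces to scaled dot-product
        # attention: for each position i, the output is
        #   y_i = sum_j softmax_j(theta_i . phi_j / sqrt(dim_inner)) * g_j,
        # summed over all (pooled) T*H*W positions j.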
        if self.instantiation == "softmax":
            # Normalizing the affinity tensor theta_phi before softmax.
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError(
                "Unknown norm type {}".format(self.instantiation)
            )

        # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))

        # (N, C, TxHxW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)

        p = self.conv_out(theta_phi_g)
        p = self.bn(p)
        return x_identity + p
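

if __name__ == "__main__":
    # Minimal smoke test; not part of the original module. The sizes below
    # are illustrative assumptions, not values taken from the paper.
    block = Nonlocal(dim=64, dim_inner=32, pool_size=[1, 2, 2])
    block.eval()  # use running stats so a tiny batch is fine for BatchNorm3d
    clip = torch.rand(2, 64, 4, 8, 8)  # (N, C, T, H, W)
    out = block(clip)
    # The residual connection preserves the input shape.
    assert out.shape == clip.shape
    print(out.shape)  # torch.Size([2, 64, 4, 8, 8])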