#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""ResNe(X)t Head helper."""

import torch
import torch.nn as nn

from .batchnorm_helper import (
    NaiveSyncBatchNorm1d as NaiveSyncBatchNorm1d,
)


class MLPHead(nn.Module):
    """
    Multi-layer perceptron (MLP) head: a stack of Linear layers with optional
    (synchronized) BatchNorm and ReLU in between, used as a projection head.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        mlp_dim,
        num_layers,
        bn_on=False,
        bias=True,
        flatten=False,
        xavier_init=True,
        bn_sync_num=1,
        global_sync=False,
    ):
        super(MLPHead, self).__init__()
        self.flatten = flatten
        # Omit the Linear bias when a BatchNorm layer follows it.
        b = False if bn_on else bias
        # assert bn_on or bn_sync_num == 1
        mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)]
        mlp_layers[-1].xavier_init = xavier_init
        for i in range(1, num_layers):
            if bn_on:
                if global_sync or bn_sync_num > 1:
                    mlp_layers.append(
                        NaiveSyncBatchNorm1d(
                            num_sync_devices=bn_sync_num,
                            global_sync=global_sync,
                            num_features=mlp_dim,
                        )
                    )
                else:
                    mlp_layers.append(nn.BatchNorm1d(num_features=mlp_dim))
            mlp_layers.append(nn.ReLU(inplace=True))
            if i == num_layers - 1:
                d = dim_out
                b = bias
            else:
                d = mlp_dim
            mlp_layers.append(nn.Linear(mlp_dim, d, bias=b))
            mlp_layers[-1].xavier_init = xavier_init
        self.projection = nn.Sequential(*mlp_layers)

    def forward(self, x):
        if x.ndim == 5:
            # (N, C, T, H, W) -> (N, T, H, W, C) so the Linear layers act on C.
            x = x.permute((0, 2, 3, 4, 1))
        if self.flatten:
            x = x.reshape(-1, x.shape[-1])

        return self.projection(x)


class TransformerBasicHead(nn.Module):
    """
    BasicHead. No pool.
    """

    def __init__(
        self,
        dim_in,
        num_classes,
        dropout_rate=0.0,
        act_func="softmax",
        cfg=None,
    ):
        """
        Perform linear projection and activation as head for transformers.
        Args:
            dim_in (int): the channel dimension of the input to the head.
            num_classes (int): the channel dimensions of the output to the head.
            dropout_rate (float): dropout rate. If equal to 0.0, perform no
                dropout.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the output.
        """
        super(TransformerBasicHead, self).__init__()
        if dropout_rate > 0.0:
            self.dropout = nn.Dropout(dropout_rate)

        if cfg.CONTRASTIVE.NUM_MLP_LAYERS == 1:
            self.projection = nn.Linear(dim_in, num_classes, bias=True)
        else:
            self.projection = MLPHead(
                dim_in,
                num_classes,
                cfg.CONTRASTIVE.MLP_DIM,
                cfg.CONTRASTIVE.NUM_MLP_LAYERS,
                bn_on=cfg.CONTRASTIVE.BN_MLP,
                bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                if cfg.CONTRASTIVE.BN_SYNC_MLP
                else 1,
                global_sync=(
                    cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC
                ),
            )
        self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC

        # Softmax for evaluation and testing.
        if act_func == "softmax":
            self.act = nn.Softmax(dim=1)
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        elif act_func == "none":
            self.act = None
        else:
            raise NotImplementedError(
                "{} is not supported as an activation "
                "function.".format(act_func)
            )

    def forward(self, x):
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        if self.detach_final_fc:
            x = x.detach()
        x = self.projection(x)

        if not self.training:
            if self.act is not None:
                x = self.act(x)
            # Performs fully convolutional inference.
            if x.ndim == 5 and x.shape[1:4] > torch.Size([1, 1, 1]):
                x = x.mean([1, 2, 3])

        x = x.view(x.shape[0], -1)
        return x
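

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module).
# It shows how the MLPHead defined above could be instantiated and applied to
# a dummy feature batch; all shapes and hyperparameters are example values.
# Because this file uses a relative import, the sketch is meant to be called
# after importing the module as part of its package, not by running the file
# directly.
def _example_mlp_head():
    head = MLPHead(
        dim_in=2048,   # input feature dimension
        dim_out=128,   # projection output dimension
        mlp_dim=2048,  # hidden dimension of the intermediate layers
        num_layers=3,
        bn_on=True,    # plain BatchNorm1d between layers (no cross-device sync)
    )
    dummy = torch.randn(4, 2048)  # (batch, dim_in)
    out = head(dummy)             # -> (batch, dim_out)
    assert out.shape == torch.Size([4, 128])
    return out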