#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""ResNe(X)t Head helper."""

import torch
import torch.nn as nn

from .batchnorm_helper import NaiveSyncBatchNorm1d


class MLPHead(nn.Module):
    """MLP projection head with optional (sync) BatchNorm between layers."""

    def __init__(
        self,
        dim_in,
        dim_out,
        mlp_dim,
        num_layers,
        bn_on=False,
        bias=True,
        flatten=False,
        xavier_init=True,
        bn_sync_num=1,
        global_sync=False,
    ):
        super(MLPHead, self).__init__()
        self.flatten = flatten
        # A Linear bias is redundant when BatchNorm directly follows it.
        b = False if bn_on else bias
        # assert bn_on or bn_sync_num == 1
        mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)]
        mlp_layers[-1].xavier_init = xavier_init
        for i in range(1, num_layers):
            if bn_on:
                if global_sync or bn_sync_num > 1:
                    mlp_layers.append(
                        NaiveSyncBatchNorm1d(
                            num_sync_devices=bn_sync_num,
                            global_sync=global_sync,
                            num_features=mlp_dim,
                        )
                    )
                else:
                    mlp_layers.append(nn.BatchNorm1d(num_features=mlp_dim))
            mlp_layers.append(nn.ReLU(inplace=True))
            if i == num_layers - 1:
                # Last layer projects to dim_out and restores the requested bias.
                d = dim_out
                b = bias
            else:
                d = mlp_dim
            mlp_layers.append(nn.Linear(mlp_dim, d, bias=b))
            mlp_layers[-1].xavier_init = xavier_init
        self.projection = nn.Sequential(*mlp_layers)

    def forward(self, x):
        if x.ndim == 5:
            # (N, C, T, H, W) -> (N, T, H, W, C) so Linear acts on channels.
            x = x.permute((0, 2, 3, 4, 1))
        if self.flatten:
            x = x.reshape(-1, x.shape[-1])
        return self.projection(x)
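

# Illustrative usage sketch (an assumption added for documentation, not code
# the rest of the library calls): build a 3-layer projection head and embed a
# batch of 2048-d features into 128-d. All dimensions are made up.
#
#   head = MLPHead(dim_in=2048, dim_out=128, mlp_dim=1024, num_layers=3, bn_on=True)
#   feats = torch.randn(8, 2048)   # (batch, channels)
#   out = head(feats)              # -> torch.Size([8, 128])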


class TransformerBasicHead(nn.Module):
    """
    Basic classification head. No pooling.
    """

    def __init__(
        self,
        dim_in,
        num_classes,
        dropout_rate=0.0,
        act_func="softmax",
        cfg=None,
    ):
        """
        Perform linear projection and activation as head for transformers.
        Args:
            dim_in (int): the channel dimension of the input to the head.
            num_classes (int): the channel dimension of the output of the head.
            dropout_rate (float): dropout rate. If equal to 0.0, perform no
                dropout.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the
                output. 'none': no activation.
            cfg (CfgNode): config; selects between a single linear projection
                and an MLPHead projection (CONTRASTIVE.NUM_MLP_LAYERS > 1).
        """
        super(TransformerBasicHead, self).__init__()
        if dropout_rate > 0.0:
            self.dropout = nn.Dropout(dropout_rate)
        if cfg.CONTRASTIVE.NUM_MLP_LAYERS == 1:
            self.projection = nn.Linear(dim_in, num_classes, bias=True)
        else:
            self.projection = MLPHead(
                dim_in,
                num_classes,
                cfg.CONTRASTIVE.MLP_DIM,
                cfg.CONTRASTIVE.NUM_MLP_LAYERS,
                bn_on=cfg.CONTRASTIVE.BN_MLP,
                bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                if cfg.CONTRASTIVE.BN_SYNC_MLP
                else 1,
                global_sync=(
                    cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC
                ),
            )
        self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC

        # Softmax for evaluation and testing.
        if act_func == "softmax":
            self.act = nn.Softmax(dim=1)
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        elif act_func == "none":
            self.act = None
        else:
            raise NotImplementedError(
                "{} is not supported as an activation "
                "function.".format(act_func)
            )

    def forward(self, x):
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        if self.detach_final_fc:
            # Stop gradients from the classifier into the backbone features.
            x = x.detach()
        x = self.projection(x)
        if not self.training:
            if self.act is not None:
                x = self.act(x)
            # Performs fully convolutional inference.
            if x.ndim == 5 and x.shape[1:4] > torch.Size([1, 1, 1]):
                x = x.mean([1, 2, 3])
            x = x.view(x.shape[0], -1)
        return x
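

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module: exercises
    # TransformerBasicHead on dummy data. The SimpleNamespace below is a
    # hypothetical stand-in for the project's real config object; only the
    # attributes read above are populated and all values are illustrative.
    # Run as a module (e.g. `python -m <package>.head_helper`) so the
    # relative import above resolves.
    from types import SimpleNamespace

    _cfg = SimpleNamespace(
        CONTRASTIVE=SimpleNamespace(
            NUM_MLP_LAYERS=3, MLP_DIM=1024, BN_MLP=True, BN_SYNC_MLP=False
        ),
        BN=SimpleNamespace(NUM_SYNC_DEVICES=1, GLOBAL_SYNC=False),
        MODEL=SimpleNamespace(DETACH_FINAL_FC=False),
    )
    _head = TransformerBasicHead(
        dim_in=768, num_classes=128, dropout_rate=0.5, act_func="none", cfg=_cfg
    )
    _head.eval()
    with torch.no_grad():
        _out = _head(torch.randn(4, 768))
    print(_out.shape)  # expected: torch.Size([4, 128])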