#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""ResNe(X)t Head helper."""
import torch
import torch.nn as nn
from .batchnorm_helper import NaiveSyncBatchNorm1d


class MLPHead(nn.Module):
    """MLP projection head with optional (optionally synchronized) BatchNorm."""

    def __init__(
        self,
        dim_in,
        dim_out,
        mlp_dim,
        num_layers,
        bn_on=False,
        bias=True,
        flatten=False,
        xavier_init=True,
        bn_sync_num=1,
        global_sync=False,
    ):
        super(MLPHead, self).__init__()
        self.flatten = flatten
        # When BatchNorm follows a Linear layer, the Linear bias is redundant.
        b = False if bn_on else bias
        # assert bn_on or bn_sync_num == 1
        mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)]
        mlp_layers[-1].xavier_init = xavier_init
        for i in range(1, num_layers):
            if bn_on:
                if global_sync or bn_sync_num > 1:
                    mlp_layers.append(
                        NaiveSyncBatchNorm1d(
                            num_sync_devices=bn_sync_num,
                            global_sync=global_sync,
                            num_features=mlp_dim,
                        )
                    )
                else:
                    mlp_layers.append(nn.BatchNorm1d(num_features=mlp_dim))
            mlp_layers.append(nn.ReLU(inplace=True))
            if i == num_layers - 1:
                d = dim_out
                b = bias
            else:
                d = mlp_dim
            mlp_layers.append(nn.Linear(mlp_dim, d, bias=b))
            mlp_layers[-1].xavier_init = xavier_init
        self.projection = nn.Sequential(*mlp_layers)

    def forward(self, x):
        # Move channels last for 5D (N, C, T, H, W) inputs.
        if x.ndim == 5:
            x = x.permute((0, 2, 3, 4, 1))
        if self.flatten:
            x = x.reshape(-1, x.shape[-1])
        return self.projection(x)
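

# Minimal usage sketch (not part of the original module). It shows a 2-layer
# projection head with BatchNorm that flattens a 5D (N, C, T, H, W) feature
# map before projecting; all dimensions below are illustrative assumptions.
#
#     head = MLPHead(
#         dim_in=768, dim_out=128, mlp_dim=2048, num_layers=2,
#         bn_on=True, flatten=True,
#     )
#     feats = torch.randn(4, 768, 2, 7, 7)  # (N, C, T, H, W)
#     emb = head(feats)                     # (N * T * H * W, 128)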


class TransformerBasicHead(nn.Module):
    """
    BasicHead. No pool.
    """

    def __init__(
        self,
        dim_in,
        num_classes,
        dropout_rate=0.0,
        act_func="softmax",
        cfg=None,
    ):
        """
        Perform linear projection and activation as head for transformers.
        Args:
            dim_in (int): the channel dimension of the input to the head.
            num_classes (int): the channel dimension of the output of the head.
            dropout_rate (float): dropout rate. If equal to 0.0, perform no
                dropout.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the
                output. 'none': applies no activation.
        """
        super(TransformerBasicHead, self).__init__()
        if dropout_rate > 0.0:
            self.dropout = nn.Dropout(dropout_rate)
        if cfg.CONTRASTIVE.NUM_MLP_LAYERS == 1:
            self.projection = nn.Linear(dim_in, num_classes, bias=True)
        else:
            self.projection = MLPHead(
                dim_in,
                num_classes,
                cfg.CONTRASTIVE.MLP_DIM,
                cfg.CONTRASTIVE.NUM_MLP_LAYERS,
                bn_on=cfg.CONTRASTIVE.BN_MLP,
                bn_sync_num=cfg.BN.NUM_SYNC_DEVICES
                if cfg.CONTRASTIVE.BN_SYNC_MLP
                else 1,
                global_sync=(
                    cfg.CONTRASTIVE.BN_SYNC_MLP and cfg.BN.GLOBAL_SYNC
                ),
            )
        self.detach_final_fc = cfg.MODEL.DETACH_FINAL_FC

        # Softmax for evaluation and testing.
        if act_func == "softmax":
            self.act = nn.Softmax(dim=1)
        elif act_func == "sigmoid":
            self.act = nn.Sigmoid()
        elif act_func == "none":
            self.act = None
        else:
            raise NotImplementedError(
                "{} is not supported as an activation "
                "function.".format(act_func)
            )

    def forward(self, x):
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        if self.detach_final_fc:
            # Stop classifier gradients from flowing back into the backbone.
            x = x.detach()
        x = self.projection(x)

        if not self.training:
            if self.act is not None:
                x = self.act(x)
            # Performs fully convolutional inference.
            if x.ndim == 5 and x.shape[1:4] > torch.Size([1, 1, 1]):
                x = x.mean([1, 2, 3])

        x = x.view(x.shape[0], -1)
        return x
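

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module). The project
    # normally passes a yacs-style cfg node; types.SimpleNamespace is only a
    # lightweight stand-in here, and every value below is an assumption. Run
    # this from within the package so the relative import above resolves.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        CONTRASTIVE=SimpleNamespace(
            NUM_MLP_LAYERS=2,
            MLP_DIM=2048,
            BN_MLP=False,
            BN_SYNC_MLP=False,
        ),
        BN=SimpleNamespace(NUM_SYNC_DEVICES=1, GLOBAL_SYNC=False),
        MODEL=SimpleNamespace(DETACH_FINAL_FC=False),
    )
    head = TransformerBasicHead(
        dim_in=768,
        num_classes=400,
        dropout_rate=0.5,
        act_func="softmax",
        cfg=cfg,
    )
    head.eval()
    out = head(torch.randn(2, 768))
    print(out.shape)  # expected: torch.Size([2, 400])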