|
"""Various modules used in the decoder of the model. |
|
|
|
Adapted from https://github.com/jinlinyi/PerspectiveFields |
|
""" |
|
|
|
import logging
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): |
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" |
|
if drop_prob == 0.0 or not training: |
|
return x |
|
keep_prob = 1 - drop_prob |
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) |
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob) |
|
if keep_prob > 0.0 and scale_by_keep: |
|
random_tensor.div_(keep_prob) |
|
return x * random_tensor |
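# Usage sketch (an assumption about intended use, not taken from this repo): drop_path
# is applied to the residual branch of a block, e.g.
#   out = x + drop_path(branch(x), drop_prob=0.1, training=module.training)
# With probability drop_prob the whole branch output is zeroed for a sample; surviving
# samples are rescaled by 1 / keep_prob when scale_by_keep is True.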
|
|
|
|
|
class DropPath(nn.Module): |
|
"""DropBlock, DropPath |
|
|
|
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. |
|
|
|
Papers: |
|
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) |
|
|
|
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) |
|
|
|
Code: |
|
DropBlock impl inspired by two Tensorflow impl: |
|
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 |
|
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
|
|
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): |
|
super(DropPath, self).__init__() |
|
self.drop_prob = drop_prob |
|
self.scale_by_keep = scale_by_keep |
|
|
|
def forward(self, x): |
|
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) |
|
|
|
def extra_repr(self): |
|
return f"drop_prob={round(self.drop_prob,3):0.3f}" |
|
|
|
|
|
class DWConv(nn.Module): |
|
def __init__(self, dim=768): |
|
super(DWConv, self).__init__() |
|
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) |
|
|
|
def forward(self, x): |
|
x = self.dwconv(x) |
|
return x |
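# Note: with groups=dim the 3x3 convolution above is depthwise (each channel filtered
# independently), and stride=1 with padding=1 preserves spatial size, so an input of
# shape (B, dim, H, W) maps to an output of the same shape.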
|
|
|
|
|
class MLP(nn.Module): |
|
"""Linear Embedding.""" |
|
|
|
def __init__(self, input_dim=2048, embed_dim=768): |
|
super().__init__() |
|
self.proj = nn.Linear(input_dim, embed_dim) |
|
|
|
def forward(self, x): |
|
x = x.flatten(2).transpose(1, 2) |
|
x = self.proj(x) |
|
return x |
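# Shape note: forward flattens a (B, C, H, W) feature map to (B, H*W, C) and projects the
# channel dimension, returning tokens of shape (B, H*W, embed_dim); callers typically
# reshape back to (B, embed_dim, H, W) if a spatial map is needed (an assumption about
# usage, not stated in this file).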
|
|
|
|
|
class ConvModule(nn.Module): |
|
"""Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency.""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
padding: int = 0, |
|
use_norm: bool = False, |
|
bias: bool = True, |
|
): |
|
super().__init__() |
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) |
|
self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity() |
|
self.activate = nn.ReLU(inplace=True) |
|
|
|
def forward(self, x): |
|
x = self.conv(x) |
|
x = self.bn(x) |
|
return self.activate(x) |
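# The block applies conv -> (optional) BatchNorm2d -> ReLU, matching the default
# conv/norm/act ordering of mmcv.cnn.ConvModule for the configurations used here; unlike
# mmcv, the norm and activation types are fixed to BatchNorm2d and ReLU.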
|
|
|
|
|
class ResidualConvUnit(nn.Module): |
|
"""Residual convolution module.""" |
|
|
|
def __init__(self, features): |
|
"""Init. |
|
Args: |
|
features (int): number of features |
|
""" |
|
super().__init__() |
|
|
|
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) |
|
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) |
|
|
|
self.relu = torch.nn.ReLU(inplace=True) |
|
|
|
def forward(self, x): |
|
"""Forward pass. |
|
Args: |
|
x (tensor): input |
|
Returns: |
|
tensor: output |
|
""" |
|
out = self.relu(x) |
|
out = self.conv1(out) |
|
out = self.relu(out) |
|
out = self.conv2(out) |
|
return out + x |
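# Note: this is a pre-activation residual unit (ReLU -> conv -> ReLU -> conv, plus the
# identity); both 3x3 convolutions keep the channel count and spatial size, so the output
# shape equals the input shape.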
|
|
|
|
|
class FeatureFusionBlock(nn.Module): |
|
"""Feature fusion block.""" |
|
|
|
def __init__(self, features, unit2only=False, upsample=True): |
|
"""Init. |
|
Args: |
|
features (int): number of features |
|
""" |
|
super().__init__() |
|
self.upsample = upsample |
|
|
|
if not unit2only: |
|
self.resConfUnit1 = ResidualConvUnit(features) |
|
self.resConfUnit2 = ResidualConvUnit(features) |
|
|
|
def forward(self, *xs): |
|
"""Forward pass.""" |
|
output = xs[0] |
|
|
|
if len(xs) == 2: |
|
output = output + self.resConfUnit1(xs[1]) |
|
|
|
output = self.resConfUnit2(output) |
|
|
|
if self.upsample: |
|
output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False) |
|
|
|
return output |
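# Usage sketch (assumed from the forward signature, names are illustrative): with two
# inputs the second is refined by resConfUnit1 and added to the first, e.g.
#   fuse = FeatureFusionBlock(features=256)
#   out = fuse(coarse_feat, skip_feat)  # 2x the spatial size when upsample=True
# With a single input only resConfUnit2 (and the optional upsampling) is applied.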
|
|
|
|
|
class _DenseLayer(nn.Module): |
|
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient): |
|
super().__init__() |
|
self.norm1 = nn.BatchNorm2d(num_input_features) |
|
self.relu1 = nn.ReLU(inplace=True) |
|
self.conv1 = nn.Conv2d( |
|
num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False |
|
) |
|
|
|
self.norm2 = nn.BatchNorm2d(bn_size * growth_rate) |
|
self.relu2 = nn.ReLU(inplace=True) |
|
self.conv2 = nn.Conv2d( |
|
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False |
|
) |
|
|
|
self.drop_rate = float(drop_rate) |
|
self.memory_efficient = memory_efficient |
|
|
|
def bn_function(self, inputs): |
|
concated_features = torch.cat(inputs, 1) |
|
return self.conv1(self.relu1(self.norm1(concated_features))) |
|
|
|
def any_requires_grad(self, inp): |
|
return any(tensor.requires_grad for tensor in inp) |
|
|
|
@torch.jit.unused |
|
def call_checkpoint_bottleneck(self, inp): |
|
def closure(*inputs): |
|
return self.bn_function(inputs) |
|
|
|
return cp.checkpoint(closure, *inp) |
|
|
|
    @torch.jit._overload_method  # noqa: F811
    def forward(self, inp: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, inp: Tensor) -> Tensor:
        pass

    # TorchScript does not support *args, so the overloads above let the layer accept
    # either a single Tensor or a list of Tensors.
    def forward(self, inp):  # noqa: F811
|
prev_features = [inp] if isinstance(inp, Tensor) else inp |
|
if self.memory_efficient and self.any_requires_grad(prev_features): |
|
if torch.jit.is_scripting(): |
|
raise Exception("Memory Efficient not supported in JIT") |
|
|
|
bottleneck_output = self.call_checkpoint_bottleneck(prev_features) |
|
else: |
|
bottleneck_output = self.bn_function(prev_features) |
|
|
|
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) |
|
if self.drop_rate > 0: |
|
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) |
|
return new_features |
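# Summary of the layer above (DenseNet bottleneck): the incoming feature list is
# concatenated along channels, reduced to bn_size * growth_rate channels by the 1x1
# convolution, then mapped to growth_rate new channels by the 3x3 convolution; when
# memory_efficient is True the bottleneck is recomputed via gradient checkpointing
# instead of being stored for backward.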
|
|
|
|
|
class _DenseBlock(nn.ModuleDict): |
|
_version = 2 |
|
|
|
def __init__( |
|
self, |
|
num_layers, |
|
num_input_features, |
|
bn_size, |
|
growth_rate, |
|
drop_rate, |
|
memory_efficient=False, |
|
): |
|
super().__init__() |
|
for i in range(num_layers): |
|
layer = _DenseLayer( |
|
num_input_features + i * growth_rate, |
|
growth_rate=growth_rate, |
|
bn_size=bn_size, |
|
drop_rate=drop_rate, |
|
memory_efficient=memory_efficient, |
|
) |
|
self.add_module("denselayer%d" % (i + 1), layer) |
|
|
|
def forward(self, init_features): |
|
features = [init_features] |
|
for name, layer in self.items(): |
|
new_features = layer(features) |
|
features.append(new_features) |
|
return torch.cat(features, 1) |
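# Channel bookkeeping (follows directly from the code above): each dense layer appends
# growth_rate channels, so a block's concatenated output has
# num_input_features + num_layers * growth_rate channels, e.g. 64 + 4 * 32 = 192 for
# num_input_features=64, num_layers=4, growth_rate=32.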
|
|
|
|
|
class _Transition(nn.Sequential): |
|
def __init__(self, num_input_features, num_output_features): |
|
super().__init__() |
|
self.norm = nn.BatchNorm2d(num_input_features) |
|
self.relu = nn.ReLU(inplace=True) |
|
self.conv = nn.Conv2d( |
|
num_input_features, num_output_features, kernel_size=1, stride=1, bias=False |
|
) |
|
self.pool = nn.AvgPool2d(kernel_size=2, stride=2) |
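# A DenseNet-style transition stage: BatchNorm -> ReLU -> 1x1 conv to change the channel
# count (typically a reduction), followed by 2x2 average pooling that halves the spatial
# resolution.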
|
|