"""Various modules used in the decoder of the model.
Adapted from https://github.com/jinlinyi/PerspectiveFields
"""
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
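
# Usage sketch (illustrative, not part of the module): stochastic depth is
# applied to the residual branch of a block, not to the trunk. With
# drop_prob=0.2 roughly 20% of samples get a zeroed branch during training;
# survivors are rescaled by 1/keep_prob so the expected output is unchanged.
# `branch` below is a hypothetical residual sub-module.
#
#   x = torch.randn(8, 64, 16, 16)
#   out = x + drop_path(branch(x), drop_prob=0.2, training=True)
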
class DropPath(nn.Module):
"""DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
Code:
DropBlock impl inspired by two Tensorflow impl:
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
Hacked together by / Copyright 2020 Ross Wightman
"""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class DWConv(nn.Module):
    """Depthwise 3x3 convolution (one filter per channel, groups == channels)."""

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x):
x = self.dwconv(x)
return x
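
# Shape sketch: with groups == channels each channel is convolved with its
# own 3x3 filter, so channel count and spatial size are preserved.
#
#   dw = DWConv(dim=768)
#   y = dw(torch.randn(2, 768, 14, 14))    # -> (2, 768, 14, 14)
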
class MLP(nn.Module):
"""Linear Embedding."""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = nn.Linear(input_dim, embed_dim)
    def forward(self, x):
        # Flatten (B, C, H, W) to (B, H*W, C), then project channels to embed_dim.
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
return x
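
# Shape sketch: the forward turns a 4D feature map into a token sequence
# before the linear projection.
#
#   mlp = MLP(input_dim=2048, embed_dim=768)
#   y = mlp(torch.randn(2, 2048, 16, 16))  # (2, 2048, 16, 16) -> (2, 256, 768)
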
class ConvModule(nn.Module):
"""Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: int = 0,
use_norm: bool = False,
bias: bool = True,
):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
self.activate = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.activate(x)
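
# Usage sketch: when use_norm=True it is common to disable the conv bias,
# since BatchNorm's affine shift makes it redundant.
#
#   m = ConvModule(256, 256, kernel_size=3, padding=1, use_norm=True, bias=False)
#   y = m(torch.randn(2, 256, 32, 32))     # -> (2, 256, 32, 32)
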
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
self.relu = torch.nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
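
# Note the pre-activation ordering: output = x + conv2(relu(conv1(relu(x)))).
# Channel count and spatial size are unchanged, so units can be chained.
#
#   rcu = ResidualConvUnit(128)
#   y = rcu(torch.randn(2, 128, 32, 32))   # -> (2, 128, 32, 32)
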
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features, unit2only=False, upsample=True):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.upsample = upsample
if not unit2only:
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass."""
output = xs[0]
if len(xs) == 2:
output = output + self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
if self.upsample:
output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False)
return output
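
# Fusion sketch: with two inputs the second (typically the lateral/skip
# feature) passes through resConfUnit1 before the addition; the sum goes
# through resConfUnit2 and is then upsampled 2x when upsample=True.
#
#   ffb = FeatureFusionBlock(128)
#   top = torch.randn(2, 128, 16, 16)      # coarser decoder feature
#   lateral = torch.randn(2, 128, 16, 16)  # skip connection
#   y = ffb(top, lateral)                  # -> (2, 128, 32, 32)
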
class _DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient):
super().__init__()
self.norm1 = nn.BatchNorm2d(num_input_features)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(
num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False
)
self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False
)
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
    def bn_function(self, inputs):
        # Concatenate all previous feature maps, then apply the 1x1 bottleneck.
        concated_features = torch.cat(inputs, 1)
        return self.conv1(self.relu1(self.norm1(concated_features)))
def any_requires_grad(self, inp):
return any(tensor.requires_grad for tensor in inp)
@torch.jit.unused # noqa: T484
def call_checkpoint_bottleneck(self, inp):
def closure(*inputs):
return self.bn_function(inputs)
return cp.checkpoint(closure, *inp)
@torch.jit._overload_method # noqa: F811
def forward(self, inp) -> Tensor: # noqa: F811
pass
@torch.jit._overload_method # noqa: F811
def forward(self, inp): # noqa: F811
pass
# torchscript does not yet support *args, so we overload method
# allowing it to take either a List[Tensor] or single Tensor
def forward(self, inp): # noqa: F811
prev_features = [inp] if isinstance(inp, Tensor) else inp
if self.memory_efficient and self.any_requires_grad(prev_features):
if torch.jit.is_scripting():
raise Exception("Memory Efficient not supported in JIT")
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
else:
bottleneck_output = self.bn_function(prev_features)
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return new_features
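
# Layer sketch: a dense layer consumes the concatenation of all previous
# feature maps and emits growth_rate new channels (1x1 bottleneck to
# bn_size * growth_rate channels, then 3x3 down to growth_rate).
#
#   layer = _DenseLayer(64, growth_rate=32, bn_size=4, drop_rate=0.0,
#                       memory_efficient=False)
#   y = layer(torch.randn(2, 64, 28, 28))  # -> (2, 32, 28, 28)
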
class _DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self,
num_layers,
num_input_features,
bn_size,
growth_rate,
drop_rate,
memory_efficient=False,
):
super().__init__()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
self.add_module("denselayer%d" % (i + 1), layer)
def forward(self, init_features):
features = [init_features]
        for layer in self.values():
new_features = layer(features)
features.append(new_features)
return torch.cat(features, 1)
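
# Block sketch: each layer appends growth_rate channels to the running
# concatenation, so the output has
# num_input_features + num_layers * growth_rate channels.
#
#   block = _DenseBlock(num_layers=4, num_input_features=64, bn_size=4,
#                       growth_rate=32, drop_rate=0.0)
#   y = block(torch.randn(2, 64, 28, 28))  # -> (2, 192, 28, 28)
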
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super().__init__()
self.norm = nn.BatchNorm2d(num_input_features)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(
num_input_features, num_output_features, kernel_size=1, stride=1, bias=False
)
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
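
# Transition sketch: BN -> ReLU -> 1x1 conv (channel reduction) -> 2x2
# average pooling, halving the spatial resolution between dense blocks.
#
#   trans = _Transition(192, 96)
#   y = trans(torch.randn(2, 192, 28, 28)) # -> (2, 96, 14, 14)
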