from collections import OrderedDict

import torch.nn as nn
import torch.nn.functional as functional

from .bn import ABN, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE


class ResidualBlock(nn.Module):
    """Configurable residual block

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    channels : list of int
        Number of channels in the internal feature maps. Can either have two or three elements: if two, construct
        a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with `1 x 1`, then
        `3 x 3`, then `1 x 1` convolutions.
    stride : int
        Stride of the first `3 x 3` convolution.
    dilation : int
        Dilation to apply to the `3 x 3` convolutions.
    groups : int
        Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
        bottleneck blocks.
    norm_act : callable
        Function to create normalization / activation Module.
    dropout : callable
        Function to create Dropout Module.
    """
    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        super(ResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        if not is_bottleneck:
            # Basic block: two 3 x 3 convolutions; the last BN carries no activation,
            # so the non-linearity is applied after the residual sum in forward().
            bn2 = norm_act(channels[1])
            bn2.activation = ACT_NONE
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn1", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn2", bn2)
            ]
            if dropout is not None:
                layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
        else:
            # Bottleneck block: 1 x 1 -> 3 x 3 (strided, grouped) -> 1 x 1 convolutions.
            bn3 = norm_act(channels[2])
            bn3.activation = ACT_NONE
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)),
                ("bn1", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=stride, padding=dilation, bias=False,
                                    groups=groups, dilation=dilation)),
                ("bn2", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)),
                ("bn3", bn3)
            ]
            if dropout is not None:
                layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
        self.convs = nn.Sequential(OrderedDict(layers))

        # Projection shortcut when the spatial size or channel count changes.
        if need_proj_conv:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
            self.proj_bn = norm_act(channels[-1])
            self.proj_bn.activation = ACT_NONE
    def forward(self, x):
        if hasattr(self, "proj_conv"):
            residual = self.proj_conv(x)
            residual = self.proj_bn(residual)
        else:
            residual = x
        x = self.convs(x) + residual

        # The last BN in `convs` was built with ACT_NONE, so the block's activation
        # is applied here, after the residual sum, mirroring the one configured in bn1.
        if self.convs.bn1.activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(x, negative_slope=self.convs.bn1.slope, inplace=True)
        elif self.convs.bn1.activation == ACT_ELU:
            return functional.elu(x, inplace=True)
        else:
            return x
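

# Usage sketch for ResidualBlock (illustrative values, assuming ABN behaves like
# BatchNorm2d followed by its configured activation; requires `import torch`):
#
#   block = ResidualBlock(64, [64, 64, 256], stride=2, norm_act=ABN)
#   y = block(torch.randn(1, 64, 56, 56))   # -> torch.Size([1, 256, 28, 28])
#
# The projection shortcut is created automatically here because stride != 1 and
# in_channels != channels[-1].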


class IdentityResidualBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        """Configurable identity-mapping residual block

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        channels : list of int
            Number of channels in the internal feature maps. Can either have two or three elements: if two, construct
            a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with `1 x 1`, then
            `3 x 3`, then `1 x 1` convolutions.
        stride : int
            Stride to apply in the first convolution of the block.
        dilation : int
            Dilation to apply to the `3 x 3` convolutions.
        groups : int
            Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
            bottleneck blocks.
        norm_act : callable
            Function to create normalization / activation Module.
        dropout : callable
            Function to create Dropout Module.
        """
        super(IdentityResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        # Pre-activation design: normalization / activation runs before the convolutions.
        self.bn1 = norm_act(in_channels)
        if not is_bottleneck:
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    dilation=dilation))
            ]
            if dropout is not None:
                layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
        else:
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    groups=groups, dilation=dilation)),
                ("bn3", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
            ]
            if dropout is not None:
                layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
        self.convs = nn.Sequential(OrderedDict(layers))

        # Projection shortcut, fed with the pre-activated input (see forward()).
        if need_proj_conv:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
    def forward(self, x):
        if hasattr(self, "proj_conv"):
            bn1 = self.bn1(x)
            shortcut = self.proj_conv(bn1)
        else:
            # Clone the input before bn1, which may modify x in place (e.g. InPlaceABN),
            # so the identity shortcut stays intact.
            shortcut = x.clone()
            bn1 = self.bn1(x)

        out = self.convs(bn1)
        out.add_(shortcut)

        return out
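

# Minimal smoke test (a sketch, not part of the original module). It assumes the file
# is run inside its package, e.g. `python -m package.module`, so that the relative
# import of ABN above resolves.
if __name__ == "__main__":
    import torch

    # Identity-mapping bottleneck that keeps the spatial size but changes the channel
    # count, so a projection shortcut is created.
    block = IdentityResidualBlock(64, [64, 64, 256], stride=1, norm_act=ABN)
    block.eval()

    with torch.no_grad():
        out = block(torch.randn(2, 64, 32, 32))
    print(out.shape)  # expected: torch.Size([2, 256, 32, 32])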