|
|
|
|
|
from copy import deepcopy |
|
import fvcore.nn.weight_init as weight_init |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from .batch_norm import get_norm |
|
from .blocks import DepthwiseSeparableConv2d |
|
from .wrappers import Conv2d |
|
|
|
|
|
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    Builds five parallel branches over the input feature map — a 1x1 conv,
    three 3x3 atrous convs (one per dilation rate), and an image-pooling
    branch — then concatenates them and fuses with a 1x1 projection conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilations,
        *,
        norm,
        activation,
        pool_kernel_size=None,
        dropout: float = 0.0,
        use_depthwise_separable_conv=False,
    ):
        """
        Args:
            in_channels (int): number of channels of the input feature map.
            out_channels (int): number of channels produced by every branch
                and by the final projection.
            dilations (list): exactly 3 dilation rates for the atrous convs.
            norm (str or callable): normalization for all conv layers; see
                :func:`layers.get_norm` for supported formats. It is applied
                to every conv except the one that follows average pooling.
            activation (callable): activation function; deep-copied per layer
                so each conv owns an independent module instance.
            pool_kernel_size (tuple, list): average-pooling window (kh, kw)
                for the image-pooling branch. ``None`` means global average
                pooling. If not ``None``, the input spatial size in
                ``forward()`` must be divisible by it. Recommended practice:
                train with a fixed input size and set this option to that
                size, so pooling is global in training while the pooling
                window stays consistent at inference time.
            dropout (float): dropout rate applied to the ASPP output. The
                official DeepLab implementation uses 0.1:
                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa
            use_depthwise_separable_conv (bool): replace the 3x3 convs with
                DepthwiseSeparableConv2d, as proposed in :paper:`DeepLabV3+`.
        """
        super().__init__()
        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
        self.pool_kernel_size = pool_kernel_size
        self.dropout = dropout
        # Convention: convs carry a bias only when no normalization follows.
        use_bias = norm == ""
        self.convs = nn.ModuleList()

        # Branch 1: 1x1 conv.
        branch_1x1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(branch_1x1)
        self.convs.append(branch_1x1)

        # Branches 2-4: 3x3 atrous convs, one per dilation rate.
        # padding == dilation keeps the spatial size unchanged.
        for rate in dilations:
            if use_depthwise_separable_conv:
                # DepthwiseSeparableConv2d initializes its own sub-convs.
                self.convs.append(
                    DepthwiseSeparableConv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=rate,
                        dilation=rate,
                        norm1=norm,
                        activation1=deepcopy(activation),
                        norm2=norm,
                        activation2=deepcopy(activation),
                    )
                )
            else:
                atrous_conv = Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    padding=rate,
                    dilation=rate,
                    bias=use_bias,
                    norm=get_norm(norm, out_channels),
                    activation=deepcopy(activation),
                )
                weight_init.c2_xavier_fill(atrous_conv)
                self.convs.append(atrous_conv)

        # Branch 5: image pooling — global average pool when no kernel size
        # is given, otherwise a fixed-size average pool with stride 1. The
        # conv after pooling intentionally skips `norm` and keeps a bias.
        if pool_kernel_size is None:
            pool = nn.AdaptiveAvgPool2d(1)
        else:
            pool = nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1)
        image_pooling = nn.Sequential(
            pool,
            Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
        )
        weight_init.c2_xavier_fill(image_pooling[1])
        self.convs.append(image_pooling)

        # Fuse the 5 concatenated branch outputs back down to out_channels.
        self.project = Conv2d(
            5 * out_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(self.project)

    def forward(self, x):
        size = x.shape[-2:]
        if self.pool_kernel_size is not None:
            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
                raise ValueError(
                    "`pool_kernel_size` must be divisible by the shape of inputs. "
                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
                )
        branch_outputs = [branch(x) for branch in self.convs]
        # The pooling branch (last) lost spatial resolution; upsample it
        # back to the input size before concatenation.
        branch_outputs[-1] = F.interpolate(
            branch_outputs[-1], size=size, mode="bilinear", align_corners=False
        )
        fused = self.project(torch.cat(branch_outputs, dim=1))
        if self.dropout > 0:
            fused = F.dropout(fused, self.dropout, training=self.training)
        return fused
|
|