|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
|
|
from mmseg.ops import resize |
|
from ..builder import HEADS |
|
from .decode_head import BaseDecodeHead |
|
|
|
|
|
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    A ``ModuleList`` holding one ``ConvModule`` branch per entry of
    ``dilations``. A dilation of 1 yields a 1x1 conv with no padding; any
    other dilation yields a 3x3 conv whose padding equals its dilation.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super().__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for rate in dilations:
            # Rate 1 degenerates to a plain pointwise projection; larger rates
            # use a 3x3 atrous conv padded by ``rate`` so that (at the default
            # stride) the output keeps the input's height and width.
            is_pointwise = rate == 1
            branch = ConvModule(
                self.in_channels,
                self.channels,
                1 if is_pointwise else 3,
                dilation=rate,
                padding=0 if is_pointwise else rate,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.append(branch)

    def forward(self, x):
        """Forward function.

        Args:
            x: Feature map passed unchanged to every branch.

        Returns:
            list: One output per branch, in the same order as the branches
                were appended (i.e. the order of ``dilations``).
        """
        return [branch(x) for branch in self]
|
|
|
|
|
@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    The transformed input is processed by an image-level branch (global
    average pooling to 1x1 followed by a 1x1 ``ConvModule``, then resized back
    to the input's spatial size) and by one ASPP branch per dilation rate.
    All branch outputs are concatenated along dim 1, fused by a 3x3
    ``bottleneck`` ``ConvModule`` and finally classified by ``cls_seg``.

    Args:
        dilations (tuple[int] | list[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead``.
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple)), (
            'dilations must be a list or tuple, got {}'.format(
                type(dilations)))
        self.dilations = dilations
        # Image-level branch: collapse spatial dims to 1x1, then project.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Input width: the image-level branch plus one branch per dilation,
        # each producing ``self.channels`` channels.
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Compute the fused ASPP feature map, before classification.

        Args:
            inputs: Multi-level features; selected/combined by
                ``self._transform_inputs`` (inherited from
                ``BaseDecodeHead``).

        Returns:
            Tensor: Output of ``self.bottleneck`` with ``self.channels``
                channels.
        """
        x = self._transform_inputs(inputs)
        # The pooled branch is 1x1 spatially; upsample it to x's size so it
        # can be concatenated with the ASPP branch outputs.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        return self.bottleneck(aspp_outs)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Multi-level features, as accepted by
                ``self._transform_inputs``.

        Returns:
            Tensor: Per-pixel class logits produced by ``self.cls_seg``.
        """
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
|
|