|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule |
|
|
|
from mmseg.ops import resize |
|
from ..builder import HEADS |
|
from .aspp_head import ASPPHead, ASPPModule |
|
|
|
|
|
class DepthwiseSeparableASPPModule(ASPPModule):
    """ASPP module whose atrous branches use depthwise separable convolution.

    The parent ``ASPPModule`` is constructed first. Afterwards, every branch
    whose dilation rate is greater than 1 is replaced in place by a 3x3
    ``DepthwiseSeparableConvModule`` with that dilation. Branches with a
    dilation of 1 are kept exactly as the parent class built them.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Collect (branch index, dilation) pairs that need replacing.
        atrous_branches = [
            (idx, rate) for idx, rate in enumerate(self.dilations) if rate > 1
        ]
        for idx, rate in atrous_branches:
            # padding == dilation preserves spatial size for a 3x3 kernel
            # (assuming the module's default stride of 1).
            self[idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                kernel_size=3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
|
|
|
|
|
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of the c1 decoder. If it is
            0, no c1 decoder will be used (``inputs[0]`` is not fused).
        c1_channels (int): The intermediate channels of the c1 decoder. Only
            used when ``c1_in_channels > 0``.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0
        # Replace the parent's ASPP branches with the depthwise separable
        # variant, reusing the parent's configuration.
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            # 1x1 projection applied to inputs[0] in ``forward``.
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            # ``forward`` concatenates the c1 projection onto the ASPP output.
            fused_channels = self.channels + c1_channels
        else:
            self.c1_bottleneck = None
            # Nothing is concatenated in ``forward`` without the c1 branch,
            # so the separable bottleneck only receives ``self.channels``.
            # Using ``self.channels + c1_channels`` here would make any
            # nonzero ``c1_channels`` crash with a channel mismatch.
            fused_channels = self.channels
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                fused_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Multi-level features. They are passed to
                ``self._transform_inputs``; ``inputs[0]`` is also fed to the
                c1 decoder when it is enabled.

        Returns:
            The result of ``self.cls_seg`` applied to the fused features.
        """
        x = self._transform_inputs(inputs)
        # Output of the parent's ``image_pool``, resized back to x's size.
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        # Concatenate all branches along dim 1 (the channel axis, given how
        # the bottleneck channel counts are built).
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            # Bring the ASPP output to the c1 feature's spatial size before
            # concatenating.
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
|
|