import torch
import torch.nn as nn
from mmcv import is_tuple_of
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        branch_channels (tuple[int]): The number of output channels in each
            branch. Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super(LRASPPHead, self).__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        # 1x1 projection convs for the low-level branches, plus the fusion
        # convs applied after each upsample-and-concatenate step in forward.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i in range(len(branch_channels)):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(
                    self.in_channels[i], branch_channels[i], 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channels[i],
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        # 1x1 conv on the deepest feature map, gated in forward() by a
        # large-kernel average-pooling branch with a sigmoid-activated 1x1
        # conv that produces a coarse attention signal.
        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[2],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        # Gate the projected deepest feature with the pooled sigmoid branch,
        # upsampled back to the feature's spatial size.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        # Progressively upsample and fuse with the lower-level branches,
        # from the deepest branch down to the shallowest.
        for i in range(len(self.branch_channels) - 1, -1, -1):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
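
# The block below is an optional, illustrative smoke test, not part of the
# mmseg API. The channel sizes, class count and feature-map shapes are
# assumptions chosen only so the head runs on random tensors; note that the
# deepest feature map must be at least 49x49 because of the 49x49 average
# pooling in `image_pool`.
if __name__ == '__main__':
    head = LRASPPHead(
        in_channels=(16, 24, 960),
        in_index=(0, 1, 2),
        channels=128,
        input_transform='multiple_select',
        num_classes=19)
    # Dummy multi-level features (batch, channels, height, width).
    feats = [
        torch.rand(1, 16, 128, 128),
        torch.rand(1, 24, 64, 64),
        torch.rand(1, 960, 64, 64),
    ]
    out = head(feats)
    print(out.shape)  # expected: torch.Size([1, 19, 128, 128])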