|
import torch |
|
from mmcv.cnn import NonLocal2d |
|
|
|
from ..builder import HEADS |
|
from .fcn_head import FCNHead |
|
|
|
|
|
@HEADS.register_module()
class NLHead(FCNHead):
    """Non-local Neural Networks.

    This head is the implementation of `NLNet
    <https://arxiv.org/abs/1711.07971>`_.

    The head is an ``FCNHead`` with exactly two convolutions
    (``num_convs=2`` is always passed to the parent). A ``NonLocal2d``
    block operating on ``self.channels`` channels is inserted between
    the first and the second convolution.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
        **kwargs: Remaining keyword arguments are forwarded to ``FCNHead``.
            ``num_convs`` must not be supplied here because this head
            fixes it to 2.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        # forward() indexes convs[0] and convs[1], so the conv count is fixed.
        super().__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        # The non-local block sees the output of convs[0], which has
        # self.channels channels. It reuses the head's conv/norm configs.
        self.nl_block = NonLocal2d(
            in_channels=self.channels,
            reduction=reduction,
            use_scale=use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=mode)

    def forward(self, inputs):
        """Forward function.

        Applies convs[0] -> non-local block -> convs[1] to the selected
        input features. The result is optionally fused with those input
        features through ``conv_cat``, then passed to ``cls_seg``.
        """
        feats = self._transform_inputs(inputs)
        refined = self.convs[1](self.nl_block(self.convs[0](feats)))
        if self.concat_input:
            # Concatenate along the channel dimension before fusing.
            refined = self.conv_cat(torch.cat([feats, refined], dim=1))
        return self.cls_seg(refined)
|
|