|
import math |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule |
|
|
|
from ..builder import HEADS |
|
from .decode_head import BaseDecodeHead |
|
|
|
|
|
def reduce_mean(tensor):
    """Average a tensor across all processes of a distributed run.

    If ``torch.distributed`` is unavailable or no process group has been
    initialized, the input tensor is returned unchanged (the very same
    object). Otherwise a clone is divided by the world size and sum-reduced
    over all ranks, so the input tensor itself is never modified.

    Args:
        tensor (Tensor): Tensor to average.

    Returns:
        Tensor: The input itself when not running distributed, else the
            averaged clone.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if not distributed:
        return tensor

    averaged = tensor.clone()
    # Pre-divide by the world size so that the SUM reduction yields a mean.
    averaged.div_(dist.get_world_size())
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged
|
|
|
|
|
class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Weight given to the freshly estimated bases when
            the stored ``bases`` buffer is updated during training.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        # [1, channels, num_bases], randomly initialized and then
        # l2-normalized along the channel dimension.
        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        bases = F.normalize(bases, dim=1, p=2)
        # A buffer (not a parameter): it is updated by a moving average in
        # ``forward`` rather than by the optimizer.
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function.

        Args:
            feats (Tensor): Feature map of shape
                (batch_size, channels, height, width). It does not need to
                be contiguous in memory.

        Returns:
            Tensor: Reconstructed feature map with the same shape as
                ``feats``.
        """
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        # ``reshape`` instead of ``view`` so that non-contiguous inputs
        # (e.g. permuted or channels_last tensors) are accepted as well.
        feats = feats.reshape(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        # The EM iterations themselves are excluded from autograd.
        with torch.no_grad():
            for _ in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm over the spatial dimension
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm over the channel dimension
                bases = F.normalize(bases, dim=1, p=2)

        # [batch_size, channels, height*width]
        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.reshape(batch_size, channels, height, width)

        if self.training:
            # Average the bases over the batch and then over all ranks
            # (reduce_mean is a no-op when not running distributed).
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            # Moving-average update of the stored bases.
            self.bases = (
                (1 - self.momentum) * self.bases + self.momentum * bases)

        return feats_recon
|
|
|
|
|
@HEADS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head is the implementation of `EMANet
    <https://arxiv.org/abs/1907.13426>`_.

    Args:
        ema_channels (int): EMA module channels
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True
        momentum (float): Momentum to update the base. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum

        # Conv/norm/activation settings shared by the "full" conv blocks.
        full_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        # 3x3 conv mapping the head input to the EMA channel width.
        self.ema_in_conv = ConvModule(
            self.in_channels, self.ema_channels, 3, padding=1, **full_cfg)

        # 1x1 conv without norm or activation; its parameters are frozen.
        self.ema_mid_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        for weight in self.ema_mid_conv.parameters():
            weight.requires_grad = False

        # 1x1 conv with norm but no activation, applied to the EMA output.
        self.ema_out_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.bottleneck = ConvModule(
            self.ema_channels, self.channels, 3, padding=1, **full_cfg)

        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                **full_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)

        # Keep the output of the input conv for the residual connection.
        shortcut = self.ema_in_conv(x)
        projected = self.ema_mid_conv(shortcut)

        recon = F.relu(self.ema_module(projected), inplace=True)
        recon = self.ema_out_conv(recon)

        fused = F.relu(shortcut + recon, inplace=True)
        fused = self.bottleneck(fused)
        if self.concat_input:
            fused = self.conv_cat(torch.cat([x, fused], dim=1))
        return self.cls_seg(fused)
|
|