import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Holds one ``ConvModule`` branch per dilation rate. A dilation of 1 uses
    a 1x1 conv with no padding. Any other dilation uses a 3x3 conv with
    ``padding == dilation``. With ConvModule's default stride (assumed 1
    here), both choices keep the spatial size of the input, so the branch
    outputs can be concatenated along the channel axis.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super(ASPPModule, self).__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for dilation in dilations:
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    # 1x1 conv for the non-dilated branch, 3x3 atrous
                    # conv otherwise; padding == dilation keeps H/W.
                    1 if dilation == 1 else 3,
                    dilation=dilation,
                    padding=0 if dilation == 1 else dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map fed to every branch.

        Returns:
            list[Tensor]: One output per branch, in the same order as
                ``dilations``.
        """
        return [aspp_module(x) for aspp_module in self]


@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of
    `DeepLabV3 <https://arxiv.org/abs/1706.05587>`_.

    It concatenates an image-level pooling branch with the ASPP branches,
    fuses them with a 3x3 bottleneck conv, and passes the result to
    ``cls_seg``.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple)), \
            'dilations must be a list or tuple'
        self.dilations = dilations
        # Image-level branch: global average pool to 1x1, then a 1x1 conv.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # One input slot per ASPP branch plus one for the image-pool branch.
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone features, selected/merged by
                ``self._transform_inputs`` (defined in BaseDecodeHead).

        Returns:
            Tensor: Output of ``self.cls_seg`` applied to the fused
                features.
        """
        x = self._transform_inputs(inputs)
        # Upsample the 1x1 pooled feature back to x's spatial size
        # (assumes x is (N, C, H, W) -- TODO confirm for all backbones).
        pooled = resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        aspp_outs = [pooled]
        aspp_outs.extend(self.aspp_modules(x))
        fused = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(fused)
        output = self.cls_seg(output)
        return output