import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init

from mmseg.ops import resize


class BaseDecodeHead(nn.Module):
"""Base class for BaseDecodeHead. |
|
|
|
Args: |
|
in_channels (int|Sequence[int]): Input channels. |
|
channels (int): Channels after modules, before conv_seg. |
|
num_classes (int): Number of classes. |
|
dropout_ratio (float): Ratio of dropout layer. Default: 0.1. |
|
conv_cfg (dict|None): Config of conv layers. Default: None. |
|
norm_cfg (dict|None): Config of norm layers. Default: None. |
|
act_cfg (dict): Config of activation layers. |
|
Default: dict(type='ReLU') |
|
in_index (int|Sequence[int]): Input feature index. Default: -1 |
|
input_transform (str|None): Transformation type of input features. |
|
Options: 'resize_concat', 'multiple_select', None. |
|
'resize_concat': Multiple feature maps will be resize to the |
|
same size as first one and than concat together. |
|
Usually used in FCN head of HRNet. |
|
'multiple_select': Multiple feature maps will be bundle into |
|
a list and passed into decode head. |
|
None: Only one select feature map is allowed. |
|
Default: None. |
|
loss_decode (dict): Config of decode loss. |
|
Default: dict(type='CrossEntropyLoss'). |
|
ignore_index (int | None): The label index to be ignored. When using |
|
masked BCE loss, ignore_index should be set to None. Default: 255 |
|
sampler (dict|None): The config of segmentation map sampler. |
|
Default: None. |
|
align_corners (bool): align_corners argument of F.interpolate. |
|
Default: False. |
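
    Example:
        >>> # A minimal illustrative sketch; BaseDecodeHead is normally
        >>> # subclassed, and the parameter values here are assumptions.
        >>> head = BaseDecodeHead(
        ...     in_channels=[32, 16],
        ...     channels=48,
        ...     num_classes=19,
        ...     in_index=[0, 1],
        ...     input_transform='resize_concat')
        >>> head.in_channels  # summed over the selected feature maps
        48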
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be a list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to
                    the same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled
                    into a list and passed into the decode head.
                None: Only one selected feature map is allowed.
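
        Example:
            >>> # Illustrative sketch (values are assumptions): with
            >>> # input_transform=None, in_channels and in_index are ints.
            >>> head = BaseDecodeHead(in_channels=32, channels=16,
            ...                       num_classes=19, in_index=-1)
            >>> head.in_channels
            32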
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)
    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
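
        Example:
            >>> # Illustrative sketch of 'resize_concat' (shapes are
            >>> # assumptions): the second map is resized to the spatial
            >>> # size of the first, then both are concatenated on channels.
            >>> head = BaseDecodeHead(
            ...     in_channels=[32, 16], channels=48, num_classes=19,
            ...     in_index=[0, 1], input_transform='resize_concat')
            >>> feats = [torch.rand(1, 32, 64, 64), torch.rand(1, 16, 32, 32)]
            >>> head._transform_inputs(feats).shape
            torch.Size([1, 48, 64, 64])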
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output


class FCNHead(BaseDecodeHead):
    """Fully Convolutional Networks for Semantic Segmentation.

    This head is the implementation of `FCNNet
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer. Default: True.
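
    Example:
        >>> # A minimal illustrative sketch; the parameter values are
        >>> # assumptions, and norm_cfg=None avoids configuring BatchNorm.
        >>> head = FCNHead(in_channels=32, channels=16, num_classes=19,
        ...                norm_cfg=None)
        >>> inputs = [torch.rand(1, 32, 45, 45)]
        >>> head(inputs).shape
        torch.Size([1, 19, 45, 45])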
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for _ in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output


class MultiHeadFCNHead(nn.Module):
    """Multi-head fully convolutional head for semantic segmentation.

    This head runs num_head parallel FCN branches, each producing its own
    segmentation logits. The branches follow `FCNNet
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer. Default: True.
        num_head (int): Number of parallel classification heads. Default: 18.
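
    Example:
        >>> # A minimal illustrative sketch; parameter values are
        >>> # assumptions. Each head returns its own logits tensor.
        >>> head = MultiHeadFCNHead(in_channels=32, channels=16,
        ...                         num_classes=2, num_head=3, norm_cfg=None)
        >>> outs = head([torch.rand(1, 32, 45, 45)])
        >>> len(outs), outs[0].shape
        (3, torch.Size([1, 2, 45, 45]))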
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 num_head=18,
                 **kwargs):
        super(MultiHeadFCNHead, self).__init__()
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.num_head = num_head
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            # forward() checks `self.dropout is not None`, so the attribute
            # must exist even when dropout is disabled.
            self.dropout = None

        conv_seg_head_list = []
        for _ in range(self.num_head):
            conv_seg_head_list.append(
                nn.Conv2d(channels, num_classes, kernel_size=1))

        self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)

        self.init_weights()

        if num_convs == 0:
            assert self.in_channels == self.channels

        convs_list = []
        conv_cat_list = []

        for _ in range(self.num_head):
            convs = []
            convs.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            if num_convs == 0:
                convs_list.append(nn.Identity())
            else:
                convs_list.append(nn.Sequential(*convs))
            if self.concat_input:
                conv_cat_list.append(
                    ConvModule(
                        self.in_channels + self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))

        self.convs_list = nn.ModuleList(convs_list)
        self.conv_cat_list = nn.ModuleList(conv_cat_list)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)

        output_list = []
        for head_idx in range(self.num_head):
            output = self.convs_list[head_idx](x)
            if self.concat_input:
                output = self.conv_cat_list[head_idx](
                    torch.cat([x, output], dim=1))
            if self.dropout is not None:
                output = self.dropout(output)
            output = self.conv_seg_head_list[head_idx](output)
            output_list.append(output)

        return output_list

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature
        map will be selected, so in_channels and in_index must be of type
        int. When input_transform is not None, in_channels and in_index
        must be a list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to
                    the same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled
                    into a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layers."""
        for conv_seg_head in self.conv_seg_head_list:
            normal_init(conv_seg_head, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs