|
import torch.nn as nn |
|
import torch.utils.checkpoint as cp |
|
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, |
|
build_norm_layer, constant_init, kaiming_init) |
|
from mmcv.runner import load_checkpoint |
|
from mmcv.utils.parrots_wrapper import _BatchNorm |
|
|
|
from mmseg.utils import get_root_logger |
|
from ..builder import BACKBONES |
|
from ..utils import UpConvBlock |
|
|
|
|
|
class BasicConvBlock(nn.Module):
    """Stack of plain 3x3 convolution layers used as a UNet building block.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by every conv layer in the
            stack.
        num_convs (int): How many ConvModule layers to stack. Default: 2.
        stride (int): Stride used by the first conv layer only; every later
            layer uses stride 1. A stride of 2 makes the first layer a
            strided downsampling convolution. Options are 1 or 2.
            Default: 1.
        dilation (int): Dilation rate used by every conv layer except the
            first, which always uses dilation 1. Those later layers use
            padding equal to the dilation. Default: 1.
        with_cp (bool): If True and the input requires grad, run the stack
            through ``torch.utils.checkpoint`` to save memory at the cost of
            extra compute during backward. Default: False.
        conv_cfg (dict | None): Config dict for the convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for the normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for the activation layer in
            ConvModule. Default: dict(type='ReLU').
        dcn (bool): Deformable convolution switch. Not implemented yet, so it
            must be None. Default: None.
        plugins (dict): Plugins for the conv layers. Not implemented yet, so
            it must be None. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super(BasicConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp

        layers = []
        # Only the first layer changes the channel count; the rest map
        # out_channels -> out_channels.
        channels_in = in_channels
        for idx in range(num_convs):
            is_first = idx == 0
            layers.append(
                ConvModule(
                    in_channels=channels_in,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride if is_first else 1,
                    dilation=1 if is_first else dilation,
                    padding=1 if is_first else dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            channels_in = out_channels
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack, with gradient checkpointing when enabled."""
        use_checkpoint = self.with_cp and x.requires_grad
        if not use_checkpoint:
            return self.convs(x)
        return cp.checkpoint(self.convs, x)
|
|
|
|
|
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Transposed-convolution upsampler for the UNet decoder.

    The module is a ``nn.ConvTranspose2d`` followed by a normalization layer
    and an activation layer. The stride is ``scale_factor`` and the padding
    is ``(kernel_size - scale_factor) // 2``, so the output spatial size is
    ``scale_factor`` times the input size.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): If True and the input requires grad, run the module
            through ``torch.utils.checkpoint`` to save memory at the cost of
            extra compute during backward. Default: False.
        norm_cfg (dict | None): Config dict for the normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for the activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the transposed convolution
            (keyword-only). Default: 4.
        scale_factor (int): Upsampling factor, used as the transposed-conv
            stride (keyword-only). ``kernel_size - scale_factor`` must be
            non-negative and even. Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        # The surplus of kernel over stride must split evenly into padding
        # on both sides.
        margin = kernel_size - scale_factor
        assert margin >= 0 and margin % 2 == 0, (
            f'kernel_size should be greater than or equal to scale_factor '
            f'and (kernel_size - scale_factor) should be even numbers, '
            f'while the kernel size is {kernel_size} and scale_factor is '
            f'{scale_factor}.')

        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale_factor,
            padding=margin // 2)
        _, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        # Attribute name (including its spelling) is kept for checkpoint
        # key compatibility.
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Upsample ``x``, with gradient checkpointing when enabled."""
        if not (self.with_cp and x.requires_grad):
            return self.deconv_upsamping(x)
        return cp.checkpoint(self.deconv_upsamping, x)
|
|
|
|
|
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation-based upsampler for the UNet decoder.

    The module pairs one ``nn.Upsample`` layer with one ConvModule. By
    default the interpolation runs first and the convolution second
    (``conv_first=False``). Setting ``conv_first=True`` reverses the order.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): If True and the input requires grad, run the module
            through ``torch.utils.checkpoint`` to save memory at the cost of
            extra compute during backward. Default: False.
        norm_cfg (dict | None): Config dict for the normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for the activation layer in
            ConvModule. Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for the convolution layer.
            Default: None.
        conv_first (bool): If True, apply the conv layer before the
            interpolation; otherwise interpolate first. Default: False.
        kernel_size (int): Kernel size of the conv layer. Default: 1.
        stride (int): Stride of the conv layer. Default: 1.
        padding (int): Padding of the conv layer. Default: 0.
        upsampe_cfg (dict): Keyword arguments forwarded to ``nn.Upsample``.
            Default: dict(
            scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsampe_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super(InterpConv, self).__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = nn.Upsample(**upsampe_cfg)
        stages = (conv, upsample) if conv_first else (upsample, conv)
        self.interp_upsample = nn.Sequential(*stages)

    def forward(self, x):
        """Upsample ``x``, with gradient checkpointing when enabled."""
        if not (self.with_cp and x.requires_grad):
            return self.interp_upsample(x)
        return cp.checkpoint(self.interp_upsample, x)
|
|
|
|
|
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.

    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage; encoder stage ``i``
            outputs ``base_channels * 2**i`` channels. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the corresponding encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the corresponding decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[bool]): Whether to use MaxPool to downsample
            the feature map after the first stage of encoder
            (stages: [1, num_stages)). If the corresponding encoder stage
            uses stride convolution (strides[i]=2), MaxPool is never used
            there, even if downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Not implemented yet; must be None. Default: None.
        plugins (dict): Plugins for convolutional layers. Not implemented
            yet; must be None. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_devisible.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                # MaxPool is only used when the stage does not already
                # downsample through a strided convolution.
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                # The matching decoder stage needs an upsampler only if this
                # encoder stage reduced the spatial size.
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append(nn.Sequential(*enc_conv_block))
            in_channels = base_channels * 2**i

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: Input tensor. Its last two dims (H, W) must be divisible by
                the whole downsample rate of the encoder.

        Returns:
            list: The deepest encoder output first, followed by the output
            of each decoder stage, from the deepest to the shallowest.
        """
        self._check_input_devisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            # Decoder stage i fuses the skip feature from encoder stage i
            # with the upsampled deeper feature.
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping normalization
        layers frozen when ``norm_eval`` is set.

        Args:
            mode (bool): True for training mode, False for eval mode.
                Default: True.

        Returns:
            UNet: ``self``, matching ``nn.Module.train`` so that
            ``model = model.train()`` and chained calls keep working.
        """
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has effect on BatchNorm only; it stops the
                # running mean/var from being updated.
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _check_input_devisible(self, x):
        """Assert that the input H and W are divisible by the total encoder
        downsample rate.

        Every encoder stage after the first that uses a stride-2 conv or a
        MaxPool halves the spatial size, so the rate is ``2**k`` for ``k``
        such stages.
        """
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be devisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights. When
                given, it is loaded non-strictly. When None, Conv2d layers get
                Kaiming init and BatchNorm/GroupNorm layers get constant 1.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
|
|