|
import torch.nn as nn
import torch.utils.checkpoint as cp

from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule,
                                          build_activation_layer,
                                          build_norm_layer, constant_init,
                                          kaiming_init)
from annotator.uniformer.mmcv.runner import load_checkpoint
from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
from annotator.uniformer.mmseg.utils import get_root_logger

from ..builder import BACKBONES
from ..utils import UpConvBlock
|
|
|
|
|
|
class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet.

    A stack of ``num_convs`` plain 3x3 convolutions (each a ConvModule, i.e.
    conv + norm + activation). Only the first convolution may downsample
    (via ``stride``); only the later convolutions apply ``dilation``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Whether use stride convolution to downsample
            the input feature map. If stride=2, it only uses stride
            convolution in the first convolutional layer to downsample the
            input feature map. Options are 1 or 2. Default: 1.
        dilation (int): Whether use dilated convolution to expand the
            receptive field. Set dilation rate of each convolutional layer
            and the dilation rate of the first convolutional layer is
            always 1. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or
            not. Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super(BasicConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp
        # The first conv adapts channel count and (optionally) downsamples;
        # subsequent convs keep the channel count and apply the dilation.
        layers = [
            ConvModule(
                in_channels=in_channels if idx == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride if idx == 0 else 1,
                dilation=1 if idx == 0 else dilation,
                padding=1 if idx == 0 else dilation,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for idx in range(num_convs)
        ]
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        # Gradient checkpointing trades compute for memory; it is a no-op
        # when the input does not require grad (e.g. inference).
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.convs, x)
        return self.convs(x)
|
|
|
|
|
|
@UPSAMPLE_LAYERS.register_module()
class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    Upsamples the feature map with a single ``nn.ConvTranspose2d`` followed
    by a normalization layer and an activation layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the deconvolution layer.
            Default: 4.
        scale_factor (int): Upsample scale factor; used as the stride of
            the deconvolution. Default: 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        assert (kernel_size - scale_factor >= 0) and\
            (kernel_size - scale_factor) % 2 == 0,\
            f'kernel_size should be greater than or equal to scale_factor '\
            f'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'

        self.with_cp = with_cp
        # stride == scale_factor gives exactly scale_factor-times upsampling;
        # the symmetric padding cancels the kernel overhang.
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale_factor,
            padding=(kernel_size - scale_factor) // 2)

        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        # NOTE(review): the attribute name carries a historical typo
        # ("upsamping"); it is kept verbatim so that pretrained state dicts
        # keep loading.
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""
        # Checkpointing only applies while training with grad enabled.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.deconv_upsamping, x)
        return self.deconv_upsamping(x)
|
|
|
|
|
|
@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    Upsamples the feature map with one interpolation layer and one
    convolutional layer. Depending on ``conv_first`` the order is either
    conv -> upsample (conv_first=True) or upsample -> conv
    (conv_first=False, the default).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer.
            Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 0.
        upsample_cfg (dict): Interpolation config of the upsample layer.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsample_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super(InterpConv, self).__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = nn.Upsample(**upsample_cfg)
        # Order of the two sub-layers is the only thing conv_first controls.
        layers = (conv, upsample) if conv_first else (upsample, conv)
        self.interp_upsample = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        # Checkpointing only applies while training with grad enabled.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(self.interp_upsample, x)
        return self.interp_upsample(x)
|
|
|
|
|
|
@BACKBONES.register_module()
class UNet(nn.Module):
    """UNet backbone.
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.

    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        # Per-stage config sequences must match num_stages (encoder) or
        # num_stages-1 (decoder / inter-stage options).
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval
        self.base_channels = base_channels

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                # Downsample with MaxPool only when the stage does NOT use
                # stride convolution; stride conv takes precedence.
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                # The matching decoder stage only upsamples if this encoder
                # stage actually reduced resolution.
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            # Channels double every stage; next stage consumes this output.
            in_channels = base_channels * 2**i

    def forward(self, x):
        """Forward function.

        Returns:
            list: outputs of the bottleneck plus every decoder stage, from
            coarsest to finest resolution.
        """
        self._check_input_divisible(x)
        enc_outs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        # Decode from the deepest stage upwards, fusing each encoder skip.
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(UNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _check_input_divisible(self, x):
        # The encoder halves the resolution once per downsampling stage, so
        # H and W must be divisible by 2**(number of downsampling stages).
        h, w = x.shape[-2:]
        whole_downsample_rate = 1
        for i in range(1, self.num_stages):
            if self.strides[i] == 2 or self.downsamples[i - 1]:
                whole_downsample_rate *= 2
        assert (h % whole_downsample_rate == 0) \
            and (w % whole_downsample_rate == 0),\
            f'The input image size {(h, w)} should be divisible by the whole '\
            f'downsample rate {whole_downsample_rate}, when num_stages is '\
            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
            f'is {self.downsamples}.'

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Default init: He init for convs, unit-constant for norm layers.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
|
|
|