import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


class Transition(BaseModule):
    """Base class for transition.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        pass


class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-samples the input by interpolation and then refines the feature with
    a convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool): Whether to align corners when interpolating.
            Default: None.
        kernel_size (int): Kernel size of the convolution. Default: 3.
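
    Example:
        >>> # Illustrative only; the channel and spatial sizes are arbitrary.
        >>> import torch
        >>> trans = UpInterpolationConv(8, 16, scale_factor=2)
        >>> tuple(trans(torch.rand(1, 8, 16, 16)).shape)
        (1, 16, 32, 32)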
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        x = self.conv(x)
        return x


class LastConv(Transition):
    """A transition used for refining the output of the last stage.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Number of input feature maps; only the last one is
            refined and returned.
        kernel_size (int): Kernel size of the convolution. Default: 3.
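
    Example:
        >>> # Illustrative only; sizes are arbitrary.
        >>> import torch
        >>> trans = LastConv(8, 8, num_inputs=3)
        >>> inputs = [
        ...     torch.rand(1, 8, 32 // 2**i, 32 // 2**i) for i in range(3)
        ... ]
        >>> tuple(trans(inputs).shape)  # only inputs[-1] is refined
        (1, 8, 8, 8)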
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        self.conv_out = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, inputs):
        assert len(inputs) == self.num_inputs
        return self.conv_out(inputs[-1])


@NECKS.register_module()
class FPG(BaseModule):
    """FPG.

    Implementation of `Feature Pyramid Grids (FPG)
    <https://arxiv.org/abs/2004.03580>`_.

    This implementation only gives the basic structure stated in the paper.
    Users can implement different types of transitions to fully explore the
    potential power of the structure of FPG.

    Args:
        in_channels (list[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        paths (list[str]): Specify the path order of each stack level.
            Each element in the list should be either 'bu' (bottom-up) or
            'td' (top-down).
        inter_channels (int or list[int], optional): Number of channels of
            the intermediate feature maps. An int is used for all levels;
            None means ``out_channels``. Default: None.
        same_up_trans (dict): Transition that goes up at the same stage.
        same_down_trans (dict): Transition that goes down at the same stage.
        across_lateral_trans (dict): Across-pathway same-stage connection.
        across_down_trans (dict): Across-pathway bottom-up connection.
        across_up_trans (dict): Across-pathway top-down connection.
        across_skip_trans (dict): Across-pathway skip connection.
        output_trans (dict): Transition that transforms the output of the
            last stage.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last
            level.
        add_extra_convs (bool): Whether to use convolution layers (instead of
            max pooling) to build the extra output levels on top of the
            original feature maps. Default: False.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        skip_inds (list[tuple[int]]): Stack indices to skip for each output
            level; at those stacks the feature of that level is passed
            through unchanged. Must be given if ``across_skip_trans`` is
            not None. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
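
    Example:
        >>> # A toy setup. The channel sizes, paths and transition configs
        >>> # below are illustrative and not taken from a released config.
        >>> import torch
        >>> feats = [
        ...     torch.rand(1, c, s, s)
        ...     for c, s in zip([8, 16, 32, 64], [64, 32, 16, 8])
        ... ]
        >>> neck = FPG(
        ...     in_channels=[8, 16, 32, 64],
        ...     out_channels=8,
        ...     inter_channels=8,
        ...     num_outs=5,
        ...     stack_times=3,
        ...     paths=['bu', 'td', 'bu'],
        ...     across_down_trans=dict(
        ...         type='interpolation_conv', kernel_size=3),
        ...     across_skip_trans=dict(type='conv', kernel_size=1),
        ...     skip_inds=[(0,), (), (), (), ()])
        >>> for out in neck(feats):
        ...     print(tuple(out.shape))
        (1, 8, 64, 64)
        (1, 8, 32, 32)
        (1, 8, 16, 16)
        (1, 8, 8, 8)
        (1, 8, 4, 4)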
    """

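    # Mapping from the ``type`` name used in transition config dicts to the
    # class that implements that transition.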
    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None

        # skip_inds must be specified if across_skip_trans is not None
        if self.across_skip_trans is not None:
            assert skip_inds is not None
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < number of inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

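        # 1x1 lateral convs that map the backbone features to inter_channels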
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

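        # extra output levels are built by downsampling the last feature,
        # either with a stride-2 conv (add_extra_convs=True) or max pooling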
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

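        # build the stacked transitions: one ModuleDict of transitions per
        # (stack, level) pair; skipped pairs keep an empty ModuleDict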
        self.fpn_transitions = nn.ModuleList()
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()
            for i in range(self.num_outs):
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    stage_trans.append(trans)
                    continue

                # same-pathway connection from the next lower level
                # (used on 'bu' paths)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans

                # same-pathway connection from the next higher level
                # (used on 'td' paths)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans

                # across-pathway connection from the same level
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans

                # across-pathway connection from the next higher level
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans

                # across-pathway connection from the next lower level
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans

                # skip connection from the original (first) pathway
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans

                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)

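        # transitions that produce the final output at each level from the
        # per-pathway features collected for that level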
        self.output_transition = nn.ModuleList()
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Build a transition module from its config dict.

        The ``type`` key of ``cfg`` selects a class from
        ``self.transition_types``; the remaining keys are passed to its
        constructor together with ``extra_args``.
        """
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        """Sum all non-None features collected for one level."""
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

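        # build all levels from the original feature maps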
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

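        # build the stacked pathways, each traversed in either bottom-up
        # ('bu') or top-down ('td') order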
        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                if i in self.skip_inds[j]:
                    next_outs.append(outs[-1][j])
                    continue
                # feature level to process in this step
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # transitions feeding this level
                if direction == 'td':
                    same_trans = self.fpn_transitions[i][lvl]['same_down']
                else:
                    same_trans = self.fpn_transitions[i][lvl]['same_up']
                across_lateral_trans = self.fpn_transitions[i][lvl][
                    'across_lateral']
                across_down_trans = self.fpn_transitions[i][lvl]['across_down']
                across_up_trans = self.fpn_transitions[i][lvl]['across_up']
                across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']

                # gather the incoming features, then sum the non-None ones
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # same pathway, from the level processed in the previous step
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across pathway, same level
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # across pathway, from the next lower level
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(
                        current_outs[lvl - 1])
                # across pathway, from the next higher level
                if lvl < self.num_outs - 1 and across_down_trans is not None:
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # skip connection from the original (first) pathway
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

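        # produce the final output of each level from its per-pathway features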
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = []
            for s in range(len(outs)):
                lvl_out_list.append(outs[s][i])
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)

        return final_outs