rlawjdghek's picture
detectron2
a9a0ec2
raw
history blame
7.85 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm
from detectron2.modeling import BACKBONE_REGISTRY, ResNet, ResNetBlockBase
from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock
from .trident_conv import TridentConv
# Names exported when this module is star-imported.
__all__ = ["TridentBottleneckBlock", "make_trident_stage", "build_trident_resnet_backbone"]
class TridentBottleneckBlock(ResNetBlockBase):
    """
    Bottleneck residual block whose middle 3x3 convolution is a :class:`TridentConv`.

    The block is a 1x1 conv, a multi-branch 3x3 ``TridentConv`` and a final 1x1 conv,
    summed with a shortcut. The 1x1 convolutions and the shortcut are the same
    modules applied to every branch in turn.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        num_branch=3,
        dilations=(1, 2, 3),
        concat_output=False,
        test_branch_idx=-1,
    ):
        """
        Args:
            in_channels (int): number of channels of the block input.
            out_channels (int): number of channels of the block output. When it differs
                from ``in_channels`` a 1x1 projection is used as the shortcut.
            bottleneck_channels (int): number of channels produced by the first 1x1 conv
                and kept by the 3x3 trident conv.
            stride (int): stride of the block, applied by the projection shortcut and by
                either the first 1x1 conv or the 3x3 conv (see ``stride_in_1x1``).
            num_groups (int): passed as ``groups`` to the 3x3 trident conv.
            norm: normalization spec handed to ``get_norm`` for every conv in the block.
            stride_in_1x1 (bool): if True the stride is placed on the first 1x1 conv,
                otherwise on the 3x3 conv.
            num_branch (int): the number of branches in TridentNet. Must equal
                ``len(dilations)``.
            dilations (tuple): the dilations of multiple branches in TridentNet. The same
                values are also passed as the per-branch paddings of the 3x3 conv.
            concat_output (bool): if True, ``forward`` joins the branch outputs with
                ``torch.cat`` instead of returning a list.
                Use 'True' for the last trident block.
            test_branch_idx (int): forwarded to ``TridentConv``. When it is not -1 and the
                module is not in training mode, a single input tensor is expanded to one
                branch only.
        """
        super().__init__(in_channels, out_channels, stride)

        assert num_branch == len(dilations)

        self.num_branch = num_branch
        self.concat_output = concat_output
        self.test_branch_idx = test_branch_idx

        # A projection shortcut is only needed when the channel count changes.
        shortcut = None
        if in_channels != out_channels:
            shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        self.shortcut = shortcut

        # Exactly one of the first two convolutions carries the block stride.
        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = TridentConv(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            paddings=dilations,
            bias=False,
            groups=num_groups,
            dilations=dilations,
            num_branch=num_branch,
            test_branch_idx=test_branch_idx,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for module in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if module is None:
                # No projection shortcut was created for this block.
                continue
            weight_init.c2_msra_fill(module)

    def forward(self, x):
        """
        Run the block on every branch.

        Args:
            x (Tensor or list[Tensor]): a single tensor, which is reused as the input of
                every active branch, or a list holding one tensor per branch.

        Returns:
            list[Tensor] or Tensor: one output per branch, or their ``torch.cat``
            concatenation when ``concat_output`` is set.
        """
        # Only one branch is active outside training when a test branch is selected.
        single_branch = not self.training and self.test_branch_idx != -1
        num_branch = 1 if single_branch else self.num_branch
        if not isinstance(x, list):
            x = [x] * num_branch

        residual = [F.relu_(self.conv1(branch_in)) for branch_in in x]
        residual = [self.conv3(F.relu_(feat)) for feat in self.conv2(residual)]

        if self.shortcut is None:
            identity = x
        else:
            identity = [self.shortcut(branch_in) for branch_in in x]

        merged = [F.relu_(main + skip) for main, skip in zip(residual, identity)]
        return torch.cat(merged) if self.concat_output else merged
def make_trident_stage(block_class, num_blocks, **kwargs):
    """
    Build a TridentNet stage through :meth:`ResNet.make_stage`.

    A ``concat_output_per_block`` entry is added to the keyword arguments so that only
    the last block of the stage is asked to concatenate its branch outputs; any
    caller-supplied value under that key is overridden.

    Args:
        block_class: block type forwarded to ``ResNet.make_stage``.
        num_blocks (int): number of blocks in the stage.
        **kwargs: remaining arguments forwarded to ``ResNet.make_stage``.

    Returns:
        whatever ``ResNet.make_stage`` returns for the stage.
    """
    per_block_concat = [False for _ in range(num_blocks - 1)]
    per_block_concat.append(True)
    stage_kwargs = {**kwargs, "concat_output_per_block": per_block_concat}
    return ResNet.make_stage(block_class, num_blocks, **stage_kwargs)
@BACKBONE_REGISTRY.register()
def build_trident_resnet_backbone(cfg, input_shape):
    """
    Assemble the TridentNet ResNet backbone described by ``cfg``.

    Stages from res2 up to the deepest one named in ``cfg.MODEL.RESNETS.OUT_FEATURES``
    are created. The stage selected by ``cfg.MODEL.TRIDENT.TRIDENT_STAGE`` is built from
    :class:`TridentBottleneckBlock`; every other stage uses :class:`DeformBottleneckBlock`
    or :class:`BottleneckBlock` according to ``cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE``.

    Args:
        cfg: config object exposing the ``MODEL.RESNETS``, ``MODEL.BACKBONE`` and
            ``MODEL.TRIDENT`` options read below.
        input_shape: object whose ``channels`` attribute gives the stem input channels.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    resnets_cfg = cfg.MODEL.RESNETS
    norm = resnets_cfg.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnets_cfg.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    if freeze_at >= 1:
        for param in stem.parameters():
            param.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    trident_cfg = cfg.MODEL.TRIDENT
    out_features = resnets_cfg.OUT_FEATURES
    depth = resnets_cfg.DEPTH
    num_groups = resnets_cfg.NUM_GROUPS
    bottleneck_channels = num_groups * resnets_cfg.WIDTH_PER_GROUP
    in_channels = resnets_cfg.STEM_OUT_CHANNELS
    out_channels = resnets_cfg.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets_cfg.STRIDE_IN_1X1
    res5_dilation = resnets_cfg.RES5_DILATION
    deform_on_per_stage = resnets_cfg.DEFORM_ON_PER_STAGE
    deform_modulated = resnets_cfg.DEFORM_MODULATED
    deform_num_groups = resnets_cfg.DEFORM_NUM_GROUPS
    num_branch = trident_cfg.NUM_BRANCH
    branch_dilations = trident_cfg.BRANCH_DILATIONS
    trident_stage = trident_cfg.TRIDENT_STAGE
    test_branch_idx = trident_cfg.TEST_BRANCH_IDX

    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    # Block counts for res2..res5, keyed by network depth.
    blocks_by_depth = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
    num_blocks_per_stage = blocks_by_depth[depth]

    res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    out_stage_idx = [res_stage_idx[name] for name in out_features]
    trident_stage_idx = res_stage_idx[trident_stage]
    # Nothing deeper than the last requested output feature needs to be built.
    max_stage_idx = max(out_stage_idx)

    stages = []
    for stage_idx in range(2, max_stage_idx + 1):
        idx = stage_idx - 2  # position of this stage in the per-stage lists
        is_trident = stage_idx == trident_stage_idx
        num_blocks = num_blocks_per_stage[idx]

        dilation = res5_dilation if stage_idx == 5 else 1
        # res2, and a res5 dilated by 2, start with stride 1; other stages with stride 2.
        keeps_stride_one = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_stride_one else 2

        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "bottleneck_channels": bottleneck_channels,
            "out_channels": out_channels,
            "num_groups": num_groups,
            "norm": norm,
            "stride_in_1x1": stride_in_1x1,
        }

        if is_trident:
            assert not deform_on_per_stage[
                idx
            ], "Not support deformable conv in Trident blocks yet."
            # Trident blocks take per-branch dilations instead of a single dilation.
            stage_kargs["block_class"] = TridentBottleneckBlock
            stage_kargs["num_branch"] = num_branch
            stage_kargs["dilations"] = branch_dilations
            stage_kargs["test_branch_idx"] = test_branch_idx
            blocks = make_trident_stage(**stage_kargs)
        else:
            stage_kargs["dilation"] = dilation
            if deform_on_per_stage[idx]:
                stage_kargs["block_class"] = DeformBottleneckBlock
                stage_kargs["deform_modulated"] = deform_modulated
                stage_kargs["deform_num_groups"] = deform_num_groups
            else:
                stage_kargs["block_class"] = BottleneckBlock
            blocks = ResNet.make_stage(**stage_kargs)

        if freeze_at >= stage_idx:
            for block in blocks:
                block.freeze()
        stages.append(blocks)

        # Channel widths double from one stage to the next.
        in_channels, out_channels = out_channels, out_channels * 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features)