# NOTE(review): the three lines below look like copied page/commit header text, not code.
# Cyril666's picture
# First model version
# 4ea50ff
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import json
import logging
from collections import OrderedDict
from . import (
fbnet_builder as mbuilder,
fbnet_modeldef as modeldef,
)
import torch.nn as nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.rpn import rpn
from maskrcnn_benchmark.modeling import poolers
# Module-level logger; used by create_builder and _get_rpn_stage in this file.
logger = logging.getLogger(__name__)
def create_builder(cfg):
    """Create an FBNet builder together with the unified architecture definition.

    The architecture is chosen by name (``cfg.MODEL.FBNET.ARCH``) from
    ``modeldef.MODEL_ARCH``. If the name is not registered, the JSON string in
    ``cfg.MODEL.FBNET.ARCH_DEF`` is used instead. When the name is registered
    and ``ARCH_DEF`` is also given, the two must be identical.

    Args:
        cfg: config node. Reads ``cfg.MODEL.FBNET.*``,
            ``cfg.GROUP_NORM.NUM_GROUPS`` (only for ``BN_TYPE == "gn"``) and
            ``cfg.MODEL.RPN.ANCHOR_STRIDE`` (only when the arch def has an
            ``"rpn_stride"`` entry).

    Returns:
        tuple: ``(builder, arch_def)`` -- an ``mbuilder.FBNetBuilder`` and the
        result of ``mbuilder.unify_arch_def``.

    Raises:
        AssertionError: on conflicting or missing architecture definitions, or
            when the arch def's ``"rpn_stride"`` differs from
            ``cfg.MODEL.RPN.ANCHOR_STRIDE[0]``.
    """
    fbnet_cfg = cfg.MODEL.FBNET

    # Group norm is passed to the builder together with its group count.
    norm = fbnet_cfg.BN_TYPE
    if norm == "gn":
        norm = (norm, cfg.GROUP_NORM.NUM_GROUPS)
    width_ratio = fbnet_cfg.SCALE_FACTOR
    arch = fbnet_cfg.ARCH

    # ARCH_DEF is a JSON string; an empty string means "not given".
    inline_def = fbnet_cfg.ARCH_DEF
    if len(inline_def) > 0:
        inline_def = json.loads(inline_def)

    registered = modeldef.MODEL_ARCH
    if arch not in registered:
        assert inline_def is not None and len(inline_def) > 0
        raw_def = inline_def
    else:
        raw_def = registered[arch]
        if len(inline_def) > 0:
            assert (
                inline_def == raw_def
            ), "Two architectures with the same name {},\n{},\n{}".format(
                arch, inline_def, raw_def
            )
    arch_def = mbuilder.unify_arch_def(raw_def)

    # The arch def may pin the RPN stride; the anchor config has to agree.
    rpn_stride = arch_def.get("rpn_stride", None)
    if rpn_stride is not None:
        anchor_stride = cfg.MODEL.RPN.ANCHOR_STRIDE
        assert (
            anchor_stride[0] == rpn_stride
        ), "Needs to set cfg.MODEL.RPN.ANCHOR_STRIDE to {}, got {}".format(
            rpn_stride, anchor_stride
        )

    builder_kwargs = dict(
        width_ratio=width_ratio,
        bn_type=norm,
        width_divisor=fbnet_cfg.WIDTH_DIVISOR,
        dw_skip_bn=fbnet_cfg.DW_CONV_SKIP_BN,
        dw_skip_relu=fbnet_cfg.DW_CONV_SKIP_RELU,
    )
    logger.info(
        "Building fbnet model with arch {} (without scaling):\n{}".format(
            arch, arch_def
        )
    )
    builder = mbuilder.FBNetBuilder(**builder_kwargs)
    return builder, arch_def
def _get_trunk_cfg(arch_def):
    """Return the block config of the trunk (backbone) stages.

    The stage indices come from ``arch_def["backbone"]`` when that key is
    present. Otherwise every stage except the last one is used.
    """
    last_stage = mbuilder.get_num_stages(arch_def) - 1
    default_stages = range(last_stage)
    return mbuilder.get_blocks(
        arch_def, stage_indices=arch_def.get("backbone", default_stages)
    )
class FBNetTrunk(nn.Module):
    """FBNet trunk: the ``"first"`` layer followed by the trunk stages.

    Args:
        builder: FBNet builder used to instantiate the layers.
        arch_def: unified architecture definition (from ``create_builder``).
        dim_in: input depth, forwarded to ``builder.add_first``.
    """

    def __init__(self, builder, arch_def, dim_in):
        super(FBNetTrunk, self).__init__()
        self.first = builder.add_first(arch_def["first"], dim_in=dim_in)
        self.stages = builder.add_blocks(_get_trunk_cfg(arch_def)["stages"])

    def forward(self, x):
        """Run ``first`` then ``stages``.

        Returns only the final output, wrapped in a one-element list. It does
        not return one feature map per stage.
        """
        return [self.stages(self.first(x))]
@registry.BACKBONES.register("FBNet")
def add_conv_body(cfg, dim_in=3):
    """Build the backbone registered under the name ``"FBNet"``.

    Returns an ``nn.Sequential`` whose only child, ``"body"``, is an
    ``FBNetTrunk``. Its ``out_channels`` attribute is the builder's
    ``last_depth`` once the trunk has been built.
    """
    builder, arch_def = create_builder(cfg)
    trunk = FBNetTrunk(builder, arch_def, dim_in)
    backbone = nn.Sequential(OrderedDict(body=trunk))
    backbone.out_channels = builder.last_depth
    return backbone
def _get_rpn_stage(arch_def, num_blocks):
    """Return the block configs used to build the RPN head.

    Args:
        arch_def: unified architecture definition. Its optional ``"rpn"``
            entry gives the stage indices to use. When the key is missing,
            ``None`` is passed to ``mbuilder.get_blocks``; how that is handled
            is up to ``get_blocks``.
        num_blocks: if > 0, keep only the last ``num_blocks`` blocks of the
            selected stages. Otherwise keep all of them.

    Returns:
        The ``"stages"`` list produced by ``mbuilder.get_blocks``.
    """
    rpn_stage = arch_def.get("rpn")
    ret = mbuilder.get_blocks(arch_def, stage_indices=rpn_stage)
    if num_blocks > 0:
        # Logger.warn is a deprecated alias of Logger.warning. Lazy %-args
        # produce the same text as the previous str.format call.
        logger.warning("Use last %s blocks in %s as rpn", num_blocks, ret)
        block_count = len(ret["stages"])
        assert num_blocks <= block_count, "use block {}, block count {}".format(
            num_blocks, block_count
        )
        blocks = range(block_count - num_blocks, block_count)
        ret = mbuilder.get_blocks(ret, block_indices=blocks)
    return ret["stages"]
class FBNetRPNHead(nn.Module):
    """FBNet blocks that form the feature part of the RPN head.

    Args:
        cfg: config node. Reads ``cfg.MODEL.FBNET.RPN_BN_TYPE`` and
            ``cfg.MODEL.FBNET.RPN_HEAD_BLOCKS``, plus
            ``cfg.GROUP_NORM.NUM_GROUPS`` when ``RPN_BN_TYPE == "gn"``.
        in_channels: depth of the incoming features. It must equal
            ``builder.last_depth``.
        builder: FBNet builder used to instantiate the blocks. Its
            ``bn_type`` is overridden when ``RPN_BN_TYPE`` is non-empty.
        arch_def: unified architecture definition.
    """

    def __init__(self, cfg, in_channels, builder, arch_def):
        super(FBNetRPNHead, self).__init__()
        assert in_channels == builder.last_depth

        rpn_bn_type = cfg.MODEL.FBNET.RPN_BN_TYPE
        if len(rpn_bn_type) > 0:
            # Normalize the same way create_builder does for BN_TYPE.
            # Without this, "gn" reached the builder without its group count.
            builder.bn_type = self._resolve_bn_type(rpn_bn_type, cfg)

        use_blocks = cfg.MODEL.FBNET.RPN_HEAD_BLOCKS
        stages = _get_rpn_stage(arch_def, use_blocks)
        self.head = builder.add_blocks(stages)
        self.out_channels = builder.last_depth

    @staticmethod
    def _resolve_bn_type(bn_type, cfg):
        """Return ``bn_type`` in the form create_builder hands to the builder.

        ``"gn"`` becomes ``("gn", cfg.GROUP_NORM.NUM_GROUPS)``. Any other value
        is returned unchanged.
        """
        if bn_type == "gn":
            return (bn_type, cfg.GROUP_NORM.NUM_GROUPS)
        return bn_type

    def forward(self, x):
        """Apply ``self.head`` to each feature map in ``x`` and return a list."""
        return [self.head(y) for y in x]
@registry.RPN_HEADS.register("FBNet.rpn_head")
def add_rpn_head(cfg, in_channels, num_anchors):
    """Build the RPN head registered under the name ``"FBNet.rpn_head"``.

    Args:
        cfg: config node.
        in_channels: depth of the incoming feature maps.
        num_anchors: forwarded to ``rpn.RPNHeadConvRegressor``.

    Returns:
        ``nn.Sequential(FBNetRPNHead, RPNHeadConvRegressor)``.
    """
    builder, model_arch = create_builder(cfg)
    # Start the head blocks from the incoming feature depth. The previous
    # `assert in_channels == builder.last_depth` directly after this
    # assignment could never fail. FBNetRPNHead performs that check itself.
    builder.last_depth = in_channels
    rpn_feature = FBNetRPNHead(cfg, in_channels, builder, model_arch)
    rpn_regressor = rpn.RPNHeadConvRegressor(
        cfg, rpn_feature.out_channels, num_anchors
    )
    return nn.Sequential(rpn_feature, rpn_regressor)
def _get_head_stage(arch, head_name, blocks):
    """Return the block configs for an ROI head.

    The stage indices come from ``arch[head_name]``. If ``head_name`` is not a
    key of ``arch``, the default ``arch["head"]`` entry is used instead. Only
    the blocks listed in ``blocks`` are kept.
    """
    stage_key = head_name if head_name in arch else "head"
    selected = mbuilder.get_blocks(
        arch, stage_indices=arch.get(stage_key), block_indices=blocks
    )
    return selected["stages"]
# Maps the head names used in the arch def to the config head-section names
# that FBNetROIHead passes to poolers.make_pooler.
ARCH_CFG_NAME_MAPPING = dict(
    bbox="ROI_BOX_HEAD",
    kpts="ROI_KEYPOINT_HEAD",
    mask="ROI_MASK_HEAD",
)
class FBNetROIHead(nn.Module):
    """ROI feature extractor: a pooler, then FBNet head blocks and a last layer.

    Args:
        cfg: config node.
        in_channels: must equal ``builder.last_depth``.
        builder: FBNet builder used to instantiate the layers.
        arch_def: unified architecture definition.
        head_name: key of ``ARCH_CFG_NAME_MAPPING`` ("bbox", "kpts" or "mask").
        use_blocks (list): block indices to keep from the head stage.
        stride_init: 0, 1 or 2. A non-zero value is written into index 3 of
            the first head block's ``"block"`` entry (presumably its stride;
            see fbnet_builder).
        last_layer_scale: replaces index 1 of a deep copy of
            ``arch_def["last"]`` before the last layer is built.
    """

    def __init__(
        self, cfg, in_channels, builder, arch_def,
        head_name, use_blocks, stride_init, last_layer_scale,
    ):
        super(FBNetROIHead, self).__init__()
        assert in_channels == builder.last_depth
        assert isinstance(use_blocks, list)

        self.pooler = poolers.make_pooler(cfg, ARCH_CFG_NAME_MAPPING[head_name])

        head_stage = _get_head_stage(arch_def, head_name, use_blocks)
        assert stride_init in [0, 1, 2]
        if stride_init != 0:
            # Override the first head block's configured value at index 3.
            head_stage[0]["block"][3] = stride_init
        head_blocks = builder.add_blocks(head_stage)

        # Copy so that arch_def["last"] itself is left untouched.
        last_cfg = copy.deepcopy(arch_def["last"])
        last_cfg[1] = last_layer_scale
        last_layer = builder.add_last(last_cfg)

        self.head = nn.Sequential(
            OrderedDict([("blocks", head_blocks), ("last", last_layer)])
        )
        self.out_channels = builder.last_depth

    def forward(self, x, proposals):
        """Pool ``x`` over ``proposals`` and run the pooled features through the head."""
        pooled = self.pooler(x, proposals)
        return self.head(pooled)
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FBNet.roi_head")
def add_roi_head(cfg, in_channels):
    """Build the box ("bbox") ROI feature extractor.

    The builder starts from ``in_channels``. Block selection, first-block
    stride and last-layer scale come from ``cfg.MODEL.FBNET.DET_HEAD_*``.
    """
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels
    fbnet_cfg = cfg.MODEL.FBNET
    return FBNetROIHead(
        cfg,
        in_channels,
        builder,
        model_arch,
        head_name="bbox",
        use_blocks=fbnet_cfg.DET_HEAD_BLOCKS,
        stride_init=fbnet_cfg.DET_HEAD_STRIDE,
        last_layer_scale=fbnet_cfg.DET_HEAD_LAST_SCALE,
    )
@registry.ROI_KEYPOINT_FEATURE_EXTRACTORS.register("FBNet.roi_head_keypoints")
def add_roi_head_keypoints(cfg, in_channels):
    """Build the keypoint ("kpts") ROI feature extractor.

    The builder starts from ``in_channels``. Block selection, first-block
    stride and last-layer scale come from ``cfg.MODEL.FBNET.KPTS_HEAD_*``.
    """
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels
    fbnet_cfg = cfg.MODEL.FBNET
    return FBNetROIHead(
        cfg,
        in_channels,
        builder,
        model_arch,
        head_name="kpts",
        use_blocks=fbnet_cfg.KPTS_HEAD_BLOCKS,
        stride_init=fbnet_cfg.KPTS_HEAD_STRIDE,
        last_layer_scale=fbnet_cfg.KPTS_HEAD_LAST_SCALE,
    )
@registry.ROI_MASK_FEATURE_EXTRACTORS.register("FBNet.roi_head_mask")
def add_roi_head_mask(cfg, in_channels):
    """Build the mask ("mask") ROI feature extractor.

    The builder starts from ``in_channels``. Block selection, first-block
    stride and last-layer scale come from ``cfg.MODEL.FBNET.MASK_HEAD_*``.
    """
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels
    fbnet_cfg = cfg.MODEL.FBNET
    return FBNetROIHead(
        cfg,
        in_channels,
        builder,
        model_arch,
        head_name="mask",
        use_blocks=fbnet_cfg.MASK_HEAD_BLOCKS,
        stride_init=fbnet_cfg.MASK_HEAD_STRIDE,
        last_layer_scale=fbnet_cfg.MASK_HEAD_LAST_SCALE,
    )