Spaces:
Runtime error
Runtime error
from __future__ import absolute_import, division, print_function, unicode_literals | |
import copy | |
import json | |
import logging | |
from collections import OrderedDict | |
from . import ( | |
fbnet_builder as mbuilder, | |
fbnet_modeldef as modeldef, | |
) | |
import torch.nn as nn | |
from maskrcnn_benchmark.modeling import registry | |
from maskrcnn_benchmark.modeling.rpn import rpn | |
from maskrcnn_benchmark.modeling import poolers | |
logger = logging.getLogger(__name__) | |
def create_builder(cfg): | |
bn_type = cfg.MODEL.FBNET.BN_TYPE | |
if bn_type == "gn": | |
bn_type = (bn_type, cfg.GROUP_NORM.NUM_GROUPS) | |
factor = cfg.MODEL.FBNET.SCALE_FACTOR | |
arch = cfg.MODEL.FBNET.ARCH | |
arch_def = cfg.MODEL.FBNET.ARCH_DEF | |
if len(arch_def) > 0: | |
arch_def = json.loads(arch_def) | |
if arch in modeldef.MODEL_ARCH: | |
if len(arch_def) > 0: | |
assert ( | |
arch_def == modeldef.MODEL_ARCH[arch] | |
), "Two architectures with the same name {},\n{},\n{}".format( | |
arch, arch_def, modeldef.MODEL_ARCH[arch] | |
) | |
arch_def = modeldef.MODEL_ARCH[arch] | |
else: | |
assert arch_def is not None and len(arch_def) > 0 | |
arch_def = mbuilder.unify_arch_def(arch_def) | |
rpn_stride = arch_def.get("rpn_stride", None) | |
if rpn_stride is not None: | |
assert ( | |
cfg.MODEL.RPN.ANCHOR_STRIDE[0] == rpn_stride | |
), "Needs to set cfg.MODEL.RPN.ANCHOR_STRIDE to {}, got {}".format( | |
rpn_stride, cfg.MODEL.RPN.ANCHOR_STRIDE | |
) | |
width_divisor = cfg.MODEL.FBNET.WIDTH_DIVISOR | |
dw_skip_bn = cfg.MODEL.FBNET.DW_CONV_SKIP_BN | |
dw_skip_relu = cfg.MODEL.FBNET.DW_CONV_SKIP_RELU | |
logger.info( | |
"Building fbnet model with arch {} (without scaling):\n{}".format( | |
arch, arch_def | |
) | |
) | |
builder = mbuilder.FBNetBuilder( | |
width_ratio=factor, | |
bn_type=bn_type, | |
width_divisor=width_divisor, | |
dw_skip_bn=dw_skip_bn, | |
dw_skip_relu=dw_skip_relu, | |
) | |
return builder, arch_def | |
def _get_trunk_cfg(arch_def): | |
""" Get all stages except the last one """ | |
num_stages = mbuilder.get_num_stages(arch_def) | |
trunk_stages = arch_def.get("backbone", range(num_stages - 1)) | |
ret = mbuilder.get_blocks(arch_def, stage_indices=trunk_stages) | |
return ret | |
class FBNetTrunk(nn.Module): | |
def __init__( | |
self, builder, arch_def, dim_in, | |
): | |
super(FBNetTrunk, self).__init__() | |
self.first = builder.add_first(arch_def["first"], dim_in=dim_in) | |
trunk_cfg = _get_trunk_cfg(arch_def) | |
self.stages = builder.add_blocks(trunk_cfg["stages"]) | |
# return features for each stage | |
def forward(self, x): | |
y = self.first(x) | |
y = self.stages(y) | |
ret = [y] | |
return ret | |
def add_conv_body(cfg, dim_in=3): | |
builder, arch_def = create_builder(cfg) | |
body = FBNetTrunk(builder, arch_def, dim_in) | |
model = nn.Sequential(OrderedDict([("body", body)])) | |
model.out_channels = builder.last_depth | |
return model | |
def _get_rpn_stage(arch_def, num_blocks): | |
rpn_stage = arch_def.get("rpn") | |
ret = mbuilder.get_blocks(arch_def, stage_indices=rpn_stage) | |
if num_blocks > 0: | |
logger.warn('Use last {} blocks in {} as rpn'.format(num_blocks, ret)) | |
block_count = len(ret["stages"]) | |
assert num_blocks <= block_count, "use block {}, block count {}".format( | |
num_blocks, block_count | |
) | |
blocks = range(block_count - num_blocks, block_count) | |
ret = mbuilder.get_blocks(ret, block_indices=blocks) | |
return ret["stages"] | |
class FBNetRPNHead(nn.Module): | |
def __init__( | |
self, cfg, in_channels, builder, arch_def, | |
): | |
super(FBNetRPNHead, self).__init__() | |
assert in_channels == builder.last_depth | |
rpn_bn_type = cfg.MODEL.FBNET.RPN_BN_TYPE | |
if len(rpn_bn_type) > 0: | |
builder.bn_type = rpn_bn_type | |
use_blocks = cfg.MODEL.FBNET.RPN_HEAD_BLOCKS | |
stages = _get_rpn_stage(arch_def, use_blocks) | |
self.head = builder.add_blocks(stages) | |
self.out_channels = builder.last_depth | |
def forward(self, x): | |
x = [self.head(y) for y in x] | |
return x | |
def add_rpn_head(cfg, in_channels, num_anchors): | |
builder, model_arch = create_builder(cfg) | |
builder.last_depth = in_channels | |
assert in_channels == builder.last_depth | |
# builder.name_prefix = "[rpn]" | |
rpn_feature = FBNetRPNHead(cfg, in_channels, builder, model_arch) | |
rpn_regressor = rpn.RPNHeadConvRegressor( | |
cfg, rpn_feature.out_channels, num_anchors) | |
return nn.Sequential(rpn_feature, rpn_regressor) | |
def _get_head_stage(arch, head_name, blocks): | |
# use default name 'head' if the specific name 'head_name' does not existed | |
if head_name not in arch: | |
head_name = "head" | |
head_stage = arch.get(head_name) | |
ret = mbuilder.get_blocks(arch, stage_indices=head_stage, block_indices=blocks) | |
return ret["stages"] | |
# name mapping for head names in arch def and cfg | |
ARCH_CFG_NAME_MAPPING = { | |
"bbox": "ROI_BOX_HEAD", | |
"kpts": "ROI_KEYPOINT_HEAD", | |
"mask": "ROI_MASK_HEAD", | |
} | |
class FBNetROIHead(nn.Module): | |
def __init__( | |
self, cfg, in_channels, builder, arch_def, | |
head_name, use_blocks, stride_init, last_layer_scale, | |
): | |
super(FBNetROIHead, self).__init__() | |
assert in_channels == builder.last_depth | |
assert isinstance(use_blocks, list) | |
head_cfg_name = ARCH_CFG_NAME_MAPPING[head_name] | |
self.pooler = poolers.make_pooler(cfg, head_cfg_name) | |
stage = _get_head_stage(arch_def, head_name, use_blocks) | |
assert stride_init in [0, 1, 2] | |
if stride_init != 0: | |
stage[0]["block"][3] = stride_init | |
blocks = builder.add_blocks(stage) | |
last_info = copy.deepcopy(arch_def["last"]) | |
last_info[1] = last_layer_scale | |
last = builder.add_last(last_info) | |
self.head = nn.Sequential(OrderedDict([ | |
("blocks", blocks), | |
("last", last) | |
])) | |
self.out_channels = builder.last_depth | |
def forward(self, x, proposals): | |
x = self.pooler(x, proposals) | |
x = self.head(x) | |
return x | |
def add_roi_head(cfg, in_channels): | |
builder, model_arch = create_builder(cfg) | |
builder.last_depth = in_channels | |
# builder.name_prefix = "_[bbox]_" | |
return FBNetROIHead( | |
cfg, in_channels, builder, model_arch, | |
head_name="bbox", | |
use_blocks=cfg.MODEL.FBNET.DET_HEAD_BLOCKS, | |
stride_init=cfg.MODEL.FBNET.DET_HEAD_STRIDE, | |
last_layer_scale=cfg.MODEL.FBNET.DET_HEAD_LAST_SCALE, | |
) | |
def add_roi_head_keypoints(cfg, in_channels): | |
builder, model_arch = create_builder(cfg) | |
builder.last_depth = in_channels | |
# builder.name_prefix = "_[kpts]_" | |
return FBNetROIHead( | |
cfg, in_channels, builder, model_arch, | |
head_name="kpts", | |
use_blocks=cfg.MODEL.FBNET.KPTS_HEAD_BLOCKS, | |
stride_init=cfg.MODEL.FBNET.KPTS_HEAD_STRIDE, | |
last_layer_scale=cfg.MODEL.FBNET.KPTS_HEAD_LAST_SCALE, | |
) | |
def add_roi_head_mask(cfg, in_channels): | |
builder, model_arch = create_builder(cfg) | |
builder.last_depth = in_channels | |
# builder.name_prefix = "_[mask]_" | |
return FBNetROIHead( | |
cfg, in_channels, builder, model_arch, | |
head_name="mask", | |
use_blocks=cfg.MODEL.FBNET.MASK_HEAD_BLOCKS, | |
stride_init=cfg.MODEL.FBNET.MASK_HEAD_STRIDE, | |
last_layer_scale=cfg.MODEL.FBNET.MASK_HEAD_LAST_SCALE, | |
) | |