""" FBNet model builder """ from __future__ import absolute_import, division, print_function, unicode_literals import copy import logging import math from collections import OrderedDict import torch import torch.nn as nn from maskrcnn_benchmark.layers import ( BatchNorm2d, Conv2d, FrozenBatchNorm2d, interpolate, ) from maskrcnn_benchmark.layers.misc import _NewEmptyTensorOp logger = logging.getLogger(__name__) def _py2_round(x): return math.floor(x + 0.5) if x >= 0.0 else math.ceil(x - 0.5) def _get_divisible_by(num, divisible_by, min_val): ret = int(num) if divisible_by > 0 and num % divisible_by != 0: ret = int((_py2_round(num / divisible_by) or min_val) * divisible_by) return ret PRIMITIVES = { "skip": lambda C_in, C_out, expansion, stride, **kwargs: Identity( C_in, C_out, stride ), "ir_k3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, **kwargs ), "ir_k5": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, kernel=5, **kwargs ), "ir_k7": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, kernel=7, **kwargs ), "ir_k1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, kernel=1, **kwargs ), "shuffle": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, shuffle_type="mid", pw_group=4, **kwargs ), "basic_block": lambda C_in, C_out, expansion, stride, **kwargs: CascadeConv3x3( C_in, C_out, stride ), "shift_5x5": lambda C_in, C_out, expansion, stride, **kwargs: ShiftBlock5x5( C_in, C_out, expansion, stride ), # layer search 2 "ir_k3_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=3, **kwargs ), "ir_k3_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=3, **kwargs ), "ir_k3_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=3, **kwargs 
), "ir_k3_s4": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, **kwargs ), "ir_k5_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=5, **kwargs ), "ir_k5_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=5, **kwargs ), "ir_k5_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=5, **kwargs ), "ir_k5_s4": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, **kwargs ), # layer search se "ir_k3_e1_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=3, se=True, **kwargs ), "ir_k3_e3_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=3, se=True, **kwargs ), "ir_k3_e6_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=3, se=True, **kwargs ), "ir_k3_s4_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, se=True, **kwargs ), "ir_k5_e1_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=5, se=True, **kwargs ), "ir_k5_e3_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=5, se=True, **kwargs ), "ir_k5_e6_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=5, se=True, **kwargs ), "ir_k5_s4_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, se=True, **kwargs ), # layer search 3 (in addition to layer search 2) "ir_k3_s2": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, **kwargs ), "ir_k5_s2": lambda C_in, C_out, 
expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, **kwargs ), "ir_k3_s2_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, se=True, **kwargs ), "ir_k5_s2_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, se=True, **kwargs ), # layer search 4 (in addition to layer search 3) "ir_k3_sep": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, kernel=3, cdw=True, **kwargs ), "ir_k33_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=3, cdw=True, **kwargs ), "ir_k33_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=3, cdw=True, **kwargs ), "ir_k33_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=3, cdw=True, **kwargs ), # layer search 5 (in addition to layer search 4) "ir_k7_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=7, **kwargs ), "ir_k7_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=7, **kwargs ), "ir_k7_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=7, **kwargs ), "ir_k7_sep": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, expansion, stride, kernel=7, cdw=True, **kwargs ), "ir_k7_sep_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 1, stride, kernel=7, cdw=True, **kwargs ), "ir_k7_sep_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 3, stride, kernel=7, cdw=True, **kwargs ), "ir_k7_sep_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock( C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs ), } class Identity(nn.Module): def __init__(self, C_in, C_out, 
stride): super(Identity, self).__init__() self.conv = ( ConvBNRelu( C_in, C_out, kernel=1, stride=stride, pad=0, no_bias=1, use_relu="relu", bn_type="bn", ) if C_in != C_out or stride != 1 else None ) def forward(self, x): if self.conv: out = self.conv(x) else: out = x return out class CascadeConv3x3(nn.Sequential): def __init__(self, C_in, C_out, stride): assert stride in [1, 2] ops = [ Conv2d(C_in, C_in, 3, stride, 1, bias=False), BatchNorm2d(C_in), nn.ReLU(inplace=True), Conv2d(C_in, C_out, 3, 1, 1, bias=False), BatchNorm2d(C_out), ] super(CascadeConv3x3, self).__init__(*ops) self.res_connect = (stride == 1) and (C_in == C_out) def forward(self, x): y = super(CascadeConv3x3, self).forward(x) if self.res_connect: y += x return y class Shift(nn.Module): def __init__(self, C, kernel_size, stride, padding): super(Shift, self).__init__() self.C = C kernel = torch.zeros((C, 1, kernel_size, kernel_size), dtype=torch.float32) ch_idx = 0 assert stride in [1, 2] self.stride = stride self.padding = padding self.kernel_size = kernel_size self.dilation = 1 hks = kernel_size // 2 ksq = kernel_size ** 2 for i in range(kernel_size): for j in range(kernel_size): if i == hks and j == hks: num_ch = C // ksq + C % ksq else: num_ch = C // ksq kernel[ch_idx : ch_idx + num_ch, 0, i, j] = 1 ch_idx += num_ch self.register_parameter("bias", None) self.kernel = nn.Parameter(kernel, requires_grad=False) def forward(self, x): if x.numel() > 0: return nn.functional.conv2d( x, self.kernel, self.bias, (self.stride, self.stride), (self.padding, self.padding), self.dilation, self.C, # groups ) output_shape = [ (i + 2 * p - (di * (k - 1) + 1)) // d + 1 for i, p, di, k, d in zip( x.shape[-2:], (self.padding, self.dilation), (self.dilation, self.dilation), (self.kernel_size, self.kernel_size), (self.stride, self.stride), ) ] output_shape = [x.shape[0], self.C] + output_shape return _NewEmptyTensorOp.apply(x, output_shape) class ShiftBlock5x5(nn.Sequential): def __init__(self, C_in, C_out, 
expansion, stride): assert stride in [1, 2] self.res_connect = (stride == 1) and (C_in == C_out) C_mid = _get_divisible_by(C_in * expansion, 8, 8) ops = [ # pw Conv2d(C_in, C_mid, 1, 1, 0, bias=False), BatchNorm2d(C_mid), nn.ReLU(inplace=True), # shift Shift(C_mid, 5, stride, 2), # pw-linear Conv2d(C_mid, C_out, 1, 1, 0, bias=False), BatchNorm2d(C_out), ] super(ShiftBlock5x5, self).__init__(*ops) def forward(self, x): y = super(ShiftBlock5x5, self).forward(x) if self.res_connect: y += x return y class ChannelShuffle(nn.Module): def __init__(self, groups): super(ChannelShuffle, self).__init__() self.groups = groups def forward(self, x): """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" N, C, H, W = x.size() g = self.groups assert C % g == 0, "Incompatible group size {} for input channel {}".format( g, C ) return ( x.view(N, g, int(C / g), H, W) .permute(0, 2, 1, 3, 4) .contiguous() .view(N, C, H, W) ) class ConvBNRelu(nn.Sequential): def __init__( self, input_depth, output_depth, kernel, stride, pad, no_bias, use_relu, bn_type, group=1, *args, **kwargs ): super(ConvBNRelu, self).__init__() assert use_relu in ["relu", None] if isinstance(bn_type, (list, tuple)): assert len(bn_type) == 2 assert bn_type[0] == "gn" gn_group = bn_type[1] bn_type = bn_type[0] assert bn_type in ["bn", "af", "gn", None] assert stride in [1, 2, 4] op = Conv2d( input_depth, output_depth, kernel_size=kernel, stride=stride, padding=pad, bias=not no_bias, groups=group, *args, **kwargs ) nn.init.kaiming_normal_(op.weight, mode="fan_out", nonlinearity="relu") if op.bias is not None: nn.init.constant_(op.bias, 0.0) self.add_module("conv", op) if bn_type == "bn": bn_op = BatchNorm2d(output_depth) elif bn_type == "gn": bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=output_depth) elif bn_type == "af": bn_op = FrozenBatchNorm2d(output_depth) if bn_type is not None: self.add_module("bn", bn_op) if use_relu == "relu": self.add_module("relu", nn.ReLU(inplace=True)) 
class SEModule(nn.Module):
    """Squeeze-and-excitation gate.

    Multiplies ``x`` by a per-channel sigmoid weight computed from the
    globally average-pooled input through a 1x1-conv bottleneck.
    """

    reduction = 4

    def __init__(self, C):
        super(SEModule, self).__init__()
        # Bottleneck width: C / reduction, but never fewer than 8 channels.
        mid = max(C // self.reduction, 8)
        conv1 = Conv2d(C, mid, 1, 1, 0)
        conv2 = Conv2d(mid, C, 1, 1, 0)

        self.op = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), conv1, nn.ReLU(inplace=True), conv2, nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.op(x)


class Upsample(nn.Module):
    """Module wrapper around ``interpolate`` with a fixed scale and mode."""

    def __init__(self, scale_factor, mode, align_corners=None):
        super(Upsample, self).__init__()
        self.scale = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return interpolate(
            x, scale_factor=self.scale, mode=self.mode, align_corners=self.align_corners
        )


def _get_upsample_op(stride):
    """Translate a (possibly negative) stride into an upsample op and a stride.

    A negative stride ``-s`` (or a tuple of negative strides) means "upsample
    by ``s``". Such strides return a nearest-neighbour ``Upsample`` and a
    conv stride of 1. Positive strides return ``(None, stride)`` unchanged.
    """
    assert (
        stride in [1, 2, 4]
        or stride in [-1, -2, -4]
        or (isinstance(stride, tuple) and all(x in [-1, -2, -4] for x in stride))
    )

    scales = stride
    ret = None
    if isinstance(stride, tuple) or stride < 0:
        scales = [-x for x in stride] if isinstance(stride, tuple) else -stride
        stride = 1
        ret = Upsample(scale_factor=scales, mode="nearest", align_corners=None)

    return ret, stride


class IRFBlock(nn.Module):
    """Inverted-residual style block: 1x1 expand -> depthwise -> 1x1 linear.

    Args:
        input_depth, output_depth: input / output channel counts.
        expansion: width multiplier for the hidden (depthwise) layer; the
            hidden width is rounded with ``width_divisor``.
        stride: depthwise stride; negative values upsample first (see
            ``_get_upsample_op``).
        bn_type: normalization type forwarded to ``ConvBNRelu``.
        kernel: depthwise kernel size (1, 3, 5 or 7; 1 skips the depthwise op).
        shuffle_type: ``"mid"`` shuffles channels after the expand conv.
        pw_group: group count of both 1x1 convs (and of the shuffle).
        se: append an ``SEModule`` after the block.
        cdw: use two cascaded depthwise convs instead of one.
        dw_skip_bn, dw_skip_relu: drop BN / ReLU from the (last) depthwise conv.

    The input is added to the output when stride is 1 and the channel count
    is unchanged; the SE gate (if any) is applied after that addition.
    """

    def __init__(
        self,
        input_depth,
        output_depth,
        expansion,
        stride,
        bn_type="bn",
        kernel=3,
        width_divisor=1,
        shuffle_type=None,
        pw_group=1,
        se=False,
        cdw=False,
        dw_skip_bn=False,
        dw_skip_relu=False,
    ):
        super(IRFBlock, self).__init__()

        assert kernel in [1, 3, 5, 7], kernel

        # Decided from the caller's stride, before negative (upsampling)
        # strides are converted to 1 below.
        self.use_res_connect = stride == 1 and input_depth == output_depth
        self.output_depth = output_depth

        mid_depth = int(input_depth * expansion)
        mid_depth = _get_divisible_by(mid_depth, width_divisor, width_divisor)

        # pw
        self.pw = ConvBNRelu(
            input_depth,
            mid_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu="relu",
            bn_type=bn_type,
            group=pw_group,
        )

        # negative stride to do upsampling
        self.upscale, stride = _get_upsample_op(stride)

        # dw
        if kernel == 1:
            # NOTE(review): with kernel 1 no depthwise conv is built, so the
            # stride is not applied anywhere in this block.
            self.dw = nn.Sequential()
        elif cdw:
            dw1 = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=stride,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu",
                bn_type=bn_type,
            )
            dw2 = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=1,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu" if not dw_skip_relu else None,
                bn_type=bn_type if not dw_skip_bn else None,
            )
            self.dw = nn.Sequential(OrderedDict([("dw1", dw1), ("dw2", dw2)]))
        else:
            self.dw = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=stride,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu" if not dw_skip_relu else None,
                bn_type=bn_type if not dw_skip_bn else None,
            )

        # pw-linear (no activation)
        self.pwl = ConvBNRelu(
            mid_depth,
            output_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu=None,
            bn_type=bn_type,
            group=pw_group,
        )

        self.shuffle_type = shuffle_type
        if shuffle_type is not None:
            self.shuffle = ChannelShuffle(pw_group)

        self.se4 = SEModule(output_depth) if se else nn.Sequential()

    def forward(self, x):
        y = self.pw(x)
        if self.shuffle_type == "mid":
            y = self.shuffle(y)
        if self.upscale is not None:
            y = self.upscale(y)
        y = self.dw(y)
        y = self.pwl(y)
        if self.use_res_connect:
            y += x
        y = self.se4(y)
        return y


def _expand_block_cfg(block_cfg):
    """Expand one ``[t, c, n, s]`` block cfg into ``n`` cfgs with ``n == 1``.

    Only the first expanded block keeps the stride ``s``; the rest use 1.
    """
    assert isinstance(block_cfg, list)
    ret = []
    for idx in range(block_cfg[2]):
        cur = copy.deepcopy(block_cfg)
        cur[2] = 1
        cur[3] = 1 if idx >= 1 else cur[3]
        ret.append(cur)
    return ret


def expand_stage_cfg(stage_cfg):
    """ For a single stage """
    assert isinstance(stage_cfg, list)
    ret = []
    for x in stage_cfg:
        ret += _expand_block_cfg(x)
    return ret


def expand_stages_cfg(stage_cfgs):
    """ For a list of stages """
    assert isinstance(stage_cfgs, list)
    ret = []
    for x in stage_cfgs:
        ret.append(expand_stage_cfg(x))
    return ret


def _block_cfgs_to_list(block_cfgs):
    """Flatten per-stage block cfgs into a list of indexed block dicts.

    Each entry is ``{"stage_idx": i, "block_idx": j, "block": [t, c, 1, s]}``.
    """
    assert isinstance(block_cfgs, list)
    ret = []
    for stage_idx, stage in enumerate(block_cfgs):
        stage = expand_stage_cfg(stage)
        for block_idx, block in enumerate(stage):
            cur = {"stage_idx": stage_idx, "block_idx": block_idx, "block": block}
            ret.append(cur)
    return ret


def _add_to_arch(arch, info, name):
    """ arch = [{block_0}, {block_1}, ...]
        info = [
            # stage 0
            [
                block0_info,
                block1_info,
                ...
            ], ...
        ]
        convert to:
        arch = [
            {
                block_0,
                name: block0_info,
            },
            {
                block_1,
                name: block1_info,
            }, ...
        ]

    ``arch`` is modified in place; the (stage, block) order of ``info`` must
    match ``arch`` entry by entry.
    """
    assert isinstance(arch, list) and all(isinstance(x, dict) for x in arch)
    assert isinstance(info, list) and all(isinstance(x, list) for x in info)
    idx = 0
    for stage_idx, stage in enumerate(info):
        for block_idx, block in enumerate(stage):
            assert (
                arch[idx]["stage_idx"] == stage_idx
                and arch[idx]["block_idx"] == block_idx
            ), "Index ({}, {}) does not match for block {}".format(
                stage_idx, block_idx, arch[idx]
            )
            assert name not in arch[idx]
            arch[idx][name] = block
            idx += 1


def unify_arch_def(arch_def):
    """ unify the arch_def to:
        {
            ...,
            "arch": [
                {
                    "stage_idx": idx,
                    "block_idx": idx,
                    ...
                },
                {}, ...
            ]
        }

    The flattened block list is stored under ``"stages"``; the entries of
    ``arch_def["block_cfg"]`` (e.g. ``first``/``last``) are hoisted to the top
    level. The input is not modified and shares no mutable objects with the
    result.
    """
    ret = copy.deepcopy(arch_def)

    assert "block_cfg" in arch_def and "stages" in arch_def["block_cfg"]
    assert "stages" not in ret
    # Hoist 'first', 'last' etc. from the *deep-copied* block_cfg so that the
    # result does not alias lists owned by the caller's arch_def.
    block_cfg = ret["block_cfg"]
    ret.update(block_cfg)
    ret["stages"] = _block_cfgs_to_list(block_cfg["stages"])
    del ret["block_cfg"]

    assert "block_op_type" in arch_def
    _add_to_arch(ret["stages"], arch_def["block_op_type"], "block_op_type")
    del ret["block_op_type"]

    return ret


def get_num_stages(arch_def):
    """Return ``max(stage_idx) + 1`` over the unified ``arch_def["stages"]``."""
    ret = 0
    for x in arch_def["stages"]:
        ret = max(x["stage_idx"], ret)
    ret = ret + 1
    return ret


def get_blocks(arch_def, stage_indices=None, block_indices=None):
    """Return a copy of a unified ``arch_def`` keeping only selected blocks.

    Args:
        arch_def: unified arch definition (output of ``unify_arch_def``).
        stage_indices: stage indices to keep; ``None`` or any empty container
            keeps every stage.
        block_indices: per-stage block indices to keep; ``None`` or any empty
            container keeps every block index.

    Returns:
        A deep copy of ``arch_def`` whose ``"stages"`` holds (copies of) the
        matching blocks in their original order. Mutating the result never
        affects ``arch_def``.
    """
    ret = copy.deepcopy(arch_def)
    # Filter the already deep-copied blocks so that the result does not alias
    # the caller's block dicts.
    ret["stages"] = [
        block
        for block in ret["stages"]
        if (not stage_indices or block["stage_idx"] in stage_indices)
        and (not block_indices or block["block_idx"] in block_indices)
    ]
    return ret


class FBNetBuilder(object):
    """Instantiate FBNet layers from a unified arch definition.

    ``last_depth`` records the output channel count of the most recently added
    layer so each ``add_*`` call chains onto the previous one.
    """

    def __init__(
        self,
        width_ratio,
        bn_type="bn",
        width_divisor=1,
        dw_skip_bn=False,
        dw_skip_relu=False,
    ):
        self.width_ratio = width_ratio
        self.last_depth = -1
        self.bn_type = bn_type
        self.width_divisor = width_divisor
        self.dw_skip_bn = dw_skip_bn
        self.dw_skip_relu = dw_skip_relu

    def add_first(self, stage_info, dim_in=3, pad=True):
        """Build the stem conv from ``stage_info = [c, s]`` or ``[c, s, kernel]``."""
        assert len(stage_info) >= 2
        channel = stage_info[0]
        stride = stage_info[1]
        out_depth = self._get_divisible_width(int(channel * self.width_ratio))
        kernel = 3
        if len(stage_info) > 2:
            kernel = stage_info[2]

        out = ConvBNRelu(
            dim_in,
            out_depth,
            kernel=kernel,
            stride=stride,
            pad=kernel // 2 if pad else 0,
            no_bias=1,
            use_relu="relu",
            bn_type=self.bn_type,
        )
        self.last_depth = out_depth
        return out

    def add_blocks(self, blocks):
        """ blocks: [{}, {}, ...]

        Each dict needs ``stage_idx``, ``block_idx``, ``block_op_type`` and an
        expanded ``block`` cfg ``[t, c, 1, s]``. Returns an ``nn.Sequential``
        whose children are named ``xif{stage_idx}_{block_idx}``.
        """
        assert isinstance(blocks, list) and all(
            isinstance(x, dict) for x in blocks
        ), blocks

        modules = OrderedDict()
        for block in blocks:
            stage_idx = block["stage_idx"]
            block_idx = block["block_idx"]
            block_op_type = block["block_op_type"]
            tcns = block["block"]
            n = tcns[2]
            assert n == 1
            nnblock = self.add_ir_block(tcns, [block_op_type])
            nn_name = "xif{}_{}".format(stage_idx, block_idx)
            assert nn_name not in modules
            modules[nn_name] = nnblock
        ret = nn.Sequential(modules)
        return ret

    def add_last(self, stage_info):
        """ skip last layer if channel_scale == 0
            use the same output channel if channel_scale < 0

        ``stage_info = [channels, channel_scale]``. A negative scale sizes the
        layer as ``last_depth * -channel_scale``. An empty ``nn.Sequential``
        is returned when the layer is skipped or rounds to 0 channels.
        """
        assert len(stage_info) == 2
        channels = stage_info[0]
        channel_scale = stage_info[1]

        if channel_scale == 0.0:
            return nn.Sequential()

        if channel_scale > 0:
            last_channel = (
                int(channels * self.width_ratio) if self.width_ratio > 1.0 else channels
            )
            last_channel = int(last_channel * channel_scale)
        else:
            last_channel = int(self.last_depth * (-channel_scale))
        last_channel = self._get_divisible_width(last_channel)

        if last_channel == 0:
            return nn.Sequential()

        dim_in = self.last_depth
        ret = ConvBNRelu(
            dim_in,
            last_channel,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu="relu",
            bn_type=self.bn_type,
        )
        self.last_depth = last_channel
        return ret

    def _add_ir_block(
        self, dim_in, dim_out, stride, expand_ratio, block_op_type, **kwargs
    ):
        """Instantiate ``PRIMITIVES[block_op_type]``; return ``(op, out_depth)``."""
        ret = PRIMITIVES[block_op_type](
            dim_in,
            dim_out,
            expansion=expand_ratio,
            stride=stride,
            bn_type=self.bn_type,
            width_divisor=self.width_divisor,
            dw_skip_bn=self.dw_skip_bn,
            dw_skip_relu=self.dw_skip_relu,
            **kwargs
        )
        # Not every primitive ("skip", "basic_block", "shift_5x5") is
        # guaranteed to expose ``output_depth``; those always emit dim_out.
        return ret, getattr(ret, "output_depth", dim_out)

    def add_ir_block(self, tcns, block_op_types, **kwargs):
        """Build one block from ``tcns = [t, c, 1, s]`` using ``block_op_types[0]``."""
        t, c, n, s = tcns
        assert n == 1
        out_depth = self._get_divisible_width(int(c * self.width_ratio))
        dim_in = self.last_depth
        op, ret_depth = self._add_ir_block(
            dim_in,
            out_depth,
            stride=s,
            expand_ratio=t,
            block_op_type=block_op_types[0],
            **kwargs
        )
        self.last_depth = ret_depth
        return op

    def _get_divisible_width(self, width):
        """Round ``width`` to a multiple of ``self.width_divisor``."""
        ret = _get_divisible_by(int(width), self.width_divisor, self.width_divisor)
        return ret