""" EfficientNet, MobileNetV3, etc Blocks

Hacked together by / Copyright 2019, Ross Wightman
"""
import torch
import torch.nn as nn
from torch.nn import functional as F

from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer

__all__ = [
    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']


def num_groups(group_size, channels):
    if not group_size:  # 0 or None -> a single group (regular conv)
        return 1
    else:
        # NOTE group_size == 1 gives a depthwise conv (one group per channel)
        assert channels % group_size == 0
        return channels // group_size
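
# Usage sketch for num_groups (values are illustrative, not from the original file):
#   num_groups(0, 64)   # -> 1, a regular convolution (single group)
#   num_groups(1, 64)   # -> 64, a depthwise convolution (one group per channel)
#   num_groups(16, 64)  # -> 4 groups of 16 channels each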


class SqueezeExcite(nn.Module):
    """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family

    Args:
        in_chs (int): input channels to layer
        rd_ratio (float): ratio of squeeze reduction
        rd_channels (int): explicit number of reduced channels, overrides rd_ratio if set
        act_layer (nn.Module): activation layer of containing block
        gate_layer (Callable): attention gate function
        force_act_layer (nn.Module): override block's activation fn if this is set/bound
        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
    """

    def __init__(
            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
        super(SqueezeExcite, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)
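
# Usage sketch (shapes/values assumed for illustration): the SE module global-average-pools
# the input, reduces 64 -> 16 channels (rd_ratio=0.25), expands back, and gates the input.
#   se = SqueezeExcite(in_chs=64, rd_ratio=0.25)
#   y = se(torch.randn(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32), channels re-weighted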


class ConvBnAct(nn.Module):
    """ Conv + Norm Layer + Activation w/ optional skip connection
    """
    def __init__(
            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='',
            skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
        super(ConvBnAct, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        groups = num_groups(group_size, in_chs)
        self.has_skip = skip and stride == 1 and in_chs == out_chs

        self.conv = create_conv2d(
            in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(out_chs, inplace=True)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
        if location == 'expansion':
            return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
        else:
            return dict(module='', num_chs=self.conv.out_channels)

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.bn1(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
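
# Usage sketch (argument values assumed for illustration): with matching channels, stride 1,
# and skip=True, the block adds a residual connection around conv + norm/act.
#   block = ConvBnAct(32, 32, kernel_size=3, skip=True)
#   y = block(torch.randn(2, 32, 56, 56))  # y.shape == (2, 32, 56, 56)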


class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
    """
    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
            noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
            se_layer=None, drop_path_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        groups = num_groups(group_size, in_chs)
        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups)
        self.bn1 = norm_act_layer(in_chs, inplace=True)

        # Squeeze-and-excitation
        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()

        # Point-wise convolution
        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
        if location == 'expansion':
            return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
        else:
            return dict(module='', num_chs=self.conv_pw.out_channels)

    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.se(x)
        x = self.conv_pw(x)
        x = self.bn2(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
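
# Usage sketch (argument values assumed for illustration): depthwise 3x3 at stride 2
# followed by a 1x1 point-wise conv that changes the channel count, so no skip is used.
#   block = DepthwiseSeparableConv(32, 64, dw_kernel_size=3, stride=2, se_layer=SqueezeExcite)
#   y = block(torch.randn(2, 32, 56, 56))  # y.shape == (2, 64, 28, 28)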


class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE

    Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often
    referred to as 'MBConv' (Mobile inverted Bottleneck Conv) and is also used in
      * MNasNet - https://arxiv.org/abs/1807.11626
      * EfficientNet - https://arxiv.org/abs/1905.11946
      * MobileNet-V3 - https://arxiv.org/abs/1905.02244
    """

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
        super(InvertedResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        conv_kwargs = conv_kwargs or {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        groups = num_groups(group_size, mid_chs)
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_act_layer(mid_chs, inplace=True)

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            groups=groups, padding=pad_type, **conv_kwargs)
        self.bn2 = norm_act_layer(mid_chs, inplace=True)

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_act_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
        if location == 'expansion':
            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
        else:
            return dict(module='', num_chs=self.conv_pwl.out_channels)

    def forward(self, x):
        shortcut = x
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn3(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
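
# Usage sketch (argument values assumed for illustration): expand 32 -> 192 channels
# (exp_ratio=6.0), depthwise 3x3, optional SE, then project back to 32 with a residual skip.
#   block = InvertedResidual(32, 32, dw_kernel_size=3, exp_ratio=6.0, se_layer=SqueezeExcite)
#   y = block(torch.randn(2, 32, 56, 56))  # y.shape == (2, 32, 56, 56)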


class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing """

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):

        self.num_experts = num_experts
        conv_kwargs = dict(num_experts=self.num_experts)

        super(CondConvResidual, self).__init__(
            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size,
            pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
            drop_path_rate=drop_path_rate)

        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        shortcut = x
        # Compute per-sample expert routing weights from globally pooled input features
        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
        x = self.conv_pw(x, routing_weights)
        x = self.bn1(x)
        x = self.conv_dw(x, routing_weights)
        x = self.bn2(x)
        x = self.se(x)
        x = self.conv_pwl(x, routing_weights)
        x = self.bn3(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
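
# Usage sketch (argument values assumed for illustration): with num_experts > 0 the pw/dw
# convs become conditionally parameterized and are mixed per sample by the sigmoid routing
# weights computed in forward().
#   block = CondConvResidual(32, 32, exp_ratio=6.0, num_experts=4)
#   y = block(torch.randn(2, 32, 56, 56))  # y.shape == (2, 32, 56, 56)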


class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride

    Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML`
        - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html

    This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers
      * MobileDet - https://arxiv.org/abs/2004.14525
      * EfficientNet-X - https://arxiv.org/abs/2102.05610
      * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
    """

    def __init__(
            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='',
            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
        super(EdgeResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        if force_in_chs > 0:
            mid_chs = make_divisible(force_in_chs * exp_ratio)
        else:
            mid_chs = make_divisible(in_chs * exp_ratio)
        groups = num_groups(group_size, in_chs)
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip

        # Expansion convolution
        self.conv_exp = create_conv2d(
            in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(mid_chs, inplace=True)

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_act_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
        if location == 'expansion':
            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
        else:
            return dict(module='', num_chs=self.conv_pwl.out_channels)

    def forward(self, x):
        shortcut = x
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn2(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
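
# Usage sketch (argument values assumed for illustration): a fused 3x3 expansion
# (32 -> 128 at stride 2, exp_ratio=4.0) followed by a 1x1 linear projection to 48 channels.
#   block = EdgeResidual(32, 48, exp_kernel_size=3, stride=2, exp_ratio=4.0)
#   y = block(torch.randn(2, 32, 56, 56))  # y.shape == (2, 48, 28, 28)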