|
|
|
|
|
|
|
import fvcore.nn.weight_init as weight_init |
|
from torch import nn |
|
|
|
from .batch_norm import FrozenBatchNorm2d, get_norm |
|
from .wrappers import Conv2d |
|
|
|
|
|
""" |
|
CNN building blocks. |
|
""" |
|
|
|
|
|
class CNNBlockBase(nn.Module):
    """
    Base class for CNN blocks that expose input channels, output channels
    and a stride.

    A subclass's `forward()` must take and return NCHW tensors, and may do
    arbitrary computation as long as it honors the declared channel counts
    and stride.

    Attributes:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        Record the block's channel/stride specification.

        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.

        Sets `requires_grad=False` on every parameter and converts all
        BatchNorm layers to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad = False
        # In-place conversion of any BatchNorm submodules to their frozen form.
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
|
|
|
|
|
class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution followed by a 1x1 (pointwise) convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        # Depthwise stage: groups == in_channels gives one kxk filter per
        # input channel.  Bias is dropped whenever a norm layer follows.
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        # Pointwise stage: 1x1 conv that mixes channels to out_channels.
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # Caffe2-style MSRA (He) initialization for both convolutions.
        weight_init.c2_msra_fill(self.depthwise)
        weight_init.c2_msra_fill(self.pointwise)

    def forward(self, x):
        depthwise_out = self.depthwise(x)
        return self.pointwise(depthwise_out)
|
|