# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import fvcore.nn.weight_init as weight_init
from torch import nn

from .batch_norm import FrozenBatchNorm2d, get_norm
from .wrappers import Conv2d


"""
CNN building blocks.
"""


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels, and a stride.
    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attributes:
        in_channels (int):
        out_channels (int):
        stride (int):
    """
    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and converts all BatchNorm layers to FrozenBatchNorm2d.

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
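

# Illustrative sketch, not part of the original file: a minimal subclass showing
# the CNNBlockBase contract described above. The name `BasicBlockExample` and
# its single 3x3 conv are assumptions chosen for demonstration, not a
# detectron2 API.
class BasicBlockExample(CNNBlockBase):
    """
    A single 3x3 conv (+ optional norm and ReLU) that honors the
    in_channels/out_channels/stride contract of `CNNBlockBase`.
    """

    def __init__(self, in_channels, out_channels, stride=1, *, norm="BN"):
        super().__init__(in_channels, out_channels, stride)
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=not norm,  # no conv bias when a norm layer follows
            norm=get_norm(norm, out_channels),
            activation=nn.ReLU(),
        )
        weight_init.c2_msra_fill(self.conv)

    def forward(self, x):
        return self.conv(x)


# Usage sketch: `freeze()` returns the block itself, so it can be chained:
#   block = BasicBlockExample(64, 128, stride=2).freeze()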


class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution + a 1x1 convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        # groups=in_channels makes this a depthwise conv: one kxk filter per channel.
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,  # skip the conv bias when a norm layer follows
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        # The 1x1 pointwise conv mixes channels and maps to out_channels.
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # default initialization
        weight_init.c2_msra_fill(self.depthwise)
        weight_init.c2_msra_fill(self.pointwise)
    def forward(self, x):
        return self.pointwise(self.depthwise(x))
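

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file; run it as
    # `python -m <package>.blocks` so the relative imports above resolve.
    # The shapes and the "BN"/ReLU choices below are assumptions for
    # demonstration only.
    import torch

    # MobileNet-style configuration: norm & activation on both convs.
    dw = DepthwiseSeparableConv2d(
        32, 64, norm1="BN", activation1=nn.ReLU(), norm2="BN", activation2=nn.ReLU()
    )
    out = dw(torch.rand(2, 32, 56, 56))
    # kernel_size=3 with padding=1 preserves spatial size; the 1x1 conv sets channels.
    assert out.shape == (2, 64, 56, 56)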