#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""ResNe(X)t 3D stem helper."""

import torch.nn as nn


class VideoModelStem(nn.Module):
    """
    Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool
    on input data tensor for one or multiple pathways.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments. List size of 1 for single pathway models (C2D, I3D, Slow
        and etc), list size of 2 for two pathway models (SlowFast).

        Args:
            dim_in (list): the list of channel dimensions of the inputs.
            dim_out (list): the output dimension of the convolution in the stem
                layer.
            kernel (list): the kernels' size of the convolutions in the stem
                layers. Temporal kernel size, height kernel size, width kernel
                size in order.
            stride (list): the stride sizes of the convolutions in the stem
                layer. Temporal stride, height stride, width stride in order.
            padding (list): the paddings' sizes of the convolutions in the stem
                layer. Temporal padding size, height padding size, width padding
                size in order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
        """
        super().__init__()
        # Every per-pathway argument must describe the same number of pathways.
        assert (
            len(
                {
                    len(dim_in),
                    len(dim_out),
                    len(kernel),
                    len(stride),
                    len(padding),
                }
            )
            == 1
        ), "Input pathway dimensions are not consistent."
        self.num_pathways = len(dim_in)
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt

        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """Build one ResNetBasicStem per pathway, registered as `pathway{i}_stem`."""
        for pathway in range(len(dim_in)):
            stem = ResNetBasicStem(
                dim_in[pathway],
                dim_out[pathway],
                self.kernel[pathway],
                self.stride[pathway],
                self.padding[pathway],
                self.inplace_relu,
                self.eps,
                self.bn_mmt,
                norm_module,
            )
            # The submodule name is part of the state-dict key layout; keep it stable.
            self.add_module("pathway{}_stem".format(pathway), stem)

    def forward(self, x):
        """
        Apply each pathway's stem to the corresponding input tensor.

        Args:
            x (list or tuple of Tensor): one input tensor per pathway, in
                pathway order.
        Returns:
            (list of Tensor): the stem output for each pathway. A new list is
                returned; the caller's `x` is not modified.
        """
        assert (
            len(x) == self.num_pathways
        ), "Input tensor does not contain {} pathway".format(self.num_pathways)
        # Build a fresh list instead of writing results back into `x`, so the
        # caller's input container is left untouched (and tuples are accepted).
        return [
            getattr(self, "pathway{}_stem".format(pathway))(x_pathway)
            for pathway, x_pathway in enumerate(x)
        ]


class ResNetBasicStem(nn.Module):
    """
    ResNe(X)t 3D stem module.
    Performs spatiotemporal Convolution, BN, and ReLU followed by a
    spatiotemporal pooling.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            dim_in (int): the channel dimension of the input. Normally 3 is used
                for rgb input, and 2 or 3 is used for optical flow input.
            dim_out (int): the output dimension of the convolution in the stem
                layer.
            kernel (list): the kernel size of the convolution in the stem layer.
                Temporal kernel size, height kernel size, width kernel size in
                order.
            stride (list): the stride size of the convolution in the stem layer.
                Temporal stride, height stride, width stride in order.
            padding (list): the padding size of the convolution in the stem
                layer. Temporal padding size, height padding size, width
                padding size in order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
        """
        super().__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """Create the conv -> norm -> ReLU -> max-pool layers of the stem."""
        # No conv bias: the following normalization layer has its own shift.
        self.conv = nn.Conv3d(
            dim_in,
            dim_out,
            self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )
        self.bn = norm_module(
            num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
        )
        self.relu = nn.ReLU(self.inplace_relu)
        # Pool only spatially (temporal kernel/stride of 1): halves H and W.
        self.pool_layer = nn.MaxPool3d(
            kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
        )

    def forward(self, x):
        """Run conv, normalization, ReLU and max-pool on the input tensor `x`."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool_layer(x)
        return x