ianpan's picture
Initial commit
231edce
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""ResNe(X)t 3D stem helper."""
import torch
import torch.nn as nn
def get_stem_func(name):
    """
    Look up a stem module class by its registered name.

    Args:
        name (string): either "x3d_stem" or "basic_stem".

    Returns:
        The stem class registered under `name` (X3DStem or ResNetBasicStem).

    Raises:
        AssertionError: if `name` is not one of the registered stem names.
    """
    registry = {"x3d_stem": X3DStem, "basic_stem": ResNetBasicStem}
    is_known = name in registry
    assert is_known, "Transformation function '{}' not supported".format(name)
    return registry[name]
class VideoModelStem(nn.Module):
    """
    Video 3D stem for one or several pathways.

    One stem sub-module is built per pathway and registered under the name
    `pathway{i}_stem`. The forward pass applies the i-th stem to the i-th
    entry of the input list.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
        stem_func_name="basic_stem",
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments. Every list argument holds one entry per pathway: list size
        of 1 for single pathway models (C2D, I3D, Slow and etc), list size of
        2 for two pathway models (SlowFast).

        Args:
            dim_in (list): channel dimension of the input of each pathway.
            dim_out (list): output dimension of the stem convolution of each
                pathway.
            kernel (list): per-pathway kernel sizes of the stem convolution,
                each in temporal, height, width order.
            stride (list): per-pathway strides of the stem convolution, each
                in temporal, height, width order.
            padding (list): per-pathway paddings of the stem convolution,
                each in temporal, height, width order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
            stem_func_name (string): name of the stem function applied on
                the input to the network; resolved through `get_stem_func`.

        Raises:
            AssertionError: if the list arguments do not all have the same
                length.
        """
        super().__init__()

        # All per-pathway lists must describe the same number of pathways.
        pathway_counts = [
            len(arg) for arg in (dim_in, dim_out, kernel, stride, padding)
        ]
        assert (
            len(set(pathway_counts)) == 1
        ), "Input pathway dimensions are not consistent. {} {} {} {} {}".format(
            *pathway_counts
        )

        self.num_pathways = pathway_counts[0]
        self.kernel, self.stride, self.padding = kernel, stride, padding
        self.inplace_relu = inplace_relu
        self.eps, self.bn_mmt = eps, bn_mmt

        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module, stem_func_name)

    def _construct_stem(self, dim_in, dim_out, norm_module, stem_func_name):
        """
        Build one stem per pathway and register it as `pathway{i}_stem`.
        """
        stem_cls = get_stem_func(stem_func_name)
        for idx, channels_in in enumerate(dim_in):
            pathway_stem = stem_cls(
                channels_in,
                dim_out[idx],
                self.kernel[idx],
                self.stride[idx],
                self.padding[idx],
                self.inplace_relu,
                self.eps,
                self.bn_mmt,
                norm_module,
            )
            self.add_module("pathway{}_stem".format(idx), pathway_stem)

    def forward(self, x):
        """
        Apply each pathway's stem to the matching entry of `x`.

        Args:
            x (list): one input per pathway.

        Returns:
            A new list holding the stem output of every pathway, in order.

        Raises:
            AssertionError: if `len(x)` differs from the number of pathways.
        """
        assert (
            len(x) == self.num_pathways
        ), "Input tensor does not contain {} pathway".format(self.num_pathways)
        # Build a fresh list instead of writing back into `x`: modifying the
        # input list in-place is bad for activation checkpointing.
        return [
            getattr(self, "pathway{}_stem".format(idx))(inp)
            for idx, inp in enumerate(x)
        ]
class ResNetBasicStem(nn.Module):
    """
    ResNe(X)t 3D stem module.

    Runs a spatiotemporal convolution, a normalization layer and a ReLU,
    followed by a max pooling with kernel [1, 3, 3] and stride [1, 2, 2].
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            dim_in (int): the channel dimension of the input. Normally 3 is
                used for rgb input, and 2 or 3 is used for optical flow input.
            dim_out (int): the output dimension of the convolution in the
                stem layer.
            kernel (list): the kernel size of the convolution in the stem
                layer, in temporal, height, width order.
            stride (list): the stride of the convolution in the stem layer,
                in temporal, height, width order.
            padding (list): the padding of the convolution in the stem layer,
                in temporal, height, width order.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
        """
        super().__init__()
        self.kernel, self.stride, self.padding = kernel, stride, padding
        self.inplace_relu = inplace_relu
        self.eps, self.bn_mmt = eps, bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """
        Create the conv, norm, relu and pooling layers of the stem.
        """
        # The convolution has no bias term; it is followed by the norm layer.
        self.conv = nn.Conv3d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )
        self.bn = norm_module(
            num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
        )
        self.relu = nn.ReLU(inplace=self.inplace_relu)
        # The pooling window is 1 along the first (temporal) axis, so only
        # the two spatial axes are downsampled.
        self.pool_layer = nn.MaxPool3d(
            kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1]
        )

    def forward(self, x):
        """
        Apply conv -> norm -> relu -> max pool to `x` and return the result.
        """
        for layer in (self.conv, self.bn, self.relu, self.pool_layer):
            x = layer(x)
        return x
class X3DStem(nn.Module):
    """
    X3D's 3D stem module.

    Runs a spatial convolution followed by a depthwise temporal convolution,
    then a normalization layer and a ReLU.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            dim_in (int): the channel dimension of the input. Normally 3 is
                used for rgb input, and 2 or 3 is used for optical flow input.
            dim_out (int): the output dimension of the convolutions in the
                stem layer.
            kernel (list): kernel sizes in temporal, height, width order. The
                temporal entry is used by the temporal convolution, the other
                two by the spatial convolution.
            stride (list): strides in temporal, height, width order, split
                between the two convolutions in the same way as `kernel`.
            padding (list): paddings in temporal, height, width order, split
                between the two convolutions in the same way as `kernel`.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
        """
        super().__init__()
        self.kernel, self.stride, self.padding = kernel, stride, padding
        self.inplace_relu = inplace_relu
        self.eps, self.bn_mmt = eps, bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """
        Create the spatial conv, temporal conv, norm and relu layers.
        """
        # Spatial convolution: size 1 / stride 1 / no padding along time.
        spatial_kernel = (1, self.kernel[1], self.kernel[2])
        spatial_stride = (1, self.stride[1], self.stride[2])
        spatial_padding = (0, self.padding[1], self.padding[2])
        self.conv_xy = nn.Conv3d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=spatial_kernel,
            stride=spatial_stride,
            padding=spatial_padding,
            bias=False,
        )

        # Temporal convolution: 1x1 in space, and depthwise because the
        # number of groups equals the number of channels.
        temporal_kernel = (self.kernel[0], 1, 1)
        temporal_stride = (self.stride[0], 1, 1)
        temporal_padding = (self.padding[0], 0, 0)
        self.conv = nn.Conv3d(
            in_channels=dim_out,
            out_channels=dim_out,
            kernel_size=temporal_kernel,
            stride=temporal_stride,
            padding=temporal_padding,
            bias=False,
            groups=dim_out,
        )

        self.bn = norm_module(
            num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
        )
        self.relu = nn.ReLU(inplace=self.inplace_relu)

    def forward(self, x):
        """
        Apply spatial conv -> temporal conv -> norm -> relu to `x`.
        """
        for layer in (self.conv_xy, self.conv, self.bn, self.relu):
            x = layer(x)
        return x
class PatchEmbed(nn.Module):
    """
    PatchEmbed.

    Projects the input with a single convolution (2D or 3D) and, unless asked
    otherwise, flattens the projected grid into a sequence of tokens.
    """

    def __init__(
        self,
        dim_in=3,
        dim_out=768,
        kernel=(1, 16, 16),
        stride=(1, 4, 4),
        padding=(1, 7, 7),
        conv_2d=False,
    ):
        """
        Args:
            dim_in (int): number of input channels of the projection.
            dim_out (int): number of output channels of the projection.
            kernel (tuple): kernel size of the projection convolution.
            stride (tuple): stride of the projection convolution.
            padding (tuple): padding of the projection convolution.
            conv_2d (bool): use nn.Conv2d for the projection when True,
                nn.Conv3d otherwise.
        """
        super().__init__()
        conv_cls = nn.Conv2d if conv_2d else nn.Conv3d
        self.proj = conv_cls(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, keep_spatial=False):
        """
        Project `x` and return it together with the projected shape.

        Args:
            x (tensor): input passed to the projection convolution.
            keep_spatial (bool): when True the projected tensor is returned
                as is; otherwise all dimensions from index 2 onwards are
                flattened and moved in front of the channel dimension.

        Returns:
            A tuple `(output, shape)` where `shape` is always the shape of
            the projected tensor before any flattening.
        """
        projected = self.proj(x)
        projected_shape = projected.shape
        if not keep_spatial:
            # B C (T) H W -> B (T)HW C
            projected = projected.flatten(2).transpose(1, 2)
        return projected, projected_shape