Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
"""ResNe(X)t 3D stem helper.""" | |
import torch | |
import torch.nn as nn | |
def get_stem_func(name):
    """
    Look up a stem module class by its registry name.

    Args:
        name (str): registry key, either "x3d_stem" or "basic_stem".
    Returns:
        The stem class itself (not an instance): X3DStem for "x3d_stem",
        ResNetBasicStem for "basic_stem".
    Raises:
        AssertionError: if ``name`` is not a registered key. Note that the
            check is an ``assert``, so it is skipped under ``python -O``; the
            subsequent dict lookup would then raise KeyError instead.
    """
    registry = {
        "x3d_stem": X3DStem,
        "basic_stem": ResNetBasicStem,
    }
    assert name in registry, "Transformation function '{}' not supported".format(
        name
    )
    return registry[name]
class VideoModelStem(nn.Module):
    """
    Multi-pathway video stem.

    Builds one independent stem sub-module per input pathway (for example
    Conv, BN, ReLU, MaxPool for the "basic_stem" variant) and registers it
    under the name ``pathway{i}_stem``. Single-pathway models (C2D, I3D,
    Slow, ...) use per-pathway lists of length 1; two-pathway models
    (SlowFast) use lists of length 2.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
        stem_func_name="basic_stem",
    ):
        """
        Subclasses are expected to accept these same `__init__` arguments.
        Every list argument holds one entry per pathway, and all five lists
        must have the same length.

        Args:
            dim_in (list): input channel count of each pathway.
            dim_out (list): output channel count of each pathway's stem
                convolution.
            kernel (list): per-pathway convolution kernel size, ordered
                (temporal, height, width).
            stride (list): per-pathway convolution stride, ordered
                (temporal, height, width).
            padding (list): per-pathway convolution padding, ordered
                (temporal, height, width).
            inplace_relu (bool): if True, the ReLU overwrites its input
                instead of allocating a new tensor.
            eps (float): epsilon forwarded to the normalization layer.
            bn_mmt (float): momentum forwarded to the normalization layer.
                BN momentum in PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): normalization layer class; defaults to
                nn.BatchNorm3d.
            stem_func_name (str): registry name of the per-pathway stem
                class, resolved with `get_stem_func` (e.g. "basic_stem",
                "x3d_stem").
        Raises:
            AssertionError: if the five per-pathway lists differ in length.
        """
        super().__init__()
        lengths = (len(dim_in), len(dim_out), len(kernel), len(stride), len(padding))
        assert len(set(lengths)) == 1, (
            "Input pathway dimensions are not consistent. {} {} {} {} {}".format(
                *lengths
            )
        )
        self.num_pathways = lengths[0]
        # Hyper-parameters are kept on the instance; `_construct_stem` reads them.
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt
        self._construct_stem(dim_in, dim_out, norm_module, stem_func_name)

    def _construct_stem(self, dim_in, dim_out, norm_module, stem_func_name):
        """Instantiate one stem per pathway and register it as ``pathway{i}_stem``."""
        stem_cls = get_stem_func(stem_func_name)
        for idx, channels_in in enumerate(dim_in):
            # Positional call: the stem class must accept (dim_in, dim_out,
            # kernel, stride, padding, inplace_relu, eps, bn_mmt, norm_module)
            # in exactly this order.
            pathway_stem = stem_cls(
                channels_in,
                dim_out[idx],
                self.kernel[idx],
                self.stride[idx],
                self.padding[idx],
                self.inplace_relu,
                self.eps,
                self.bn_mmt,
                norm_module,
            )
            self.add_module(f"pathway{idx}_stem", pathway_stem)

    def forward(self, x):
        """
        Run each pathway's stem on the matching input.

        Args:
            x (list): one input per pathway; its length must equal
                ``self.num_pathways``.
        Returns:
            list: a new list of stem outputs in pathway order. ``x`` itself
            is left unmodified.
        Raises:
            AssertionError: if ``len(x)`` differs from ``self.num_pathways``.
        """
        assert (
            len(x) == self.num_pathways
        ), "Input tensor does not contain {} pathway".format(self.num_pathways)
        # Build a fresh list instead of writing back into `x`; mutating the
        # input list in place is bad for activation checkpointing.
        return [
            getattr(self, f"pathway{idx}_stem")(pathway_input)
            for idx, pathway_input in enumerate(x)
        ]
class ResNetBasicStem(nn.Module):
    """
    ResNe(X)t 3D stem.

    Applies a spatiotemporal convolution, normalization and ReLU, followed by
    a max-pool with kernel [1, 3, 3], stride [1, 2, 2] and padding [0, 1, 1].
    The pool therefore strides height and width by 2 and leaves the temporal
    dimension unchanged.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Subclasses are expected to accept these same `__init__` arguments.

        Args:
            dim_in (int): input channel count (typically 3 for RGB, 2 or 3
                for optical flow).
            dim_out (int): output channel count of the stem convolution.
            kernel (list): convolution kernel size, ordered
                (temporal, height, width).
            stride (list): convolution stride, ordered
                (temporal, height, width).
            padding (int or list): convolution padding, ordered
                (temporal, height, width) when given per dimension.
            inplace_relu (bool): if True, the ReLU overwrites its input
                instead of allocating a new tensor.
            eps (float): epsilon forwarded to the normalization layer.
            bn_mmt (float): momentum forwarded to the normalization layer.
                BN momentum in PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): normalization layer class; defaults to
                nn.BatchNorm3d.
        """
        super().__init__()
        # Hyper-parameters are kept on the instance; `_construct_stem` reads them.
        self.kernel, self.stride, self.padding = kernel, stride, padding
        self.inplace_relu = inplace_relu
        self.eps, self.bn_mmt = eps, bn_mmt
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """Create the conv -> norm -> ReLU -> max-pool sub-modules."""
        # The conv has no bias (bias=False); presumably the normalization
        # layer that follows supplies the learned shift.
        self.conv = nn.Conv3d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )
        self.bn = norm_module(num_features=dim_out, eps=self.eps, momentum=self.bn_mmt)
        self.relu = nn.ReLU(inplace=self.inplace_relu)
        self.pool_layer = nn.MaxPool3d(
            kernel_size=[1, 3, 3],
            stride=[1, 2, 2],
            padding=[0, 1, 1],
        )

    def forward(self, x):
        """Apply conv, norm, ReLU and max-pool to ``x`` in that order."""
        return self.pool_layer(self.relu(self.bn(self.conv(x))))
class X3DStem(nn.Module):
    """
    X3D 3D stem.

    Applies a spatial-only convolution, then a depthwise temporal
    convolution (one filter per channel, ``groups == dim_out``), then
    normalization and ReLU. This stem has no pooling layer.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel,
        stride,
        padding,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Subclasses are expected to accept these same `__init__` arguments.

        Args:
            dim_in (int): input channel count (typically 3 for RGB, 2 or 3
                for optical flow).
            dim_out (int): output channel count of both convolutions.
            kernel (list): kernel size ordered (temporal, height, width).
                Index 0 drives the temporal conv; indices 1 and 2 drive the
                spatial conv.
            stride (list): stride ordered (temporal, height, width), split
                between the two convs the same way as ``kernel``.
            padding (list): padding ordered (temporal, height, width), split
                between the two convs the same way as ``kernel``.
            inplace_relu (bool): if True, the ReLU overwrites its input
                instead of allocating a new tensor.
            eps (float): epsilon forwarded to the normalization layer.
            bn_mmt (float): momentum forwarded to the normalization layer.
                BN momentum in PyTorch = 1 - BN momentum in Caffe2.
            norm_module (nn.Module): normalization layer class; defaults to
                nn.BatchNorm3d.
        """
        super().__init__()
        # Hyper-parameters are kept on the instance; `_construct_stem` reads them.
        self.kernel, self.stride, self.padding = kernel, stride, padding
        self.inplace_relu = inplace_relu
        self.eps, self.bn_mmt = eps, bn_mmt
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """Create the spatial conv, depthwise temporal conv, norm and ReLU."""
        k_t, k_h, k_w = self.kernel[0], self.kernel[1], self.kernel[2]
        s_t, s_h, s_w = self.stride[0], self.stride[1], self.stride[2]
        p_t, p_h, p_w = self.padding[0], self.padding[1], self.padding[2]
        # Spatial conv: kernel/stride/padding of 1/1/0 along time.
        self.conv_xy = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=(1, k_h, k_w),
            stride=(1, s_h, s_w),
            padding=(0, p_h, p_w),
            bias=False,
        )
        # Depthwise temporal conv: groups == channels, so each channel is
        # filtered independently along time only.
        self.conv = nn.Conv3d(
            dim_out,
            dim_out,
            kernel_size=(k_t, 1, 1),
            stride=(s_t, 1, 1),
            padding=(p_t, 0, 0),
            bias=False,
            groups=dim_out,
        )
        self.bn = norm_module(num_features=dim_out, eps=self.eps, momentum=self.bn_mmt)
        self.relu = nn.ReLU(inplace=self.inplace_relu)

    def forward(self, x):
        """Apply spatial conv, temporal conv, norm and ReLU to ``x`` in order."""
        return self.relu(self.bn(self.conv(self.conv_xy(x))))
class PatchEmbed(nn.Module):
    """
    Patch embedding: projects an input clip (or image) to embedding
    channels with a single strided convolution and, by default, flattens
    the result into a token sequence.
    """

    def __init__(
        self,
        dim_in=3,
        dim_out=768,
        kernel=(1, 16, 16),
        stride=(1, 4, 4),
        padding=(1, 7, 7),
        conv_2d=False,
    ):
        """
        Args:
            dim_in (int): input channel count.
            dim_out (int): embedding dimension (conv output channels).
            kernel (tuple): convolution kernel size.
            stride (tuple): convolution stride.
            padding (tuple): convolution padding.
            conv_2d (bool): use nn.Conv2d instead of nn.Conv3d. The default
                kernel/stride/padding are 3-element (Conv3d) values, so pass
                2-element values when this is True.
        """
        super().__init__()
        conv_cls = nn.Conv2d if conv_2d else nn.Conv3d
        self.proj = conv_cls(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, keep_spatial=False):
        """
        Args:
            x (Tensor): input laid out as B C (T) H W; T is present only for
                the Conv3d variant.
            keep_spatial (bool): if True, return the conv feature map as-is
                instead of flattening it into tokens.
        Returns:
            tuple: ``(out, shape)``, where ``shape`` is the conv feature
            map's shape B C (T) H W. ``out`` is that feature map when
            ``keep_spatial`` is True, otherwise a token tensor B (T)HW C.
        """
        feat = self.proj(x)
        feat_shape = feat.shape
        if not keep_spatial:
            # B C (T) H W -> B (T)HW C
            feat = feat.flatten(2).transpose(1, 2)
        return feat, feat_shape