Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Contains common building blocks for MoViNets. | |
Reference: https://arxiv.org/pdf/2103.11511.pdf | |
""" | |
from typing import Any, Mapping, Optional, Sequence, Tuple, Union | |
import tensorflow as tf, tf_keras | |
from official.modeling import tf_utils | |
from official.vision.modeling.layers import nn_layers | |
# Default kernel weight decay that may be overridden | |
KERNEL_WEIGHT_DECAY = 1.5e-5 | |
def normalize_tuple(value: Union[int, Tuple[int, ...]], size: int, name: str): | |
"""Transforms a single integer or iterable of integers into an integer tuple. | |
Arguments: | |
value: The value to validate and convert. Could an int, or any iterable of | |
ints. | |
size: The size of the tuple to be returned. | |
name: The name of the argument being validated, e.g. "strides" or | |
"kernel_size". This is only used to format error messages. | |
Returns: | |
A tuple of `size` integers. | |
Raises: | |
ValueError: If something else than an int/long or iterable thereof was | |
passed. | |
""" | |
if isinstance(value, int): | |
return (value,) * size | |
else: | |
try: | |
value_tuple = tuple(value) | |
except TypeError: | |
raise ValueError('The `' + name + '` argument must be a tuple of ' + | |
str(size) + ' integers. Received: ' + str(value)) | |
if len(value_tuple) != size: | |
raise ValueError('The `' + name + '` argument must be a tuple of ' + | |
str(size) + ' integers. Received: ' + str(value)) | |
for single_value in value_tuple: | |
try: | |
int(single_value) | |
except (ValueError, TypeError): | |
raise ValueError('The `' + name + '` argument must be a tuple of ' + | |
str(size) + ' integers. Received: ' + str(value) + ' ' | |
'including element ' + str(single_value) + ' of type' + | |
' ' + str(type(single_value))) | |
return value_tuple | |
class Squeeze3D(tf_keras.layers.Layer): | |
"""Squeeze3D layer to remove singular dimensions.""" | |
def call(self, inputs): | |
"""Calls the layer with the given inputs.""" | |
return tf.squeeze(inputs, axis=(1, 2, 3)) | |
class MobileConv2D(tf_keras.layers.Layer): | |
"""Conv2D layer with extra options to support mobile devices. | |
Reshapes 5D video tensor inputs to 4D, allowing Conv2D to run across | |
dimensions (2, 3) or (3, 4). Reshapes tensors back to 5D when returning the | |
output. | |
""" | |
def __init__( | |
self, | |
filters: int, | |
kernel_size: Union[int, Sequence[int]], | |
strides: Union[int, Sequence[int]] = (1, 1), | |
padding: str = 'valid', | |
data_format: Optional[str] = None, | |
dilation_rate: Union[int, Sequence[int]] = (1, 1), | |
groups: int = 1, | |
use_bias: bool = True, | |
kernel_initializer: str = 'glorot_uniform', | |
bias_initializer: str = 'zeros', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None, | |
bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None, | |
activity_regularizer: Optional[tf_keras.regularizers.Regularizer] = None, | |
kernel_constraint: Optional[tf_keras.constraints.Constraint] = None, | |
bias_constraint: Optional[tf_keras.constraints.Constraint] = None, | |
use_depthwise: bool = False, | |
use_temporal: bool = False, | |
use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras | |
batch_norm_op: Optional[Any] = None, | |
activation_op: Optional[Any] = None, | |
**kwargs): # pylint: disable=g-doc-args | |
"""Initializes mobile conv2d. | |
For the majority of arguments, see tf_keras.layers.Conv2D. | |
Args: | |
use_depthwise: if True, use DepthwiseConv2D instead of Conv2D | |
use_temporal: if True, apply Conv2D starting from the temporal dimension | |
instead of the spatial dimensions. | |
use_buffered_input: if True, the input is expected to be padded | |
beforehand. In effect, calling this layer will use 'valid' padding on | |
the temporal dimension to simulate 'causal' padding. | |
batch_norm_op: A callable object of batch norm layer. If None, no batch | |
norm will be applied after the convolution. | |
activation_op: A callabel object of activation layer. If None, no | |
activation will be applied after the convolution. | |
**kwargs: keyword arguments to be passed to this layer. | |
Returns: | |
A output tensor of the MobileConv2D operation. | |
""" | |
super(MobileConv2D, self).__init__(**kwargs) | |
self._filters = filters | |
self._kernel_size = kernel_size | |
self._strides = strides | |
self._padding = padding | |
self._data_format = data_format | |
self._dilation_rate = dilation_rate | |
self._groups = groups | |
self._use_bias = use_bias | |
self._kernel_initializer = kernel_initializer | |
self._bias_initializer = bias_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._bias_regularizer = bias_regularizer | |
self._activity_regularizer = activity_regularizer | |
self._kernel_constraint = kernel_constraint | |
self._bias_constraint = bias_constraint | |
self._use_depthwise = use_depthwise | |
self._use_temporal = use_temporal | |
self._use_buffered_input = use_buffered_input | |
self._batch_norm_op = batch_norm_op | |
self._activation_op = activation_op | |
kernel_size = normalize_tuple(kernel_size, 2, 'kernel_size') | |
if self._use_temporal and kernel_size[1] > 1: | |
raise ValueError('Temporal conv with spatial kernel is not supported.') | |
if use_depthwise: | |
self._conv = nn_layers.DepthwiseConv2D( | |
kernel_size=kernel_size, | |
strides=strides, | |
padding=padding, | |
depth_multiplier=1, | |
data_format=data_format, | |
dilation_rate=dilation_rate, | |
use_bias=use_bias, | |
depthwise_initializer=kernel_initializer, | |
bias_initializer=bias_initializer, | |
depthwise_regularizer=kernel_regularizer, | |
bias_regularizer=bias_regularizer, | |
activity_regularizer=activity_regularizer, | |
depthwise_constraint=kernel_constraint, | |
bias_constraint=bias_constraint, | |
use_buffered_input=use_buffered_input) | |
else: | |
self._conv = nn_layers.Conv2D( | |
filters=filters, | |
kernel_size=kernel_size, | |
strides=strides, | |
padding=padding, | |
data_format=data_format, | |
dilation_rate=dilation_rate, | |
groups=groups, | |
use_bias=use_bias, | |
kernel_initializer=kernel_initializer, | |
bias_initializer=bias_initializer, | |
kernel_regularizer=kernel_regularizer, | |
bias_regularizer=bias_regularizer, | |
activity_regularizer=activity_regularizer, | |
kernel_constraint=kernel_constraint, | |
bias_constraint=bias_constraint, | |
use_buffered_input=use_buffered_input) | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'filters': self._filters, | |
'kernel_size': self._kernel_size, | |
'strides': self._strides, | |
'padding': self._padding, | |
'data_format': self._data_format, | |
'dilation_rate': self._dilation_rate, | |
'groups': self._groups, | |
'use_bias': self._use_bias, | |
'kernel_initializer': self._kernel_initializer, | |
'bias_initializer': self._bias_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'bias_regularizer': self._bias_regularizer, | |
'activity_regularizer': self._activity_regularizer, | |
'kernel_constraint': self._kernel_constraint, | |
'bias_constraint': self._bias_constraint, | |
'use_depthwise': self._use_depthwise, | |
'use_temporal': self._use_temporal, | |
'use_buffered_input': self._use_buffered_input, | |
} | |
base_config = super(MobileConv2D, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, inputs): | |
"""Calls the layer with the given inputs.""" | |
if self._use_temporal: | |
input_shape = [ | |
tf.shape(inputs)[0], | |
tf.shape(inputs)[1], | |
tf.shape(inputs)[2] * tf.shape(inputs)[3], | |
inputs.shape[4]] | |
else: | |
input_shape = [ | |
tf.shape(inputs)[0] * tf.shape(inputs)[1], | |
tf.shape(inputs)[2], | |
tf.shape(inputs)[3], | |
inputs.shape[4]] | |
x = tf.reshape(inputs, input_shape) | |
x = self._conv(x) | |
if self._batch_norm_op is not None: | |
x = self._batch_norm_op(x) | |
if self._activation_op is not None: | |
x = self._activation_op(x) | |
if self._use_temporal: | |
output_shape = [ | |
tf.shape(x)[0], | |
tf.shape(x)[1], | |
tf.shape(inputs)[2], | |
tf.shape(inputs)[3], | |
x.shape[3]] | |
else: | |
output_shape = [ | |
tf.shape(inputs)[0], | |
tf.shape(inputs)[1], | |
tf.shape(x)[1], | |
tf.shape(x)[2], | |
x.shape[3]] | |
x = tf.reshape(x, output_shape) | |
return x | |
class ConvBlock(tf_keras.layers.Layer): | |
"""A Conv followed by optional BatchNorm and Activation.""" | |
def __init__( | |
self, | |
filters: int, | |
kernel_size: Union[int, Sequence[int]], | |
strides: Union[int, Sequence[int]] = 1, | |
depthwise: bool = False, | |
causal: bool = False, | |
use_bias: bool = False, | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = | |
tf_keras.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
use_batch_norm: bool = True, | |
batch_norm_layer: tf_keras.layers.Layer = | |
tf_keras.layers.BatchNormalization, | |
batch_norm_momentum: float = 0.99, | |
batch_norm_epsilon: float = 1e-3, | |
use_sync_bn: bool = False, | |
activation: Optional[Any] = None, | |
conv_type: str = '3d', | |
use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Initializes a conv block. | |
Args: | |
filters: filters for the conv operation. | |
kernel_size: kernel size for the conv operation. | |
strides: strides for the conv operation. | |
depthwise: if True, use DepthwiseConv2D instead of Conv2D | |
causal: if True, use causal mode for the conv operation. | |
use_bias: use bias for the conv operation. | |
kernel_initializer: kernel initializer for the conv operation. | |
kernel_regularizer: kernel regularizer for the conv operation. | |
use_batch_norm: if True, apply batch norm after the conv operation. | |
batch_norm_layer: class to use for batch norm, if applied. | |
batch_norm_momentum: momentum of the batch norm operation, if applied. | |
batch_norm_epsilon: epsilon of the batch norm operation, if applied. | |
use_sync_bn: if True, use synchronized batch normalization. | |
activation: activation after the conv and batch norm operations. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
use_buffered_input: if True, the input is expected to be padded | |
beforehand. In effect, calling this layer will use 'valid' padding on | |
the temporal dimension to simulate 'causal' padding. | |
**kwargs: keyword arguments to be passed to this layer. | |
Returns: | |
A output tensor of the ConvBlock operation. | |
""" | |
super(ConvBlock, self).__init__(**kwargs) | |
kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size') | |
strides = normalize_tuple(strides, 3, 'strides') | |
self._filters = filters | |
self._kernel_size = kernel_size | |
self._strides = strides | |
self._depthwise = depthwise | |
self._causal = causal | |
self._use_bias = use_bias | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._use_batch_norm = use_batch_norm | |
self._batch_norm_layer = batch_norm_layer | |
self._batch_norm_momentum = batch_norm_momentum | |
self._batch_norm_epsilon = batch_norm_epsilon | |
self._use_sync_bn = use_sync_bn | |
self._activation = activation | |
self._conv_type = conv_type | |
self._use_buffered_input = use_buffered_input | |
if activation is not None: | |
self._activation_layer = tf_utils.get_activation( | |
activation, use_keras_layer=True) | |
else: | |
self._activation_layer = None | |
self._groups = None | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'filters': self._filters, | |
'kernel_size': self._kernel_size, | |
'strides': self._strides, | |
'depthwise': self._depthwise, | |
'causal': self._causal, | |
'use_bias': self._use_bias, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'use_batch_norm': self._use_batch_norm, | |
'batch_norm_momentum': self._batch_norm_momentum, | |
'batch_norm_epsilon': self._batch_norm_epsilon, | |
'use_sync_bn': self._use_sync_bn, | |
'activation': self._activation, | |
'conv_type': self._conv_type, | |
'use_buffered_input': self._use_buffered_input, | |
} | |
base_config = super(ConvBlock, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def build(self, input_shape): | |
"""Builds the layer with the given input shape.""" | |
padding = 'causal' if self._causal else 'same' | |
self._groups = input_shape[-1] if self._depthwise else 1 | |
self._batch_norm = None | |
self._batch_norm_temporal = None | |
if self._use_batch_norm: | |
self._batch_norm = self._batch_norm_layer( | |
momentum=self._batch_norm_momentum, | |
epsilon=self._batch_norm_epsilon, | |
synchronized=self._use_sync_bn, | |
name='bn') | |
if self._conv_type != '3d' and self._kernel_size[0] > 1: | |
self._batch_norm_temporal = self._batch_norm_layer( | |
momentum=self._batch_norm_momentum, | |
epsilon=self._batch_norm_epsilon, | |
synchronized=self._use_sync_bn, | |
name='bn_temporal') | |
self._conv_temporal = None | |
if self._conv_type == '3d_2plus1d' and self._kernel_size[0] > 1: | |
self._conv = nn_layers.Conv3D( | |
self._filters, | |
(1, self._kernel_size[1], self._kernel_size[2]), | |
strides=(1, self._strides[1], self._strides[2]), | |
padding='same', | |
groups=self._groups, | |
use_bias=self._use_bias, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_buffered_input=False, | |
name='conv3d') | |
self._conv_temporal = nn_layers.Conv3D( | |
self._filters, | |
(self._kernel_size[0], 1, 1), | |
strides=(self._strides[0], 1, 1), | |
padding=padding, | |
groups=self._groups, | |
use_bias=self._use_bias, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_buffered_input=self._use_buffered_input, | |
name='conv3d_temporal') | |
elif self._conv_type == '2plus1d': | |
self._conv = MobileConv2D( | |
self._filters, | |
(self._kernel_size[1], self._kernel_size[2]), | |
strides=(self._strides[1], self._strides[2]), | |
padding='same', | |
use_depthwise=self._depthwise, | |
groups=self._groups, | |
use_bias=self._use_bias, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_buffered_input=False, | |
batch_norm_op=self._batch_norm, | |
activation_op=self._activation_layer, | |
name='conv2d') | |
if self._kernel_size[0] > 1: | |
self._conv_temporal = MobileConv2D( | |
self._filters, | |
(self._kernel_size[0], 1), | |
strides=(self._strides[0], 1), | |
padding=padding, | |
use_temporal=True, | |
use_depthwise=self._depthwise, | |
groups=self._groups, | |
use_bias=self._use_bias, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_buffered_input=self._use_buffered_input, | |
batch_norm_op=self._batch_norm_temporal, | |
activation_op=self._activation_layer, | |
name='conv2d_temporal') | |
else: | |
self._conv = nn_layers.Conv3D( | |
self._filters, | |
self._kernel_size, | |
strides=self._strides, | |
padding=padding, | |
groups=self._groups, | |
use_bias=self._use_bias, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_buffered_input=self._use_buffered_input, | |
name='conv3d') | |
super(ConvBlock, self).build(input_shape) | |
def call(self, inputs): | |
"""Calls the layer with the given inputs.""" | |
x = inputs | |
# bn_op and activation_op are folded into the '2plus1d' conv layer so that | |
# we do not explicitly call them here. | |
# TODO(lzyuan): clean the conv layers api once the models are re-trained. | |
x = self._conv(x) | |
if self._batch_norm is not None and self._conv_type != '2plus1d': | |
x = self._batch_norm(x) | |
if self._activation_layer is not None and self._conv_type != '2plus1d': | |
x = self._activation_layer(x) | |
if self._conv_temporal is not None: | |
x = self._conv_temporal(x) | |
if self._batch_norm_temporal is not None and self._conv_type != '2plus1d': | |
x = self._batch_norm_temporal(x) | |
if self._activation_layer is not None and self._conv_type != '2plus1d': | |
x = self._activation_layer(x) | |
return x | |
class StreamBuffer(tf_keras.layers.Layer): | |
"""Stream buffer wrapper which caches activations of previous frames.""" | |
def __init__(self, | |
buffer_size: int, | |
state_prefix: Optional[str] = None, | |
**kwargs): | |
"""Initializes a stream buffer. | |
Args: | |
buffer_size: the number of input frames to cache. | |
state_prefix: a prefix string to identify states. | |
**kwargs: keyword arguments to be passed to this layer. | |
Returns: | |
A output tensor of the StreamBuffer operation. | |
""" | |
super(StreamBuffer, self).__init__(**kwargs) | |
state_prefix = state_prefix if state_prefix is not None else '' | |
self._state_prefix = state_prefix | |
self._state_name = f'{state_prefix}_stream_buffer' | |
self._buffer_size = buffer_size | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'buffer_size': self._buffer_size, | |
'state_prefix': self._state_prefix, | |
} | |
base_config = super(StreamBuffer, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call( | |
self, | |
inputs: tf.Tensor, | |
states: Optional[nn_layers.States] = None, | |
) -> Tuple[Any, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Expected keys include `state_prefix + '_stream_buffer'`. | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
buffer = states.get(self._state_name, None) | |
# Create the buffer if it does not exist in the states. | |
# Output buffer shape: | |
# [batch_size, buffer_size, input_height, input_width, num_channels] | |
if buffer is None: | |
shape = tf.shape(inputs) | |
buffer = tf.zeros( | |
[shape[0], self._buffer_size, shape[2], shape[3], shape[4]], | |
dtype=inputs.dtype) | |
# tf.pad has limited support for tf lite, so use tf.concat instead. | |
full_inputs = tf.concat([buffer, inputs], axis=1) | |
# Cache the last b frames of the input where b is the buffer size and f | |
# is the number of input frames. If b > f, then we will cache the last b - f | |
# frames from the previous buffer concatenated with the current f input | |
# frames. | |
new_buffer = full_inputs[:, -self._buffer_size:] | |
states[self._state_name] = new_buffer | |
return full_inputs, states | |
class StreamConvBlock(ConvBlock): | |
"""ConvBlock with StreamBuffer.""" | |
def __init__( | |
self, | |
filters: int, | |
kernel_size: Union[int, Sequence[int]], | |
strides: Union[int, Sequence[int]] = 1, | |
depthwise: bool = False, | |
causal: bool = False, | |
use_bias: bool = False, | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = tf.keras | |
.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
use_batch_norm: bool = True, | |
batch_norm_layer: tf_keras.layers.Layer = | |
tf_keras.layers.BatchNormalization, | |
batch_norm_momentum: float = 0.99, | |
batch_norm_epsilon: float = 1e-3, | |
use_sync_bn: bool = False, | |
activation: Optional[Any] = None, | |
conv_type: str = '3d', | |
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Initializes a stream conv block. | |
Args: | |
filters: filters for the conv operation. | |
kernel_size: kernel size for the conv operation. | |
strides: strides for the conv operation. | |
depthwise: if True, use DepthwiseConv2D instead of Conv2D | |
causal: if True, use causal mode for the conv operation. | |
use_bias: use bias for the conv operation. | |
kernel_initializer: kernel initializer for the conv operation. | |
kernel_regularizer: kernel regularizer for the conv operation. | |
use_batch_norm: if True, apply batch norm after the conv operation. | |
batch_norm_layer: class to use for batch norm, if applied. | |
batch_norm_momentum: momentum of the batch norm operation, if applied. | |
batch_norm_epsilon: epsilon of the batch norm operation, if applied. | |
use_sync_bn: if True, use synchronized batch normalization. | |
activation: activation after the conv and batch norm operations. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
state_prefix: a prefix string to identify states. | |
**kwargs: keyword arguments to be passed to this layer. | |
Returns: | |
A output tensor of the StreamConvBlock operation. | |
""" | |
kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size') | |
buffer_size = kernel_size[0] - 1 | |
use_buffer = buffer_size > 0 and causal | |
self._state_prefix = state_prefix | |
super(StreamConvBlock, self).__init__( | |
filters, | |
kernel_size, | |
strides=strides, | |
depthwise=depthwise, | |
causal=causal, | |
use_bias=use_bias, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
use_batch_norm=use_batch_norm, | |
batch_norm_layer=batch_norm_layer, | |
batch_norm_momentum=batch_norm_momentum, | |
batch_norm_epsilon=batch_norm_epsilon, | |
use_sync_bn=use_sync_bn, | |
activation=activation, | |
conv_type=conv_type, | |
use_buffered_input=use_buffer, | |
**kwargs) | |
self._stream_buffer = None | |
if use_buffer: | |
self._stream_buffer = StreamBuffer( | |
buffer_size=buffer_size, state_prefix=state_prefix) | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = {'state_prefix': self._state_prefix} | |
base_config = super(StreamConvBlock, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, | |
inputs: tf.Tensor, | |
states: Optional[nn_layers.States] = None | |
) -> Tuple[tf.Tensor, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
x = inputs | |
# If we have no separate temporal conv, use the buffer before the 3D conv. | |
if self._conv_temporal is None and self._stream_buffer is not None: | |
x, states = self._stream_buffer(x, states=states) | |
# bn_op and activation_op are folded into the '2plus1d' conv layer so that | |
# we do not explicitly call them here. | |
# TODO(lzyuan): clean the conv layers api once the models are re-trained. | |
x = self._conv(x) | |
if self._batch_norm is not None and self._conv_type != '2plus1d': | |
x = self._batch_norm(x) | |
if self._activation_layer is not None and self._conv_type != '2plus1d': | |
x = self._activation_layer(x) | |
if self._conv_temporal is not None: | |
if self._stream_buffer is not None: | |
# If we have a separate temporal conv, use the buffer before the | |
# 1D conv instead (otherwise, we may waste computation on the 2D conv). | |
x, states = self._stream_buffer(x, states=states) | |
x = self._conv_temporal(x) | |
if self._batch_norm_temporal is not None and self._conv_type != '2plus1d': | |
x = self._batch_norm_temporal(x) | |
if self._activation_layer is not None and self._conv_type != '2plus1d': | |
x = self._activation_layer(x) | |
return x, states | |
class StreamSqueezeExcitation(tf_keras.layers.Layer): | |
"""Squeeze and excitation layer with causal mode. | |
Reference: https://arxiv.org/pdf/1709.01507.pdf | |
""" | |
def __init__( | |
self, | |
hidden_filters: int, | |
se_type: str = '3d', | |
activation: nn_layers.Activation = 'swish', | |
gating_activation: nn_layers.Activation = 'sigmoid', | |
causal: bool = False, | |
conv_type: str = '3d', | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = tf.keras | |
.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
use_positional_encoding: bool = False, | |
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Implementation for squeeze and excitation. | |
Args: | |
hidden_filters: The hidden filters of squeeze excite. | |
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D | |
spatiotemporal global average pooling for squeeze excitation. '2d' | |
uses 2D spatial global average pooling on each frame. '2plus3d' | |
concatenates both 3D and 2D global average pooling. | |
activation: name of the activation function. | |
gating_activation: name of the activation function for gating. | |
causal: if True, use causal mode in the global average pool. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
kernel_initializer: kernel initializer for the conv operations. | |
kernel_regularizer: kernel regularizer for the conv operation. | |
use_positional_encoding: add a positional encoding after the (cumulative) | |
global average pooling layer. | |
state_prefix: a prefix string to identify states. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(StreamSqueezeExcitation, self).__init__(**kwargs) | |
self._hidden_filters = hidden_filters | |
self._se_type = se_type | |
self._activation = activation | |
self._gating_activation = gating_activation | |
self._causal = causal | |
self._conv_type = conv_type | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._use_positional_encoding = use_positional_encoding | |
self._state_prefix = state_prefix | |
self._spatiotemporal_pool = nn_layers.GlobalAveragePool3D( | |
keepdims=True, causal=causal, state_prefix=state_prefix) | |
self._spatial_pool = nn_layers.SpatialAveragePool3D(keepdims=True) | |
self._pos_encoding = None | |
if use_positional_encoding: | |
self._pos_encoding = nn_layers.PositionalEncoding( | |
initializer='zeros', state_prefix=state_prefix) | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'hidden_filters': self._hidden_filters, | |
'se_type': self._se_type, | |
'activation': self._activation, | |
'gating_activation': self._gating_activation, | |
'causal': self._causal, | |
'conv_type': self._conv_type, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'use_positional_encoding': self._use_positional_encoding, | |
'state_prefix': self._state_prefix, | |
} | |
base_config = super(StreamSqueezeExcitation, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def build(self, input_shape): | |
"""Builds the layer with the given input shape.""" | |
self._se_reduce = ConvBlock( | |
filters=self._hidden_filters, | |
kernel_size=1, | |
causal=self._causal, | |
use_bias=True, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_batch_norm=False, | |
activation=self._activation, | |
conv_type=self._conv_type, | |
name='se_reduce') | |
self._se_expand = ConvBlock( | |
filters=input_shape[-1], | |
kernel_size=1, | |
causal=self._causal, | |
use_bias=True, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_batch_norm=False, | |
activation=self._gating_activation, | |
conv_type=self._conv_type, | |
name='se_expand') | |
super(StreamSqueezeExcitation, self).build(input_shape) | |
def call(self, | |
inputs: tf.Tensor, | |
states: Optional[nn_layers.States] = None | |
) -> Tuple[tf.Tensor, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
if self._se_type == '3d': | |
x, states = self._spatiotemporal_pool( | |
inputs, states=states, output_states=True) | |
elif self._se_type == '2d': | |
x = self._spatial_pool(inputs) | |
elif self._se_type == '2plus3d': | |
x_space = self._spatial_pool(inputs) | |
x, states = self._spatiotemporal_pool( | |
x_space, states=states, output_states=True) | |
if not self._causal: | |
x = tf.tile(x, [1, tf.shape(inputs)[1], 1, 1, 1]) | |
x = tf.concat([x, x_space], axis=-1) | |
else: | |
raise ValueError('Unknown Squeeze Excitation type {}'.format( | |
self._se_type)) | |
if self._pos_encoding is not None: | |
x, states = self._pos_encoding(x, states=states) | |
x = self._se_reduce(x) | |
x = self._se_expand(x) | |
return x * inputs, states | |
class MobileBottleneck(tf_keras.layers.Layer): | |
"""A depthwise inverted bottleneck block. | |
Uses dependency injection to allow flexible definition of different layers | |
within this block. | |
""" | |
def __init__(self, | |
expansion_layer: tf_keras.layers.Layer, | |
feature_layer: tf_keras.layers.Layer, | |
projection_layer: tf_keras.layers.Layer, | |
attention_layer: Optional[tf_keras.layers.Layer] = None, | |
skip_layer: Optional[tf_keras.layers.Layer] = None, | |
stochastic_depth_drop_rate: Optional[float] = None, | |
**kwargs): | |
"""Implementation for mobile bottleneck. | |
Args: | |
expansion_layer: initial layer used for pointwise expansion. | |
feature_layer: main layer used for computing 3D features. | |
projection_layer: layer used for pointwise projection. | |
attention_layer: optional layer used for attention-like operations (e.g., | |
squeeze excite). | |
skip_layer: optional skip layer used to project the input before summing | |
with the output for the residual connection. | |
stochastic_depth_drop_rate: optional drop rate for stochastic depth. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(MobileBottleneck, self).__init__(**kwargs) | |
self._projection_layer = projection_layer | |
self._attention_layer = attention_layer | |
self._skip_layer = skip_layer | |
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate | |
self._identity = tf_keras.layers.Activation(tf.identity) | |
self._rezero = nn_layers.Scale(initializer='zeros', name='rezero') | |
if stochastic_depth_drop_rate: | |
self._stochastic_depth = nn_layers.StochasticDepth( | |
stochastic_depth_drop_rate, name='stochastic_depth') | |
else: | |
self._stochastic_depth = None | |
self._feature_layer = feature_layer | |
self._expansion_layer = expansion_layer | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate, | |
} | |
base_config = super(MobileBottleneck, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, | |
inputs: tf.Tensor, | |
states: Optional[nn_layers.States] = None | |
) -> Tuple[tf.Tensor, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
x = self._expansion_layer(inputs) | |
x, states = self._feature_layer(x, states=states) | |
if self._attention_layer is not None: | |
x, states = self._attention_layer(x, states=states) | |
x = self._projection_layer(x) | |
# Add identity so that the ops are ordered as written. This is useful for, | |
# e.g., quantization. | |
x = self._identity(x) | |
x = self._rezero(x) | |
if self._stochastic_depth is not None: | |
x = self._stochastic_depth(x) | |
if self._skip_layer is not None: | |
skip = self._skip_layer(inputs) | |
else: | |
skip = inputs | |
return x + skip, states | |
class SkipBlock(tf_keras.layers.Layer): | |
"""Skip block for bottleneck blocks.""" | |
def __init__( | |
self, | |
out_filters: int, | |
downsample: bool = False, | |
conv_type: str = '3d', | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = | |
tf_keras.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
batch_norm_layer: tf_keras.layers.Layer = | |
tf_keras.layers.BatchNormalization, | |
batch_norm_momentum: float = 0.99, | |
batch_norm_epsilon: float = 1e-3, # pytype: disable=annotation-type-mismatch # typed-keras | |
use_sync_bn: bool = False, | |
**kwargs): | |
"""Implementation for skip block. | |
Args: | |
out_filters: the number of projected output filters. | |
downsample: if True, downsamples the input by a factor of 2 by applying | |
average pooling with a 3x3 kernel size on the spatial dimensions. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
kernel_initializer: kernel initializer for the conv operations. | |
kernel_regularizer: kernel regularizer for the conv projection. | |
batch_norm_layer: class to use for batch norm. | |
batch_norm_momentum: momentum of the batch norm operation. | |
batch_norm_epsilon: epsilon of the batch norm operation. | |
use_sync_bn: if True, use synchronized batch normalization. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(SkipBlock, self).__init__(**kwargs) | |
self._out_filters = out_filters | |
self._downsample = downsample | |
self._conv_type = conv_type | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._batch_norm_layer = batch_norm_layer | |
self._batch_norm_momentum = batch_norm_momentum | |
self._batch_norm_epsilon = batch_norm_epsilon | |
self._use_sync_bn = use_sync_bn | |
self._projection = ConvBlock( | |
filters=self._out_filters, | |
kernel_size=1, | |
conv_type=conv_type, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
use_batch_norm=True, | |
batch_norm_layer=self._batch_norm_layer, | |
batch_norm_momentum=self._batch_norm_momentum, | |
batch_norm_epsilon=self._batch_norm_epsilon, | |
use_sync_bn=self._use_sync_bn, | |
name='skip_project') | |
if downsample: | |
if self._conv_type == '2plus1d': | |
self._pool = tf_keras.layers.AveragePooling2D( | |
pool_size=(3, 3), | |
strides=(2, 2), | |
padding='same', | |
name='skip_pool') | |
else: | |
self._pool = tf_keras.layers.AveragePooling3D( | |
pool_size=(1, 3, 3), | |
strides=(1, 2, 2), | |
padding='same', | |
name='skip_pool') | |
else: | |
self._pool = None | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'out_filters': self._out_filters, | |
'downsample': self._downsample, | |
'conv_type': self._conv_type, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'batch_norm_momentum': self._batch_norm_momentum, | |
'batch_norm_epsilon': self._batch_norm_epsilon, | |
'use_sync_bn': self._use_sync_bn | |
} | |
base_config = super(SkipBlock, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, inputs): | |
"""Calls the layer with the given inputs.""" | |
x = inputs | |
if self._pool is not None: | |
if self._conv_type == '2plus1d': | |
x = tf.reshape(x, [-1, tf.shape(x)[2], tf.shape(x)[3], x.shape[4]]) | |
x = self._pool(x) | |
if self._conv_type == '2plus1d': | |
x = tf.reshape( | |
x, | |
[tf.shape(inputs)[0], -1, tf.shape(x)[1], | |
tf.shape(x)[2], x.shape[3]]) | |
return self._projection(x) | |
class MovinetBlock(tf_keras.layers.Layer): | |
"""A basic block for MoViNets. | |
Applies a mobile inverted bottleneck with pointwise expansion, 3D depthwise | |
convolution, 3D squeeze excite, pointwise projection, and residual connection. | |
""" | |
def __init__( | |
self, | |
out_filters: int, | |
expand_filters: int, | |
kernel_size: Union[int, Sequence[int]] = (3, 3, 3), | |
strides: Union[int, Sequence[int]] = (1, 1, 1), | |
causal: bool = False, | |
activation: nn_layers.Activation = 'swish', | |
gating_activation: nn_layers.Activation = 'sigmoid', | |
se_ratio: float = 0.25, | |
stochastic_depth_drop_rate: float = 0., | |
conv_type: str = '3d', | |
se_type: str = '3d', | |
use_positional_encoding: bool = False, | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = tf.keras | |
.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
batch_norm_layer: tf_keras.layers.Layer = | |
tf_keras.layers.BatchNormalization, | |
batch_norm_momentum: float = 0.99, | |
batch_norm_epsilon: float = 1e-3, | |
use_sync_bn: bool = False, | |
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Implementation for MoViNet block. | |
Args: | |
out_filters: number of output filters for the final projection. | |
expand_filters: number of expansion filters after the input. | |
kernel_size: kernel size of the main depthwise convolution. | |
strides: strides of the main depthwise convolution. | |
causal: if True, run the temporal convolutions in causal mode. | |
activation: activation to use across all conv operations. | |
gating_activation: gating activation to use in squeeze excitation layers. | |
se_ratio: squeeze excite filters ratio. | |
stochastic_depth_drop_rate: optional drop rate for stochastic depth. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D | |
spatiotemporal global average pooling for squeeze excitation. '2d' | |
uses 2D spatial global average pooling on each frame. '2plus3d' | |
concatenates both 3D and 2D global average pooling. | |
use_positional_encoding: add a positional encoding after the (cumulative) | |
global average pooling layer in the squeeze excite layer. | |
kernel_initializer: kernel initializer for the conv operations. | |
kernel_regularizer: kernel regularizer for the conv operations. | |
batch_norm_layer: class to use for batch norm. | |
batch_norm_momentum: momentum of the batch norm operation. | |
batch_norm_epsilon: epsilon of the batch norm operation. | |
use_sync_bn: if True, use synchronized batch normalization. | |
state_prefix: a prefix string to identify states. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(MovinetBlock, self).__init__(**kwargs) | |
self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size') | |
self._strides = normalize_tuple(strides, 3, 'strides') | |
# Use a multiplier of 2 if concatenating multiple features | |
se_multiplier = 2 if se_type == '2plus3d' else 1 | |
se_hidden_filters = nn_layers.make_divisible( | |
se_ratio * expand_filters * se_multiplier, divisor=8) | |
self._out_filters = out_filters | |
self._expand_filters = expand_filters | |
self._causal = causal | |
self._activation = activation | |
self._gating_activation = gating_activation | |
self._se_ratio = se_ratio | |
self._downsample = any(s > 1 for s in self._strides) | |
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate | |
self._conv_type = conv_type | |
self._se_type = se_type | |
self._use_positional_encoding = use_positional_encoding | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._batch_norm_layer = batch_norm_layer | |
self._batch_norm_momentum = batch_norm_momentum | |
self._batch_norm_epsilon = batch_norm_epsilon | |
self._use_sync_bn = use_sync_bn | |
self._state_prefix = state_prefix | |
self._expansion = ConvBlock( | |
expand_filters, | |
(1, 1, 1), | |
activation=activation, | |
conv_type=conv_type, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
use_batch_norm=True, | |
batch_norm_layer=self._batch_norm_layer, | |
batch_norm_momentum=self._batch_norm_momentum, | |
batch_norm_epsilon=self._batch_norm_epsilon, | |
use_sync_bn=self._use_sync_bn, | |
name='expansion') | |
self._feature = StreamConvBlock( | |
expand_filters, | |
self._kernel_size, | |
strides=self._strides, | |
depthwise=True, | |
causal=self._causal, | |
activation=activation, | |
conv_type=conv_type, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
use_batch_norm=True, | |
batch_norm_layer=self._batch_norm_layer, | |
batch_norm_momentum=self._batch_norm_momentum, | |
batch_norm_epsilon=self._batch_norm_epsilon, | |
use_sync_bn=self._use_sync_bn, | |
state_prefix=state_prefix, | |
name='feature') | |
self._projection = ConvBlock( | |
out_filters, | |
(1, 1, 1), | |
activation=None, | |
conv_type=conv_type, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
use_batch_norm=True, | |
batch_norm_layer=self._batch_norm_layer, | |
batch_norm_momentum=self._batch_norm_momentum, | |
batch_norm_epsilon=self._batch_norm_epsilon, | |
use_sync_bn=self._use_sync_bn, | |
name='projection') | |
self._attention = None | |
if se_type != 'none': | |
self._attention = StreamSqueezeExcitation( | |
se_hidden_filters, | |
se_type=se_type, | |
activation=activation, | |
gating_activation=gating_activation, | |
causal=self._causal, | |
conv_type=conv_type, | |
use_positional_encoding=use_positional_encoding, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
state_prefix=state_prefix, | |
name='se') | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'out_filters': self._out_filters, | |
'expand_filters': self._expand_filters, | |
'kernel_size': self._kernel_size, | |
'strides': self._strides, | |
'causal': self._causal, | |
'activation': self._activation, | |
'gating_activation': self._gating_activation, | |
'se_ratio': self._se_ratio, | |
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate, | |
'conv_type': self._conv_type, | |
'se_type': self._se_type, | |
'use_positional_encoding': self._use_positional_encoding, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'batch_norm_momentum': self._batch_norm_momentum, | |
'batch_norm_epsilon': self._batch_norm_epsilon, | |
'use_sync_bn': self._use_sync_bn, | |
'state_prefix': self._state_prefix, | |
} | |
base_config = super(MovinetBlock, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def build(self, input_shape): | |
"""Builds the layer with the given input shape.""" | |
if input_shape[-1] == self._out_filters and not self._downsample: | |
self._skip = None | |
else: | |
self._skip = SkipBlock( | |
self._out_filters, | |
downsample=self._downsample, | |
conv_type=self._conv_type, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
name='skip') | |
self._mobile_bottleneck = MobileBottleneck( | |
self._expansion, | |
self._feature, | |
self._projection, | |
attention_layer=self._attention, | |
skip_layer=self._skip, | |
stochastic_depth_drop_rate=self._stochastic_depth_drop_rate, | |
name='bneck') | |
super(MovinetBlock, self).build(input_shape) | |
def call(self, | |
inputs: tf.Tensor, | |
states: Optional[nn_layers.States] = None | |
) -> Tuple[tf.Tensor, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
return self._mobile_bottleneck(inputs, states=states) | |
class Stem(tf_keras.layers.Layer): | |
"""Stem layer for video networks. | |
Applies an initial convolution block operation. | |
""" | |
def __init__( | |
self, | |
out_filters: int, | |
kernel_size: Union[int, Sequence[int]], | |
strides: Union[int, Sequence[int]] = (1, 1, 1), | |
causal: bool = False, | |
conv_type: str = '3d', | |
activation: nn_layers.Activation = 'swish', | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = tf.keras | |
.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
batch_norm_layer: tf_keras.layers.Layer = | |
tf_keras.layers.BatchNormalization, | |
batch_norm_momentum: float = 0.99, | |
batch_norm_epsilon: float = 1e-3, | |
use_sync_bn: bool = False, | |
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Implementation for video model stem. | |
Args: | |
out_filters: number of output filters. | |
kernel_size: kernel size of the convolution. | |
strides: strides of the convolution. | |
causal: if True, run the temporal convolutions in causal mode. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
activation: the input activation name. | |
kernel_initializer: kernel initializer for the conv operations. | |
kernel_regularizer: kernel regularizer for the conv operations. | |
batch_norm_layer: class to use for batch norm. | |
batch_norm_momentum: momentum of the batch norm operation. | |
batch_norm_epsilon: epsilon of the batch norm operation. | |
use_sync_bn: if True, use synchronized batch normalization. | |
state_prefix: a prefix string to identify states. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(Stem, self).__init__(**kwargs) | |
self._out_filters = out_filters | |
self._kernel_size = normalize_tuple(kernel_size, 3, 'kernel_size') | |
self._strides = normalize_tuple(strides, 3, 'strides') | |
self._causal = causal | |
self._conv_type = conv_type | |
self._activation = activation | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._batch_norm_layer = batch_norm_layer | |
self._batch_norm_momentum = batch_norm_momentum | |
self._batch_norm_epsilon = batch_norm_epsilon | |
self._use_sync_bn = use_sync_bn | |
self._state_prefix = state_prefix | |
self._stem = StreamConvBlock( | |
filters=self._out_filters, | |
kernel_size=self._kernel_size, | |
strides=self._strides, | |
causal=self._causal, | |
activation=self._activation, | |
conv_type=self._conv_type, | |
kernel_initializer=self._kernel_initializer, | |
kernel_regularizer=self._kernel_regularizer, | |
use_batch_norm=True, | |
batch_norm_layer=self._batch_norm_layer, | |
batch_norm_momentum=self._batch_norm_momentum, | |
batch_norm_epsilon=self._batch_norm_epsilon, | |
use_sync_bn=self._use_sync_bn, | |
state_prefix=self._state_prefix, | |
name='stem') | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'out_filters': self._out_filters, | |
'kernel_size': self._kernel_size, | |
'strides': self._strides, | |
'causal': self._causal, | |
'activation': self._activation, | |
'conv_type': self._conv_type, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'batch_norm_momentum': self._batch_norm_momentum, | |
'batch_norm_epsilon': self._batch_norm_epsilon, | |
'use_sync_bn': self._use_sync_bn, | |
'state_prefix': self._state_prefix, | |
} | |
base_config = super(Stem, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, | |
inputs: tf.Tensor, | |
states: Optional[nn_layers.States] = None | |
) -> Tuple[tf.Tensor, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
return self._stem(inputs, states=states) | |
class Head(tf_keras.layers.Layer): | |
"""Head layer for video networks. | |
Applies pointwise projection and global pooling. | |
""" | |
def __init__( | |
self, | |
project_filters: int, | |
conv_type: str = '3d', | |
activation: nn_layers.Activation = 'swish', | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = tf.keras | |
.regularizers.L2(KERNEL_WEIGHT_DECAY), | |
batch_norm_layer: tf_keras.layers.Layer = | |
tf_keras.layers.BatchNormalization, | |
batch_norm_momentum: float = 0.99, | |
batch_norm_epsilon: float = 1e-3, | |
use_sync_bn: bool = False, | |
average_pooling_type: str = '3d', | |
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Implementation for video model head. | |
Args: | |
project_filters: number of pointwise projection filters. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
activation: the input activation name. | |
kernel_initializer: kernel initializer for the conv operations. | |
kernel_regularizer: kernel regularizer for the conv operations. | |
batch_norm_layer: class to use for batch norm. | |
batch_norm_momentum: momentum of the batch norm operation. | |
batch_norm_epsilon: epsilon of the batch norm operation. | |
use_sync_bn: if True, use synchronized batch normalization. | |
average_pooling_type: The average pooling type. Currently supporting | |
['3d', '2d', 'none']. | |
state_prefix: a prefix string to identify states. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(Head, self).__init__(**kwargs) | |
self._project_filters = project_filters | |
self._conv_type = conv_type | |
self._activation = activation | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._batch_norm_layer = batch_norm_layer | |
self._batch_norm_momentum = batch_norm_momentum | |
self._batch_norm_epsilon = batch_norm_epsilon | |
self._use_sync_bn = use_sync_bn | |
self._state_prefix = state_prefix | |
self._project = ConvBlock( | |
filters=project_filters, | |
kernel_size=1, | |
activation=activation, | |
conv_type=conv_type, | |
kernel_regularizer=kernel_regularizer, | |
use_batch_norm=True, | |
batch_norm_layer=self._batch_norm_layer, | |
batch_norm_momentum=self._batch_norm_momentum, | |
batch_norm_epsilon=self._batch_norm_epsilon, | |
use_sync_bn=self._use_sync_bn, | |
name='project') | |
if average_pooling_type.lower() == '3d': | |
self._pool = nn_layers.GlobalAveragePool3D( | |
keepdims=True, causal=False, state_prefix=state_prefix) | |
elif average_pooling_type.lower() == '2d': | |
self._pool = nn_layers.SpatialAveragePool3D(keepdims=True) | |
elif average_pooling_type == 'none': | |
self._pool = None | |
else: | |
raise ValueError( | |
'%s average_pooling_type is not supported.' % average_pooling_type) | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'project_filters': self._project_filters, | |
'conv_type': self._conv_type, | |
'activation': self._activation, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
'batch_norm_momentum': self._batch_norm_momentum, | |
'batch_norm_epsilon': self._batch_norm_epsilon, | |
'use_sync_bn': self._use_sync_bn, | |
'state_prefix': self._state_prefix, | |
} | |
base_config = super(Head, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call( | |
self, | |
inputs: Union[tf.Tensor, Mapping[str, tf.Tensor]], | |
states: Optional[nn_layers.States] = None, | |
) -> Tuple[tf.Tensor, nn_layers.States]: | |
"""Calls the layer with the given inputs. | |
Args: | |
inputs: the input tensor or dict of endpoints. | |
states: a dict of states such that, if any of the keys match for this | |
layer, will overwrite the contents of the buffer(s). | |
Returns: | |
the output tensor and states | |
""" | |
states = dict(states) if states is not None else {} | |
x = self._project(inputs) | |
if self._pool is not None: | |
outputs = self._pool(x, states=states, output_states=True) | |
else: | |
outputs = (x, states) | |
return outputs | |
class ClassifierHead(tf_keras.layers.Layer): | |
"""Head layer for video networks. | |
Applies dense projection, dropout, and classifier projection. Expects input | |
to be pooled vector with shape [batch_size, 1, 1, 1, num_channels] | |
""" | |
def __init__( | |
self, | |
head_filters: int, | |
num_classes: int, | |
dropout_rate: float = 0., | |
conv_type: str = '3d', | |
activation: nn_layers.Activation = 'swish', | |
output_activation: Optional[nn_layers.Activation] = None, | |
max_pool_predictions: bool = False, | |
kernel_initializer: tf_keras.initializers.Initializer = 'HeNormal', | |
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = | |
tf_keras.regularizers.L2(KERNEL_WEIGHT_DECAY), # pytype: disable=annotation-type-mismatch # typed-keras | |
**kwargs): | |
"""Implementation for video model classifier head. | |
Args: | |
head_filters: number of dense head projection filters. | |
num_classes: number of output classes for the final logits. | |
dropout_rate: the dropout rate applied to the head projection. | |
conv_type: '3d', '2plus1d', or '3d_2plus1d'. '3d' uses the default 3D | |
ops. '2plus1d' split any 3D ops into two sequential 2D ops with their | |
own batch norm and activation. '3d_2plus1d' is like '2plus1d', but | |
uses two sequential 3D ops instead. | |
activation: the input activation name. | |
output_activation: optional final activation (e.g., 'softmax'). | |
max_pool_predictions: apply temporal softmax pooling to predictions. | |
Intended for multi-label prediction, where multiple labels are | |
distributed across the video. Currently only supports single clips. | |
kernel_initializer: kernel initializer for the conv operations. | |
kernel_regularizer: kernel regularizer for the conv operations. | |
**kwargs: keyword arguments to be passed to this layer. | |
""" | |
super(ClassifierHead, self).__init__(**kwargs) | |
self._head_filters = head_filters | |
self._num_classes = num_classes | |
self._dropout_rate = dropout_rate | |
self._conv_type = conv_type | |
self._activation = activation | |
self._output_activation = output_activation | |
self._max_pool_predictions = max_pool_predictions | |
self._kernel_initializer = kernel_initializer | |
self._kernel_regularizer = kernel_regularizer | |
self._dropout = tf_keras.layers.Dropout(dropout_rate) | |
self._head = ConvBlock( | |
filters=head_filters, | |
kernel_size=1, | |
activation=activation, | |
use_bias=True, | |
use_batch_norm=False, | |
conv_type=conv_type, | |
kernel_initializer=kernel_initializer, | |
kernel_regularizer=kernel_regularizer, | |
name='head') | |
self._classifier = ConvBlock( | |
filters=num_classes, | |
kernel_size=1, | |
kernel_initializer=tf_keras.initializers.random_normal(stddev=0.01), | |
kernel_regularizer=None, | |
use_bias=True, | |
use_batch_norm=False, | |
conv_type=conv_type, | |
name='classifier') | |
self._max_pool = nn_layers.TemporalSoftmaxPool() | |
self._squeeze = Squeeze3D() | |
output_activation = output_activation if output_activation else 'linear' | |
self._cast = tf_keras.layers.Activation( | |
output_activation, dtype='float32', name='cast') | |
def get_config(self): | |
"""Returns a dictionary containing the config used for initialization.""" | |
config = { | |
'head_filters': self._head_filters, | |
'num_classes': self._num_classes, | |
'dropout_rate': self._dropout_rate, | |
'conv_type': self._conv_type, | |
'activation': self._activation, | |
'output_activation': self._output_activation, | |
'max_pool_predictions': self._max_pool_predictions, | |
'kernel_initializer': self._kernel_initializer, | |
'kernel_regularizer': self._kernel_regularizer, | |
} | |
base_config = super(ClassifierHead, self).get_config() | |
return dict(list(base_config.items()) + list(config.items())) | |
def call(self, inputs: tf.Tensor) -> tf.Tensor: | |
"""Calls the layer with the given inputs.""" | |
# Input Shape: [batch_size, 1, 1, 1, input_channels] | |
x = inputs | |
x = self._head(x) | |
if self._dropout_rate and self._dropout_rate > 0: | |
x = self._dropout(x) | |
x = self._classifier(x) | |
if self._max_pool_predictions: | |
x = self._max_pool(x) | |
x = self._squeeze(x) | |
x = self._cast(x) | |
return x | |