# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for 3D networks."""
# Import libraries
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.vision.modeling.layers import nn_layers
@tf_keras.utils.register_keras_serializable(package='Vision')
class SelfGating(tf_keras.layers.Layer):
"""Feature gating as used in S3D-G.
This implements the self-gating module from the S3D-G network described in:
Saining Xie, Chen Sun, Jonathan Huang, Zhuowen Tu, Kevin Murphy.
Rethinking Spatiotemporal Feature Learning: Speed-Accuracy Trade-offs in Video
Classification.
(https://arxiv.org/pdf/1712.04851.pdf)
"""
def __init__(self, filters, **kwargs):
"""Initializes a self-gating layer.
Args:
filters: An `int` number of filters for the convolutional layer.
**kwargs: Additional keyword arguments to be passed.
"""
super(SelfGating, self).__init__(**kwargs)
self._filters = filters
def build(self, input_shape):
self._spatial_temporal_average = tf_keras.layers.GlobalAveragePooling3D()
# No BN or activation after the conv.
self._transformer_w = tf_keras.layers.Conv3D(
filters=self._filters,
kernel_size=[1, 1, 1],
use_bias=True,
kernel_initializer=tf_keras.initializers.TruncatedNormal(
mean=0.0, stddev=0.01))
super(SelfGating, self).build(input_shape)
def call(self, inputs):
x = self._spatial_temporal_average(inputs)
# Restore singleton time/height/width dims so the gate broadcasts against
# the [batch, T, H, W, C] input.
x = tf.expand_dims(x, 1)
x = tf.expand_dims(x, 2)
x = tf.expand_dims(x, 3)
x = self._transformer_w(x)
x = tf.nn.sigmoid(x)
return tf.math.multiply(x, inputs)
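# Example usage of `SelfGating` (an illustrative sketch, not part of the
# original module); the shapes below are arbitrary example values:
#
#   gating = SelfGating(filters=64)
#   features = tf.random.normal([2, 8, 14, 14, 64])  # [batch, T, H, W, C]
#   gated = gating(features)  # Same shape as `features`, gated per channel.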
@tf_keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock3D(tf_keras.layers.Layer):
"""Creates a 3D bottleneck block."""
def __init__(self,
filters,
temporal_kernel_size,
temporal_strides,
spatial_strides,
stochastic_depth_drop_rate=0.0,
se_ratio=None,
use_self_gating=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""Initializes a 3D bottleneck block with BN after convolutions.
Args:
filters: An `int` number of filters for the first two convolutions. Note
that the third and final convolution will use 4 times as many filters.
temporal_kernel_size: An `int` of kernel size for the temporal
convolutional layer.
temporal_strides: An `int` of temporal stride for the temporal
convolutional layer.
spatial_strides: An `int` of spatial stride for the spatial convolutional
layer.
stochastic_depth_drop_rate: A `float` or None. If not None, drop rate for
the stochastic depth layer.
se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
use_self_gating: A `bool` of whether to apply self-gating module or not.
kernel_initializer: A `str` of kernel_initializer for convolutional
layers.
kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
Conv3D. Defaults to None.
bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv3D.
Defaults to None.
activation: A `str` name of the activation function.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: Additional keyword arguments to be passed.
"""
super(BottleneckBlock3D, self).__init__(**kwargs)
self._filters = filters
self._temporal_kernel_size = temporal_kernel_size
self._spatial_strides = spatial_strides
self._temporal_strides = temporal_strides
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._use_self_gating = use_self_gating
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._norm = tf_keras.layers.BatchNormalization
if tf_keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
self._shortcut_maxpool = tf_keras.layers.MaxPool3D(
pool_size=[1, 1, 1],
strides=[
self._temporal_strides, self._spatial_strides, self._spatial_strides
])
self._shortcut_conv = tf_keras.layers.Conv3D(
filters=4 * self._filters,
kernel_size=1,
strides=[
self._temporal_strides, self._spatial_strides, self._spatial_strides
],
use_bias=False,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._temporal_conv = tf_keras.layers.Conv3D(
filters=self._filters,
kernel_size=[self._temporal_kernel_size, 1, 1],
strides=[self._temporal_strides, 1, 1],
padding='same',
use_bias=False,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._spatial_conv = tf_keras.layers.Conv3D(
filters=self._filters,
kernel_size=[1, 3, 3],
strides=[1, self._spatial_strides, self._spatial_strides],
padding='same',
use_bias=False,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._expand_conv = tf_keras.layers.Conv3D(
filters=4 * self._filters,
kernel_size=[1, 1, 1],
strides=[1, 1, 1],
padding='same',
use_bias=False,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
if self._se_ratio and 0 < self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
in_filters=self._filters * 4,
out_filters=self._filters * 4,
se_ratio=self._se_ratio,
use_3d_input=True,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
if self._use_self_gating:
self._self_gating = SelfGating(filters=4 * self._filters)
else:
self._self_gating = None
super(BottleneckBlock3D, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'temporal_kernel_size': self._temporal_kernel_size,
'temporal_strides': self._temporal_strides,
'spatial_strides': self._spatial_strides,
'use_self_gating': self._use_self_gating,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(BottleneckBlock3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
in_filters = inputs.shape.as_list()[-1]
if in_filters == 4 * self._filters:
if self._temporal_strides == 1 and self._spatial_strides == 1:
shortcut = inputs
else:
shortcut = self._shortcut_maxpool(inputs)
else:
shortcut = self._shortcut_conv(inputs)
shortcut = self._norm0(shortcut)
x = self._temporal_conv(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._spatial_conv(x)
x = self._norm2(x)
x = self._activation_fn(x)
x = self._expand_conv(x)
x = self._norm3(x)
# Apply self-gating, SE, stochastic depth.
if self._self_gating:
x = self._self_gating(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
# Add the residual shortcut and apply the final activation.
x = self._activation_fn(x + shortcut)
return x
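# Minimal smoke test (an illustrative sketch, not part of the original module):
# builds one bottleneck block and pushes a random clip through it. All shapes
# and hyperparameters below are arbitrary example values.
if __name__ == '__main__':
  block = BottleneckBlock3D(
      filters=64,
      temporal_kernel_size=3,
      temporal_strides=1,
      spatial_strides=2,
      use_self_gating=True)
  clip = tf.random.normal([2, 8, 56, 56, 64])  # [batch, time, H, W, channels]
  out = block(clip, training=False)
  # With spatial_strides=2 the spatial dims halve and the channels expand to
  # 4 * filters, so the expected output shape is [2, 8, 28, 28, 256].
  print(out.shape)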