# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Neural network operations commonly shared by the architectures.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import tensorflow as tf, tf_keras class NormActivation(tf_keras.layers.Layer): """Combined Normalization and Activation layers.""" def __init__(self, momentum=0.997, epsilon=1e-4, trainable=True, init_zero=False, use_activation=True, activation='relu', fused=True, name=None): """A class to construct layers for a batch normalization followed by a ReLU. Args: momentum: momentum for the moving average. epsilon: small float added to variance to avoid dividing by zero. trainable: `bool`, if True also add variables to the graph collection GraphKeys.TRAINABLE_VARIABLES. If False, freeze batch normalization layer. init_zero: `bool` if True, initializes scale parameter of batch normalization with 0. If False, initialize it with 1. use_activation: `bool`, whether to add the optional activation layer after the batch normalization layer. activation: 'string', the type of the activation layer. Currently support `relu` and `swish`. fused: `bool` fused option in batch normalziation. name: `str` name for the operation. """ super(NormActivation, self).__init__(trainable=trainable) if init_zero: gamma_initializer = tf_keras.initializers.Zeros() else: gamma_initializer = tf_keras.initializers.Ones() self._normalization_op = tf_keras.layers.BatchNormalization( momentum=momentum, epsilon=epsilon, center=True, scale=True, trainable=trainable, fused=fused, gamma_initializer=gamma_initializer, name=name) self._use_activation = use_activation if activation == 'relu': self._activation_op = tf.nn.relu elif activation == 'swish': self._activation_op = tf.nn.swish else: raise ValueError('Unsupported activation `{}`.'.format(activation)) def __call__(self, inputs, is_training=None): """Builds the normalization layer followed by an optional activation layer. Args: inputs: `Tensor` of shape `[batch, channels, ...]`. is_training: `boolean`, if True if model is in training mode. Returns: A normalized `Tensor` with the same `data_format`. """ # We will need to keep training=None by default, so that it can be inherit # from keras.Model.training if is_training and self.trainable: is_training = True inputs = self._normalization_op(inputs, training=is_training) if self._use_activation: inputs = self._activation_op(inputs) return inputs def norm_activation_builder(momentum=0.997, epsilon=1e-4, trainable=True, activation='relu', **kwargs): return functools.partial( NormActivation, momentum=momentum, epsilon=epsilon, trainable=trainable, activation=activation, **kwargs)