# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based gated feedforward layer."""
# pylint: disable=g-classes-have-attributes
from typing import Optional
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
class BlockDiagFeedforward(tf_keras.layers.Layer):
"""Block diagonal feedforward layer.
This layer replaces the weight matrix of the output_dense layer with a block
diagonal matrix to save layer parameters and FLOPs. A linear mixing layer can
be added optionally to improve layer expressibility.
Args:
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout: Dropout probability for the output dropout.
    num_blocks: The number of blocks for the block diagonal matrix of the
      output_dense layer. Both intermediate_size and the input hidden size
      must be divisible by num_blocks.
apply_mixing: Apply linear mixing if True.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
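
  Example:

  ```python
  # A minimal usage sketch with illustrative sizes: hidden size 768 and
  # intermediate size 3072 are both divisible by num_blocks=4.
  layer = BlockDiagFeedforward(
      intermediate_size=3072,
      intermediate_activation="gelu",
      dropout=0.1,
      num_blocks=4,
      apply_mixing=True)
  outputs = layer(tf.ones((2, 16, 768)))  # -> shape (2, 16, 768)
  ```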
"""

  def __init__(
self,
intermediate_size: int,
intermediate_activation: str,
dropout: float,
num_blocks: int = 1,
apply_mixing: bool = True,
kernel_initializer: str = "glorot_uniform",
bias_initializer: str = "zeros",
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
activity_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
kernel_constraint: Optional[tf_keras.constraints.Constraint] = None,
bias_constraint: Optional[tf_keras.constraints.Constraint] = None,
**kwargs): # pylint: disable=g-doc-args
super().__init__(**kwargs)
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
self._dropout = dropout
self._num_blocks = num_blocks
self._apply_mixing = apply_mixing
if intermediate_size % num_blocks != 0:
raise ValueError("Intermediate_size (%d) isn't a multiple of num_blocks "
"(%d)." % (intermediate_size, num_blocks))
self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
self._bias_initializer = tf_keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
self._bias_constraint = tf_keras.constraints.get(bias_constraint)

  def build(self, input_shape):
hidden_size = input_shape.as_list()[-1]
common_kwargs = dict(
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
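    # "abc,cde->abde": a dense projection of (batch, seq, hidden) inputs,
    # with the intermediate dimension split into d = num_blocks groups of
    # size e = intermediate_size // num_blocks.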
self._intermediate_dense = tf_keras.layers.EinsumDense(
"abc,cde->abde",
output_shape=(None, self._num_blocks,
self._intermediate_size // self._num_blocks),
bias_axes="de",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf_keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
policy = tf.float32
self._intermediate_activation_layer = tf_keras.layers.Activation(
self._intermediate_activation, dtype=policy)
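    # "abde,deo->abdo": each block d has its own (e, o) kernel, so the
    # overall projection is a block diagonal weight matrix from the
    # intermediate dimension back to the hidden dimension.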
self._output_dense = tf_keras.layers.EinsumDense(
"abde,deo->abdo",
output_shape=(None, self._num_blocks, hidden_size // self._num_blocks),
bias_axes="do",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
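    # "abdo,de->abeo": optional linear mixing across the block axis, which
    # restores some cross-block interaction lost to the block diagonal
    # structure.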
if self._apply_mixing:
self._output_mixing = tf_keras.layers.EinsumDense(
"abdo,de->abeo",
output_shape=(None, self._num_blocks,
hidden_size // self._num_blocks),
name="output_mixing",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
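    # Collapse the (num_blocks, hidden_size // num_blocks) axes back into
    # the hidden dimension.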
self._output_reshape = tf_keras.layers.Reshape((-1, hidden_size))
self._output_dropout = tf_keras.layers.Dropout(rate=self._dropout)

  def get_config(self):
config = {
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"dropout":
self._dropout,
"num_blocks":
self._num_blocks,
"apply_mixing":
self._apply_mixing,
"kernel_initializer":
tf_keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf_keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf_keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf_keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf_keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf_keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf_keras.constraints.serialize(self._bias_constraint)
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
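    """Applies the block diagonal feedforward to the inputs."""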
intermediate_output = self._intermediate_dense(inputs)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output)
if self._apply_mixing:
layer_output = self._output_mixing(layer_output)
layer_output = self._output_reshape(layer_output)
layer_output = self._output_dropout(layer_output)
return layer_output