# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based gated feedforward layer."""
# pylint: disable=g-classes-have-attributes
import gin
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling.layers import util


@tf_keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class GatedFeedforward(tf_keras.layers.Layer):
  """Gated linear feedforward layer.

  This layer follows the paper "GLU Variants Improve Transformer"
  (https://arxiv.org/abs/2002.05202). In addition, it allows stacking multiple
  feedforward blocks and specifying the position of the dropout layer.

  Args:
    inner_dim: Size of the inner (intermediate) layer.
    inner_activation: Activation for the inner (intermediate) layer.
    dropout: Dropout probability for the output dropout.
    use_gate: Whether to use gated linear units. If True, assuming `GELU` as
      the activation and omitting bias, will apply
      `GEGLU(x, W, V, W_2) = (GELU(xW) * xV)W_2`; if False, will follow the
      "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) paper and
      apply `FFN(x, W_1, W_2) = GELU(xW_1)W_2`.
    apply_output_layer_norm: Whether to apply layer norm after the residual
      connection in each block.
    num_blocks: The number of feedforward blocks to stack. Each block contains
      a (gated) linear layer and a fully connected layer, followed by dropout,
      a residual connection and layer norm (ordering controlled by
      `dropout_position`).
    dropout_position: Where to apply the dropout; the value can be either
      `before_residual` or `after_residual`. If `before_residual`, will apply
      `layer_output = layer_norm(dropout(layer_output) + layer_input)`; if
      `after_residual`, will apply
      `layer_output = dropout(layer_norm(layer_output + layer_input))`.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
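
  Example (a minimal usage sketch; the shapes and `inner_dim` below are
  illustrative, not defaults):

  >>> inputs = tf.ones((2, 8, 16))  # [batch, seq_length, hidden]
  >>> layer = GatedFeedforward(inner_dim=32)
  >>> outputs = layer(inputs)  # Shape (2, 8, 16), same as `inputs`.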
""" def __init__(self, inner_dim=768, inner_activation=tf_utils.get_activation("gelu"), dropout=0.0, use_gate=True, apply_output_layer_norm=True, num_blocks=1, dropout_position="before_residual", kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, **kwargs): inner_dim = kwargs.pop("intermediate_size", inner_dim) inner_activation = kwargs.pop("intermediate_activation", inner_activation) util.filter_kwargs(kwargs) super().__init__(**kwargs) self._inner_dim = inner_dim self._inner_activation = inner_activation self._dropout = dropout self._use_gate = use_gate self._num_blocks = num_blocks self._apply_output_layer_norm = apply_output_layer_norm self._dropout_position = dropout_position if self._dropout_position not in ("before_residual", "after_residual"): raise ValueError( "The dropout_position should be either `before_residual` or" "`after_residual`, got: %s" % self._dropout_position) self._kernel_initializer = tf_keras.initializers.get(kernel_initializer) self._bias_initializer = tf_keras.initializers.get(bias_initializer) self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer) self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer) self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer) self._kernel_constraint = tf_keras.constraints.get(kernel_constraint) self._bias_constraint = tf_keras.constraints.get(bias_constraint) def build(self, input_shape): hidden_size = input_shape.as_list()[-1] common_kwargs = dict( kernel_regularizer=self._kernel_regularizer, bias_regularizer=self._bias_regularizer, activity_regularizer=self._activity_regularizer, kernel_constraint=self._kernel_constraint, bias_constraint=self._bias_constraint) self._intermediate_dense = [] self._inner_activation_layers = [] self._gate_dense = [] self._output_dense = [] self._output_dropout = [] self._output_layer_norm = [] activation_policy = tf_keras.mixed_precision.global_policy() if activation_policy.name == "mixed_bfloat16": # bfloat16 causes BERT with the LAMB optimizer to not converge # as well, so we use float32. # TODO(b/154538392): Investigate this. activation_policy = tf.float32 for i in range(self._num_blocks): self._intermediate_dense.append( tf_keras.layers.EinsumDense( "abc,cd->abd", output_shape=(None, self._inner_dim), bias_axes="d", name="intermediate_%d" % i, kernel_initializer=tf_utils.clone_initializer( self._kernel_initializer), bias_initializer=tf_utils.clone_initializer( self._bias_initializer), **common_kwargs)) self._inner_activation_layers.append( tf_keras.layers.Activation( self._inner_activation, dtype=activation_policy)) if self._use_gate: self._gate_dense.append( tf_keras.layers.EinsumDense( "abc,cd->abd", output_shape=(None, self._inner_dim), bias_axes="d", name="gate_%d" % i, kernel_initializer=tf_utils.clone_initializer( self._kernel_initializer), bias_initializer=tf_utils.clone_initializer( self._bias_initializer), **common_kwargs)) self._output_dense.append( tf_keras.layers.EinsumDense( "abc,cd->abd", output_shape=(None, hidden_size), bias_axes="d", name="output_%d" % i, kernel_initializer=tf_utils.clone_initializer( self._kernel_initializer), bias_initializer=tf_utils.clone_initializer( self._bias_initializer), **common_kwargs)) self._output_dropout.append(tf_keras.layers.Dropout(rate=self._dropout)) # Use float32 in layernorm for numeric stability. 
      if self._apply_output_layer_norm:
        self._output_layer_norm.append(
            tf_keras.layers.LayerNormalization(
                name="output_layer_norm_%d" % i,
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32))

  def get_config(self):
    config = {
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "dropout": self._dropout,
        "use_gate": self._use_gate,
        "apply_output_layer_norm": self._apply_output_layer_norm,
        "num_blocks": self._num_blocks,
        "dropout_position": self._dropout_position,
        "kernel_initializer":
            tf_keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf_keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf_keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf_keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf_keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf_keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf_keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    layer_output = inputs
    for i in range(self._num_blocks):
      layer_input = layer_output
      intermediate_output = self._intermediate_dense[i](layer_input)
      intermediate_output = self._inner_activation_layers[i](
          intermediate_output)
      if self._use_gate:
        # Gated linear unit: elementwise product of the activated projection
        # and a second, linear projection of the same input.
        gated_linear = self._gate_dense[i](layer_input)
        intermediate_output = intermediate_output * gated_linear

      layer_output = self._output_dense[i](intermediate_output)
      if self._dropout_position == "before_residual":
        layer_output = self._output_dropout[i](layer_output)

      # During mixed precision training, `layer_input` may be from layer norm.
      # If so, it is always fp32. Cast layer_output to fp32 for the subsequent
      # add.
      if layer_input.dtype == tf.float32:
        layer_output = tf.cast(layer_output, tf.float32)
      if self._apply_output_layer_norm:
        layer_output = self._output_layer_norm[i](layer_output + layer_input)
      if self._dropout_position == "after_residual":
        layer_output = self._output_dropout[i](layer_output)

    return layer_output
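

# The following is an illustrative, self-contained usage sketch (not part of
# the original module): it builds the layer in its gated and non-gated
# configurations and checks the output shape. All shapes and hyperparameters
# here are arbitrary choices for demonstration.
if __name__ == "__main__":
  sample = tf.ones((2, 8, 16))  # [batch, seq_length, hidden]
  # Gated variant (GEGLU-style): two inner projections multiplied elementwise.
  gated_layer = GatedFeedforward(inner_dim=32, use_gate=True)
  # Plain variant: a single inner projection followed by the output projection.
  plain_layer = GatedFeedforward(inner_dim=32, use_gate=False)
  print(gated_layer(sample).shape)  # (2, 8, 16): hidden size is preserved.
  print(plain_layer(sample).shape)  # (2, 8, 16)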