# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based block diagonal feedforward layer."""
# pylint: disable=g-classes-have-attributes
from typing import Optional

import tensorflow as tf, tf_keras

from official.modeling import tf_utils


class BlockDiagFeedforward(tf_keras.layers.Layer):
  """Block diagonal feedforward layer.

  This layer replaces the weight matrix of the output_dense layer with a block
  diagonal matrix to save layer parameters and FLOPs. A linear mixing layer can
  optionally be added to improve the layer's expressiveness. A minimal usage
  sketch appears at the end of this module.

  Args:
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout: Dropout probability for the output dropout.
    num_blocks: The number of blocks for the block diagonal matrix of the
      output_dense layer.
    apply_mixing: Apply linear mixing if True.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def __init__(
      self,
      intermediate_size: int,
      intermediate_activation: str,
      dropout: float,
      num_blocks: int = 1,
      apply_mixing: bool = True,
      kernel_initializer: str = "glorot_uniform",
      bias_initializer: str = "zeros",
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      activity_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      kernel_constraint: Optional[tf_keras.constraints.Constraint] = None,
      bias_constraint: Optional[tf_keras.constraints.Constraint] = None,
      **kwargs):  # pylint: disable=g-doc-args
    super().__init__(**kwargs)
    self._intermediate_size = intermediate_size
    self._intermediate_activation = intermediate_activation
    self._dropout = dropout
    self._num_blocks = num_blocks
    self._apply_mixing = apply_mixing

    if intermediate_size % num_blocks != 0:
      raise ValueError("intermediate_size (%d) must be a multiple of "
                       "num_blocks (%d)." % (intermediate_size, num_blocks))
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)

  def build(self, input_shape):
    hidden_size = input_shape.as_list()[-1]

    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)

    self._intermediate_dense = tf_keras.layers.EinsumDense(
        "abc,cde->abde",
        output_shape=(None, self._num_blocks,
                      self._intermediate_size // self._num_blocks),
        bias_axes="de",
        name="intermediate",
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)

    policy = tf_keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      policy = tf.float32
    self._intermediate_activation_layer = tf_keras.layers.Activation(
        self._intermediate_activation, dtype=policy)

    self._output_dense = tf_keras.layers.EinsumDense(
        "abde,deo->abdo",
        output_shape=(None, self._num_blocks,
                      hidden_size // self._num_blocks),
        bias_axes="do",
        name="output",
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)

    if self._apply_mixing:
      self._output_mixing = tf_keras.layers.EinsumDense(
          "abdo,de->abeo",
          output_shape=(None, self._num_blocks,
                        hidden_size // self._num_blocks),
          name="output_mixing",
          kernel_initializer=tf_utils.clone_initializer(
              self._kernel_initializer),
          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
          **common_kwargs)

    self._output_reshape = tf_keras.layers.Reshape((-1, hidden_size))
    self._output_dropout = tf_keras.layers.Dropout(rate=self._dropout)
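
  # Shape sketch (illustrative, not part of the original file): with hidden
  # size H, intermediate size I, and B blocks, the EinsumDense kernels built
  # above have shapes
  #   intermediate: (H, B, I // B)       -> I * H weights
  #   output:       (B, I // B, H // B)  -> I * H / B weights
  #   mixing:       (B, B)               -> B * B weights
  # so the output projection uses 1/B of the parameters and FLOPs of a plain
  # dense output layer, which is the saving described in the class docstring.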

  def get_config(self):
    config = {
        "intermediate_size": self._intermediate_size,
        "intermediate_activation": self._intermediate_activation,
        "dropout": self._dropout,
        "num_blocks": self._num_blocks,
        "apply_mixing": self._apply_mixing,
        "kernel_initializer":
            tf_keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf_keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf_keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf_keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf_keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf_keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf_keras.constraints.serialize(self._bias_constraint),
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    intermediate_output = self._intermediate_dense(inputs)
    intermediate_output = self._intermediate_activation_layer(
        intermediate_output)
    layer_output = self._output_dense(intermediate_output)
    if self._apply_mixing:
      layer_output = self._output_mixing(layer_output)
    layer_output = self._output_reshape(layer_output)
    layer_output = self._output_dropout(layer_output)
    return layer_output
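

# A minimal usage sketch (illustrative, not part of the original file). The
# sizes are arbitrary, chosen so that both intermediate_size (1024) and the
# input hidden size (64) are divisible by num_blocks:
#
#   layer = BlockDiagFeedforward(
#       intermediate_size=1024,
#       intermediate_activation="gelu",
#       dropout=0.1,
#       num_blocks=4,
#       apply_mixing=True)
#   outputs = layer(tf.ones((2, 8, 64)))  # -> shape (2, 8, 64), same as input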