# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based transformer block layer."""
# pylint: disable=g-classes-have-attributes

from absl import logging
import gin
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers import transformer_encoder_block
from official.nlp.modeling.layers.util import tf_function_if_eager


@tf_keras.utils.register_keras_serializable(package="Text")
class Transformer(transformer_encoder_block.TransformerEncoderBlock):
  """Transformer layer.

  This layer implements the Transformer from "Attention Is All You Need"
  (https://arxiv.org/abs/1706.03762).

  **Warning: this layer is deprecated. Please don't use it. Use the
  `TransformerEncoderBlock` layer instead.**

  Args:
    num_attention_heads: Number of attention heads.
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout_rate: Dropout probability for the post-attention and output
      dropout.
    attention_dropout_rate: Dropout probability within the attention layer.
    output_range: The sequence output range, [0, output_range), obtained by
      slicing the target sequence. `None` means the target sequence is not
      sliced.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    use_bias: Whether to enable use_bias in attention layer. If set False,
      use_bias in attention layer is disabled.
    norm_first: Whether to normalize inputs to attention and intermediate
      dense layers. If set False, output of attention and intermediate dense
      layers is normalized.
    norm_epsilon: Epsilon value to initialize normalization layers.
    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
    attention_initializer: Initializer for kernels of attention layers. If set
      `None`, attention layers use kernel_initializer as initializer for
      kernel.
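
  Example:

  A minimal construction sketch (the values below are illustrative only; this
  class simply forwards its arguments to `TransformerEncoderBlock`):

    block = Transformer(
        num_attention_heads=8,
        intermediate_size=2048,
        intermediate_activation="relu")
    # Equivalent to TransformerEncoderBlock(num_attention_heads=8,
    #     inner_dim=2048, inner_activation="relu").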
""" def __init__(self, num_attention_heads, intermediate_size, intermediate_activation, dropout_rate=0.0, attention_dropout_rate=0.0, output_range=None, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, use_bias=True, norm_first=False, norm_epsilon=1e-12, intermediate_dropout=0.0, attention_initializer=None, **kwargs): super().__init__( num_attention_heads=num_attention_heads, inner_dim=intermediate_size, inner_activation=intermediate_activation, output_dropout=dropout_rate, attention_dropout=attention_dropout_rate, output_range=output_range, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint, use_bias=use_bias, norm_first=norm_first, norm_epsilon=norm_epsilon, inner_dropout=intermediate_dropout, attention_initializer=attention_initializer, **kwargs) logging.warning("The `Transformer` layer is deprecated. Please directly " "use `TransformerEncoderBlock`.") def get_config(self): return { "num_attention_heads": self._num_heads, "intermediate_size": self._inner_dim, "intermediate_activation": self._inner_activation, "dropout_rate": self._output_dropout_rate, "attention_dropout_rate": self._attention_dropout_rate, "output_range": self._output_range, "kernel_initializer": tf_utils.serialize_initializer( self._kernel_initializer, use_legacy_format=True ), "bias_initializer": tf_utils.serialize_initializer( self._bias_initializer, use_legacy_format=True ), "kernel_regularizer": tf_utils.serialize_regularizer( self._kernel_regularizer, use_legacy_format=True ), "bias_regularizer": tf_utils.serialize_regularizer( self._bias_regularizer, use_legacy_format=True ), "activity_regularizer": tf_utils.serialize_regularizer( self._activity_regularizer, use_legacy_format=True ), "kernel_constraint": tf_utils.serialize_constraint( self._kernel_constraint, use_legacy_format=True ), "bias_constraint": tf_utils.serialize_constraint( self._bias_constraint, use_legacy_format=True ), "use_bias": self._use_bias, "norm_first": self._norm_first, "norm_epsilon": self._norm_epsilon, "intermediate_dropout": self._inner_dropout, "attention_initializer": tf_utils.serialize_initializer( self._attention_initializer, use_legacy_format=True ), } @tf_keras.utils.register_keras_serializable(package="Text") @gin.configurable class CompiledTransformer(Transformer): @tf_function_if_eager(experimental_compile=True) def call(self, inputs): return super().call(inputs) @tf_keras.utils.register_keras_serializable(package="Text") class TransformerDecoderBlock(tf_keras.layers.Layer): """Single transformer layer for decoder. It has three sub-layers: (1) a multi-head self-attention mechanism. (2) a encoder-decoder attention. (3) a positionwise fully connected feed-forward network. Args: num_attention_heads: Number of attention heads. intermediate_size: Size of the intermediate layer. intermediate_activation: Activation for the intermediate layer. dropout_rate: Dropout probability for the post-attention and output dropout. attention_dropout_rate: Dropout probability for within the attention layer. multi_channel_cross_attention: Whether to use `MultiChannelAttention` for cross-attention between target sequences and source sequences. kernel_initializer: Initializer for dense layer kernels. 
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    use_bias: Whether to enable use_bias in attention layer. If set False,
      use_bias in attention layer is disabled.
    norm_first: Whether to normalize inputs to attention and intermediate
      dense layers. If set False, output of attention and intermediate dense
      layers is normalized.
    norm_epsilon: Epsilon value to initialize normalization layers.
    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
    attention_initializer: Initializer for kernels of attention layers. If set
      `None`, attention layers use kernel_initializer as initializer for
      kernel.
    self_attention_cls: An optional class to use for self attention.
    cross_attention_cls: An optional class to use for cross attention.
  """

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               multi_channel_cross_attention=False,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               intermediate_dropout=0.0,
               attention_initializer=None,
               self_attention_cls=None,
               cross_attention_cls=None,
               **kwargs):
    super().__init__(**kwargs)
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.intermediate_activation = tf_keras.activations.get(
        intermediate_activation)
    self.dropout_rate = dropout_rate
    self.attention_dropout_rate = attention_dropout_rate
    self.multi_channel_cross_attention = multi_channel_cross_attention
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout
    if attention_initializer:
      self._attention_initializer = tf_keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
    self._self_attention_cls = self_attention_cls or attention.CachedAttention
    if cross_attention_cls is not None:
      self._cross_attention_cls = cross_attention_cls
      if self.multi_channel_cross_attention:
        logging.warning(
            "%s will be used for cross attention", cross_attention_cls
        )
    elif self.multi_channel_cross_attention:
      self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
    else:
      self._cross_attention_cls = attention.MultiHeadAttention

  def build(self, input_shape):
    target_tensor_shape = tf.TensorShape(input_shape[0])
    if len(target_tensor_shape.as_list()) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    hidden_size = target_tensor_shape[2]
    if hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          "The hidden size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    self.attention_head_size = int(hidden_size) // self.num_attention_heads
    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Self attention.
    self.self_attention = self._self_attention_cls(
        num_heads=self.num_attention_heads,
        key_dim=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=tf_utils.clone_initializer(
            self._attention_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="self_attention",
        **common_kwargs)
    self.self_attention_output_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="output",
        **common_kwargs)
    self.self_attention_dropout = tf_keras.layers.Dropout(
        rate=self.dropout_rate)
    self.self_attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Encoder-decoder attention.
    self.encdec_attention = self._cross_attention_cls(
        num_heads=self.num_attention_heads,
        key_dim=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        output_shape=hidden_size,
        use_bias=self._use_bias,
        kernel_initializer=tf_utils.clone_initializer(
            self._attention_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="attention/encdec",
        **common_kwargs)
    self.encdec_attention_dropout = tf_keras.layers.Dropout(
        rate=self.dropout_rate)
    self.encdec_attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Feed-forward projection.
    self.intermediate_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, self.intermediate_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="intermediate",
        **common_kwargs)
    self.intermediate_activation_layer = tf_keras.layers.Activation(
        self.intermediate_activation)
    self._intermediate_dropout_layer = tf_keras.layers.Dropout(
        rate=self._intermediate_dropout)
    self.output_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="output",
        **common_kwargs)
    self.output_dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)
    self.output_layer_norm = tf_keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype="float32")
    super().build(input_shape)

  def get_config(self):
    config = {
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self.intermediate_size,
        "intermediate_activation": self.intermediate_activation,
        "dropout_rate": self.dropout_rate,
        "attention_dropout_rate": self.attention_dropout_rate,
        "multi_channel_cross_attention": self.multi_channel_cross_attention,
        "kernel_initializer": tf_utils.serialize_initializer(
            self._kernel_initializer, use_legacy_format=True
        ),
        "bias_initializer": tf_utils.serialize_initializer(
            self._bias_initializer, use_legacy_format=True
        ),
        "kernel_regularizer": tf_utils.serialize_regularizer(
            self._kernel_regularizer, use_legacy_format=True
        ),
        "bias_regularizer": tf_utils.serialize_regularizer(
            self._bias_regularizer, use_legacy_format=True
        ),
        "activity_regularizer": tf_utils.serialize_regularizer(
            self._activity_regularizer, use_legacy_format=True
        ),
        "kernel_constraint": tf_utils.serialize_constraint(
            self._kernel_constraint, use_legacy_format=True
        ),
        "bias_constraint": tf_utils.serialize_constraint(
            self._bias_constraint, use_legacy_format=True
        ),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout,
        "attention_initializer": tf_utils.serialize_initializer(
            self._attention_initializer, use_legacy_format=True
        ),
        "self_attention_cls": self._self_attention_cls,
        "cross_attention_cls": self._cross_attention_cls,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def common_layers_with_encoder(self):
    """Gets layer objects that can make a Transformer encoder block."""
    return [
        self.self_attention, self.self_attention_layer_norm,
        self.intermediate_dense, self.output_dense, self.output_layer_norm
    ]

  def call(self, inputs, cache=None, decode_loop_step=None):
    """Runs the decoder block.

    Args:
      inputs: A list of `[target_tensor, memory, attention_mask,
        self_attention_mask]`. When `multi_channel_cross_attention` is True, a
        fifth tensor carrying the context attention weights is expected.
      cache: Optional cache for the self-attention layer, used during
        autoregressive decoding.
      decode_loop_step: Optional integer, the step number of the decoding
        loop, used only together with the cache.

    Returns:
      A tuple of `(layer_output, cache)`.
    """
    if self.multi_channel_cross_attention:
      if len(inputs) != 5:
        raise ValueError(
            "TransformerDecoderBlock must have 5 inputs, when it uses "
            "multi_channel_cross_attention. But it got: %d" % len(inputs))
But it got: %d" % len(inputs)) elif len(inputs) != 4: raise ValueError( "TransformerDecoderBlock must have 4 inputs, but it got: %d" % len(inputs)) input_tensor, memory, attention_mask, self_attention_mask = inputs[:4] source_tensor = input_tensor if self._norm_first: input_tensor = self.self_attention_layer_norm(input_tensor) self_attention_output, cache = self.self_attention( query=input_tensor, value=input_tensor, attention_mask=self_attention_mask, cache=cache, decode_loop_step=decode_loop_step) self_attention_output = self.self_attention_dropout(self_attention_output) if self._norm_first: self_attention_output = source_tensor + self_attention_output else: self_attention_output = self.self_attention_layer_norm( input_tensor + self_attention_output) if self._norm_first: source_self_attention_output = self_attention_output self_attention_output = self.encdec_attention_layer_norm( self_attention_output) cross_attn_inputs = dict( query=self_attention_output, value=memory, attention_mask=attention_mask) if self.multi_channel_cross_attention: # Accesses the 5-th input tensor for the doc-attention probabilities. cross_attn_inputs["context_attention_weights"] = inputs[-1] attention_output = self.encdec_attention(**cross_attn_inputs) attention_output = self.encdec_attention_dropout(attention_output) if self._norm_first: attention_output = source_self_attention_output + attention_output else: attention_output = self.encdec_attention_layer_norm( self_attention_output + attention_output) if self._norm_first: source_attention_output = attention_output attention_output = self.output_layer_norm(attention_output) intermediate_output = self.intermediate_dense(attention_output) intermediate_output = self.intermediate_activation_layer( intermediate_output) intermediate_output = self._intermediate_dropout_layer(intermediate_output) layer_output = self.output_dense(intermediate_output) layer_output = self.output_dropout(layer_output) if self._norm_first: layer_output = source_attention_output + layer_output else: layer_output = self.output_layer_norm(layer_output + attention_output) return layer_output, cache