# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based TransformerEncoder block layer."""

from typing import Any, Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling.layers import util


@tf_keras.utils.register_keras_serializable(package="Text")
class TransformerEncoderBlock(tf_keras.layers.Layer):
  """TransformerEncoderBlock layer.

  This layer implements the Transformer Encoder from
  "Attention Is All You Need" (https://arxiv.org/abs/1706.03762), which
  combines a `tf_keras.layers.MultiHeadAttention` layer with a two-layer
  feedforward network.

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
     Understanding](https://arxiv.org/abs/1810.04805)
  """

  def __init__(self,
               num_attention_heads,
               inner_dim,
               inner_activation,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               use_query_residual=True,
               key_dim=None,
               value_dim=None,
               output_last_dim=None,
               diff_q_kv_att_layer_norm=False,
               return_attention_scores=False,
               **kwargs):
    """Initializes `TransformerEncoderBlock`.

    Note: If `output_last_dim` is used and `use_query_residual` is `True`, the
    value of `output_last_dim` must equal the last dimension of the first
    input for the query residual connection to work. This is because the
    residual connection after the multi-head attention requires the two
    dimensions to match. If `use_query_residual` is `False`, `output_last_dim`
    dictates the last dimension of the output of this module and of the
    multi-head attention.

    E.g. let's say the input dims are `[batch_size, seq_dim, input_last_dim]`.
    Scenario 1: If `output_last_dim` is not `None`, the output dims of this
    module are `[batch_size, seq_dim, output_last_dim]`. Note that `key_dim`
    is overridden by `output_last_dim`.
    Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`,
    the output dims of this module are `[batch_size, seq_dim, key_dim]`.
    Scenario 3: If `output_last_dim` and `key_dim` are both `None`, the output
    dims are `[batch_size, seq_dim, input_last_dim]`.

    Args:
      num_attention_heads: Number of attention heads.
      inner_dim: The output dimension of the first Dense layer in a two-layer
        feedforward network.
      inner_activation: The activation for the first Dense layer in a
        two-layer feedforward network.
      output_range: The sequence output range, [0, output_range), for slicing
        the target sequence. `None` means the target sequence is not sliced.
        Deprecated; prefer passing `output_range` to `call()`.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer biases.
      use_bias: Whether to enable `use_bias` in the attention layer. If set to
        `False`, `use_bias` in the attention layer is disabled.
      norm_first: Whether to normalize the inputs to the attention and
        intermediate dense layers. If set to `False`, the outputs of the
        attention and intermediate dense layers are normalized instead.
      norm_epsilon: Epsilon value to initialize normalization layers.
      output_dropout: Dropout probability for the post-attention and output
        dropout.
      attention_dropout: Dropout probability within the attention layer.
      inner_dropout: Dropout probability for the first Dense layer in a
        two-layer feedforward network.
      attention_initializer: Initializer for kernels of attention layers. If
        set to `None`, attention layers use `kernel_initializer` as the kernel
        initializer.
      attention_axes: Axes over which the attention is applied. `None` means
        attention over all axes except batch, heads, and features.
      use_query_residual: Toggle to execute the residual connection after
        attention.
      key_dim: `key_dim` for the `tf_keras.layers.MultiHeadAttention`. If
        `None`, we use the first `input_shape`'s last dim.
      value_dim: `value_dim` for the `tf_keras.layers.MultiHeadAttention`.
      output_last_dim: Final dimension of the output of this module. This also
        dictates the value for the final dimension of the multi-head
        attention. When it is `None`, we use, in order of decreasing
        precedence, `key_dim` * `num_heads` or the first `input_shape`'s last
        dim as the output's last dim.
      diff_q_kv_att_layer_norm: If `True`, create a separate attention layer
        norm layer for query and key-value if `norm_first` is `True`. Invalid
        to set to `True` if `norm_first` is `False`.
      return_attention_scores: If `True`, the output of this layer will be a
        tuple and additionally contain the attention scores with the shape
        `[batch_size, num_attention_heads, seq_dim, seq_dim]`.
      **kwargs: keyword arguments.
    """
    util.filter_kwargs(kwargs)
    super().__init__(**kwargs)
    # Deprecation warning.
    if output_range is not None:
      logging.warning(
          "`output_range` is available as an argument for `call()`. The "
          "`output_range` as an `__init__` argument is deprecated.")
    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    self._attention_dropout_rate = attention_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    self._use_query_residual = use_query_residual
    self._key_dim = key_dim
    self._value_dim = value_dim
    self._output_last_dim = output_last_dim
    self._diff_q_kv_att_layer_norm = diff_q_kv_att_layer_norm
    self._return_attention_scores = return_attention_scores
    if attention_initializer:
      self._attention_initializer = tf_keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
    self._attention_axes = attention_axes

    if self._diff_q_kv_att_layer_norm and not self._norm_first:
      raise ValueError("Setting `diff_q_kv_att_layer_norm` to True "
                       "when `norm_first` is False is invalid.")

  def build(self, input_shape):
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: %s" %
          type(input_shape))
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      logging.warning(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)", hidden_size, self._num_heads)
    if self._key_dim is None:
      self._key_dim = int(hidden_size // self._num_heads)
    if self._output_last_dim is None:
      last_output_shape = hidden_size
    else:
      last_output_shape = self._output_last_dim

    common_kwargs = dict(
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._attention_layer = tf_keras.layers.MultiHeadAttention(
        num_heads=self._num_heads,
        key_dim=self._key_dim,
        value_dim=self._value_dim,
        dropout=self._attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        attention_axes=self._attention_axes,
        output_shape=self._output_last_dim,
        name="self_attention",
        **common_kwargs)
    self._attention_dropout = tf_keras.layers.Dropout(
        rate=self._attention_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    self._attention_layer_norm_kv = self._attention_layer_norm
    if self._diff_q_kv_att_layer_norm:
      self._attention_layer_norm_kv = (
          tf_keras.layers.LayerNormalization(
              name="self_attention_layer_norm_kv",
              axis=-1,
              epsilon=self._norm_epsilon,
              dtype=tf.float32))

    self._intermediate_dense = tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="intermediate",
        **common_kwargs)
    policy = tf_keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    self._intermediate_activation_layer = tf_keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf_keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, last_output_shape),
        bias_axes="d",
        name="output",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
    self._output_dropout = tf_keras.layers.Dropout(
        rate=self._output_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf_keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)
    super().build(input_shape)

  def get_config(self):
    config = {
        "num_attention_heads": self._num_heads,
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "output_dropout": self._output_dropout_rate,
        "attention_dropout": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer": tf_utils.serialize_initializer(
            self._kernel_initializer, use_legacy_format=True
        ),
        "bias_initializer": tf_utils.serialize_initializer(
            self._bias_initializer, use_legacy_format=True
        ),
        "kernel_regularizer": tf_utils.serialize_regularizer(
            self._kernel_regularizer, use_legacy_format=True
        ),
        "bias_regularizer": tf_utils.serialize_regularizer(
            self._bias_regularizer, use_legacy_format=True
        ),
        "activity_regularizer": tf_utils.serialize_regularizer(
            self._activity_regularizer, use_legacy_format=True
        ),
        "kernel_constraint": tf_utils.serialize_constraint(
            self._kernel_constraint, use_legacy_format=True
        ),
        "bias_constraint": tf_utils.serialize_constraint(
            self._bias_constraint, use_legacy_format=True
        ),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "inner_dropout": self._inner_dropout,
        "attention_initializer": tf_utils.serialize_initializer(
            self._attention_initializer, use_legacy_format=True
        ),
        "attention_axes": self._attention_axes,
        "use_query_residual": self._use_query_residual,
        "key_dim": self._key_dim,
        "value_dim": self._value_dim,
        "output_last_dim": self._output_last_dim,
        "diff_q_kv_att_layer_norm": self._diff_q_kv_att_layer_norm,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
    """Transformer self-attention encoder block call.

    Args:
      inputs: a single tensor or a list of tensors.
        `input tensor` as the single sequence of embeddings.
        [`input tensor`, `attention mask`] to have the additional attention
          mask.
        [`query tensor`, `key value tensor`, `attention mask`] to have
          separate input streams for the query, and key/value to the
          multi-head attention.
      output_range: The sequence output range, [0, output_range), for slicing
        the target sequence. `None` means the target sequence is not sliced.
        To leave model training unchanged, it is best to set `output_range`
        only for serving.

    Returns:
      An output tensor with the same dimensions as the input/query tensor. If
      `return_attention_scores` is `True`, a tuple of
      `(layer output, attention scores)` is returned instead.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length %d." %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if output_range is None:
      output_range = self._output_range
    if output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm_kv(key_value)
      target_tensor = input_tensor[:, 0:output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm_kv(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor

    if self._return_attention_scores:
      attention_output, attention_scores = self._attention_layer(
          query=target_tensor,
          value=key_value,
          attention_mask=attention_mask,
          return_attention_scores=True)
    else:
      attention_output = self._attention_layer(
          query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)

    if self._norm_first:
      # Important to not combine `self._norm_first` and
      # `self._use_query_residual` into one if clause because else is only for
      # `_norm_first == False`.
      if self._use_query_residual:
        attention_output = source_tensor + attention_output
    else:
      if self._use_query_residual:
        attention_output = target_tensor + attention_output
      attention_output = self._attention_layer_norm(attention_output)

    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      layer_output = source_attention_output + layer_output
    else:
      # During mixed precision training, layer norm output is always fp32 for
      # now. Casts fp32 for the subsequent add.
      layer_output = tf.cast(layer_output, tf.float32)
      layer_output = self._output_layer_norm(layer_output + attention_output)

    if self._return_attention_scores:
      return layer_output, attention_scores
    else:
      return layer_output
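

# The block below is a minimal usage sketch, not part of the library API: it
# instantiates a `TransformerEncoderBlock` with arbitrary example
# hyperparameters and runs a random batch of embeddings (plus a trivial
# attention mask) through it. All shapes and values are illustrative
# assumptions.
if __name__ == "__main__":
  block = TransformerEncoderBlock(
      num_attention_heads=4, inner_dim=256, inner_activation="relu")
  embeddings = tf.random.uniform([2, 16, 64])  # [batch, seq_len, hidden].
  mask = tf.ones([2, 16, 16])  # [batch, target_len, source_len].
  outputs = block([embeddings, mask])
  # With the default settings, the output keeps the input shape: (2, 16, 64).
  print(outputs.shape)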