# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based TransformerEncoder block layer."""
from typing import Any, Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling.layers import util


class TransformerEncoderBlock(tf_keras.layers.Layer):
  """TransformerEncoderBlock layer.

  This layer implements the Transformer Encoder from
  "Attention Is All You Need" (https://arxiv.org/abs/1706.03762),
  which combines a `tf_keras.layers.MultiHeadAttention` layer with a
  two-layer feedforward network.

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
    Understanding](https://arxiv.org/abs/1810.04805)
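
  Example (a minimal usage sketch; the shapes and argument values below are
  illustrative and are not taken from the original documentation):

  ```python
  encoder_block = TransformerEncoderBlock(
      num_attention_heads=8, inner_dim=2048, inner_activation="relu")
  # Illustrative input: [batch_size, seq_dim, hidden].
  inputs = tf.random.uniform((2, 10, 512))
  outputs = encoder_block(inputs)  # shape: (2, 10, 512)
  ```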
""" | |
def __init__(self, | |
num_attention_heads, | |
inner_dim, | |
inner_activation, | |
output_range=None, | |
kernel_initializer="glorot_uniform", | |
bias_initializer="zeros", | |
kernel_regularizer=None, | |
bias_regularizer=None, | |
activity_regularizer=None, | |
kernel_constraint=None, | |
bias_constraint=None, | |
use_bias=True, | |
norm_first=False, | |
norm_epsilon=1e-12, | |
output_dropout=0.0, | |
attention_dropout=0.0, | |
inner_dropout=0.0, | |
attention_initializer=None, | |
attention_axes=None, | |
use_query_residual=True, | |
key_dim=None, | |
value_dim=None, | |
output_last_dim=None, | |
diff_q_kv_att_layer_norm=False, | |
return_attention_scores=False, | |
**kwargs): | |
"""Initializes `TransformerEncoderBlock`. | |
Note: If `output_last_dim` is used and `use_query_residual` is `True`, the | |
`output_last_dim`'s value must equal the first input's last dimension for | |
the query residual connection to work. This is because the residual | |
connection after the multi-head-attention requires their dimensions to | |
match. If `use_query_residual` is `False`, the `output_last_dim` dictactes | |
the last dimension of the output of this module and the | |
multi-head-attention. | |
E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`. | |
Scenario 1: If `output_last_dim` is not `None`, then the output dims of this | |
module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is | |
overriden by `output_last_dim`. | |
Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then | |
the output dims of this module would be `[batch_size, seq_dim, key_dim]`. | |
Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the | |
output dims would be `[batch_size, seq_dim, input_last_dim]`. | |
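
    For example, the following is an illustrative sketch of Scenario 1 with
    `use_query_residual=False` (the argument values and shapes are
    hypothetical, not from the original documentation):

    ```python
    block = TransformerEncoderBlock(
        num_attention_heads=4, inner_dim=1024, inner_activation="relu",
        output_last_dim=32, use_query_residual=False)
    # Illustrative input: [batch_size, seq_dim, input_last_dim].
    x = tf.random.uniform((2, 10, 64))
    y = block(x)  # shape: (2, 10, 32), i.e. last dim == output_last_dim
    ```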
    Args:
      num_attention_heads: Number of attention heads.
      inner_dim: The output dimension of the first Dense layer in a two-layer
        feedforward network.
      inner_activation: The activation for the first Dense layer in a two-layer
        feedforward network.
      output_range: The sequence output range, [0, output_range), for slicing
        the target sequence. `None` means the target sequence is not sliced.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer biases.
      use_bias: Whether to enable `use_bias` in the attention layer. If set to
        `False`, `use_bias` in the attention layer is disabled.
      norm_first: Whether to normalize the inputs to the attention and
        intermediate dense layers. If set to `False`, the outputs of the
        attention and intermediate dense layers are normalized instead.
      norm_epsilon: Epsilon value used to initialize normalization layers.
      output_dropout: Dropout probability for the post-attention and output
        dropout.
      attention_dropout: Dropout probability used within the attention layer.
      inner_dropout: Dropout probability for the first Dense layer in a
        two-layer feedforward network.
      attention_initializer: Initializer for kernels of attention layers. If
        set to `None`, attention layers use `kernel_initializer` as the kernel
        initializer.
      attention_axes: Axes over which the attention is applied. `None` means
        attention over all axes except batch, heads, and features.
      use_query_residual: Toggle to execute the residual connection after
        attention.
      key_dim: `key_dim` for the `tf_keras.layers.MultiHeadAttention`. If
        `None`, we use the first `input_shape`'s last dim divided by
        `num_attention_heads`.
      value_dim: `value_dim` for the `tf_keras.layers.MultiHeadAttention`.
      output_last_dim: Final dimension of the output of this module. This also
        dictates the value for the final dimension of the multi-head-attention.
        When it's `None`, we use, in order of decreasing precedence, `key_dim` *
        `num_heads` or the first `input_shape`'s last dim as the output's last
        dim.
      diff_q_kv_att_layer_norm: If `True`, create a separate attention layer
        norm layer for query and key-value if `norm_first` is `True`. Invalid
        to set to `True` if `norm_first` is `False`.
      return_attention_scores: If `True`, the output of this layer will be a
        tuple and additionally contain the attention scores in the shape of
        `[batch_size, num_attention_heads, seq_dim, seq_dim]`.
      **kwargs: Keyword arguments.
    """
    util.filter_kwargs(kwargs)
    super().__init__(**kwargs)
    # Deprecation warning.
    if output_range is not None:
      logging.warning(
          "`output_range` is available as an argument for `call()`. "
          "The `output_range` as __init__ argument is deprecated.")
    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    self._attention_dropout_rate = attention_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    self._use_query_residual = use_query_residual
    self._key_dim = key_dim
    self._value_dim = value_dim
    self._output_last_dim = output_last_dim
    self._diff_q_kv_att_layer_norm = diff_q_kv_att_layer_norm
    self._return_attention_scores = return_attention_scores
    if attention_initializer:
      self._attention_initializer = tf_keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
    self._attention_axes = attention_axes

    if self._diff_q_kv_att_layer_norm and not self._norm_first:
      raise ValueError("Setting `diff_q_kv_att_layer_norm` to `True` "
                       "when `norm_first` is `False` is invalid.")

  def build(self, input_shape):
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: %s" %
          type(input_shape))
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      logging.warning(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)", hidden_size, self._num_heads)
    if self._key_dim is None:
      self._key_dim = int(hidden_size // self._num_heads)
    if self._output_last_dim is None:
      last_output_shape = hidden_size
    else:
      last_output_shape = self._output_last_dim
    common_kwargs = dict(
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._attention_layer = tf_keras.layers.MultiHeadAttention(
        num_heads=self._num_heads,
        key_dim=self._key_dim,
        value_dim=self._value_dim,
        dropout=self._attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        attention_axes=self._attention_axes,
        output_shape=self._output_last_dim,
        name="self_attention",
        **common_kwargs)
    self._attention_dropout = tf_keras.layers.Dropout(
        rate=self._attention_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    self._attention_layer_norm_kv = self._attention_layer_norm
    if self._diff_q_kv_att_layer_norm:
      self._attention_layer_norm_kv = (
          tf_keras.layers.LayerNormalization(
              name="self_attention_layer_norm_kv",
              axis=-1,
              epsilon=self._norm_epsilon,
              dtype=tf.float32))
    self._intermediate_dense = tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="intermediate",
        **common_kwargs)
    policy = tf_keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    self._intermediate_activation_layer = tf_keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf_keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, last_output_shape),
        bias_axes="d",
        name="output",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
    self._output_dropout = tf_keras.layers.Dropout(
        rate=self._output_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf_keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)
    super().build(input_shape)

  def get_config(self):
    config = {
        "num_attention_heads": self._num_heads,
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "output_dropout": self._output_dropout_rate,
        "attention_dropout": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer": tf_utils.serialize_initializer(
            self._kernel_initializer, use_legacy_format=True
        ),
        "bias_initializer": tf_utils.serialize_initializer(
            self._bias_initializer, use_legacy_format=True
        ),
        "kernel_regularizer": tf_utils.serialize_regularizer(
            self._kernel_regularizer, use_legacy_format=True
        ),
        "bias_regularizer": tf_utils.serialize_regularizer(
            self._bias_regularizer, use_legacy_format=True
        ),
        "activity_regularizer": tf_utils.serialize_regularizer(
            self._activity_regularizer, use_legacy_format=True
        ),
        "kernel_constraint": tf_utils.serialize_constraint(
            self._kernel_constraint, use_legacy_format=True
        ),
        "bias_constraint": tf_utils.serialize_constraint(
            self._bias_constraint, use_legacy_format=True
        ),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "inner_dropout": self._inner_dropout,
        "attention_initializer": tf_utils.serialize_initializer(
            self._attention_initializer, use_legacy_format=True
        ),
        "attention_axes": self._attention_axes,
        "use_query_residual": self._use_query_residual,
        "key_dim": self._key_dim,
        "value_dim": self._value_dim,
        "output_last_dim": self._output_last_dim,
        "diff_q_kv_att_layer_norm": self._diff_q_kv_att_layer_norm,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs: Any, output_range: Optional[tf.Tensor] = None) -> Any:
    """Transformer self-attention encoder block call.

    Args:
      inputs: A single tensor or a list of tensors. `input tensor` as the
        single sequence of embeddings. [`input tensor`, `attention mask`] to
        have the additional attention mask. [`query tensor`, `key value
        tensor`, `attention mask`] to have separate input streams for the
        query, and key/value to the multi-head attention.
      output_range: The sequence output range, [0, output_range), for slicing
        the target sequence. `None` means the target sequence is not sliced.
        If you would like to have no change to the model training, it is
        better to only set the `output_range` for serving.

    Returns:
      An output tensor with the same dimensions as the input/query tensor.
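
    Example (an illustrative sketch; the tensors and shapes below are
    hypothetical, not taken from the original documentation):

    ```python
    block = TransformerEncoderBlock(
        num_attention_heads=2, inner_dim=128, inner_activation="relu")
    x = tf.random.uniform((2, 8, 16))  # [batch_size, seq_dim, hidden]
    mask = tf.ones((2, 8, 8))          # [batch_size, target_len, source_len]
    y1 = block(x)                      # (2, 8, 16), no mask
    y2 = block([x, mask])              # (2, 8, 16), with attention mask
    y3 = block(x, output_range=1)      # (2, 1, 16), only the first position
    ```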
""" | |
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length %d" %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if output_range is None:
      output_range = self._output_range
    if output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm_kv(key_value)
      target_tensor = input_tensor[:, 0:output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm_kv(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor

    if self._return_attention_scores:
      attention_output, attention_scores = self._attention_layer(
          query=target_tensor,
          value=key_value,
          attention_mask=attention_mask,
          return_attention_scores=True)
    else:
      attention_output = self._attention_layer(
          query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)

    if self._norm_first:
      # Important to not combine `self._norm_first` and
      # `self._use_query_residual` into one if clause because the `else`
      # branch below only applies when `_norm_first == False`.
      if self._use_query_residual:
        attention_output = source_tensor + attention_output
    else:
      if self._use_query_residual:
        attention_output = target_tensor + attention_output
      attention_output = self._attention_layer_norm(attention_output)

    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      layer_output = source_attention_output + layer_output
    else:
      # During mixed precision training, layer norm output is always fp32 for
      # now. Casts fp32 for the subsequent add.
      layer_output = tf.cast(layer_output, tf.float32)
      layer_output = self._output_layer_norm(layer_output + attention_output)

    if self._return_attention_scores:
      return layer_output, attention_scores
    else:
      return layer_output