# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based transformer block layer.""" | |
# pylint: disable=g-classes-have-attributes | |
from absl import logging | |
import gin | |
import tensorflow as tf, tf_keras | |
from official.modeling import tf_utils | |
from official.nlp.modeling.layers import attention | |
from official.nlp.modeling.layers import multi_channel_attention | |
from official.nlp.modeling.layers import transformer_encoder_block | |
from official.nlp.modeling.layers.util import tf_function_if_eager | |


class Transformer(transformer_encoder_block.TransformerEncoderBlock):
  """Transformer layer.

  This layer implements the Transformer from "Attention Is All You Need"
  (https://arxiv.org/abs/1706.03762).

  **Warning: this layer is deprecated. Please don't use it. Use the
  `TransformerEncoderBlock` layer instead.**
  Args:
    num_attention_heads: Number of attention heads.
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout_rate: Dropout probability for the post-attention and output
      dropout.
    attention_dropout_rate: Dropout probability within the attention layer.
    output_range: The sequence output range, [0, output_range), obtained by
      slicing the target sequence. `None` means the target sequence is not
      sliced.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    use_bias: Whether to enable use_bias in attention layer. If set False,
      use_bias in attention layer is disabled.
    norm_first: Whether to normalize inputs to attention and intermediate
      dense layers. If set False, output of attention and intermediate dense
      layers is normalized.
    norm_epsilon: Epsilon value to initialize normalization layers.
    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
    attention_initializer: Initializer for kernels of attention layers. If set
      `None`, attention layers use kernel_initializer as initializer for
      kernel.
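
  Example:

  A minimal usage sketch; the tensor shapes and layer sizes below are
  illustrative assumptions, not values defined in this module:

    data = tf.ones((2, 16, 64))    # [batch, seq_length, hidden]
    mask = tf.ones((2, 16, 16))    # [batch, from_seq_length, to_seq_length]
    layer = Transformer(
        num_attention_heads=4,
        intermediate_size=256,
        intermediate_activation="relu")
    output = layer([data, mask])   # [batch, seq_length, hidden]

    # Equivalent construction with the replacement layer:
    encoder_block = transformer_encoder_block.TransformerEncoderBlock(
        num_attention_heads=4,
        inner_dim=256,
        inner_activation="relu")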
""" | |

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               intermediate_dropout=0.0,
               attention_initializer=None,
               **kwargs):
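    # Forward the legacy constructor arguments to `TransformerEncoderBlock`,
    # which uses the newer names (`inner_dim`, `inner_activation`,
    # `output_dropout`, `attention_dropout`, `inner_dropout`).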
    super().__init__(
        num_attention_heads=num_attention_heads,
        inner_dim=intermediate_size,
        inner_activation=intermediate_activation,
        output_dropout=dropout_rate,
        attention_dropout=attention_dropout_rate,
        output_range=output_range,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        use_bias=use_bias,
        norm_first=norm_first,
        norm_epsilon=norm_epsilon,
        inner_dropout=intermediate_dropout,
        attention_initializer=attention_initializer,
        **kwargs)
    logging.warning("The `Transformer` layer is deprecated. Please directly "
                    "use `TransformerEncoderBlock`.")

  def get_config(self):
    return {
        "num_attention_heads": self._num_heads,
        "intermediate_size": self._inner_dim,
        "intermediate_activation": self._inner_activation,
        "dropout_rate": self._output_dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer": tf_utils.serialize_initializer(
            self._kernel_initializer, use_legacy_format=True
        ),
        "bias_initializer": tf_utils.serialize_initializer(
            self._bias_initializer, use_legacy_format=True
        ),
        "kernel_regularizer": tf_utils.serialize_regularizer(
            self._kernel_regularizer, use_legacy_format=True
        ),
        "bias_regularizer": tf_utils.serialize_regularizer(
            self._bias_regularizer, use_legacy_format=True
        ),
        "activity_regularizer": tf_utils.serialize_regularizer(
            self._activity_regularizer, use_legacy_format=True
        ),
        "kernel_constraint": tf_utils.serialize_constraint(
            self._kernel_constraint, use_legacy_format=True
        ),
        "bias_constraint": tf_utils.serialize_constraint(
            self._bias_constraint, use_legacy_format=True
        ),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._inner_dropout,
        "attention_initializer": tf_utils.serialize_initializer(
            self._attention_initializer, use_legacy_format=True
        ),
    }


@gin.configurable
class CompiledTransformer(Transformer):

  @tf_function_if_eager(experimental_compile=True)
  def call(self, inputs):
    return super().call(inputs)


class TransformerDecoderBlock(tf_keras.layers.Layer):
  """Single transformer layer for decoder.

  It has three sub-layers:
  (1) a multi-head self-attention mechanism.
  (2) an encoder-decoder attention.
  (3) a positionwise fully connected feed-forward network.

  Args:
    num_attention_heads: Number of attention heads.
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout_rate: Dropout probability for the post-attention and output
      dropout.
    attention_dropout_rate: Dropout probability within the attention layer.
    multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
      cross-attention between target sequences and source sequences.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    use_bias: Whether to enable use_bias in attention layer. If set False,
      use_bias in attention layer is disabled.
    norm_first: Whether to normalize inputs to attention and intermediate
      dense layers. If set False, output of attention and intermediate dense
      layers is normalized.
    norm_epsilon: Epsilon value to initialize normalization layers.
    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
    attention_initializer: Initializer for kernels of attention layers. If set
      `None`, attention layers use kernel_initializer as initializer for
      kernel.
    self_attention_cls: An optional class to use for self-attention.
    cross_attention_cls: An optional class to use for cross-attention.
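
  Example:

  A minimal calling sketch; the batch size, sequence lengths, and hidden size
  below are illustrative assumptions, not values defined in this module:

    decoder_block = TransformerDecoderBlock(
        num_attention_heads=4,
        intermediate_size=256,
        intermediate_activation="relu")
    target = tf.ones((2, 8, 64))      # [batch, target_length, hidden]
    memory = tf.ones((2, 16, 64))     # [batch, source_length, hidden]
    cross_mask = tf.ones((2, 8, 16))  # [batch, target_length, source_length]
    self_mask = tf.ones((2, 8, 8))    # [batch, target_length, target_length]
    output, cache = decoder_block(
        [target, memory, cross_mask, self_mask])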
""" | |

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               multi_channel_cross_attention=False,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               intermediate_dropout=0.0,
               attention_initializer=None,
               self_attention_cls=None,
               cross_attention_cls=None,
               **kwargs):
    super().__init__(**kwargs)
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.intermediate_activation = tf_keras.activations.get(
        intermediate_activation)
    self.dropout_rate = dropout_rate
    self.attention_dropout_rate = attention_dropout_rate
    self.multi_channel_cross_attention = multi_channel_cross_attention
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout
    if attention_initializer:
      self._attention_initializer = tf_keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
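    # Choose the attention implementations. Self-attention defaults to
    # `CachedAttention` so key/value tensors can be cached during
    # autoregressive decoding; the cross-attention class can be overridden,
    # otherwise it is picked based on `multi_channel_cross_attention`.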
    self._self_attention_cls = self_attention_cls or attention.CachedAttention
    if cross_attention_cls is not None:
      self._cross_attention_cls = cross_attention_cls
      if self.multi_channel_cross_attention:
        logging.warning(
            "%s will be used for cross attention", cross_attention_cls
        )
    elif self.multi_channel_cross_attention:
      self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
    else:
      self._cross_attention_cls = attention.MultiHeadAttention

  def build(self, input_shape):
    target_tensor_shape = tf.TensorShape(input_shape[0])
    if len(target_tensor_shape.as_list()) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    hidden_size = target_tensor_shape[2]
    if hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          "The hidden size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    self.attention_head_size = int(hidden_size) // self.num_attention_heads
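    # Each head attends over a hidden_size / num_attention_heads slice, used
    # as `key_dim` below; with illustrative numbers, hidden_size=512 and 8
    # heads give a per-head key dimension of 64.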
    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Self attention.
    self.self_attention = self._self_attention_cls(
        num_heads=self.num_attention_heads,
        key_dim=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=tf_utils.clone_initializer(
            self._attention_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="self_attention",
        **common_kwargs)
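    # "abc,cd->abd": project the last axis of a [batch, sequence, width]
    # tensor with a [width, hidden_size] kernel, keeping the batch and
    # sequence dimensions untouched.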
    self.self_attention_output_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="output",
        **common_kwargs)
    self.self_attention_dropout = tf_keras.layers.Dropout(
        rate=self.dropout_rate)
    self.self_attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Encoder-decoder attention.
    self.encdec_attention = self._cross_attention_cls(
        num_heads=self.num_attention_heads,
        key_dim=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        output_shape=hidden_size,
        use_bias=self._use_bias,
        kernel_initializer=tf_utils.clone_initializer(
            self._attention_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="attention/encdec",
        **common_kwargs)
    self.encdec_attention_dropout = tf_keras.layers.Dropout(
        rate=self.dropout_rate)
    self.encdec_attention_layer_norm = (
        tf_keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Feed-forward projection.
    self.intermediate_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, self.intermediate_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="intermediate",
        **common_kwargs)
    self.intermediate_activation_layer = tf_keras.layers.Activation(
        self.intermediate_activation)
    self._intermediate_dropout_layer = tf_keras.layers.Dropout(
        rate=self._intermediate_dropout)
    self.output_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="output",
        **common_kwargs)
    self.output_dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)
    self.output_layer_norm = tf_keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype="float32")
    super().build(input_shape)

  def get_config(self):
    config = {
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self.intermediate_size,
        "intermediate_activation": self.intermediate_activation,
        "dropout_rate": self.dropout_rate,
        "attention_dropout_rate": self.attention_dropout_rate,
        "multi_channel_cross_attention": self.multi_channel_cross_attention,
        "kernel_initializer": tf_utils.serialize_initializer(
            self._kernel_initializer, use_legacy_format=True
        ),
        "bias_initializer": tf_utils.serialize_initializer(
            self._bias_initializer, use_legacy_format=True
        ),
        "kernel_regularizer": tf_utils.serialize_regularizer(
            self._kernel_regularizer, use_legacy_format=True
        ),
        "bias_regularizer": tf_utils.serialize_regularizer(
            self._bias_regularizer, use_legacy_format=True
        ),
        "activity_regularizer": tf_utils.serialize_regularizer(
            self._activity_regularizer, use_legacy_format=True
        ),
        "kernel_constraint": tf_utils.serialize_constraint(
            self._kernel_constraint, use_legacy_format=True
        ),
        "bias_constraint": tf_utils.serialize_constraint(
            self._bias_constraint, use_legacy_format=True
        ),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout,
        "attention_initializer": tf_utils.serialize_initializer(
            self._attention_initializer, use_legacy_format=True
        ),
        "self_attention_cls": self._self_attention_cls,
        "cross_attention_cls": self._cross_attention_cls,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def common_layers_with_encoder(self):
    """Gets layer objects that can make a Transformer encoder block."""
    return [
        self.self_attention, self.self_attention_layer_norm,
        self.intermediate_dense, self.output_dense, self.output_layer_norm
    ]

  def call(self, inputs, cache=None, decode_loop_step=None):
    if self.multi_channel_cross_attention:
      if len(inputs) != 5:
        raise ValueError(
            "TransformerDecoderBlock must have 5 inputs, when it uses "
            "multi_channel_cross_attention. But it got: %d" % len(inputs))
    elif len(inputs) != 4:
      raise ValueError(
          "TransformerDecoderBlock must have 4 inputs, but it got: %d" %
          len(inputs))
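    # Expected inputs: the decoder hidden state (`input_tensor`, shaped
    # [batch, target_sequence, width]), the encoder output (`memory`), the
    # cross-attention mask between target and source positions, and the
    # self-attention mask over target positions. With
    # `multi_channel_cross_attention`, a fifth tensor of context attention
    # weights is appended.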
    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
    source_tensor = input_tensor
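    # With `norm_first` (pre-layer-norm), each sub-layer normalizes its input
    # and adds the residual from the un-normalized tensor; otherwise
    # (post-layer-norm) the residual sum itself is normalized.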
    if self._norm_first:
      input_tensor = self.self_attention_layer_norm(input_tensor)
    self_attention_output, cache = self.self_attention(
        query=input_tensor,
        value=input_tensor,
        attention_mask=self_attention_mask,
        cache=cache,
        decode_loop_step=decode_loop_step)
    self_attention_output = self.self_attention_dropout(self_attention_output)
    if self._norm_first:
      self_attention_output = source_tensor + self_attention_output
    else:
      self_attention_output = self.self_attention_layer_norm(
          input_tensor + self_attention_output)
    if self._norm_first:
      source_self_attention_output = self_attention_output
      self_attention_output = self.encdec_attention_layer_norm(
          self_attention_output)
    cross_attn_inputs = dict(
        query=self_attention_output,
        value=memory,
        attention_mask=attention_mask)
    if self.multi_channel_cross_attention:
      # Accesses the fifth input tensor for the doc-attention probabilities.
      cross_attn_inputs["context_attention_weights"] = inputs[-1]
    attention_output = self.encdec_attention(**cross_attn_inputs)
    attention_output = self.encdec_attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_self_attention_output + attention_output
    else:
      attention_output = self.encdec_attention_layer_norm(
          self_attention_output + attention_output)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self.output_layer_norm(attention_output)
    intermediate_output = self.intermediate_dense(attention_output)
    intermediate_output = self.intermediate_activation_layer(
        intermediate_output)
    intermediate_output = self._intermediate_dropout_layer(intermediate_output)
    layer_output = self.output_dense(intermediate_output)
    layer_output = self.output_dropout(layer_output)
    if self._norm_first:
      layer_output = source_attention_output + layer_output
    else:
      layer_output = self.output_layer_norm(layer_output + attention_output)
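    # The returned `cache` carries the key/value tensors updated by the
    # self-attention layer, so subsequent decoding steps can reuse them.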
    return layer_output, cache