"""Keras-based transformer block layer.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import gin |
|
import tensorflow as tf |
|
|
|
from official.nlp.modeling.layers import attention |
|
from official.nlp.modeling.layers import dense_einsum |
|
from official.nlp.modeling.layers import multi_channel_attention |
|
from official.nlp.modeling.layers.util import tf_function_if_eager |
|
|
|
|
|
@tf.keras.utils.register_keras_serializable(package="Text") |
|
class Transformer(tf.keras.layers.Layer): |
|
"""Transformer layer. |
|
|
|
This layer implements the Transformer from "Attention Is All You Need". |
|
(https://arxiv.org/abs/1706.03762). |
|
|
|
Arguments: |
|
num_attention_heads: Number of attention heads. |
|
intermediate_size: Size of the intermediate layer. |
|
intermediate_activation: Activation for the intermediate layer. |
|
dropout_rate: Dropout probability for the post-attention and output dropout. |
|
attention_dropout_rate: Dropout probability for within the attention layer. |
|
output_range: the sequence output range, [0, output_range) by slicing the |
|
target sequence. `None` means the target sequence is not sliced. |
|
kernel_initializer: Initializer for dense layer kernels. |
|
bias_initializer: Initializer for dense layer biases. |
|
kernel_regularizer: Regularizer for dense layer kernels. |
|
bias_regularizer: Regularizer for dense layer biases. |
|
activity_regularizer: Regularizer for dense layer activity. |
|
kernel_constraint: Constraint for dense layer kernels. |
|
bias_constraint: Constraint for dense layer kernels. |
|
""" |

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(Transformer, self).__init__(**kwargs)

    self._num_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._intermediate_activation = intermediate_activation
    self._attention_dropout_rate = attention_dropout_rate
    self._dropout_rate = dropout_rate
    self._output_range = output_range
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape

    if len(input_shape) == 2:
      mask_tensor_shape = tf.TensorShape(input_shape[1])
      expected_mask_tensor_shape = tf.TensorShape(
          [batch_size, sequence_length, sequence_length])
      if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape):
        raise ValueError("When passing a mask tensor to TransformerLayer, the "
                         "mask tensor must be of shape [batch, "
                         "sequence_length, sequence_length] (here %s). Got a "
                         "mask tensor of shape %s." %
                         (expected_mask_tensor_shape, mask_tensor_shape))
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)

    self._attention_layer = attention.MultiHeadAttention(
        num_heads=self._num_heads,
        key_size=self._attention_head_size,
        dropout=self._attention_dropout_rate,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention")
    self._attention_layer.build([input_tensor_shape] * 3)
    # Expose the attention layer's output projection (note: this reaches into
    # a private attribute of MultiHeadAttention).
    self._attention_output_dense = self._attention_layer._output_dense

    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
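    # Keep the layer norms in float32 for numeric stability, regardless of the
    # compute policy (note the explicit `dtype=tf.float32` below).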
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32))
    self._intermediate_dense = dense_einsum.DenseEinsum(
        output_shape=self._intermediate_size,
        activation=None,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="intermediate")
    policy = tf.keras.mixed_precision.experimental.global_policy()
    if policy.name == "mixed_bfloat16":
      # Compute the intermediate activation in float32: running it in
      # bfloat16 can hurt convergence.
      policy = tf.float32
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._intermediate_activation, dtype=policy)
    self._output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="output")
    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)

    super(Transformer, self).build(input_shape)

  def get_config(self):
    config = {
        "num_attention_heads": self._num_heads,
        "intermediate_size": self._intermediate_size,
        "intermediate_activation": self._intermediate_activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
    }
    base_config = super(Transformer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
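
  # Serialization round-trip sketch: because the class is registered with
  # `register_keras_serializable`, the config alone rebuilds an identically
  # configured (untrained) layer:
  #   clone = Transformer.from_config(block.get_config())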

  def call(self, inputs):
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
      input_tensor, attention_mask = inputs
    else:
      input_tensor, attention_mask = (inputs, None)

    if self._output_range:
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      target_tensor = input_tensor
    attention_inputs = [target_tensor, input_tensor]

    attention_output = self._attention_layer(attention_inputs, attention_mask)
    attention_output = self._attention_dropout(attention_output)
    attention_output = self._attention_layer_norm(target_tensor +
                                                  attention_output)
    intermediate_output = self._intermediate_dense(attention_output)
    intermediate_output = self._intermediate_activation_layer(
        intermediate_output)
    layer_output = self._output_dense(intermediate_output)
    layer_output = self._output_dropout(layer_output)
    # During mixed-precision training, `attention_output` comes out of the
    # float32 layer norm; cast `layer_output` to float32 for the residual add.
    layer_output = tf.cast(layer_output, tf.float32)
    layer_output = self._output_layer_norm(layer_output + attention_output)

    return layer_output


@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class CompiledTransformer(Transformer):
  """A `Transformer` whose call is compiled with XLA when running eagerly."""

  @tf_function_if_eager(experimental_compile=True)
  def call(self, inputs):
    return super(CompiledTransformer, self).call(inputs)


@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
  """Single transformer layer for decoder.

  It has three sub-layers:
  (1) a multi-head self-attention mechanism.
  (2) an encoder-decoder (cross-) attention.
  (3) a position-wise fully connected feed-forward network.

  Arguments:
    num_attention_heads: Number of attention heads.
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout_rate: Dropout probability for the post-attention and output
      dropout.
    attention_dropout_rate: Dropout probability within the attention layer.
    multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
      cross-attention between target sequences and source sequences.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               multi_channel_cross_attention=False,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(TransformerDecoderLayer, self).__init__(**kwargs)
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.intermediate_activation = tf.keras.activations.get(
        intermediate_activation)
    self.dropout_rate = dropout_rate
    self.attention_dropout_rate = attention_dropout_rate
    self.multi_channel_cross_attention = multi_channel_cross_attention
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    if self.multi_channel_cross_attention:
      self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
    else:
      self._cross_attention_cls = attention.MultiHeadAttention

  def build(self, input_shape):
    target_tensor_shape = tf.TensorShape(input_shape[0])
    if len(target_tensor_shape) != 3:
      raise ValueError("TransformerDecoderLayer expects a three-dimensional "
                       "input of shape [batch, sequence, width].")
    hidden_size = target_tensor_shape[2]
    if hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          "The hidden size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    self.attention_head_size = int(hidden_size // self.num_attention_heads)

    self.self_attention = attention.CachedAttention(
        num_heads=self.num_attention_heads,
        key_size=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention")
    self.self_attention_output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        num_summed_dimensions=2,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="self_attention_output")
    self.self_attention_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_rate)
    self.self_attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm", axis=-1, epsilon=1e-12))

    self.encdec_attention = self._cross_attention_cls(
        num_heads=self.num_attention_heads,
        key_size=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        output_shape=hidden_size,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="attention/encdec")
    self.encdec_attention_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_rate)
    self.encdec_attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
    self.intermediate_dense = dense_einsum.DenseEinsum(
        output_shape=self.intermediate_size,
        activation=None,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="intermediate")
    self.intermediate_activation_layer = tf.keras.layers.Activation(
        self.intermediate_activation)
    self.output_dense = dense_einsum.DenseEinsum(
        output_shape=hidden_size,
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        name="output")
    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    self.output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1, epsilon=1e-12)
    super(TransformerDecoderLayer, self).build(input_shape)

  def common_layers_with_encoder(self):
    """Gets layer objects that can make a Transformer encoder block."""
    return [
        self.self_attention, self.self_attention_layer_norm,
        self.intermediate_dense, self.output_dense, self.output_layer_norm
    ]

  def call(self, inputs, cache=None, decode_loop_step=None):
    if self.multi_channel_cross_attention:
      if len(inputs) != 5:
        raise ValueError(
            "TransformerDecoderLayer must have 5 inputs when it uses "
            "multi_channel_cross_attention, but it got %d." % len(inputs))
    elif len(inputs) != 4:
      raise ValueError(
          "TransformerDecoderLayer must have 4 inputs, but it got %d." %
          len(inputs))
    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
    self_attention_inputs = [input_tensor, input_tensor]
    self_attention_output, cache = self.self_attention(
        self_attention_inputs,
        attention_mask=self_attention_mask,
        cache=cache,
        decode_loop_step=decode_loop_step)
    self_attention_output = self.self_attention_dropout(self_attention_output)
    self_attention_output = self.self_attention_layer_norm(
        input_tensor + self_attention_output)

    cross_attn_inputs = [self_attention_output, memory]
    if self.multi_channel_cross_attention:
      # `MultiChannelAttention` consumes one extra input, appended here.
      cross_attn_inputs.append(inputs[-1])
    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
    attention_output = self.encdec_attention_dropout(attention_output)
    attention_output = self.encdec_attention_layer_norm(self_attention_output +
                                                        attention_output)

    intermediate_output = self.intermediate_dense(attention_output)
    intermediate_output = self.intermediate_activation_layer(
        intermediate_output)
    layer_output = self.output_dense(intermediate_output)
    layer_output = self.output_dropout(layer_output)
    layer_output = self.output_layer_norm(layer_output + attention_output)
    return layer_output, cache
|
|