# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements the Seq2Seq Transformer model with the TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
"""
import inspect
import math
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
EOS_ID = 1
class Seq2SeqTransformer(tf_keras.Model):
"""Transformer model with Keras.
Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
The Transformer model consists of an encoder and decoder. The input is an int
sequence (or a batch of sequences). The encoder produces a continuous
representation, and the decoder uses the encoder output to generate
probabilities for the output sequence.
"""
def __init__(self,
vocab_size=33708,
embedding_width=512,
dropout_rate=0.0,
padded_decode=False,
decode_max_length=None,
extra_decode_length=0,
beam_size=4,
alpha=0.6,
encoder_layer=None,
decoder_layer=None,
eos_id=EOS_ID,
**kwargs):
"""Initialize layers to build Transformer model.
Args:
vocab_size: Size of vocabulary.
embedding_width: Size of hidden layer for embedding.
dropout_rate: Dropout probability.
      padded_decode: Whether max_sequence_length padding is used. If set to
        False, max_sequence_length padding is not used.
      decode_max_length: Maximum number of steps to decode a sequence.
      extra_decode_length: Number of extra decoding steps beam search runs
        beyond the encoder input length when `decode_max_length` is not set.
      beam_size: Number of beams for beam search.
alpha: The strength of length normalization for beam search.
encoder_layer: An initialized encoder layer.
decoder_layer: An initialized decoder layer.
      eos_id: ID of the end-of-sentence token.
**kwargs: other keyword arguments.
"""
super().__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._dropout_rate = dropout_rate
self._padded_decode = padded_decode
self._decode_max_length = decode_max_length
self._extra_decode_length = extra_decode_length
self._beam_size = beam_size
self._alpha = alpha
self._eos_id = eos_id
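    # The token embeddings are scaled by sqrt(embedding_width) and initialized
    # with stddev embedding_width**-0.5, following the original Transformer.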
self.embedding_lookup = layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=self._embedding_width,
initializer=tf.random_normal_initializer(
mean=0., stddev=self._embedding_width**-0.5),
scale_factor=self._embedding_width**0.5)
self.encoder_layer = encoder_layer
self.decoder_layer = decoder_layer
self.position_embedding = layers.RelativePositionEmbedding(
hidden_size=self._embedding_width)
self.encoder_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
self.decoder_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "embedding_width": self._embedding_width,
"dropout_rate": self._dropout_rate,
"padded_decode": self._padded_decode,
"decode_max_length": self._decode_max_length,
"eos_id": self._eos_id,
"extra_decode_length": self._extra_decode_length,
"beam_size": self._beam_size,
"alpha": self._alpha,
"encoder_layer": self.encoder_layer,
"decoder_layer": self.decoder_layer,
}
base_config = super(Seq2SeqTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _embedding_linear(self, embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
embedding_matrix = tf.cast(embedding_matrix, dtype=self.compute_dtype)
x = tf.cast(x, dtype=self.compute_dtype)
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
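    # Reuse the token embedding matrix as the pre-softmax projection (weight
    # tying); no separate output projection variable is created.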
logits = tf.matmul(x, embedding_matrix, transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
  def _parse_inputs(self, inputs):
    """Parses the `call` inputs and returns a uniform output tuple."""
sources = inputs.get("inputs", None)
input_mask = inputs.get("input_masks", None)
embedded = inputs.get("embedded_inputs", None)
if sources is None and embedded is not None:
embedded_inputs = embedded
boolean_mask = input_mask
input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
source_dtype = embedded.dtype
elif sources is not None:
embedded_inputs = self.embedding_lookup(sources)
boolean_mask = tf.not_equal(sources, 0)
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
source_dtype = sources.dtype
else:
raise KeyError(
"The call method expects either `inputs` or `embedded_inputs` and "
"`input_masks` as input features.")
return embedded_inputs, boolean_mask, input_shape, source_dtype
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""Calculate target logits or inferred target sequences.
Args:
inputs: a dictionary of tensors.
Feature `inputs` (optional): int tensor with shape
`[batch_size, input_length]`.
Feature `embedded_inputs` (optional): float tensor with shape
`[batch_size, input_length, embedding_width]`.
Feature `targets` (optional): None or int tensor with shape
`[batch_size, target_length]`.
Feature `input_masks` (optional): When providing the `embedded_inputs`,
the dictionary must provide a boolean mask marking the filled time
steps. The shape of the tensor is `[batch_size, input_length]`.
Either `inputs` or `embedded_inputs` and `input_masks` must be present
in the input dictionary. In the second case the projection of the
integer tokens to the transformer embedding space is skipped and
`input_masks` is expected to be present.
Returns:
      If `targets` is defined, then returns logits for each word in the target
      sequence, a float tensor with shape
      `(batch_size, target_length, vocab_size)`. If `targets` is `None`, then
      generates the output sequence one token at a time and
returns a dictionary {
outputs: `(batch_size, decoded_length)`
scores: `(batch_size, 1)`}
Even when `float16` is used, the output tensor(s) are always `float32`.
Raises:
      NotImplementedError: If the padded decode method is used on CPUs/GPUs.
"""
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
targets = inputs.get("targets", None)
(embedded_inputs, boolean_mask,
input_shape, source_dtype) = self._parse_inputs(inputs)
embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
# Attention_mask generation.
attention_mask = tf.cast(
tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
dtype=source_dtype)
broadcast_ones = tf.ones(
shape=[input_shape[0], input_shape[1], 1], dtype=source_dtype)
attention_mask = broadcast_ones * attention_mask
pos_encoding = self.position_embedding(embedded_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
encoder_inputs = embedded_inputs + pos_encoding
encoder_inputs = self.encoder_dropout(encoder_inputs)
encoder_outputs = self.encoder_layer(
encoder_inputs, attention_mask=attention_mask)
if targets is None:
if self._padded_decode:
max_decode_length = self._decode_max_length
else:
max_decode_length = self._decode_max_length or (
tf.shape(encoder_outputs)[1] + self._extra_decode_length)
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
batch_size = tf.shape(encoder_outputs)[0]
# Create initial set of IDs that will be passed to symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
init_decode_length = (max_decode_length if self._padded_decode else 0)
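      # With padded_decode (the TPU decoding path) the key/value cache is
      # pre-allocated to the full max_decode_length; otherwise it starts empty
      # and grows one step per decoded token.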
num_heads = self.decoder_layer.num_attention_heads
dim_per_head = self._embedding_width // num_heads
# Cache dtype needs to match beam_search dtype.
# pylint: disable=g-complex-comprehension
cache = {
str(layer): {
"key":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.compute_dtype),
"value":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=self.compute_dtype)
} for layer in range(self.decoder_layer.num_layers)
}
# pylint: enable=g-complex-comprehension
# Add encoder output and attention bias to the cache.
encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
attention_mask = tf.cast(
tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
dtype=self.compute_dtype)
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_mask"] = attention_mask
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self._vocab_size,
beam_size=self._beam_size,
alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=self._eos_id,
padded_decode=self._padded_decode,
dtype=self.compute_dtype)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
top_scores = scores[:, 0]
return {"outputs": top_decoded_ids, "scores": top_scores}
# Shift targets to the right, and remove the last element
targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
decoder_inputs = self.embedding_lookup(targets)
length = tf.shape(decoder_inputs)[1]
pos_encoding = self.position_embedding(decoder_inputs)
pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
decoder_inputs += pos_encoding
decoder_inputs = self.decoder_dropout(decoder_inputs)
decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
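    # Lower-triangular (causal) mask: target position j may only attend to
    # positions <= j.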
self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = tf.cast(
tf.expand_dims(boolean_mask, axis=1), dtype=source_dtype)
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
outputs = self.decoder_layer(
decoder_inputs,
encoder_outputs,
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask)
logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
# Model outputs should be float32 to avoid numeric issues.
# https://www.tensorflow.org/guide/mixed_precision#building_the_model
logits = tf.cast(logits, tf.float32)
return logits
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = self.position_embedding(
inputs=None, length=max_decode_length + 1)
timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
decoder_self_attention_mask = tf.linalg.band_part(
tf.ones([max_decode_length, max_decode_length],
dtype=self.compute_dtype), -1, 0)
decoder_self_attention_mask = tf.reshape(
decoder_self_attention_mask, [1, max_decode_length, max_decode_length])
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences. int tensor with shape `(batch_size *
beam_size, i + 1)`.
i: Loop index.
cache: Dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape `(batch_size * beam_size, vocab_size)`,
updated cache values)
"""
# Set decoder input to the last generated IDs
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_lookup(decoder_input)
decoder_input += timing_signal[i]
if self._padded_decode:
        # Indexing with a dynamic loop step does not work on TPU; use tf.slice.
bias_shape = decoder_self_attention_mask.shape.as_list()
self_attention_mask = tf.slice(decoder_self_attention_mask, [0, i, 0],
[bias_shape[0], 1, bias_shape[2]])
else:
self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]
decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = cache.get("encoder_decoder_attention_mask")
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
decoder_outputs = self.decoder_layer(
decoder_input,
cache.get("encoder_outputs"),
self_attention_mask=self_attention_mask,
cross_attention_mask=attention_mask,
cache=cache,
decode_loop_step=i if self._padded_decode else None)
decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
logits = self._embedding_linear(self.embedding_lookup.embeddings,
decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
return logits, cache
return symbols_to_logits_fn
class TransformerEncoder(tf_keras.layers.Layer):
"""Transformer encoder.
Transformer encoder is made up of N identical layers. Each layer is composed
of the sublayers:
1. Self-attention layer
2. Feedforward network (which is 2 fully-connected layers)
"""
def __init__(self,
num_layers=6,
num_attention_heads=8,
intermediate_size=2048,
activation="relu",
dropout_rate=0.0,
attention_dropout_rate=0.0,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.0,
**kwargs):
"""Initialize a Transformer encoder.
Args:
num_layers: Number of layers.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate (Feedforward) layer.
activation: Activation for the intermediate layer.
dropout_rate: Dropout probability.
attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable `use_bias` in the attention layers. If set to
        False, bias terms in the attention layers are disabled.
      norm_first: Whether to normalize inputs to the attention and intermediate
        dense layers. If set to False, the outputs of the attention and
        intermediate dense layers are normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      **kwargs: Keyword arguments passed to tf_keras.layers.Layer.
"""
super(TransformerEncoder, self).__init__(**kwargs)
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._activation = activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
def build(self, input_shape):
"""Implements build() for the layer."""
self.encoder_layers = []
for i in range(self.num_layers):
self.encoder_layers.append(
layers.TransformerEncoderBlock(
num_attention_heads=self.num_attention_heads,
inner_dim=self._intermediate_size,
inner_activation=self._activation,
output_dropout=self._dropout_rate,
attention_dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
norm_first=self._norm_first,
norm_epsilon=self._norm_epsilon,
inner_dropout=self._intermediate_dropout,
attention_initializer=attention_initializer(input_shape[2]),
name=("layer_%d" % i)))
self.output_normalization = tf_keras.layers.LayerNormalization(
epsilon=self._norm_epsilon, dtype="float32")
super(TransformerEncoder, self).build(input_shape)
def get_config(self):
config = {
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self._intermediate_size,
"activation": self._activation,
"dropout_rate": self._dropout_rate,
"attention_dropout_rate": self._attention_dropout_rate,
"use_bias": self._use_bias,
"norm_first": self._norm_first,
"norm_epsilon": self._norm_epsilon,
"intermediate_dropout": self._intermediate_dropout
}
base_config = super(TransformerEncoder, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, encoder_inputs, attention_mask=None):
"""Return the output of the encoder.
Args:
encoder_inputs: A tensor with shape `(batch_size, input_length,
hidden_size)`.
attention_mask: A mask for the encoder self-attention layer with shape
`(batch_size, input_length, input_length)`.
Returns:
      Output of the encoder, which is a `float32` tensor with shape
`(batch_size, input_length, hidden_size)`.
"""
for layer_idx in range(self.num_layers):
encoder_inputs = self.encoder_layers[layer_idx](
[encoder_inputs, attention_mask])
output_tensor = encoder_inputs
output_tensor = self.output_normalization(output_tensor)
return output_tensor
class TransformerDecoder(tf_keras.layers.Layer):
"""Transformer decoder.
Like the encoder, the decoder is made up of N identical layers.
Each layer is composed of the sublayers:
1. Self-attention layer
2. Multi-headed attention layer combining encoder outputs with results from
the previous self-attention layer.
3. Feedforward network (2 fully-connected layers)
"""
def __init__(self,
num_layers=6,
num_attention_heads=8,
intermediate_size=2048,
activation="relu",
dropout_rate=0.0,
attention_dropout_rate=0.0,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6,
intermediate_dropout=0.0,
self_attention_cls=None,
cross_attention_cls=None,
**kwargs):
"""Initialize a Transformer decoder.
Args:
num_layers: Number of layers.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate (Feedforward) layer.
activation: Activation for the intermediate layer.
dropout_rate: Dropout probability.
attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable `use_bias` in the attention layers. If set to
        `False`, bias terms in the attention layers are disabled.
      norm_first: Whether to normalize inputs to the attention and intermediate
        dense layers. If set to `False`, the outputs of the attention and
        intermediate dense layers are normalized instead.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
self_attention_cls: An optional class to use for self attention
or a function that provides the class per layer.
cross_attention_cls: An optional class to use for cross attention
or a function that provides the class per layer.
      **kwargs: Keyword arguments passed to tf_keras.layers.Layer.
"""
super(TransformerDecoder, self).__init__(**kwargs)
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self._intermediate_size = intermediate_size
self._activation = activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._intermediate_dropout = intermediate_dropout
self._self_attention_cls = self_attention_cls
self._cross_attention_cls = cross_attention_cls
def build(self, input_shape):
"""Implements build() for the layer."""
def _select_attention_cls(attention_cls, index):
cls = None
if attention_cls is not None:
cls = (
attention_cls(index)
if inspect.isfunction(attention_cls)
else attention_cls
)
return cls
self.decoder_layers = []
for i in range(self.num_layers):
self_attention_cls = _select_attention_cls(self._self_attention_cls, i)
cross_attention_cls = _select_attention_cls(self._cross_attention_cls, i)
self.decoder_layers.append(
layers.TransformerDecoderBlock(
num_attention_heads=self.num_attention_heads,
intermediate_size=self._intermediate_size,
intermediate_activation=self._activation,
dropout_rate=self._dropout_rate,
attention_dropout_rate=self._attention_dropout_rate,
use_bias=self._use_bias,
norm_first=self._norm_first,
norm_epsilon=self._norm_epsilon,
intermediate_dropout=self._intermediate_dropout,
attention_initializer=attention_initializer(input_shape[2]),
name=("layer_%d" % i),
self_attention_cls=self_attention_cls,
cross_attention_cls=cross_attention_cls))
self.output_normalization = tf_keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
super(TransformerDecoder, self).build(input_shape)
def get_config(self):
config = {
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self._intermediate_size,
"activation": self._activation,
"dropout_rate": self._dropout_rate,
"attention_dropout_rate": self._attention_dropout_rate,
"use_bias": self._use_bias,
"norm_first": self._norm_first,
"norm_epsilon": self._norm_epsilon,
"intermediate_dropout": self._intermediate_dropout,
"self_attention_cls": self._self_attention_cls,
"cross_attention_cls": self._cross_attention_cls,
}
base_config = super(TransformerDecoder, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
target,
memory,
self_attention_mask=None,
cross_attention_mask=None,
cache=None,
decode_loop_step=None,
return_all_decoder_outputs=False):
"""Return the output of the decoder layer stacks.
Args:
target: A tensor with shape `(batch_size, target_length, hidden_size)`.
memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
      self_attention_mask: A tensor with shape `(batch_size, target_length,
        target_length)`, the mask for the decoder self-attention layer.
cross_attention_mask: A tensor with shape `(batch_size, target_length,
input_length)` which is the mask for encoder-decoder attention layer.
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
"v": A tensor with shape `(batch_size, i, value_channels)`},
...}
decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU.
      return_all_decoder_outputs: Whether to return the outputs of all decoder
        layers. Note that the outputs are layer-normalized. This is useful when
        introducing a per-layer auxiliary loss.
Returns:
      Output of the decoder, a `float32` tensor with shape
      `(batch_size, target_length, hidden_size)`.
"""
output_tensor = target
decoder_outputs = []
for layer_idx in range(self.num_layers):
transformer_inputs = [
output_tensor, memory, cross_attention_mask, self_attention_mask
]
# Gets the cache for decoding.
if cache is None:
output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
else:
cache_layer_idx = str(layer_idx)
output_tensor, cache[cache_layer_idx] = self.decoder_layers[layer_idx](
transformer_inputs,
cache=cache[cache_layer_idx],
decode_loop_step=decode_loop_step)
if return_all_decoder_outputs:
decoder_outputs.append(self.output_normalization(output_tensor))
if return_all_decoder_outputs:
return decoder_outputs
else:
return self.output_normalization(output_tensor)
def attention_initializer(hidden_size):
"""Initializer for attention layers in Seq2SeqTransformer."""
hidden_size = int(hidden_size)
limit = math.sqrt(6.0 / (hidden_size + hidden_size))
return tf_keras.initializers.RandomUniform(minval=-limit, maxval=limit)
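

if __name__ == "__main__":
  # Minimal usage sketch; runs only when this module is executed directly.
  # All hyperparameters and shapes below are illustrative assumptions, not
  # values required by the model.
  demo_encoder = TransformerEncoder(
      num_layers=2, num_attention_heads=4, intermediate_size=128)
  demo_decoder = TransformerDecoder(
      num_layers=2, num_attention_heads=4, intermediate_size=128)
  demo_model = Seq2SeqTransformer(
      vocab_size=100,
      embedding_width=64,
      beam_size=4,
      alpha=0.6,
      encoder_layer=demo_encoder,
      decoder_layer=demo_decoder)
  demo_features = {
      # Token IDs; the model treats 0 as padding.
      "inputs": tf.ones((2, 8), dtype=tf.int32),
      "targets": tf.ones((2, 8), dtype=tf.int32),
  }
  # With "targets" present the call returns float32 logits of shape
  # (batch_size, target_length, vocab_size); dropping "targets" switches to
  # beam-search decoding and returns {"outputs": ..., "scores": ...}.
  demo_logits = demo_model(demo_features)
  print(demo_logits.shape)  # Expected: (2, 8, 100)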