deanna-emery's picture
updates
93528c6
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based Transformer XL layer."""
from absl import logging
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling.layers import relative_attention
def _cache_memory(current_state, previous_state, memory_length, reuse_length=0):
"""Caches hidden states into memory.
Args:
current_state: `Tensor`, the current state.
previous_state: `Tensor`, the previous state.
memory_length: `int`, the number of tokens to cache.
reuse_length: `int`, the number of tokens in the current batch to be cached
and reused in the future.
Returns:
A `Tensor`, representing the cached state with stopped gradients.
"""
if memory_length is None or memory_length == 0:
return None
else:
if reuse_length > 0:
current_state = current_state[:, :reuse_length, :]
if previous_state is None:
new_mem = current_state[:, -memory_length:, :]
else:
new_mem = tf.concat(
[previous_state, current_state], 1)[:, -memory_length:, :]
return tf.stop_gradient(new_mem)
@tf_keras.utils.register_keras_serializable(package="Text")
class TransformerXLBlock(tf_keras.layers.Layer):
  """A single Transformer XL block.

  Implements the block described in "Transformer-XL: Attentive Language Models
  Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860), extended
  to support the re-parameterization from "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  For every input stream the block runs relative attention, then dropout and a
  residual layer norm, followed by a two-layer feed-forward network with its
  own dropout and residual layer norm.

  Note: This layer is currently experimental.

  Attributes:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
    num_attention_heads: The number of attention heads.
    head_size: The dimension size of each attention head.
    inner_size: The inner size for the transformer layers.
    dropout_rate: Dropout rate for the output of this layer.
    attention_dropout_rate: Dropout rate on attention probabilities.
    two_stream: Whether or not to use `TwoStreamRelativeAttention` used in the
      XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
    norm_epsilon: Epsilon value to initialize normalization layers.
    inner_activation: The activation to use for the inner FFN layers.
    kernel_initializer: Initializer for dense layer kernels.
    inner_dropout: Dropout probability for the inner dropout layer.
  """

  def __init__(self,
               vocab_size,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               two_stream=False,
               norm_epsilon=1e-12,
               inner_activation="relu",
               kernel_initializer="variance_scaling",
               inner_dropout=0.0,
               **kwargs):
    """Stores the block hyperparameters; sub-layers are created in `build`."""
    super().__init__(**kwargs)

    self._vocab_size = vocab_size
    self._num_heads = num_attention_heads
    self._head_size = head_size
    self._hidden_size = hidden_size
    self._inner_size = inner_size
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._inner_activation = inner_activation
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._inner_dropout = inner_dropout
    self._two_stream = two_stream

    # Only the class is selected here; it is instantiated later in `build`.
    self._attention_layer_type = (
        relative_attention.TwoStreamRelativeAttention
        if two_stream else relative_attention.MultiHeadRelativeAttention)

  def _feedforward_dense(self, units, name):
    """Returns a per-position dense projection to `units` features."""
    return tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, units),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name=name)

  def build(self, input_shape):
    """Validates the input shape(s) and creates all sub-layers.

    Args:
      input_shape: Either the shape of the content stream, or a pair
        `(content_shape, mask_shape)`.

    Raises:
      ValueError: If the content shape is not three-dimensional, if a mask
        shape is given that is incompatible with
        `[batch, sequence_length, sequence_length]`, or if the input width is
        not a multiple of the number of attention heads.
    """
    with_mask = len(input_shape) == 2
    content_shape = tf.TensorShape(input_shape[0] if with_mask else input_shape)
    if len(content_shape.as_list()) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = content_shape

    if with_mask:
      mask_shape = tf.TensorShape(input_shape[1])
      required_mask_shape = tf.TensorShape(
          [batch_size, sequence_length, sequence_length])
      if not required_mask_shape.is_compatible_with(mask_shape):
        raise ValueError("When passing a mask tensor to TransformerXLBlock, "
                         "the mask tensor must be of shape [batch, "
                         "sequence_length, sequence_length] (here %s). Got a "
                         "mask tensor of shape %s." %
                         (required_mask_shape, mask_shape))

    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))

    # Attention sub-block.
    self._attention_layer = self._attention_layer_type(
        num_heads=self._num_heads,
        key_dim=self._head_size,
        value_dim=self._head_size,
        dropout=self._attention_dropout_rate,
        use_bias=False,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name="rel_attn")
    self._attention_dropout = tf_keras.layers.Dropout(
        rate=self._attention_dropout_rate)
    self._attention_layer_norm = tf_keras.layers.LayerNormalization(
        name="self_attention_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)

    # Feed-forward sub-block: expand to `inner_size`, then project back to the
    # input width.
    self._inner_dense = self._feedforward_dense(self._inner_size, "inner")
    self._inner_activation_layer = tf_keras.layers.Activation(
        self._inner_activation)
    self._inner_dropout_layer = tf_keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = self._feedforward_dense(hidden_size, "output")
    self._output_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
    self._output_layer_norm = tf_keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon)

    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with this block's arguments."""
    merged = dict(super().get_config())
    merged.update({
        "vocab_size": self._vocab_size,
        "hidden_size": self._hidden_size,
        "num_attention_heads": self._num_heads,
        "head_size": self._head_size,
        "inner_size": self._inner_size,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "two_stream": self._two_stream,
        "norm_epsilon": self._norm_epsilon,
        "inner_activation": self._inner_activation,
        "kernel_initializer": self._kernel_initializer,
        "inner_dropout": self._inner_dropout,
    })
    return merged

  def call(self,
           content_stream,
           content_attention_bias,
           positional_attention_bias,
           relative_position_encoding=None,
           segment_matrix=None,
           segment_encoding=None,
           segment_attention_bias=None,
           state=None,
           content_attention_mask=None,
           query_stream=None,
           query_attention_mask=None,
           target_mapping=None):
    """Runs attention and the feed-forward network on the input stream(s).

    Args:
      content_stream: `Tensor`, the input content stream. This is the standard
        input to Transformer XL and is commonly referred to as `h` in XLNet.
      content_attention_bias: Bias `Tensor` for content based attention of shape
        `[num_heads, dim]`.
      positional_attention_bias: Bias `Tensor` for position based attention of
        shape `[num_heads, dim]`.
      relative_position_encoding: Relative positional encoding `Tensor` of shape
        `[B, L, dim]`.
      segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
        but not in Transformer XL.
      segment_encoding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
        in XLNet, but not in Transformer XL.
      segment_attention_bias: Optional bias `Tensor` for segment based attention
        of shape `[num_heads, dim]`.
      state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
        the state or memory. If passed, this is also attended over as in
        Transformer XL.
      content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      query_stream: Optional `Tensor`, the query stream. This is introduced in
        `TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
        `two_stream` is `False`.
      query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      target_mapping: Optional `Tensor` representing the target mapping when
        calculating query attention.

    Returns:
      A `dict` object, containing the key value pairs for `content_attention`
      and (if `two_stream` is `True`) `query_attention`.
    """
    if query_stream is not None and not self._two_stream:
      logging.warning("`query_stream` was provided but two stream attention is "
                      "disabled. `query_stream` will be ignored.")

    # Arguments understood by both attention layer types.
    attention_args = {
        "content_attention_bias": content_attention_bias,
        "relative_position_encoding": relative_position_encoding,
        "positional_attention_bias": positional_attention_bias,
        "segment_matrix": segment_matrix,
        "segment_encoding": segment_encoding,
        "segment_attention_bias": segment_attention_bias,
        "state": state,
    }
    if self._two_stream:
      attention_args.update({
          "content_stream": content_stream,
          "query_stream": query_stream,
          "query_attention_mask": query_attention_mask,
          "target_mapping": target_mapping,
          "content_attention_mask": content_attention_mask,
      })
    else:
      attention_args.update({
          "query": content_stream,
          "value": content_stream,
          "key": content_stream,
          "attention_mask": content_attention_mask,
      })
    attended = self._attention_layer(**attention_args)

    if self._two_stream:
      per_stream_attention = attended
      per_stream_input = [content_stream, query_stream]
    else:
      per_stream_attention = [attended]
      per_stream_input = [content_stream]
    output_keys = ["content_attention", "query_attention"]

    # `zip` stops at the shortest sequence, so with a single stream only
    # "content_attention" is produced.
    outputs = {}
    for stream_attention, stream_input, output_key in zip(
        per_stream_attention, per_stream_input, output_keys):
      # Residual connection around attention.
      normed_attention = self._attention_layer_norm(
          self._attention_dropout(stream_attention) + stream_input)

      # Position-wise feed-forward network with its own residual connection.
      hidden = self._inner_dense(normed_attention)
      hidden = self._inner_activation_layer(hidden)
      hidden = self._inner_dropout_layer(hidden)
      projected = self._output_dropout(self._output_dense(hidden))
      outputs[output_key] = self._output_layer_norm(
          projected + normed_attention)

    return outputs
class TransformerXL(tf_keras.layers.Layer):
  """Transformer XL.

  This layer combines multiple Transformer XL blocks from "Transformer-XL:
  Attentive Language Models Beyond a Fixed-Length Context"
  (https://arxiv.org/abs/1901.02860).

  This layer handles the attention biases as well as memory caching and reuse
  as in Transformer XL and XLNet.

  Attributes:
    vocab_size: The number of tokens in vocabulary.
    num_layers: The number of layers.
    hidden_size: The hidden size.
    num_attention_heads: The number of attention heads.
    head_size: The dimension size of each attention head.
    inner_size: The hidden size in feed-forward layers.
    dropout_rate: Dropout rate used in each Transformer XL block.
    attention_dropout_rate: Dropout rate on attention probabilities.
    two_stream: Whether or not to use `TwoStreamRelativeAttention` used
      in the XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
    initializer: The initializer to use for attention biases.
    tie_attention_biases: Whether or not to tie biases together. If `True`, then
      each Transformer XL block shares the same trainable attention bias. If
      `False`, then each block has its own attention bias. This is usually set
      to `True`.
    memory_length: The number of tokens to cache.
    reuse_length: The number of tokens in the current batch to be cached
      and reused in the future. `None` is treated like 0 when caching.
    inner_activation: The activation to use in the inner layers
      for Transformer XL blocks. Typically "relu" or "gelu".
  """

  def __init__(self,
               vocab_size,
               num_layers,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               initializer,
               two_stream=False,
               tie_attention_biases=True,
               memory_length=None,
               reuse_length=None,
               inner_activation="relu",
               **kwargs):
    """Initializes TransformerXL: attention biases and the stacked blocks."""
    super().__init__(**kwargs)

    self._vocab_size = vocab_size
    self._initializer = initializer
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._num_attention_heads = num_attention_heads
    self._head_size = head_size
    self._inner_size = inner_size
    self._inner_activation = inner_activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._tie_attention_biases = tie_attention_biases
    self._two_stream = two_stream

    self._memory_length = memory_length
    self._reuse_length = reuse_length

    # Tied biases are shared by every block; untied biases carry a leading
    # per-layer axis that `call` indexes with the layer number.
    if self._tie_attention_biases:
      attention_bias_shape = [self._num_attention_heads, self._head_size]
    else:
      attention_bias_shape = [self._num_layers, self._num_attention_heads,
                              self._head_size]

    self.content_attention_bias = self.add_weight(
        "content_attention_bias",
        shape=attention_bias_shape,
        dtype=tf.float32,
        initializer=tf_utils.clone_initializer(self._initializer))
    self.positional_attention_bias = self.add_weight(
        "positional_attention_bias",
        shape=attention_bias_shape,
        dtype=tf.float32,
        initializer=tf_utils.clone_initializer(self._initializer))
    self.segment_attention_bias = self.add_weight(
        "segment_attention_bias",
        shape=attention_bias_shape,
        dtype=tf.float32,
        initializer=tf_utils.clone_initializer(self._initializer))

    self.transformer_xl_layers = []
    for i in range(self._num_layers):
      self.transformer_xl_layers.append(
          TransformerXLBlock(
              vocab_size=self._vocab_size,
              # The block is told the attention width (heads * head size)
              # rather than the `hidden_size` argument of this layer.
              hidden_size=self._head_size * self._num_attention_heads,
              num_attention_heads=self._num_attention_heads,
              head_size=self._head_size,
              inner_size=self._inner_size,
              dropout_rate=self._dropout_rate,
              attention_dropout_rate=self._attention_dropout_rate,
              norm_epsilon=1e-12,
              inner_activation=self._inner_activation,
              two_stream=self._two_stream,
              kernel_initializer="variance_scaling",
              name="layer_%d" % i))

    # Created here but not applied anywhere in `call` of this class.
    self.output_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)

  def get_config(self):
    """Returns the base layer config extended with the constructor arguments."""
    config = {
        "vocab_size":
            self._vocab_size,
        "num_layers":
            self._num_layers,
        "hidden_size":
            self._hidden_size,
        "num_attention_heads":
            self._num_attention_heads,
        "head_size":
            self._head_size,
        "inner_size":
            self._inner_size,
        "dropout_rate":
            self._dropout_rate,
        "attention_dropout_rate":
            self._attention_dropout_rate,
        "initializer":
            self._initializer,
        "two_stream":
            self._two_stream,
        "tie_attention_biases":
            self._tie_attention_biases,
        "memory_length":
            self._memory_length,
        "reuse_length":
            self._reuse_length,
        "inner_activation":
            self._inner_activation,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           content_stream,
           relative_position_encoding,
           segment_matrix=None,
           segment_embedding=None,
           state=None,
           content_attention_mask=None,
           query_stream=None,
           query_attention_mask=None,
           target_mapping=None):
    """Implements call() for the layer.

    Args:
      content_stream: `Tensor`, the input content stream. This is the standard
        input to Transformer XL and is commonly referred to as `h` in XLNet.
      relative_position_encoding: Relative positional encoding `Tensor` of shape
        `[B, L, dim]`.
      segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet,
        but not in Transformer XL.
      segment_embedding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used
        in XLNet, but not in Transformer XL. It is indexed by layer number
        whenever `segment_matrix` is given.
      state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of
        the state or memory. If passed, this is also attended over as in
        Transformer XL. It is indexed by layer number, so it must provide one
        entry per layer.
      content_attention_mask: Optional `Tensor` representing the mask that is
        added to content attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      query_stream: Optional `Tensor`, the query stream. This is introduced in
        `TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if
        `two_stream` is `False`.
      query_attention_mask: Optional `Tensor` representing the mask that is
        added to query attention logits. If state is not None, the mask source
        sequence dimension should extend M.
      target_mapping: Optional `Tensor` representing the target mapping when
        calculating query attention.

    Returns:
      A tuple consisting of the attention output and the list of cached memory
      states.
      The attention output is `content_attention` if `two_stream` is `False`,
      otherwise it is `query_attention`.
    """
    new_mems = []

    if state is None:
      state = [None] * self._num_layers

    # The constructor default for `reuse_length` is `None`, but the caching
    # helper compares it numerically; normalise once so the default does not
    # raise a `TypeError` when `memory_length` is set.
    reuse_length = 0 if self._reuse_length is None else self._reuse_length

    for i in range(self._num_layers):
      # Cache the input of layer `i` (before the layer transforms it).
      new_mems.append(
          _cache_memory(content_stream, state[i],
                        self._memory_length, reuse_length))

      # Segment inputs are only forwarded when a segment matrix is given.
      if segment_matrix is None:
        segment_attention_bias = None
        segment_encoding = None
      else:
        segment_attention_bias = (self.segment_attention_bias
                                  if self._tie_attention_biases
                                  else self.segment_attention_bias[i])
        segment_encoding = segment_embedding[i]

      content_attention_bias = (self.content_attention_bias
                                if self._tie_attention_biases
                                else self.content_attention_bias[i])
      positional_attention_bias = (self.positional_attention_bias
                                   if self._tie_attention_biases
                                   else self.positional_attention_bias[i])

      transformer_xl_layer = self.transformer_xl_layers[i]
      transformer_xl_output = transformer_xl_layer(
          content_stream=content_stream,
          content_attention_bias=content_attention_bias,
          positional_attention_bias=positional_attention_bias,
          relative_position_encoding=relative_position_encoding,
          segment_matrix=segment_matrix,
          segment_encoding=segment_encoding,
          segment_attention_bias=segment_attention_bias,
          state=state[i],
          content_attention_mask=content_attention_mask,
          query_attention_mask=query_attention_mask,
          query_stream=query_stream,
          target_mapping=target_mapping)

      # Each layer's outputs become the next layer's input streams.
      content_stream = transformer_xl_output["content_attention"]
      if self._two_stream:
        query_stream = transformer_xl_output["query_attention"]
      else:
        query_stream = None

    if self._two_stream:
      output_stream = query_stream
    else:
      output_stream = content_stream

    return output_stream, new_mems