# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based XLNet Model."""
from absl import logging
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer_xl
_SEG_ID_CLS = 2
def _create_causal_attention_mask(
seq_length,
memory_length,
dtype=tf.float32,
same_length=False):
"""Creates a causal attention mask with a single-sided context.
When applying the attention mask in `MultiHeadRelativeAttention`, the
attention scores are of shape `[(batch dimensions), S, S + M]`, where:
- S = sequence length.
- M = memory length.
In a simple case where S = 2, M = 1, here is a simple illustration of the
`attention_scores` matrix, where `a` represents an attention function:
  token_0 [[a(token_0, mem_0)   a(token_0, token_0)   a(token_0, token_1)],
  token_1  [a(token_1, mem_0)   a(token_1, token_0)   a(token_1, token_1)]]
             mem_0              token_0               token_1
For uni-directional attention, we want to mask out values in the attention
scores that represent a(token_i, token_j) where j > i. We can achieve this by
concatenating 0s (representing memory positions) with a strictly upper
triangular matrix of 1s.
We then flip the matrix values in order to match the representation where
real values are 1s.
Args:
seq_length: int, The length of each sequence.
memory_length: int, The length of memory blocks.
dtype: dtype of the mask.
same_length: bool, whether to use the same attention length for each token.
Returns:
A unidirectional attention mask of shape
    `[seq_length, seq_length + memory_length]`. E.g., for `seq_length=4` and
    `memory_length=2`:
[[1. 1. 1. 0. 0. 0.]
[1. 1. 1. 1. 0. 0.]
[1. 1. 1. 1. 1. 0.]
[1. 1. 1. 1. 1. 1.]]
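
  For instance, matching the S = 2, M = 1 illustration above (eager execution
  assumed; the parameter values are illustrative only):

  >>> _create_causal_attention_mask(seq_length=2, memory_length=1).numpy()
  array([[1., 1., 0.],
         [1., 1., 1.]], dtype=float32)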
"""
ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
diagonal = tf.linalg.band_part(ones_matrix, 0, 0)
padding = tf.zeros([seq_length, memory_length], dtype=dtype)
causal_attention_mask = tf.concat(
[padding, upper_triangular - diagonal], 1)
if same_length:
lower_triangular = tf.linalg.band_part(ones_matrix, -1, 0)
strictly_lower_triangular = lower_triangular - diagonal
causal_attention_mask = tf.concat(
[causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
causal_attention_mask[:, seq_length:]], 1)
return 1 - causal_attention_mask
def _combine_masks(mask1, mask2, dtype, how="and"):
"""Combines two masks.
Use "and" if trying to combine two existing masks.
Use "or" if trying to flip a few positions to "real".
Args:
mask1: tf.Tensor, input mask 1
mask2: tf.Tensor, input mask 2
dtype: tf.dtype
    how: str, the logical operation to apply; either "and" or "or".
Returns:
The combined input masks.
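
  For example (a minimal sketch; the operand values below are arbitrary):

  >>> _combine_masks([0, 1], [1, 1], dtype=tf.float32, how="and").numpy()
  array([0., 1.], dtype=float32)
  >>> _combine_masks([0, 1], [1, 1], dtype=tf.float32, how="or").numpy()
  array([1., 1.], dtype=float32)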
"""
if how == "and":
operator = tf.math.logical_and
else:
operator = tf.math.logical_or
return tf.cast(operator(
tf.cast(mask1, tf.bool),
tf.cast(mask2, tf.bool)), dtype=dtype)
def _compute_attention_mask(
input_mask,
permutation_mask,
attention_type,
seq_length,
memory_length,
batch_size,
dtype=tf.float32):
"""Combines all input attention masks for XLNet.
In XLNet modeling, `0` represents tokens that can be attended, and `1`
represents tokens that cannot be attended.
  For XLNet pre-training and fine-tuning, there are a few masks used:
- Causal attention mask: If the attention type is unidirectional, then all
tokens after the current position cannot be attended to.
- Input mask: when generating data, padding is added to a max sequence length
to make all sequences the same length. This masks out real tokens (`0`) from
padding tokens (`1`).
- Permutation mask: during XLNet pretraining, the input sequence is factorized
into a factorization sequence `z`. During partial prediction, `z` is split
at a cutting point `c` (an index of the factorization sequence) and
    prediction is applied only to tokens after `c`. Therefore, tokens at
factorization positions `i` > `c` can be attended to and tokens at
factorization positions `i` <= `c` cannot be attended to.
This function broadcasts and combines all attention masks to produce the
query attention mask and the content attention mask.
Args:
input_mask: Tensor, the input mask related to padding. Input shape:
`(B, S)`.
permutation_mask: Tensor, the permutation mask used in partial prediction.
Input shape: `(B, S, S)`.
attention_type: str, the attention type. Can be "uni" (directional) or
"bi" (directional).
seq_length: int, the length of each sequence.
    memory_length: int, the length of memory blocks.
batch_size: int, the batch size.
dtype: The dtype of the masks.
Returns:
    attention_mask, content_attention_mask: The query attention mask and the
      content attention mask, respectively. The content attention mask
      additionally allows each position to attend to itself.
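
  For example, a minimal shape check (eager execution assumed; the all-zero
  masks below are used purely to exercise the shapes):

  >>> query_mask, content_mask = _compute_attention_mask(
  ...     input_mask=tf.zeros([1, 2], dtype=tf.float32),
  ...     permutation_mask=tf.zeros([1, 2, 2], dtype=tf.float32),
  ...     attention_type="uni",
  ...     seq_length=2,
  ...     memory_length=1,
  ...     batch_size=1,
  ...     dtype=tf.float32)
  >>> query_mask.shape, content_mask.shape
  (TensorShape([1, 1, 2, 3]), TensorShape([1, 1, 2, 3]))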
"""
attention_mask = None
# `1` values mean do not attend to this position.
if attention_type == "uni":
causal_attention_mask = _create_causal_attention_mask(
seq_length=seq_length,
memory_length=memory_length,
dtype=dtype)
causal_attention_mask = causal_attention_mask[None, None, :, :]
# `causal_attention_mask`: [1, 1, S, S + M]
# input_mask: [B, S]
# permutation_mask: [B, S, S]
if input_mask is not None and permutation_mask is not None:
data_mask = _combine_masks(input_mask[:, None, :], permutation_mask, dtype)
elif input_mask is not None and permutation_mask is None:
data_mask = input_mask[:, None, :]
elif input_mask is None and permutation_mask is not None:
data_mask = permutation_mask
else:
data_mask = None
# data_mask: [B, S, S] or [B, 1, S]
if data_mask is not None:
# All positions within state can be attended to.
state_mask = tf.ones([batch_size, tf.shape(data_mask)[1], memory_length],
dtype=dtype)
# state_mask: [B, 1, M] or [B, S, M]
data_mask = tf.concat([state_mask, data_mask], 2)
# data_mask: [B, 1, S + M] or [B, S, S + M]
if attention_type == "uni":
attention_mask = _combine_masks(causal_attention_mask,
data_mask[:, None, :, :],
dtype=dtype)
else:
attention_mask = data_mask[:, None, :, :]
if attention_mask is not None:
    # Construct the content attention mask.
    # Unlike the query attention mask, the content attention mask also allows
    # each position to attend to itself (the content diagonal).
non_target_mask = tf.concat(
[tf.zeros([seq_length, memory_length], dtype=dtype),
tf.eye(seq_length, dtype=dtype)], axis=-1)
content_attention_mask = _combine_masks(
attention_mask, non_target_mask, how="or", dtype=dtype)
else:
content_attention_mask = None
return attention_mask, content_attention_mask
def _compute_segment_matrix(
segment_ids,
memory_length,
batch_size,
use_cls_mask):
"""Computes the segment embedding matrix.
XLNet introduced segment-based attention for attention calculations. This
extends the idea of relative encodings in Transformer XL by considering
whether or not two positions are within the same segment, rather than
which segments they come from.
This function generates a segment matrix by broadcasting provided segment IDs
in two different dimensions and checking where values are equal. This output
matrix shows `True` whenever two tokens are NOT in the same segment and
`False` whenever they are.
Args:
segment_ids: A Tensor of size `[B, S]` that represents which segment
each token belongs to.
memory_length: int, the length of memory blocks.
batch_size: int, the batch size.
use_cls_mask: bool, whether or not to introduce cls mask in
input sequences.
Returns:
A boolean Tensor of size `[B, S, S + M]`, where `True` means that two
tokens are NOT in the same segment, and `False` means they are in the same
segment.
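
  For example, with no memory and `use_cls_mask=False` (eager execution
  assumed):

  >>> _compute_segment_matrix(
  ...     segment_ids=tf.constant([[0, 0, 1]]),
  ...     memory_length=0,
  ...     batch_size=1,
  ...     use_cls_mask=False).numpy()
  array([[[False, False,  True],
          [False, False,  True],
          [ True,  True, False]]])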
"""
if segment_ids is None:
return None
memory_padding = tf.zeros([batch_size, memory_length],
dtype=segment_ids.dtype)
padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
# segment_ids: [B, S]
# padded_segment_ids: [B, S + M]
if use_cls_mask:
# `1` indicates not in the same segment.
# Target result: [B, S, S + M]
# segment_ids: [B, S]
# padded_segment_ids: [B, S + M]
broadcasted_segment_class_indices = (
tf.equal(segment_ids,
tf.constant([_SEG_ID_CLS]))[:, :, None])
broadcasted_padded_class_indices = (
tf.equal(
padded_segment_ids,
tf.constant([_SEG_ID_CLS]))[:, None, :])
class_index_matrix = tf.logical_or(broadcasted_segment_class_indices,
broadcasted_padded_class_indices)
segment_matrix = tf.equal(segment_ids[:, :, None],
padded_segment_ids[:, None, :])
segment_matrix = tf.logical_or(class_index_matrix, segment_matrix)
else:
# TODO(allencwang) - address this legacy mismatch from `use_cls_mask`.
segment_matrix = tf.logical_not(
tf.equal(segment_ids[:, :, None], padded_segment_ids[:, None, :]))
return segment_matrix
def _compute_positional_encoding(
attention_type,
position_encoding_layer,
hidden_size,
batch_size,
total_length,
seq_length,
clamp_length,
bi_data,
dtype=tf.float32):
"""Computes the relative position encoding.
Args:
attention_type: str, the attention type. Can be "uni" (directional) or
"bi" (directional).
position_encoding_layer: An instance of `RelativePositionEncoding`.
hidden_size: int, the hidden size.
batch_size: int, the batch size.
total_length: int, the sequence length added to the memory length.
seq_length: int, the length of each sequence.
clamp_length: int, clamp all relative distances larger than clamp_length. -1
means no clamping.
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
dtype: the dtype of the encoding.
Returns:
A Tensor, representing the position encoding.
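
  For example, a shape sketch with illustrative sizes (eager execution
  assumed):

  >>> layer = RelativePositionEncoding(hidden_size=4)
  >>> encoding = _compute_positional_encoding(
  ...     attention_type="bi",
  ...     position_encoding_layer=layer,
  ...     hidden_size=4,
  ...     batch_size=2,
  ...     total_length=3,
  ...     seq_length=2,
  ...     clamp_length=-1,
  ...     bi_data=False)
  >>> encoding.shape
  TensorShape([2, 5, 4])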
"""
freq_seq = tf.range(0, hidden_size, 2.0)
if dtype is not None and dtype != tf.float32:
freq_seq = tf.cast(freq_seq, dtype=dtype)
if attention_type == "bi":
beg, end = total_length, -seq_length
elif attention_type == "uni":
beg, end = total_length, -1
else:
raise ValueError("Unknown `attention_type` {}.".format(attention_type))
if bi_data:
forward_position_sequence = tf.range(beg, end, -1.0)
backward_position_sequence = tf.range(-beg, -end, 1.0)
if dtype is not None and dtype != tf.float32:
forward_position_sequence = tf.cast(forward_position_sequence,
dtype=dtype)
backward_position_sequence = tf.cast(backward_position_sequence,
dtype=dtype)
if clamp_length > 0:
forward_position_sequence = tf.clip_by_value(
forward_position_sequence,
-clamp_length,
clamp_length)
backward_position_sequence = tf.clip_by_value(
backward_position_sequence,
-clamp_length,
clamp_length)
if batch_size is not None:
forward_positional_encoding = position_encoding_layer(
forward_position_sequence, batch_size // 2)
backward_positional_encoding = position_encoding_layer(
backward_position_sequence, batch_size // 2)
else:
forward_positional_encoding = position_encoding_layer(
forward_position_sequence, None)
backward_positional_encoding = position_encoding_layer(
backward_position_sequence, None)
relative_position_encoding = tf.concat(
[forward_positional_encoding, backward_positional_encoding], axis=0)
else:
forward_position_sequence = tf.range(beg, end, -1.0)
if dtype is not None and dtype != tf.float32:
forward_position_sequence = tf.cast(
forward_position_sequence, dtype=dtype)
if clamp_length > 0:
forward_position_sequence = tf.clip_by_value(
forward_position_sequence,
-clamp_length,
clamp_length)
relative_position_encoding = position_encoding_layer(
forward_position_sequence, batch_size)
return relative_position_encoding
class RelativePositionEncoding(tf_keras.layers.Layer):
"""Creates a relative positional encoding.
This layer creates a relative positional encoding as described in
"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
(https://arxiv.org/abs/1901.02860).
Rather than an absolute position embedding as in Transformer, this
formulation represents position as the relative distance between tokens using
sinusoidal positional embeddings.
Note: This layer is currently experimental.
Attributes:
hidden_size: The dimensionality of the input embeddings.
"""
def __init__(self, hidden_size, **kwargs):
super().__init__(**kwargs)
self._hidden_size = hidden_size
self._inv_freq = 1.0 / (10000.0**(
tf.range(0, self._hidden_size, 2.0) / self._hidden_size))
def call(self, pos_seq, batch_size=None):
"""Implements call() for the layer.
Args:
      pos_seq: A 1-D `Tensor` of relative positions.
      batch_size: The optionally provided batch size used to tile the relative
        positional encoding.
Returns:
The relative positional encoding of shape:
[batch_size, len(pos_seq), hidden_size] if batch_size is provided, else
[1, len(pos_seq), hidden_size].
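
    Example (a minimal shape sketch, assuming eager execution):

    >>> layer = RelativePositionEncoding(hidden_size=8)
    >>> pos_seq = tf.range(3.0, -1.0, -1.0)  # Relative positions [3, 2, 1, 0].
    >>> layer(pos_seq, batch_size=2).shape
    TensorShape([2, 4, 8])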
"""
sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq)
relative_position_encoding = tf.concat([tf.sin(sinusoid_input),
tf.cos(sinusoid_input)], -1)
relative_position_encoding = relative_position_encoding[None, :, :]
if batch_size is not None:
relative_position_encoding = tf.tile(relative_position_encoding,
[batch_size, 1, 1])
return relative_position_encoding
@tf_keras.utils.register_keras_serializable(package="Text")
class XLNetBase(tf_keras.layers.Layer):
"""Base XLNet model.
Attributes:
vocab_size: int, the number of tokens in vocabulary.
num_layers: int, the number of layers.
hidden_size: int, the hidden size.
num_attention_heads: int, the number of attention heads.
head_size: int, the dimension size of each attention head.
inner_size: int, the hidden size in feed-forward layers.
dropout_rate: float, dropout rate.
attention_dropout_rate: float, dropout rate on attention probabilities.
attention_type: str, "uni" or "bi".
bi_data: bool, whether to use bidirectional input pipeline. Usually set to
True during pretraining and False during finetuning.
initializer: A tf initializer.
two_stream: bool, whether or not to use `TwoStreamRelativeAttention` used
in the XLNet pretrainer. If `False`, then it will use
`MultiHeadRelativeAttention` as in Transformer XL.
tie_attention_biases: bool, whether or not to tie the biases together.
Usually set to `True`. Used for backwards compatibility.
memory_length: int, the number of tokens to cache.
clamp_length: int, clamp all relative distances larger than clamp_length. -1
means no clamping.
    reuse_length: int, the number of tokens in the current batch to be cached
      and reused in the future.
inner_activation: str, "relu" or "gelu".
use_cls_mask: bool, whether or not cls mask is included in the
input sequences.
embedding_width: The width of the word embeddings. If the embedding width
is not equal to hidden size, embedding parameters will be factorized
into two matrices in the shape of ["vocab_size", "embedding_width"] and
["embedding_width", "hidden_size"] ("embedding_width" is usually much
smaller than "hidden_size").
embedding_layer: The word embedding layer. `None` means we will create a
new embedding layer. Otherwise, we will reuse the given embedding layer.
This parameter is originally added for ELECTRA model which needs to tie
the generator embeddings with the discriminator embeddings.
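
  Example (a minimal construction sketch; every hyperparameter value below is
  illustrative only):

  >>> network = XLNetBase(
  ...     vocab_size=32000,
  ...     num_layers=2,
  ...     hidden_size=64,
  ...     num_attention_heads=4,
  ...     head_size=16,
  ...     inner_size=128,
  ...     dropout_rate=0.1,
  ...     attention_dropout_rate=0.1,
  ...     attention_type="bi",
  ...     bi_data=False,
  ...     initializer=tf_keras.initializers.RandomNormal(stddev=0.02),
  ...     memory_length=0)

  The layer is then called on `input_ids` (plus optional `segment_ids`,
  `input_mask`, `state`, `permutation_mask`, `target_mapping` and
  `masked_tokens`) and returns the output of the underlying `TransformerXL`
  layer.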
"""
def __init__(self,
vocab_size,
num_layers,
hidden_size,
num_attention_heads,
head_size,
inner_size,
dropout_rate,
attention_dropout_rate,
attention_type,
bi_data,
initializer,
two_stream=False,
tie_attention_biases=True,
memory_length=None,
clamp_length=-1,
reuse_length=None,
inner_activation="relu",
use_cls_mask=False,
embedding_width=None,
**kwargs):
super().__init__(**kwargs)
self._vocab_size = vocab_size
self._initializer = initializer
self._attention_type = attention_type
self._num_layers = num_layers
self._hidden_size = hidden_size
self._num_attention_heads = num_attention_heads
self._head_size = head_size
self._inner_size = inner_size
self._inner_activation = inner_activation
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._tie_attention_biases = tie_attention_biases
self._two_stream = two_stream
self._memory_length = memory_length
self._reuse_length = reuse_length
self._bi_data = bi_data
self._clamp_length = clamp_length
self._use_cls_mask = use_cls_mask
self._segment_embedding = None
self._mask_embedding = None
self._embedding_width = embedding_width
if embedding_width is None:
embedding_width = hidden_size
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=embedding_width,
initializer=tf_utils.clone_initializer(self._initializer),
dtype=tf.float32,
name="word_embedding")
self._dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
self.embedding_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
self.position_encoding = RelativePositionEncoding(self._hidden_size)
self._transformer_xl = transformer_xl.TransformerXL(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
head_size=head_size,
inner_size=inner_size,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
initializer=initializer,
two_stream=two_stream,
tie_attention_biases=tie_attention_biases,
memory_length=memory_length,
reuse_length=reuse_length,
inner_activation=inner_activation,
name="transformer_xl")
def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "num_layers": self._num_layers,
        "hidden_size": self._hidden_size,
        "num_attention_heads": self._num_attention_heads,
        "head_size": self._head_size,
        "inner_size": self._inner_size,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "attention_type": self._attention_type,
        "bi_data": self._bi_data,
        "initializer": self._initializer,
        "two_stream": self._two_stream,
        "tie_attention_biases": self._tie_attention_biases,
        "memory_length": self._memory_length,
        "clamp_length": self._clamp_length,
        "reuse_length": self._reuse_length,
        "inner_activation": self._inner_activation,
        "use_cls_mask": self._use_cls_mask,
        "embedding_width": self._embedding_width,
    }
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def get_embedding_lookup_table(self):
"""Returns the embedding layer weights."""
return self._embedding_layer.embeddings
def __call__(self,
input_ids,
segment_ids=None,
input_mask=None,
state=None,
permutation_mask=None,
target_mapping=None,
masked_tokens=None,
**kwargs):
# Uses dict to feed inputs into call() in order to keep state as a python
# list.
inputs = {
"input_ids": input_ids,
"segment_ids": segment_ids,
"input_mask": input_mask,
"state": state,
"permutation_mask": permutation_mask,
"target_mapping": target_mapping,
"masked_tokens": masked_tokens
}
return super().__call__(inputs, **kwargs)
def call(self, inputs):
"""Implements call() for the layer."""
input_ids = inputs["input_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
state = inputs["state"]
permutation_mask = inputs["permutation_mask"]
target_mapping = inputs["target_mapping"]
masked_tokens = inputs["masked_tokens"]
batch_size = tf.shape(input_ids)[0]
seq_length = tf.shape(input_ids)[1]
if state is not None:
memory_length = tf.shape(state[0])[1]
else:
memory_length = 0
total_length = memory_length + seq_length
if self._two_stream and masked_tokens is None:
raise ValueError("`masked_tokens` must be provided in order to "
"initialize the query stream in "
"`TwoStreamRelativeAttention`.")
if masked_tokens is not None and not self._two_stream:
logging.warning("`masked_tokens` is provided but `two_stream` is not "
"enabled. Please enable `two_stream` to enable two "
"stream attention.")
if input_mask is not None:
dtype = input_mask.dtype
elif permutation_mask is not None:
dtype = permutation_mask.dtype
else:
dtype = tf.int32
query_attention_mask, content_attention_mask = _compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
attention_type=self._attention_type,
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=dtype)
relative_position_encoding = _compute_positional_encoding(
attention_type=self._attention_type,
position_encoding_layer=self.position_encoding,
hidden_size=self._hidden_size,
batch_size=batch_size,
total_length=total_length,
seq_length=seq_length,
clamp_length=self._clamp_length,
bi_data=self._bi_data,
dtype=tf.float32)
relative_position_encoding = self.embedding_dropout(
relative_position_encoding)
if segment_ids is None:
segment_embedding = None
segment_matrix = None
else:
if self._segment_embedding is None:
self._segment_embedding = self.add_weight(
"seg_embed",
shape=[self._num_layers, 2, self._num_attention_heads,
self._head_size],
dtype=tf.float32,
initializer=tf_utils.clone_initializer(self._initializer))
segment_embedding = self._segment_embedding
segment_matrix = _compute_segment_matrix(
segment_ids=segment_ids,
memory_length=memory_length,
batch_size=batch_size,
use_cls_mask=self._use_cls_mask)
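    # The content stream is initialized from the (dropout-regularized) word
    # embeddings of the input tokens.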
word_embeddings = self._embedding_layer(input_ids)
content_stream = self._dropout(word_embeddings)
if self._two_stream:
if self._mask_embedding is None:
self._mask_embedding = self.add_weight(
"mask_emb/mask_emb",
shape=[1, 1, self._hidden_size],
dtype=tf.float32)
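      # Initialize the query stream: without a `target_mapping`, masked
      # positions take the learned mask embedding while unmasked positions
      # keep their word embeddings; with a `target_mapping`, the query stream
      # is the mask embedding tiled to one row per prediction target.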
if target_mapping is None:
masked_tokens = masked_tokens[:, :, None]
masked_token_embedding = (
masked_tokens * self._mask_embedding +
(1 - masked_tokens) * word_embeddings)
else:
masked_token_embedding = tf.tile(
self._mask_embedding,
[batch_size, tf.shape(target_mapping)[1], 1])
query_stream = self._dropout(masked_token_embedding)
else:
query_stream = None
return self._transformer_xl(
content_stream=content_stream,
query_stream=query_stream,
target_mapping=target_mapping,
state=state,
relative_position_encoding=relative_position_encoding,
segment_matrix=segment_matrix,
segment_embedding=segment_embedding,
content_attention_mask=content_attention_mask,
query_attention_mask=query_attention_mask)