# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based XLNet Model."""

from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer_xl

_SEG_ID_CLS = 2


def _create_causal_attention_mask(
    seq_length,
    memory_length,
    dtype=tf.float32,
    same_length=False):
  """Creates a causal attention mask with a single-sided context.

  When applying the attention mask in `MultiHeadRelativeAttention`, the
  attention scores are of shape `[(batch dimensions), S, S + M]`, where:
  - S = sequence length.
  - M = memory length.

  In a simple case where S = 2, M = 1, here is a simple illustration of the
  `attention_scores` matrix, where `a` represents an attention function:

  token_0 [[a(token_0, mem_0) a(token_0, token_0) a(token_0, token_1)],
  token_1  [a(token_1, mem_0) a(token_1, token_0) a(token_1, token_1)]]
               mem_0              token_0             token_1

  For uni-directional attention, we want to mask out values in the attention
  scores that represent a(token_i, token_j) where j > i. We can achieve this by
  concatenating 0s (representing memory positions) with a strictly upper
  triangular matrix of 1s.

  We then flip the matrix values in order to match the representation where
  real values are 1s.

  Args:
    seq_length: int, The length of each sequence.
    memory_length: int, The length of memory blocks.
    dtype: dtype of the mask.
    same_length: bool, whether to use the same attention length for each token.

  Returns:
    A unidirectional attention mask of shape
    `[seq_length, seq_length + memory_length]`. E.g.:

    [[1. 1. 1. 0. 0. 0.]
     [1. 1. 1. 1. 0. 0.]
     [1. 1. 1. 1. 1. 0.]
     [1. 1. 1. 1. 1. 1.]]
  """
  ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
  upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
  diagonal = tf.linalg.band_part(ones_matrix, 0, 0)
  padding = tf.zeros([seq_length, memory_length], dtype=dtype)
  causal_attention_mask = tf.concat(
      [padding, upper_triangular - diagonal], 1)
  if same_length:
    lower_triangular = tf.linalg.band_part(ones_matrix, -1, 0)
    strictly_lower_triangular = lower_triangular - diagonal
    causal_attention_mask = tf.concat(
        [causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
         causal_attention_mask[:, seq_length:]], 1)

  return 1 - causal_attention_mask


def _combine_masks(mask1, mask2, dtype, how="and"):
  """Combines two masks.

  Use "and" if trying to combine two existing masks.
  Use "or" if trying to flip a few positions to "real".

  Args:
    mask1: tf.Tensor, input mask 1.
    mask2: tf.Tensor, input mask 2.
    dtype: tf.dtype of the combined mask.
    how: str, which logical operation to run ("and" or "or").

  Returns:
    The combined input masks.
  """
  if how == "and":
    operator = tf.math.logical_and
  else:
    operator = tf.math.logical_or
  return tf.cast(operator(
      tf.cast(mask1, tf.bool),
      tf.cast(mask2, tf.bool)), dtype=dtype)
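

# A minimal illustration of `_combine_masks` (toy values): both inputs are
# cast to bool, combined elementwise, and cast back to `dtype`. With "and", a
# position stays set only if it is set in both masks; with "or", setting a
# position in either mask is enough, which is how a few positions can be
# flipped back to "real".
#
#   m1 = tf.constant([0., 1., 1., 0.])
#   m2 = tf.constant([0., 1., 0., 1.])
#   _combine_masks(m1, m2, tf.float32, how="and")  # -> [0., 1., 0., 0.]
#   _combine_masks(m1, m2, tf.float32, how="or")   # -> [0., 1., 1., 1.]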


def _compute_attention_mask(
    input_mask,
    permutation_mask,
    attention_type,
    seq_length,
    memory_length,
    batch_size,
    dtype=tf.float32):
  """Combines all input attention masks for XLNet.

  In XLNet modeling, `0` represents tokens that can be attended to and `1`
  represents tokens that cannot be attended to.

  For XLNet pre-training and fine-tuning, there are a few masks used:
  - Causal attention mask: If the attention type is unidirectional, then all
    tokens after the current position cannot be attended to.
  - Input mask: when generating data, padding is added up to a max sequence
    length to make all sequences the same length. This distinguishes real
    tokens (`0`) from padding tokens (`1`).
  - Permutation mask: during XLNet pretraining, the input sequence is
    factorized into a factorization sequence `z`. During partial prediction,
    `z` is split at a cutting point `c` (an index of the factorization
    sequence) and prediction is only applied to the tokens after `c`.
    Therefore, tokens at factorization positions `i` > `c` can be attended to
    and tokens at factorization positions `i` <= `c` cannot be attended to.

  This function broadcasts and combines all attention masks to produce the
  query attention mask and the content attention mask.

  Args:
    input_mask: Tensor, the input mask related to padding. Input shape:
      `(B, S)`.
    permutation_mask: Tensor, the permutation mask used in partial prediction.
      Input shape: `(B, S, S)`.
    attention_type: str, the attention type. Can be "uni" (directional) or
      "bi" (directional).
    seq_length: int, the length of each sequence.
    memory_length: int, the length of memory blocks.
    batch_size: int, the batch size.
    dtype: The dtype of the masks.

  Returns:
    attention_mask, content_attention_mask: The query attention mask and the
      content attention mask, respectively.
  """
  attention_mask = None

  # `1` values mean do not attend to this position.
  if attention_type == "uni":
    causal_attention_mask = _create_causal_attention_mask(
        seq_length=seq_length,
        memory_length=memory_length,
        dtype=dtype)
    causal_attention_mask = causal_attention_mask[None, None, :, :]

  # `causal_attention_mask`: [1, 1, S, S + M]
  # `input_mask`: [B, S]
  # `permutation_mask`: [B, S, S]
  if input_mask is not None and permutation_mask is not None:
    data_mask = _combine_masks(input_mask[:, None, :], permutation_mask, dtype)
  elif input_mask is not None and permutation_mask is None:
    data_mask = input_mask[:, None, :]
  elif input_mask is None and permutation_mask is not None:
    data_mask = permutation_mask
  else:
    data_mask = None

  # `data_mask`: [B, S, S] or [B, 1, S]
  if data_mask is not None:
    # All positions within state can be attended to.
    state_mask = tf.ones([batch_size, tf.shape(data_mask)[1], memory_length],
                         dtype=dtype)
    # `state_mask`: [B, 1, M] or [B, S, M]
    data_mask = tf.concat([state_mask, data_mask], 2)
    # `data_mask`: [B, 1, S + M] or [B, S, S + M]
    if attention_type == "uni":
      attention_mask = _combine_masks(causal_attention_mask,
                                      data_mask[:, None, :, :],
                                      dtype=dtype)
    else:
      attention_mask = data_mask[:, None, :, :]

  if attention_mask is not None:
    # Construct the content attention mask. This ensures that the model can
    # always attend to a token's own content position (the content diagonal).
    non_target_mask = tf.concat(
        [tf.zeros([seq_length, memory_length], dtype=dtype),
         tf.eye(seq_length, dtype=dtype)], axis=-1)
    content_attention_mask = _combine_masks(
        attention_mask, non_target_mask, how="or", dtype=dtype)
  else:
    content_attention_mask = None

  return attention_mask, content_attention_mask
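

# Shape summary for `_compute_attention_mask` (B = batch size, S = sequence
# length, M = memory length): when at least one of `input_mask` /
# `permutation_mask` is given, the returned masks broadcast against attention
# scores of shape [B, num_heads, S, S + M]; the query attention mask is
# [B, 1, S, S + M] (or a broadcast-compatible [B, 1, 1, S + M] when only the
# padding mask is supplied with "bi" attention), and the content mask is the
# query mask "or"-combined with an identity block over the S x S token
# positions, so the two differ only on that diagonal. When neither mask input
# is provided, both return values are None.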


def _compute_segment_matrix(
    segment_ids,
    memory_length,
    batch_size,
    use_cls_mask):
  """Computes the segment embedding matrix.

  XLNet introduced segment-based attention for attention calculations. This
  extends the idea of relative encodings in Transformer XL by considering
  whether or not two positions are within the same segment, rather than
  which segments they come from.

  This function generates a segment matrix by broadcasting provided segment IDs
  in two different dimensions and checking where values are equal. This output
  matrix shows `True` whenever two tokens are NOT in the same segment and
  `False` whenever they are.

  Args:
    segment_ids: A Tensor of size `[B, S]` that represents which segment
      each token belongs to.
    memory_length: int, the length of memory blocks.
    batch_size: int, the batch size.
    use_cls_mask: bool, whether or not to introduce cls mask in
      input sequences.

  Returns:
    A boolean Tensor of size `[B, S, S + M]`, where `True` means that two
    tokens are NOT in the same segment, and `False` means they are in the same
    segment.
  """
  if segment_ids is None:
    return None

  memory_padding = tf.zeros([batch_size, memory_length],
                            dtype=segment_ids.dtype)
  padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
  # `segment_ids`: [B, S]
  # `padded_segment_ids`: [B, S + M]

  if use_cls_mask:
    # `1` indicates not in the same segment.
    # Target result: [B, S, S + M]
    broadcasted_segment_class_indices = (
        tf.equal(segment_ids,
                 tf.constant([_SEG_ID_CLS]))[:, :, None])
    broadcasted_padded_class_indices = (
        tf.equal(
            padded_segment_ids,
            tf.constant([_SEG_ID_CLS]))[:, None, :])

    class_index_matrix = tf.logical_or(broadcasted_segment_class_indices,
                                       broadcasted_padded_class_indices)

    segment_matrix = tf.equal(segment_ids[:, :, None],
                              padded_segment_ids[:, None, :])
    segment_matrix = tf.logical_or(class_index_matrix, segment_matrix)
  else:
    # TODO(allencwang) - address this legacy mismatch from `use_cls_mask`.
    segment_matrix = tf.logical_not(
        tf.equal(segment_ids[:, :, None], padded_segment_ids[:, None, :]))
  return segment_matrix
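

# Worked example for `_compute_segment_matrix` with `use_cls_mask=False`
# (toy values): for segment_ids = [[0, 0, 1]] and memory_length = 0, the
# padded ids equal the original ids and the result is
#
#   [[[False, False, True],
#     [False, False, True],
#     [True,  True,  False]]]
#
# i.e. `True` exactly where two tokens belong to different segments. With
# `use_cls_mask=True` the polarity differs (see the TODO above): that branch
# returns `True` where two tokens share a segment or where either position
# carries the special `_SEG_ID_CLS` segment id.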


def _compute_positional_encoding(
    attention_type,
    position_encoding_layer,
    hidden_size,
    batch_size,
    total_length,
    seq_length,
    clamp_length,
    bi_data,
    dtype=tf.float32):
  """Computes the relative position encoding.

  Args:
    attention_type: str, the attention type. Can be "uni" (directional) or
      "bi" (directional).
    position_encoding_layer: An instance of `RelativePositionEncoding`.
    hidden_size: int, the hidden size.
    batch_size: int, the batch size.
    total_length: int, the sequence length added to the memory length.
    seq_length: int, the length of each sequence.
    clamp_length: int, clamp all relative distances larger than clamp_length.
      -1 means no clamping.
    bi_data: bool, whether to use bidirectional input pipeline. Usually set to
      True during pretraining and False during finetuning.
    dtype: the dtype of the encoding.

  Returns:
    A Tensor, representing the position encoding.
  """
  freq_seq = tf.range(0, hidden_size, 2.0)
  if dtype is not None and dtype != tf.float32:
    freq_seq = tf.cast(freq_seq, dtype=dtype)

  if attention_type == "bi":
    beg, end = total_length, -seq_length
  elif attention_type == "uni":
    beg, end = total_length, -1
  else:
    raise ValueError("Unknown `attention_type` {}.".format(attention_type))

  if bi_data:
    forward_position_sequence = tf.range(beg, end, -1.0)
    backward_position_sequence = tf.range(-beg, -end, 1.0)

    if dtype is not None and dtype != tf.float32:
      forward_position_sequence = tf.cast(forward_position_sequence,
                                          dtype=dtype)
      backward_position_sequence = tf.cast(backward_position_sequence,
                                           dtype=dtype)

    if clamp_length > 0:
      forward_position_sequence = tf.clip_by_value(
          forward_position_sequence,
          -clamp_length,
          clamp_length)
      backward_position_sequence = tf.clip_by_value(
          backward_position_sequence,
          -clamp_length,
          clamp_length)

    if batch_size is not None:
      forward_positional_encoding = position_encoding_layer(
          forward_position_sequence, batch_size // 2)
      backward_positional_encoding = position_encoding_layer(
          backward_position_sequence, batch_size // 2)
    else:
      forward_positional_encoding = position_encoding_layer(
          forward_position_sequence, None)
      backward_positional_encoding = position_encoding_layer(
          backward_position_sequence, None)

    relative_position_encoding = tf.concat(
        [forward_positional_encoding, backward_positional_encoding], axis=0)
  else:
    forward_position_sequence = tf.range(beg, end, -1.0)
    if dtype is not None and dtype != tf.float32:
      forward_position_sequence = tf.cast(
          forward_position_sequence, dtype=dtype)
    if clamp_length > 0:
      forward_position_sequence = tf.clip_by_value(
          forward_position_sequence,
          -clamp_length,
          clamp_length)

    relative_position_encoding = position_encoding_layer(
        forward_position_sequence, batch_size)
  return relative_position_encoding
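

# Positional-encoding sketch (S = seq_length, M = memory length, so
# total_length = S + M): the forward position sequence runs from S + M down to
# -S + 1 for "bi" attention and down to 0 for "uni" attention. With
# `bi_data=True`, the elementwise-negated sequence is encoded as well and the
# two encodings are concatenated along the batch axis (each half tiled to
# `batch_size // 2` when a batch size is given), so downstream code still sees
# a single tensor of shape [batch, positions, hidden_size].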


class RelativePositionEncoding(tf_keras.layers.Layer):
  """Creates a relative positional encoding.

  This layer creates a relative positional encoding as described in
  "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
  (https://arxiv.org/abs/1901.02860).

  Rather than an absolute position embedding as in Transformer, this
  formulation represents position as the relative distance between tokens using
  sinusoidal positional embeddings.

  Note: This layer is currently experimental.

  Attributes:
    hidden_size: The dimensionality of the input embeddings.
  """

  def __init__(self, hidden_size, **kwargs):
    super().__init__(**kwargs)
    self._hidden_size = hidden_size
    self._inv_freq = 1.0 / (10000.0**(
        tf.range(0, self._hidden_size, 2.0) / self._hidden_size))

  def call(self, pos_seq, batch_size=None):
    """Implements call() for the layer.

    Args:
      pos_seq: A 1-D `Tensor` of positions.
      batch_size: The optionally provided batch size that tiles the relative
        positional encoding.

    Returns:
      The relative positional encoding of shape
      `[batch_size, len(pos_seq), hidden_size]` if batch_size is provided, else
      `[1, len(pos_seq), hidden_size]`.
    """
    sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq)
    relative_position_encoding = tf.concat([tf.sin(sinusoid_input),
                                            tf.cos(sinusoid_input)], -1)
    relative_position_encoding = relative_position_encoding[None, :, :]
    if batch_size is not None:
      relative_position_encoding = tf.tile(relative_position_encoding,
                                           [batch_size, 1, 1])
    return relative_position_encoding
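

# Minimal usage sketch for `RelativePositionEncoding` (toy sizes):
#
#   layer = RelativePositionEncoding(hidden_size=8)
#   pos_seq = tf.range(3.0, -2.0, -1.0)        # [3., 2., 1., 0., -1.]
#   layer(pos_seq).shape                       # TensorShape([1, 5, 8])
#   layer(pos_seq, batch_size=4).shape         # TensorShape([4, 5, 8])
#
# Each position is mapped to `hidden_size // 2` sinusoid frequencies; the sin
# and cos components are concatenated to form the final `hidden_size`-wide
# encoding.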


class XLNetBase(tf_keras.layers.Layer):
  """Base XLNet model.

  Attributes:
    vocab_size: int, the number of tokens in vocabulary.
    num_layers: int, the number of layers.
    hidden_size: int, the hidden size.
    num_attention_heads: int, the number of attention heads.
    head_size: int, the dimension size of each attention head.
    inner_size: int, the hidden size in feed-forward layers.
    dropout_rate: float, dropout rate.
    attention_dropout_rate: float, dropout rate on attention probabilities.
    attention_type: str, "uni" or "bi".
    bi_data: bool, whether to use bidirectional input pipeline. Usually set to
      True during pretraining and False during finetuning.
    initializer: A tf initializer.
    two_stream: bool, whether or not to use `TwoStreamRelativeAttention` used
      in the XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
    tie_attention_biases: bool, whether or not to tie the biases together.
      Usually set to `True`. Used for backwards compatibility.
    memory_length: int, the number of tokens to cache.
    same_length: bool, whether to use the same attention length for each
      token.
    clamp_length: int, clamp all relative distances larger than clamp_length.
      -1 means no clamping.
    reuse_length: int, the number of tokens in the current batch to be cached
      and reused in the future.
    inner_activation: str, "relu" or "gelu".
    use_cls_mask: bool, whether or not cls mask is included in the
      input sequences.
    embedding_width: The width of the word embeddings. If the embedding width
      is not equal to hidden size, embedding parameters will be factorized
      into two matrices in the shape of `[vocab_size, embedding_width]` and
      `[embedding_width, hidden_size]` (`embedding_width` is usually much
      smaller than `hidden_size`).
    embedding_layer: The word embedding layer. `None` means we will create a
      new embedding layer. Otherwise, we will reuse the given embedding layer.
      This parameter was originally added for the ELECTRA model, which needs
      to tie the generator embeddings with the discriminator embeddings.
  """

  def __init__(self,
               vocab_size,
               num_layers,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               attention_type,
               bi_data,
               initializer,
               two_stream=False,
               tie_attention_biases=True,
               memory_length=None,
               clamp_length=-1,
               reuse_length=None,
               inner_activation="relu",
               use_cls_mask=False,
               embedding_width=None,
               **kwargs):
    super().__init__(**kwargs)

    self._vocab_size = vocab_size
    self._initializer = initializer
    self._attention_type = attention_type
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._num_attention_heads = num_attention_heads
    self._head_size = head_size
    self._inner_size = inner_size
    self._inner_activation = inner_activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._tie_attention_biases = tie_attention_biases
    self._two_stream = two_stream
    self._memory_length = memory_length
    self._reuse_length = reuse_length
    self._bi_data = bi_data
    self._clamp_length = clamp_length
    self._use_cls_mask = use_cls_mask

    self._segment_embedding = None
    self._mask_embedding = None
    self._embedding_width = embedding_width

    if embedding_width is None:
      embedding_width = hidden_size

    self._embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=self._vocab_size,
        embedding_width=embedding_width,
        initializer=tf_utils.clone_initializer(self._initializer),
        dtype=tf.float32,
        name="word_embedding")
    self._dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
    self.embedding_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate)
    self.position_encoding = RelativePositionEncoding(self._hidden_size)

    self._transformer_xl = transformer_xl.TransformerXL(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        head_size=head_size,
        inner_size=inner_size,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        initializer=initializer,
        two_stream=two_stream,
        tie_attention_biases=tie_attention_biases,
        memory_length=memory_length,
        reuse_length=reuse_length,
        inner_activation=inner_activation,
        name="transformer_xl")

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "num_layers": self._num_layers,
        "hidden_size": self._hidden_size,
        "num_attention_heads": self._num_attention_heads,
        "head_size": self._head_size,
        "inner_size": self._inner_size,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "attention_type": self._attention_type,
        "bi_data": self._bi_data,
        "initializer": self._initializer,
        "two_stream": self._two_stream,
        "tie_attention_biases": self._tie_attention_biases,
        "memory_length": self._memory_length,
        "clamp_length": self._clamp_length,
        "reuse_length": self._reuse_length,
        "inner_activation": self._inner_activation,
        "use_cls_mask": self._use_cls_mask,
        "embedding_width": self._embedding_width,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def get_embedding_lookup_table(self):
    """Returns the embedding layer weights."""
    return self._embedding_layer.embeddings

  def __call__(self,
               input_ids,
               segment_ids=None,
               input_mask=None,
               state=None,
               permutation_mask=None,
               target_mapping=None,
               masked_tokens=None,
               **kwargs):
    # Uses dict to feed inputs into call() in order to keep state as a python
    # list.
    inputs = {
        "input_ids": input_ids,
        "segment_ids": segment_ids,
        "input_mask": input_mask,
        "state": state,
        "permutation_mask": permutation_mask,
        "target_mapping": target_mapping,
        "masked_tokens": masked_tokens
    }
    return super().__call__(inputs, **kwargs)

  def call(self, inputs):
    """Implements call() for the layer."""
    input_ids = inputs["input_ids"]
    segment_ids = inputs["segment_ids"]
    input_mask = inputs["input_mask"]
    state = inputs["state"]
    permutation_mask = inputs["permutation_mask"]
    target_mapping = inputs["target_mapping"]
    masked_tokens = inputs["masked_tokens"]

    batch_size = tf.shape(input_ids)[0]
    seq_length = tf.shape(input_ids)[1]
    if state is not None:
      memory_length = tf.shape(state[0])[1]
    else:
      memory_length = 0
    total_length = memory_length + seq_length

    if self._two_stream and masked_tokens is None:
      raise ValueError("`masked_tokens` must be provided in order to "
                       "initialize the query stream in "
                       "`TwoStreamRelativeAttention`.")
    if masked_tokens is not None and not self._two_stream:
      logging.warning("`masked_tokens` is provided but `two_stream` is not "
                      "enabled. Please enable `two_stream` to enable two "
                      "stream attention.")

    if input_mask is not None:
      dtype = input_mask.dtype
    elif permutation_mask is not None:
      dtype = permutation_mask.dtype
    else:
      dtype = tf.int32

    query_attention_mask, content_attention_mask = _compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type=self._attention_type,
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=dtype)
    relative_position_encoding = _compute_positional_encoding(
        attention_type=self._attention_type,
        position_encoding_layer=self.position_encoding,
        hidden_size=self._hidden_size,
        batch_size=batch_size,
        total_length=total_length,
        seq_length=seq_length,
        clamp_length=self._clamp_length,
        bi_data=self._bi_data,
        dtype=tf.float32)
    relative_position_encoding = self.embedding_dropout(
        relative_position_encoding)

    if segment_ids is None:
      segment_embedding = None
      segment_matrix = None
    else:
      if self._segment_embedding is None:
        self._segment_embedding = self.add_weight(
            "seg_embed",
            shape=[self._num_layers, 2, self._num_attention_heads,
                   self._head_size],
            dtype=tf.float32,
            initializer=tf_utils.clone_initializer(self._initializer))

      segment_embedding = self._segment_embedding
      segment_matrix = _compute_segment_matrix(
          segment_ids=segment_ids,
          memory_length=memory_length,
          batch_size=batch_size,
          use_cls_mask=self._use_cls_mask)

    word_embeddings = self._embedding_layer(input_ids)
    content_stream = self._dropout(word_embeddings)

    if self._two_stream:
      if self._mask_embedding is None:
        self._mask_embedding = self.add_weight(
            "mask_emb/mask_emb",
            shape=[1, 1, self._hidden_size],
            dtype=tf.float32)
      if target_mapping is None:
        masked_tokens = masked_tokens[:, :, None]
        masked_token_embedding = (
            masked_tokens * self._mask_embedding +
            (1 - masked_tokens) * word_embeddings)
      else:
        masked_token_embedding = tf.tile(
            self._mask_embedding,
            [batch_size, tf.shape(target_mapping)[1], 1])
      query_stream = self._dropout(masked_token_embedding)
    else:
      query_stream = None

    return self._transformer_xl(
        content_stream=content_stream,
        query_stream=query_stream,
        target_mapping=target_mapping,
        state=state,
        relative_position_encoding=relative_position_encoding,
        segment_matrix=segment_matrix,
        segment_embedding=segment_embedding,
        content_attention_mask=content_attention_mask,
        query_attention_mask=query_attention_mask)
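

# Usage sketch for `XLNetBase` (hypothetical hyperparameters, chosen only to
# keep the example small). The layer forwards to the underlying
# `TransformerXL`, so it returns the final attention output together with the
# new memory state:
#
#   xlnet = XLNetBase(
#       vocab_size=32000, num_layers=2, hidden_size=64,
#       num_attention_heads=4, head_size=16, inner_size=128,
#       dropout_rate=0.1, attention_dropout_rate=0.1,
#       attention_type="bi", bi_data=False,
#       initializer=tf_keras.initializers.RandomNormal(stddev=0.02))
#   input_ids = tf.ones([2, 8], dtype=tf.int32)
#   segment_ids = tf.zeros([2, 8], dtype=tf.int32)
#   input_mask = tf.zeros([2, 8], dtype=tf.int32)
#   output, new_states = xlnet(
#       input_ids, segment_ids=segment_ids, input_mask=input_mask)
#   # output: [2, 8, 64]  (batch, sequence, hidden_size)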