# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Keras-based Transformer XL layer.""" from absl import logging import tensorflow as tf, tf_keras from official.modeling import tf_utils from official.nlp.modeling.layers import relative_attention def _cache_memory(current_state, previous_state, memory_length, reuse_length=0): """Caches hidden states into memory. Args: current_state: `Tensor`, the current state. previous_state: `Tensor`, the previous state. memory_length: `int`, the number of tokens to cache. reuse_length: `int`, the number of tokens in the current batch to be cached and reused in the future. Returns: A `Tensor`, representing the cached state with stopped gradients. """ if memory_length is None or memory_length == 0: return None else: if reuse_length > 0: current_state = current_state[:, :reuse_length, :] if previous_state is None: new_mem = current_state[:, -memory_length:, :] else: new_mem = tf.concat( [previous_state, current_state], 1)[:, -memory_length:, :] return tf.stop_gradient(new_mem) @tf_keras.utils.register_keras_serializable(package="Text") class TransformerXLBlock(tf_keras.layers.Layer): """Transformer XL block. This implements a Transformer XL block from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860). This block is further extended to allow for the Transformer-XL re-parameterization in "XLNet: Generalized Autoregressive Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237). Given an input stream, this block computes attention, applies dropouts and layer norms and feeds into the FFN network. **Note: This layer is currently experimental. Attributes: vocab_size: The size of the token vocabulary. hidden_size: The size of the transformer hidden layers. num_attention_heads: The number of attention heads. head_size: The dimension size of each attention head. inner_size: The inner size for the transformer layers. dropout_rate: Dropout rate for the output of this layer. attention_dropout_rate: Dropout rate on attention probabilities. two_stream: Whether or not to use `TwoStreamRelativeAttention` used in the XLNet pretrainer. If `False`, then it will use `MultiHeadRelativeAttention` as in Transformer XL. norm_epsilon: Epsilon value to initialize normalization layers. inner_activation: The activation to use for the inner FFN layers. kernel_initializer: Initializer for dense layer kernels. inner_dropout: Dropout probability for the inner dropout layer. """ def __init__(self, vocab_size, hidden_size, num_attention_heads, head_size, inner_size, dropout_rate, attention_dropout_rate, two_stream=False, norm_epsilon=1e-12, inner_activation="relu", kernel_initializer="variance_scaling", inner_dropout=0.0, **kwargs): """Initializes TransformerXLBlock layer.""" super().__init__(**kwargs) self._vocab_size = vocab_size self._num_heads = num_attention_heads self._head_size = head_size self._hidden_size = hidden_size self._inner_size = inner_size self._dropout_rate = dropout_rate self._attention_dropout_rate = attention_dropout_rate self._inner_activation = inner_activation self._norm_epsilon = norm_epsilon self._kernel_initializer = kernel_initializer self._inner_dropout = inner_dropout self._two_stream = two_stream if two_stream: self._attention_layer_type = relative_attention.TwoStreamRelativeAttention else: self._attention_layer_type = relative_attention.MultiHeadRelativeAttention def build(self, input_shape): input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape input_tensor_shape = tf.TensorShape(input_tensor) if len(input_tensor_shape.as_list()) != 3: raise ValueError("TransformerLayer expects a three-dimensional input of " "shape [batch, sequence, width].") batch_size, sequence_length, hidden_size = input_tensor_shape if len(input_shape) == 2: mask_tensor_shape = tf.TensorShape(input_shape[1]) expected_mask_tensor_shape = tf.TensorShape( [batch_size, sequence_length, sequence_length]) if not expected_mask_tensor_shape.is_compatible_with(mask_tensor_shape): raise ValueError("When passing a mask tensor to TransformerXLBlock, " "the mask tensor must be of shape [batch, " "sequence_length, sequence_length] (here %s). Got a " "mask tensor of shape %s." % (expected_mask_tensor_shape, mask_tensor_shape)) if hidden_size % self._num_heads != 0: raise ValueError( "The input size (%d) is not a multiple of the number of attention " "heads (%d)" % (hidden_size, self._num_heads)) self._attention_layer = self._attention_layer_type( num_heads=self._num_heads, key_dim=self._head_size, value_dim=self._head_size, dropout=self._attention_dropout_rate, use_bias=False, kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer), name="rel_attn") self._attention_dropout = tf_keras.layers.Dropout( rate=self._attention_dropout_rate) self._attention_layer_norm = tf_keras.layers.LayerNormalization( name="self_attention_layer_norm", axis=-1, epsilon=self._norm_epsilon, dtype=tf.float32) self._inner_dense = tf_keras.layers.EinsumDense( "abc,cd->abd", output_shape=(None, self._inner_size), bias_axes="d", kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer), name="inner") self._inner_activation_layer = tf_keras.layers.Activation( self._inner_activation) self._inner_dropout_layer = tf_keras.layers.Dropout( rate=self._inner_dropout) self._output_dense = tf_keras.layers.EinsumDense( "abc,cd->abd", output_shape=(None, hidden_size), bias_axes="d", name="output", kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer)) self._output_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate) self._output_layer_norm = tf_keras.layers.LayerNormalization( name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon) super().build(input_shape) def get_config(self): config = { "vocab_size": self._vocab_size, "hidden_size": self._hidden_size, "num_attention_heads": self._num_heads, "head_size": self._head_size, "inner_size": self._inner_size, "dropout_rate": self._dropout_rate, "attention_dropout_rate": self._attention_dropout_rate, "two_stream": self._two_stream, "norm_epsilon": self._norm_epsilon, "inner_activation": self._inner_activation, "kernel_initializer": self._kernel_initializer, "inner_dropout": self._inner_dropout, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) def call(self, content_stream, content_attention_bias, positional_attention_bias, relative_position_encoding=None, segment_matrix=None, segment_encoding=None, segment_attention_bias=None, state=None, content_attention_mask=None, query_stream=None, query_attention_mask=None, target_mapping=None): """Implements `call` for the Layer. Args: content_stream: `Tensor`, the input content stream. This is the standard input to Transformer XL and is commonly referred to as `h` in XLNet. content_attention_bias: Bias `Tensor` for content based attention of shape `[num_heads, dim]`. positional_attention_bias: Bias `Tensor` for position based attention of shape `[num_heads, dim]`. relative_position_encoding: Relative positional encoding `Tensor` of shape `[B, L, dim]`. segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet, but not in Transformer XL. segment_encoding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used in XLNet, but not in Transformer XL. segment_attention_bias: Optional bias `Tensor` for segment based attention of shape `[num_heads, dim]`. state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of the state or memory. If passed, this is also attended over as in Transformer XL. content_attention_mask: Optional `Tensor` representing the mask that is added to content attention logits. If state is not None, the mask source sequence dimension should extend M. query_stream: Optional `Tensor`, the query stream. This is introduced in `TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if `two_stream` is `False`. query_attention_mask: Optional `Tensor` representing the mask that is added to query attention logits. If state is not None, the mask source sequence dimension should extend M. target_mapping: Optional `Tensor` representing the target mapping when calculating query attention. Returns: A `dict` object, containing the key value pairs for `content_attention` and (if `two_stream` is `True`) `query_attention`. """ if not self._two_stream and query_stream is not None: logging.warning("`query_stream` was provided but two stream attention is " "disabled. `query_stream` will be ignored.") if self._two_stream: attention_kwargs = dict( content_stream=content_stream, query_stream=query_stream, query_attention_mask=query_attention_mask, target_mapping=target_mapping, content_attention_mask=content_attention_mask) else: attention_kwargs = dict( query=content_stream, value=content_stream, key=content_stream, attention_mask=content_attention_mask) common_attention_kwargs = dict( content_attention_bias=content_attention_bias, relative_position_encoding=relative_position_encoding, positional_attention_bias=positional_attention_bias, segment_matrix=segment_matrix, segment_encoding=segment_encoding, segment_attention_bias=segment_attention_bias, state=state) attention_kwargs.update(common_attention_kwargs) attention_output = self._attention_layer(**attention_kwargs) if self._two_stream: attention_streams = attention_output input_streams = [content_stream, query_stream] else: attention_streams = [attention_output] input_streams = [content_stream] attention_keys = ["content_attention", "query_attention"] attention_output = {} for attention_stream, input_stream, attention_key in zip( attention_streams, input_streams, attention_keys): attention_stream = self._attention_dropout(attention_stream) attention_stream = self._attention_layer_norm( attention_stream + input_stream) inner_output = self._inner_dense(attention_stream) inner_output = self._inner_activation_layer( inner_output) inner_output = self._inner_dropout_layer( inner_output) layer_output = self._output_dense(inner_output) layer_output = self._output_dropout(layer_output) layer_output = self._output_layer_norm(layer_output + attention_stream) attention_output[attention_key] = layer_output return attention_output class TransformerXL(tf_keras.layers.Layer): """Transformer XL. This layer combines multiple Transformer XL blocks from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860). This layer handles the attention biases as well as memory caching and reuse as in Transformer XL and XLNet. Attributes: vocab_size: The number of tokens in vocabulary. num_layers: The number of layers. hidden_size: The hidden size. num_attention_heads: The number of attention heads. head_size: The dimension size of each attention head. inner_size: The hidden size in feed-forward layers. dropout_rate: Dropout rate used in each Transformer XL block. attention_dropout_rate: Dropout rate on attention probabilities. two_stream: Whether or not to use `TwoStreamRelativeAttention` used in the XLNet pretrainer. If `False`, then it will use `MultiHeadRelativeAttention` as in Transformer XL. initializer: The initializer to use for attention biases. tie_attention_biases: Whether or not to tie biases together. If `True`, then each Transformer XL block shares the same trainable attention bias. If `False`, then each block has its own attention bias. This is usually set to `True`. memory_length: The number of tokens to cache. reuse_length: The number of tokens in the current batch to be cached and reused in the future. inner_activation: The activation to use in the inner layers for Transformer XL blocks. Typically "relu" or "gelu". """ def __init__(self, vocab_size, num_layers, hidden_size, num_attention_heads, head_size, inner_size, dropout_rate, attention_dropout_rate, initializer, two_stream=False, tie_attention_biases=True, memory_length=None, reuse_length=None, inner_activation="relu", **kwargs): """Initializes TransformerXL.""" super().__init__(**kwargs) self._vocab_size = vocab_size self._initializer = initializer self._num_layers = num_layers self._hidden_size = hidden_size self._num_attention_heads = num_attention_heads self._head_size = head_size self._inner_size = inner_size self._inner_activation = inner_activation self._dropout_rate = dropout_rate self._attention_dropout_rate = attention_dropout_rate self._tie_attention_biases = tie_attention_biases self._two_stream = two_stream self._memory_length = memory_length self._reuse_length = reuse_length if self._tie_attention_biases: attention_bias_shape = [self._num_attention_heads, self._head_size] else: attention_bias_shape = [self._num_layers, self._num_attention_heads, self._head_size] self.content_attention_bias = self.add_weight( "content_attention_bias", shape=attention_bias_shape, dtype=tf.float32, initializer=tf_utils.clone_initializer(self._initializer)) self.positional_attention_bias = self.add_weight( "positional_attention_bias", shape=attention_bias_shape, dtype=tf.float32, initializer=tf_utils.clone_initializer(self._initializer)) self.segment_attention_bias = self.add_weight( "segment_attention_bias", shape=attention_bias_shape, dtype=tf.float32, initializer=tf_utils.clone_initializer(self._initializer)) self.transformer_xl_layers = [] for i in range(self._num_layers): self.transformer_xl_layers.append( TransformerXLBlock( vocab_size=self._vocab_size, hidden_size=self._head_size * self._num_attention_heads, num_attention_heads=self._num_attention_heads, head_size=self._head_size, inner_size=self._inner_size, dropout_rate=self._dropout_rate, attention_dropout_rate=self._attention_dropout_rate, norm_epsilon=1e-12, inner_activation=self._inner_activation, two_stream=self._two_stream, kernel_initializer="variance_scaling", name="layer_%d" % i)) self.output_dropout = tf_keras.layers.Dropout(rate=self._dropout_rate) def get_config(self): config = { "vocab_size": self._vocab_size, "num_layers": self._num_layers, "hidden_size": self._hidden_size, "num_attention_heads": self._num_attention_heads, "head_size": self._head_size, "inner_size": self._inner_size, "dropout_rate": self._dropout_rate, "attention_dropout_rate": self._attention_dropout_rate, "initializer": self._initializer, "two_stream": self._two_stream, "tie_attention_biases": self._tie_attention_biases, "memory_length": self._memory_length, "reuse_length": self._reuse_length, "inner_activation": self._inner_activation, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) def call(self, content_stream, relative_position_encoding, segment_matrix=None, segment_embedding=None, state=None, content_attention_mask=None, query_stream=None, query_attention_mask=None, target_mapping=None): """Implements call() for the layer. Args: content_stream: `Tensor`, the input content stream. This is the standard input to Transformer XL and is commonly referred to as `h` in XLNet. relative_position_encoding: Relative positional encoding `Tensor` of shape `[B, L, dim]`. segment_matrix: Optional `Tensor` of shape `[B, S, S + M]`. Used in XLNet, but not in Transformer XL. segment_embedding: Optional `Tensor` of shape `[2, num_heads, dim]`. Used in XLNet, but not in Transformer XL. state: Optional `Tensor` of shape `[B, M, E]`, where M is the length of the state or memory. If passed, this is also attended over as in Transformer XL. content_attention_mask: Optional `Tensor` representing the mask that is added to content attention logits. If state is not None, the mask source sequence dimension should extend M. query_stream: Optional `Tensor`, the query stream. This is introduced in `TwoStreamRelativeAttention`/XLNet pretrainer. This is ignored if `two_stream` is `False`. query_attention_mask: Optional `Tensor` representing the mask that is added to query attention logits. If state is not None, the mask source sequence dimension should extend M. target_mapping: Optional `Tensor` representing the target mapping when calculating query attention. Returns: A tuple consisting of the attention output and the list of cached memory states. The attention output is `content_attention` if `two_stream` is `False`, otherwise it is `query_attention`. """ new_mems = [] if state is None: state = [None] * self._num_layers for i in range(self._num_layers): # cache new mems new_mems.append( _cache_memory(content_stream, state[i], self._memory_length, self._reuse_length)) # segment bias if segment_matrix is None: segment_attention_bias = None segment_encoding = None else: segment_attention_bias = (self.segment_attention_bias if self._tie_attention_biases else self.segment_attention_bias[i]) segment_encoding = segment_embedding[i] content_attention_bias = (self.content_attention_bias if self._tie_attention_biases else self.content_attention_bias[i]) positional_attention_bias = (self.positional_attention_bias if self._tie_attention_biases else self.positional_attention_bias[i]) transformer_xl_layer = self.transformer_xl_layers[i] transformer_xl_output = transformer_xl_layer( content_stream=content_stream, content_attention_bias=content_attention_bias, positional_attention_bias=positional_attention_bias, relative_position_encoding=relative_position_encoding, segment_matrix=segment_matrix, segment_encoding=segment_encoding, segment_attention_bias=segment_attention_bias, state=state[i], content_attention_mask=content_attention_mask, query_attention_mask=query_attention_mask, query_stream=query_stream, target_mapping=target_mapping) content_stream = transformer_xl_output["content_attention"] if self._two_stream: query_stream = transformer_xl_output["query_attention"] else: query_stream = None if self._two_stream: output_stream = query_stream else: output_stream = content_stream return output_stream, new_mems