# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
import collections
import math
import string

import numpy as np
import tensorflow as tf, tf_keras

from official.modeling import tf_utils

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
  `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
  `bs` and `<non-attention dims>` are treated as `<batch dims>`.

  The attention operations can be generalized:
  (1) Query-key dot product:
  `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
  <key attention dims>, num_heads, channels) -> (<batch dims>,
  num_heads, <query attention dims>, <key attention dims>)`
  (2) Combination:
  `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
  <query attention dims>, num_heads, channels)`

  Args:
    rank: Rank of query, key, value tensors.
    attn_axes: Tuple of axes, `[-1, rank)`, that attention will be applied to.
      Must be a tuple because it is concatenated with `(rank - 1,)`.

  Returns:
    A tuple `(dot_product_equation, combine_equation, attn_scores_rank)`: the
    key-query einsum, the scores-value einsum, and the rank of the attention
    score tensor.
  """
  target_notation = _CHR_IDX[:rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  # Attention axes of the source (key/value) get fresh letters so they are
  # contracted/kept separately from the target (query) attention axes.
  letter_offset = rank
  source_notation = ""
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
                                        product_notation)
  attn_scores_rank = len(product_notation)
  combine_equation = "%s,%s->%s" % (product_notation, source_notation,
                                    target_notation)
  return dot_product_equation, combine_equation, attn_scores_rank


def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention.

  Args:
    free_dims: Number of leading input dims carried through to the output.
    bound_dims: Number of trailing input dims contracted with the kernel.
    output_dims: Number of new trailing output dims produced by the kernel.

  Returns:
    A tuple `(equation, bias_axes, output_rank)` where `bias_axes` holds the
    letters of the new output dims.
  """
  input_str = ""
  kernel_str = ""
  output_str = ""
  bias_axes = ""
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    output_str += char

  letter_offset += free_dims
  for i in range(bound_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    kernel_str += char

  letter_offset += bound_dims
  for i in range(output_dims):
    char = _CHR_IDX[i + letter_offset]
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = "%s,%s->%s" % (input_str, kernel_str, output_str)

  return equation, bias_axes, len(output_str)


def _get_output_shape(output_rank, known_last_dims):
  """Returns a shape list of `output_rank` with only the trailing dims known."""
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


class ReuseMultiHeadAttention(tf_keras.layers.Layer):
  """MultiHeadAttention layer that can reuse attention scores.

  This is an implementation of multi-headed attention as described in the paper
  "Attention is all you Need" (Vaswani et al., 2017).
  If `query`, `key,` `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_attention_heads`, where the
  corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
  `(batch_size, <key/value dimensions>, key_dim)`,
  `(batch_size, <key/value dimensions>, value_dim)`.

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor.

  Finally, the result tensor with the last dimension as value_dim can take a
  linear projection and return.

  When `reuse_attention` is positive, that many heads do not compute new
  attention scores. Instead they take the first `reuse_attention` heads of the
  `reuse_attention_scores` passed to `call`. Only the remaining
  `num_heads - reuse_attention` heads compute new scores. Reused and new heads
  have separate value and output projections, and their projected outputs are
  summed.

  Examples:

  Performs 1D cross-attention over two sequence inputs with an attention mask.
  Returns the additional attention weights over heads.

  >>> layer = ReuseMultiHeadAttention(num_heads=2, key_dim=2)
  >>> target = tf_keras.Input(shape=[8, 16])
  >>> source = tf_keras.Input(shape=[4, 16])
  >>> output_tensor, weights = layer(target, source,
  ...                                return_attention_scores=True)
  >>> print(output_tensor.shape)
  (None, 8, 16)
  >>> print(weights.shape)
  (None, 2, 8, 4)

  Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

  >>> layer = ReuseMultiHeadAttention(
  ...     num_heads=2, key_dim=2, attention_axes=(2, 3))
  >>> input_tensor = tf_keras.Input(shape=[5, 3, 4, 16])
  >>> output_tensor = layer(input_tensor, input_tensor)
  >>> print(output_tensor.shape)
  (None, 5, 3, 4, 16)

  Args:
    num_heads: Number of attention heads.
    key_dim: Size of each attention head for query and key.
    value_dim: Size of each attention head for value.
    dropout: Dropout probability.
    reuse_attention: An integer specifying number of heads to reuse.
      -1 for all heads.
    use_relative_pe: Whether to use relative position bias. Only applied to
      the heads that compute new attention scores.
    pe_max_seq_length: Used to set the size of the relative position
      encodings.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the query feature dim.
    attention_axes: axes over which the attention is applied. `None` means
      attention over all axes, but batch, heads, and features.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.

  Call arguments:
    query: Query `Tensor` of shape `(B, T, dim)`.
    value: Value `Tensor` of shape `(B, S, dim)`.
    key: Optional key `Tensor` of shape `(B, S, dim)`. If not given, will use
      `value` for both `key` and `value`, which is the most common case.
    attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
      attention to certain positions. The boolean mask specifies which query
      elements can attend to which key elements, 1 indicates attention and 0
      indicates no attention. Broadcasting can happen for the missing batch
      dimensions and the head dimension.
    return_attention_scores: A boolean to indicate whether the output should
      be `(attention_output, attention_scores)` if `True`, or
      `attention_output` if `False`. Defaults to `False`.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
      Defaults to either using the training mode of the parent layer/model,
      or False (inference) if there is no parent layer.
    reuse_attention_scores: Attention scores (e.g. from a previous layer) whose
      first `reuse_attention` heads are reused. Required when
      `reuse_attention` is nonzero.

  Returns:
    attention_output: The result of the computation, of shape `(B, T, E)`,
      where `T` is for target sequence shapes and `E` is the query input last
      dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
      are projected to the shape specified by `output_shape`.
    attention_scores: [Optional] multi-head attention coefficients over
      attention axes. With partial reuse the newly computed heads come first,
      followed by the reused heads.
  """

  def __init__(self,
               num_heads,
               key_dim,
               value_dim=None,
               dropout=0.0,
               reuse_attention=0,
               use_relative_pe=False,
               pe_max_seq_length=512,
               use_bias=True,
               output_shape=None,
               attention_axes=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_heads = num_heads
    self._key_dim = key_dim
    self._value_dim = value_dim if value_dim else key_dim
    self._dropout = dropout
    if reuse_attention > self._num_heads or reuse_attention < -1:
      # Argument order must match the format string: `%d` takes the head
      # count and `%s` takes the class.
      raise ValueError("reuse_attention should be between -1 "
                       "and %d in call to %s." % (self._num_heads,
                                                 self.__class__))
    if reuse_attention == -1:
      reuse_attention = self._num_heads
    self._reuse_heads = reuse_attention
    self._use_relative_pe = use_relative_pe
    self._pe_max_seq_length = pe_max_seq_length
    self._use_bias = use_bias
    self._output_shape = output_shape
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    # `activity_regularizer` is consumed by this signature rather than being
    # forwarded to the base Layer. Store it explicitly so that the dense
    # sub-layers and `get_config` can see it.
    self._activity_regularizer = tf_keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)
    if attention_axes is not None and not isinstance(
        attention_axes, collections.abc.Sized):
      self._attention_axes = (attention_axes,)
    else:
      self._attention_axes = attention_axes
    self._built_from_signature = False
    self._query_shape, self._key_shape, self._value_shape = None, None, None
    # Use relative PE only if reuse_heads < num_heads.
    if self._use_relative_pe and self._reuse_heads < self._num_heads:
      # The dense sub-layers run under the global policy, so the attention
      # logits use its compute dtype. Create the biases in the same dtype so
      # the two can be added directly. Fall back to float32 when the policy
      # reports no compute dtype.
      pe_dtype = tf.as_dtype(
          tf_keras.mixed_precision.global_policy().compute_dtype or "float32")
      self._position_embeddings = tf.Variable(
          name="relative_position_embeddings",
          initial_value=lambda: tf.random.truncated_normal(  # pylint: disable=g-long-lambda
              [
                  1, self._num_heads - self._reuse_heads,
                  2 * self._pe_max_seq_length - 1
              ],
              mean=0.0,
              stddev=0.2,
              dtype=pe_dtype),
          trainable=True,
          dtype=pe_dtype)

  def get_config(self):
    config = {
        "num_heads":
            self._num_heads,
        "key_dim":
            self._key_dim,
        "value_dim":
            self._value_dim,
        "dropout":
            self._dropout,
        "use_bias":
            self._use_bias,
        "output_shape":
            self._output_shape,
        "attention_axes":
            self._attention_axes,
        "reuse_attention":
            self._reuse_heads,
        "use_relative_pe":
            self._use_relative_pe,
        "pe_max_seq_length":
            self._pe_max_seq_length,
        "kernel_initializer":
            tf_keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf_keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf_keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf_keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf_keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf_keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf_keras.constraints.serialize(self._bias_constraint),
        # Recorded so `from_config` can rebuild the weights eagerly.
        "query_shape": self._query_shape,
        "key_shape": self._key_shape,
        "value_shape": self._value_shape,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    """Creates a layer from `get_config()` output.

    The recorded input shapes are popped from `config`. When all three are
    present, the weights are built immediately.
    """
    # If the layer has a different build() function from the Keras default,
    # we need to trigger the customized build to create weights.
    query_shape = config.pop("query_shape")
    key_shape = config.pop("key_shape")
    value_shape = config.pop("value_shape")
    layer = cls(**config)
    if None in [query_shape, key_shape, value_shape]:
      tf.get_logger().warning(
          "One of dimensions of the input shape is missing. It should have been"
          " memorized when the layer was serialized. "
          "%s is created without weights.",
          str(cls))
    else:
      layer._build_from_signature(query_shape, value_shape, key_shape)  # pylint: disable=protected-access
    return layer

  def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: Query tensor or TensorShape.
      value: Value tensor or TensorShape.
      key: Key tensor or TensorShape.
    """
    self._built_from_signature = True
    if hasattr(query, "shape"):
      self._query_shape = tf.TensorShape(query.shape)
    else:
      self._query_shape = tf.TensorShape(query)
    if hasattr(value, "shape"):
      self._value_shape = tf.TensorShape(value.shape)
    else:
      self._value_shape = tf.TensorShape(value)
    if key is None:
      self._key_shape = self._value_shape
    elif hasattr(key, "shape"):
      self._key_shape = tf.TensorShape(key.shape)
    else:
      self._key_shape = tf.TensorShape(key)

    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Any setup work performed only once should happen in an `init_scope`
    # to avoid creating symbolic Tensors that will later pollute any eager
    # operations.
    with tf.init_scope():
      free_dims = self._query_shape.rank - 1
      # Query/key projections are only needed for heads that compute new
      # attention scores.
      if self._reuse_heads < self._num_heads:
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=1, output_dims=2)
        self._query_dense = tf_keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1,
                [self._num_heads - self._reuse_heads, self._key_dim]),
            bias_axes=bias_axes if self._use_bias else None,
            name="query",
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
            **common_kwargs)
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            self._key_shape.rank - 1, bound_dims=1, output_dims=2)
        self._key_dense = tf_keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1,
                [self._num_heads - self._reuse_heads, self._key_dim]),
            bias_axes=bias_axes if self._use_bias else None,
            name="key",
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
            **common_kwargs)

      einsum_equation, bias_axes, output_rank = _build_proj_equation(
          self._value_shape.rank - 1, bound_dims=1, output_dims=2)
      # Ordered [reuse, new]; `_compute_attention` relies on value[0] being the
      # reused-head projection and value[-1] the new-head projection.
      self._value_dense = []
      if self._reuse_heads > 0:
        self._value_dense.append(
            tf_keras.layers.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._reuse_heads, self._value_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="value_reuse",
                kernel_initializer=tf_utils.clone_initializer(
                    self._kernel_initializer),
                bias_initializer=tf_utils.clone_initializer(
                    self._bias_initializer),
                **common_kwargs))
      if self._reuse_heads < self._num_heads:
        self._value_dense.append(
            tf_keras.layers.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1,
                    [self._num_heads - self._reuse_heads, self._value_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="value_new",
                kernel_initializer=tf_utils.clone_initializer(
                    self._kernel_initializer),
                bias_initializer=tf_utils.clone_initializer(
                    self._bias_initializer),
                **common_kwargs))

      # Builds the attention computations for multi-head dot product attention.
      # These computations could be wrapped into the keras attention layer once
      # it supports multi-head einsum computations.
      self._build_attention(output_rank)
      # Same [reuse, new] ordering as `_value_dense`. The two outputs are
      # summed in `call`, so only one of them needs a bias. The "new" branch
      # gets a bias only when there is no "reuse" branch.
      self._output_dense = []
      if self._reuse_heads > 0:
        self._output_dense.append(self._make_output_dense(
            free_dims, common_kwargs, "attention_output_reuse"))
      if self._reuse_heads < self._num_heads:
        self._output_dense.append(self._make_output_dense(
            free_dims, common_kwargs, "attention_output_new",
            self._reuse_heads == 0))

  def _make_output_dense(self, free_dims, common_kwargs, name=None,
                         use_bias=True):
    """Builds the output projection matrix.

    Args:
      free_dims: Number of free dimensions for einsum equation building.
      common_kwargs: Common keyword arguments for einsum layer.
      name: Name for the projection layer.
      use_bias: Use bias if self._use_bias is true

    Returns:
      Projection layer.
    """
    if self._output_shape:
      if not isinstance(self._output_shape, collections.abc.Sized):
        output_shape = [self._output_shape]
      else:
        output_shape = self._output_shape
    else:
      output_shape = [self._query_shape[-1]]
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=2, output_dims=len(output_shape))
    return tf_keras.layers.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1, output_shape),
        bias_axes=bias_axes if (use_bias and self._use_bias) else None,
        name=name,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)

  def _build_attention(self, rank):
    """Builds multi-head dot-product attention computations.

    This function builds attributes necessary for `_compute_attention` to
    customize attention computation to replace the default dot-product
    attention.

    Args:
      rank: the rank of query, key, value tensors.
    """
    if self._attention_axes is None:
      self._attention_axes = tuple(range(1, rank - 2))
    else:
      self._attention_axes = tuple(self._attention_axes)
    self._dot_product_equation, self._combine_equation, attn_scores_rank = (
        _build_attention_equation(rank, attn_axes=self._attention_axes))
    # Softmax normalizes over the trailing key attention axes of the scores.
    norm_axes = tuple(
        range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
    self._softmax = tf_keras.layers.Softmax(axis=norm_axes)
    self._dropout_layer = tf_keras.layers.Dropout(rate=self._dropout)

  def _masked_softmax(self, attention_scores, attention_mask=None):
    """Applies the softmax, broadcasting `attention_mask` over missing dims."""
    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, T, S]
    if attention_mask is not None:
      # The expand dim happens starting from the `num_heads` dimension,
      # (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
      mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
      for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
        attention_mask = tf.expand_dims(
            attention_mask, axis=mask_expansion_axes)
    return self._softmax(attention_scores, attention_mask)

  def _compute_relative_position(self, query_seq_length, key_seq_length):
    """Gathers relative position biases for a (query, key) length pair.

    Entry `[i, j]` uses embedding index `j - i + pe_max_seq_length - 1`,
    clipped to `[0, 2 * pe_max_seq_length - 2]`. Gathering along axis 2 of the
    `[1, new_heads, 2 * pe_max_seq_length - 1]` variable yields biases of shape
    `[1, new_heads, query_seq_length, key_seq_length]`.
    """
    position_zero = self._pe_max_seq_length - 1
    # We take the vector position variable and concatenate to form a matrix of
    # relative position encodings. i=0 indicates relative position is 0.
    indices = tf.expand_dims(tf.range(0, -query_seq_length, -1),
                             -1) + tf.range(key_seq_length) + position_zero
    indices = tf.maximum(indices, 0)
    indices = tf.minimum(indices, 2 * self._pe_max_seq_length - 2)
    attention_biases = tf.gather(self._position_embeddings, indices, axis=2)
    return attention_biases

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         reuse_scores=None,
                         attention_mask=None,
                         training=None):
    """Applies Dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementation.

    Args:
      query: Projected query `Tensor` of shape `(B, T, N, key_dim)`, or `None`
        when all heads are reused.
      key: Projected key `Tensor` of shape `(B, S, N, key_dim)`, or `None`
        when all heads are reused.
      value: List of projected value `Tensor`s of shape `(B, S, N, value_dim)`:
        the reused-head projection first (if any), then the new-head one
        (if any).
      reuse_scores: Attention scores from a previous layer if needed. Only the
        first `reuse_attention` heads are used.
      attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
        attention to certain positions.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: List of multi-headed outputs of attention computation,
        ordered like `value`.
      attention_scores: Multi-headed attention weights.
    """
    if self._reuse_heads > 0:
      # Keep only the heads being reused. This is a no-op when the supplied
      # scores already have exactly `reuse_attention` heads.
      reuse_scores = reuse_scores[:, :self._reuse_heads, :, :]

    # Partial or no reuse
    if self._reuse_heads < self._num_heads:
      query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
      new_scores = tf.einsum(self._dot_product_equation, key, query)
      # Add relative position embeddings if required.
      if self._use_relative_pe:
        new_scores = new_scores + self._compute_relative_position(
            tf.shape(query)[1], tf.shape(key)[1])
      new_scores = self._masked_softmax(new_scores, attention_mask)
      if self._reuse_heads > 0:  # Partial reuse
        attention_scores = tf.concat([new_scores, reuse_scores], 1)
      else:  # No reuse
        attention_scores = new_scores
    else:  # Full reuse
      attention_scores = reuse_scores
      new_scores = None

    # `context_layer` = [B, T, N, H]
    attention_output = []
    # Partial or full reuse
    if self._reuse_heads > 0:
      attention_output.append(
          tf.einsum(self._combine_equation, self._dropout_layer(
              reuse_scores, training=training), value[0]))
    # Partial or no reuse
    if self._reuse_heads < self._num_heads:
      attention_output.append(
          tf.einsum(self._combine_equation, self._dropout_layer(
              new_scores, training=training), value[-1]))
    return attention_output, attention_scores

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           return_attention_scores=False,
           training=None,
           reuse_attention_scores=None):
    if self._reuse_heads > 0 and reuse_attention_scores is None:
      raise ValueError("reuse_attention_scores cannot be None when "
                       "reuse_attention is True or > 0.")
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `value` = [B, S, N, H]
    value = [vd(value) for vd in self._value_dense]
    if self._reuse_heads < self._num_heads:
      # `query` = [B, T, N ,H]
      query = self._query_dense(query)

      # `key` = [B, S, N, H]
      key = self._key_dense(key)
    else:
      query, key = None, None

    attention_output, attention_scores = self._compute_attention(
        query, key, value, reuse_attention_scores, attention_mask, training)
    # `_output_dense` and `attention_output` share the [reuse, new] ordering.
    attention_output = [
        od(output) for od, output in zip(self._output_dense, attention_output)
    ]
    if len(attention_output) == 1:
      attention_output = attention_output[0]
    else:
      attention_output = attention_output[0] + attention_output[1]

    if return_attention_scores:
      return attention_output, attention_scores
    return attention_output