# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-channel Attention."""
# pylint: disable=g-classes-have-attributes
import math

import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling.layers import masked_softmax


class VotingAttention(tf_keras.layers.Layer):
  """Voting Attention layer.

  Scores each document in a multi-document input by letting the documents'
  CLS embeddings attend to each other, and returns a probability
  distribution over the documents.

  Args:
    num_heads: The number of attention heads.
    head_size: Per-head hidden size.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.

  Call args:
    encoder_outputs: A `Tensor` of shape `[batch_size, num_docs, seq_length,
      hidden_size]` with per-document encoder outputs.
    doc_attention_mask: A mask of shape `[batch_size, num_docs]`, with 1 for
      valid documents and 0 for padded ones.
  """

  def __init__(self,
               num_heads,
               head_size,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_heads = num_heads
    self._head_size = head_size
    self._kernel_initializer = tf_keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf_keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf_keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf_keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf_keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf_keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf_keras.constraints.get(bias_constraint)

  def build(self, unused_input_shapes):
    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._query_dense = tf_keras.layers.EinsumDense(
        "BAE,ENH->BANH",
        output_shape=(None, self._num_heads, self._head_size),
        bias_axes="NH",
        name="query",
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
    self._key_dense = tf_keras.layers.EinsumDense(
        "BAE,ENH->BANH",
        output_shape=(None, self._num_heads, self._head_size),
        bias_axes="NH",
        name="key",
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
    super().build(unused_input_shapes)

  def call(self, encoder_outputs, doc_attention_mask):
    num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
    cls_embeddings = encoder_outputs[:, :, 0, :]
    key = self._key_dense(cls_embeddings)
    query = self._query_dense(cls_embeddings)
    doc_attention_mask = tf.cast(doc_attention_mask, tf.float32)

    # Zero out the keys and queries of padded documents.
    key = tf.einsum("BANH,BA->BANH", key, doc_attention_mask)
    query = tf.einsum("BANH,BA->BANH", query, doc_attention_mask)
    # Cross-document attention scores: [batch, num_heads, num_docs, num_docs].
    attention_matrix = tf.einsum("BXNH,BYNH->BNXY", query, key)
    # Remove each document's attention to itself.
    mask = tf.ones([num_docs, num_docs])
    mask = tf.linalg.set_diag(mask, tf.zeros(num_docs))
    attention_matrix = tf.einsum("BNXY,XY->BNXY", attention_matrix, mask)
    # Sum over the other documents and over the heads to get one score per
    # document, then mask out padded documents before the final softmax.
    doc_attention_probs = tf.einsum("BNAY->BNA", attention_matrix)
    doc_attention_probs = tf.einsum("BNA->BA", doc_attention_probs)
    infadder = (1.0 - doc_attention_mask) * -100000.0
    return tf.nn.softmax(doc_attention_probs + infadder)
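

# The function below is an illustrative usage sketch, not part of the
# original library. It shows the expected input and output shapes of
# `VotingAttention`; the batch size, number of documents, sequence length,
# and hidden size are arbitrary assumptions.
def _voting_attention_example():
  """Runs `VotingAttention` on random inputs; returns doc probabilities."""
  batch_size, num_docs, seq_length, hidden_size = 2, 4, 16, 64
  # Per-document encoder outputs: [batch_size, num_docs, seq_length, hidden].
  encoder_outputs = tf.random.uniform(
      [batch_size, num_docs, seq_length, hidden_size])
  # 1 marks a valid document, 0 a padded one.
  doc_attention_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])
  layer = VotingAttention(num_heads=8, head_size=hidden_size // 8)
  # Output: a distribution over documents, shape [batch_size, num_docs].
  return layer(encoder_outputs, doc_attention_mask)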


class MultiChannelAttention(tf_keras.layers.MultiHeadAttention):
  """Multi-channel Attention layer.

  Introduced in [Generating Representative Headlines for News Stories
  ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
  target sequences.

  Call args:
    query: Query `Tensor` of shape `[B, T, dim]`.
    value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
      number of documents (channels).
    context_attention_weights: Context weights of shape `[B, N, T, A]`, where
      N is the number of attention heads. Combines the per-channel context
      tensors according to the distribution among channels.
    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will
      use `value` for both `key` and `value`, which is the most common case.
    attention_mask: A boolean mask of shape `[B, T, S]` that prevents
      attention to certain positions.
  """

  def _build_attention(self, rank):
    super()._build_attention(rank)  # pytype: disable=attribute-error  # typed-keras
    self._masked_softmax = masked_softmax.MaskedSoftmax(
        mask_expansion_axes=[2])

  def call(self,
           query,
           value,
           key=None,
           context_attention_weights=None,
           attention_mask=None):
    if not self._built_from_signature:
      self._build_from_signature(query, value, key=key)
    if key is None:
      key = value

    # Scalar dimensions referenced here:
    #   B = batch size (number of stories)
    #   A = num_docs (number of docs)
    #   F = target sequence length
    #   T = source sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query_tensor` = [B, F, N, H]
    query_tensor = self._query_dense(query)

    # `key_tensor` = [B, A, T, N, H]
    key_tensor = self._key_dense(key)

    # `value_tensor` = [B, A, T, N, H]
    value_tensor = self._value_dense(value)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_dim)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, A, N, F, T]
    attention_probs = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self._dropout_layer(attention_probs)

    # Per-channel context: `context_layer` = [B, A, F, N, H].
    context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                              value_tensor)

    # Combine the channels with `context_attention_weights`:
    # `attention_output` = [B, F, N, H].
    attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
                                 context_layer)
    attention_output = self._output_dense(attention_output)
    return attention_output
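

# The function below is an illustrative usage sketch, not part of the
# original library. It shows how `MultiChannelAttention` combines several
# context documents into a single attention output; all shapes and
# hyper-parameters are arbitrary assumptions, and random channel weights
# stand in for a learned per-document distribution.
def _multi_channel_attention_example():
  """Runs `MultiChannelAttention` over random multi-document contexts."""
  batch_size, num_docs, target_len, source_len = 2, 4, 8, 16
  num_heads, hidden_size = 8, 64
  query = tf.random.uniform([batch_size, target_len, hidden_size])
  value = tf.random.uniform([batch_size, num_docs, source_len, hidden_size])
  # Per-head weights over the A documents: [B, N, F, A].
  context_attention_weights = tf.nn.softmax(
      tf.random.uniform([batch_size, num_heads, target_len, num_docs]),
      axis=-1)
  layer = MultiChannelAttention(
      num_heads=num_heads, key_dim=hidden_size // num_heads)
  # Output shape: [batch_size, target_len, hidden_size].
  return layer(
      query, value, context_attention_weights=context_attention_weights)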