"""Talking Head Attention layer.""" |
|
|
|
import math |
|
import string |
|
|
|
import gin |
|
import tensorflow as tf |
|
|
|
from official.nlp.modeling.layers import attention |
|
|
|
_CHR_IDX = string.ascii_lowercase |
|
|
|
|
|
@tf.keras.utils.register_keras_serializable(package="Text") |
|
@gin.configurable |
|
class TalkingHeadsAttention(attention.MultiHeadAttention): |
|
"""Implements Talking-Heads Attention. |
|
|
|
This is an implementation of Talking-Heads Attention based on the paper |
|
Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhanced |
|
multi-head attention by including linearprojections across the attention-heads |
|
dimension, immediately before and after the softmax operation. |
|
|
|
See the base class `MultiHeadAttention` for more details. |
|
|
|
Arguments: |
|
num_heads: Number of attention heads. |
|
key_size: Size of each attention head for query and key. |
|
value_size: Size of each attention head for value. |
|
dropout: Dropout probability. |
|
use_bias: Boolean, whether the dense layers use bias vectors/matrices. |
|
output_shape: The expected shape of an output tensor, besides the batch and |
|
sequence dims. If not specified, projects back to the key feature dim. |
|
attention_axes: axes over which the attention is applied. `None` means |
|
attention over all axes, but batch, heads, and features. |
|
return_attention_scores: bool, if `True`, returns the multi-head attention |
|
scores as an additional output argument. |
|
kernel_initializer: Initializer for dense layer kernels. |
|
bias_initializer: Initializer for dense layer biases. |
|
kernel_regularizer: Regularizer for dense layer kernels. |
|
bias_regularizer: Regularizer for dense layer biases. |
|
activity_regularizer: Regularizer for dense layer activity. |
|
kernel_constraint: Constraint for dense layer kernels. |
|
bias_constraint: Constraint for dense layer kernels. |
|
""" |
|
|
|
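  # A sketch of the extra computation this layer performs, shown for the
  # common case where the attention scores have shape [B, N, T, S]
  # (N = num_heads). The exact einsum strings are built dynamically in
  # `_build_attention`; the shapes below are illustrative only:
  #
  #   scores = einsum(dot_product_equation, key, query) / sqrt(key_size)
  #   scores = einsum("bnts,nm->bmts", scores, pre_softmax_weight)
  #   probs  = masked_softmax(scores, attention_mask)
  #   probs  = einsum("bnts,nm->bmts", probs, post_softmax_weight)
  #   output = einsum(combine_equation, dropout(probs), value)
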
  def _build_attention(self, qkv_rank):
    """Builds multi-head dot-product attention computations.

    This function overrides the base class to create the additional linear
    projections that are applied to the attention scores before and after
    the softmax.

    Args:
      qkv_rank: the rank of query, key, value tensors after projection.
    """
    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)

    # Build an einsum equation of the form:
    #   (<batch_dims>, num_heads_a, ...), (num_heads_a, num_heads_b)
    #     -> (<batch_dims>, num_heads_b, ...)
    # `qkv_rank` is num_batch_dims + num_attention_dims + 2 (one heads dim
    # and one feature dim).
    num_batch_dims = qkv_rank - len(self._attention_axes) - 2

    # The shape of the attention scores is:
    #   (<batch_dims>, num_heads, <query_attention_dims>, <key_attention_dims>)
    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
    scores_notation = _CHR_IDX[:attn_scores_rank]
    projection_notation = scores_notation[num_batch_dims] + (
        _CHR_IDX[attn_scores_rank])
    projected_scores_notation = scores_notation[:num_batch_dims] + (
        _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
    self._talking_heads_equation = "%s,%s->%s" % (
        scores_notation, projection_notation, projected_scores_notation)
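    # For example, with one batch dimension and a single attention axis the
    # scores have rank 4 ([B, N, T, S]) and the generated equation is
    # "abcd,be->aecd": the projection mixes information across the heads
    # axis only and leaves all other axes untouched.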

    self._pre_softmax_weight = self.add_weight(
        "pre_softmax_weight",
        shape=(self._num_heads, self._num_heads),
        initializer=self._kernel_initializer,
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    self._post_softmax_weight = self.add_weight(
        "post_softmax_weight",
        shape=(self._num_heads, self._num_heads),
        initializer=self._kernel_initializer,
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)

  def _compute_attention(self,
                         query_tensor,
                         key_tensor,
                         value_tensor,
                         attention_mask=None):
    """Applies dot-product attention with query, key, value tensors.

    This function overrides the base class to apply the additional linear
    projections to the attention scores before and after the softmax.

    Args:
      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
      key_tensor: Projected key `Tensor` of shape `[B, S, N, key_size]`.
      value_tensor: Projected value `Tensor` of shape `[B, S, N, value_size]`.
      attention_mask: a boolean mask of shape `[B, T, S]` that prevents
        attention to certain positions.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Take the dot product between "query" and "key" to get the raw
    # attention scores, then scale by 1/sqrt(key_size).
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Apply the talking-heads projection before the softmax.
    attention_scores = tf.einsum(self._talking_heads_equation,
                                 attention_scores, self._pre_softmax_weight)

    # Normalize the attention scores to probabilities, applying the
    # attention mask if provided.
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # Apply the talking-heads projection after the softmax.
    attention_scores = tf.einsum(self._talking_heads_equation,
                                 attention_scores, self._post_softmax_weight)

    # Dropout on the attention probabilities drops entire tokens to attend
    # to, as in the original Transformer.
    attention_scores_dropout = self._dropout_layer(attention_scores)

    # Combine the attention probabilities with the value tensor to produce
    # the per-head attention output.
    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value_tensor)
    return attention_output, attention_scores
|
|
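# A minimal, self-contained sketch (not part of the library API) showing the
# heads-mixing einsum that this layer inserts around the softmax. The tensor
# shapes and the "abcd,be->aecd" equation below are illustrative, matching
# the rank-4 case built in `_build_attention`; invoking the layer itself
# follows the base `MultiHeadAttention` call signature and is not shown here.
if __name__ == "__main__":
  batch, num_heads, q_len, kv_len = 2, 4, 8, 8
  # Raw attention scores: [batch, heads, query_len, key_len].
  scores = tf.random.normal([batch, num_heads, q_len, kv_len])
  # Heads-mixing projection weight: [heads, heads].
  pre_softmax_weight = tf.random.normal([num_heads, num_heads])
  # Mix information across the heads axis before the softmax.
  mixed_scores = tf.einsum("abcd,be->aecd", scores, pre_softmax_weight)
  probs = tf.nn.softmax(mixed_scores, axis=-1)
  print("scores:", scores.shape, "-> mixed:", mixed_scores.shape,
        "probs:", probs.shape)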