# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Talking Head Attention layer."""
# pylint: disable=g-classes-have-attributes
import math
import string
import gin
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
@tf_keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsAttention(tf_keras.layers.MultiHeadAttention):
"""Implements Talking-Heads Attention.
This is an implementation of Talking-Heads Attention based on the paper
Talking-Heads Attention (https://arxiv.org/abs/2003.02436): it enhanced
multi-head attention by including linearprojections across the attention-heads
dimension, immediately before and after the softmax operation.
See the base class `tf_keras.layers.MultiHeadAttention` for more details.
Args:
num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
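
  Example (a minimal usage sketch; the shapes below are illustrative and not
  taken from this file):

  >>> layer = TalkingHeadsAttention(num_heads=2, key_dim=8)
  >>> query = tf.ones([1, 4, 16])   # [batch, target_len, feature_dim]
  >>> value = tf.ones([1, 4, 16])   # [batch, source_len, feature_dim]
  >>> output, scores = layer(query, value, return_attention_scores=True)
  >>> # output has shape [1, 4, 16]; scores has shape [1, 2, 4, 4] (B, N, T, S).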
"""
def _build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection
that will be applied on attention scores before and after softmax.
Args:
qkv_rank: The rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
# (<batch_dims>, num_heads_b, ...)
# qkv_ranks has `batch_dims`, `attention_dims`, `num_heads` and `channels`.
num_batch_dims = qkv_rank - len(self._attention_axes) - 2
# The shape of attn_scores is:
# (<batch_dims>, num_heads, <query_attn_dims>, <key_attn_dims>)
attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
scores_notation = _CHR_IDX[:attn_scores_rank]
projection_notation = scores_notation[num_batch_dims] + (
_CHR_IDX[attn_scores_rank])
projected_scores_notation = scores_notation[:num_batch_dims] + (
_CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
self._talking_heads_equation = "%s,%s->%s" % (
scores_notation, projection_notation, projected_scores_notation)
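    # For instance (an illustrative trace, not code from this file): with a
    # rank-4 projected query/key of shape [B, T, N, key_dim] and a single
    # attention axis, num_batch_dims == 1 and attn_scores_rank == 4, so
    # scores_notation == "abcd", projection_notation == "be",
    # projected_scores_notation == "aecd", and the resulting equation is
    # "abcd,be->aecd", i.e. the (num_heads, num_heads) weight is contracted
    # against the heads axis only.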
    self._pre_softmax_weight = self.add_weight(
        "pre_softmax_weight",
        shape=(self._num_heads, self._num_heads),
        initializer=tf_utils.clone_initializer(self._kernel_initializer),
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    self._post_softmax_weight = self.add_weight(
        "post_softmax_weight",
        shape=(self._num_heads, self._num_heads),
        initializer=tf_utils.clone_initializer(self._kernel_initializer),
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)
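    # Sketch of the math these two weights implement (paraphrasing the paper,
    # not code from this file): with P_pre and P_post of shape
    # (num_heads, num_heads),
    #   mixed_scores[..., h', ...] = sum_h scores[..., h, ...] * P_pre[h, h']
    # is applied before the softmax, and the same contraction with P_post is
    # applied to the attention probabilities after the softmax.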

  def _compute_attention(self,
                         query_tensor,
                         key_tensor,
                         value_tensor,
                         attention_mask=None,
                         training=None):
    """Applies dot-product attention with query, key, value tensors.

    This function overrides the base class to apply an additional linear
    projection to the attention scores before and after the softmax.

    Args:
      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key_tensor: Projected key `Tensor` of shape `[B, S, N, key_dim]`.
      value_tensor: Projected value `Tensor` of shape `[B, S, N, value_dim]`.
      attention_mask: A boolean mask of shape `[B, T, S]` that prevents
        attention to certain positions.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation.
      attention_scores: Multi-headed attention weights.
    """
    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
                                 query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_dim)))

    # Apply linear projection before softmax.
    attention_scores = tf.einsum(self._talking_heads_equation,
                                 attention_scores, self._pre_softmax_weight)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, T, S]
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # Apply linear projection after softmax.
    attention_scores = tf.einsum(self._talking_heads_equation,
                                 attention_scores, self._post_softmax_weight)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores_dropout = self._dropout_layer(
        attention_scores, training=training)

    # `context_layer` = [B, T, N, H]
    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value_tensor)
    return attention_output, attention_scores