# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes

import math

import tensorflow as tf
import tf_keras

EinsumDense = tf_keras.layers.EinsumDense
MultiHeadAttention = tf_keras.layers.MultiHeadAttention


@tf_keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf_keras.layers.MultiHeadAttention):
  """Attention layer with cache used for autoregressive decoding.

  Arguments are the same as `tf_keras.layers.MultiHeadAttention` layer.
  """

  def _update_cache(self, key, value, cache, decode_loop_step):
    """Updates cache states and gets full-length key/value tensors."""
    # Combines cached keys and values with new keys and values.
    if decode_loop_step is not None:
      # TPU special case: the cache is pre-allocated to the full decode length
      # and updated in place at `decode_loop_step` via a one-hot scatter.
      key_seq_dim = cache["key"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
          [1, key_seq_dim, 1, 1])
      key = cache["key"] + key * indices
      value_seq_dim = cache["value"].shape.as_list()[1]
      indices = tf.reshape(
          tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
          [1, value_seq_dim, 1, 1])
      value = cache["value"] + value * indices
    else:
      key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
      value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)

    # Update cache.
    cache["key"] = key
    cache["value"] = value

    return key, value

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_attention_scores=False):
    """Computes multi-head attention, optionally reusing a decoding cache.

    Args:
      query: Query tensor of shape `[B, F, dim]`.
      value: Value tensor of shape `[B, T, dim]`.
      key: Optional key tensor of shape `[B, T, dim]`. Defaults to `value`.
      attention_mask: Optional boolean mask that prevents attention to certain
        positions, broadcastable to `[B, N, F, T]`.
      cache: Optional dict with `"key"` and `"value"` tensors from previous
        decoding steps; it is updated in place with the new keys/values.
      decode_loop_step: Optional scalar index of the current decoding step,
        used to scatter into a fixed-length cache (TPU decoding).
      return_attention_scores: If `True`, also returns the attention scores.

    Returns:
      `(attention_output, cache)`, or
      `(attention_output, attention_scores, cache)` if
      `return_attention_scores` is `True`.
    """
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, F, N, H]
    query = self._query_dense(query)

    # `key` = [B, T, N, H]
    key = self._key_dense(key)

    # `value` = [B, T, N, H]
    value = self._value_dense(value)

    if cache:
      key, value = self._update_cache(key, value, cache, decode_loop_step)

    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # Normalize the attention scores to probabilities.
    # `attention_scores` = [B, N, F, T]
    attention_scores = self._masked_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_scores = self._dropout_layer(attention_scores)

    # `context_layer` = [B, F, N, H]
    attention_output = tf.einsum(self._combine_equation, attention_scores,
                                 value)

    attention_output = self._output_dense(attention_output)

    if return_attention_scores:
      return attention_output, attention_scores, cache
    return attention_output, cache
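

# Illustrative usage sketch (not part of the library): one way the cache dict
# might be initialized and threaded through an autoregressive decoding loop on
# TPU, where the cache is pre-allocated to the full decode length. The names
# `batch_size`, `max_decode_len`, `hidden_dim`, and `token_embedding` below
# are assumptions for illustration only; shapes follow `_update_cache`, which
# expects cached keys/values of shape [batch, seq_len, num_heads, key_dim].
#
#   layer = CachedAttention(num_heads=8, key_dim=64)
#   cache = {
#       "key": tf.zeros([batch_size, max_decode_len, 8, 64]),
#       "value": tf.zeros([batch_size, max_decode_len, 8, 64]),
#   }
#   for step in range(max_decode_len):
#     # `token_embedding` is the embedding of the current token,
#     # shape [batch_size, 1, hidden_dim].
#     output, cache = layer(
#         query=token_embedding,
#         value=token_embedding,
#         cache=cache,
#         decode_loop_step=step)
#
# Without `decode_loop_step`, the cached keys/values instead grow by
# concatenation along the sequence axis at each call.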