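# NOTE: The following lines are residue from the web page this file was copied
# from (author, commit message, commit hash, view links, file size). They are
# kept as comments so the module remains valid Python.
# deanna-emery's picture
# updates
# 93528c6
# raw
# history blame
# 7.13 kB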
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of multiheaded attention and self-attention layers."""
import math
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
class Attention(tf_keras.layers.Layer):
  """Multi-headed attention layer.

  Queries are projected from `query_input`, and keys/values from
  `source_input`. Each projection yields `num_heads` heads of size
  `hidden_size // num_heads`. The layer applies scaled dot-product attention
  with an additive `bias`, then projects the heads back to `hidden_size`.
  An optional key/value `cache` dict supports incremental decoding.
  """

  def __init__(self, hidden_size, num_heads, attention_dropout, **kwargs):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
        Must evenly divide `hidden_size`.
      attention_dropout: float, dropout rate inside attention for training.
      **kwargs: Optional keyword arguments forwarded to the base Keras
        `Layer` constructor (e.g. `name`, `dtype`, `trainable`).

    Raises:
      ValueError: If `hidden_size` is not divisible by `num_heads`.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))
    super(Attention, self).__init__(**kwargs)
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

  def build(self, input_shape):
    """Creates the query/key/value/output projection sub-layers.

    Args:
      input_shape: Shape of the layer's first input (`query_input`). Only its
        last dimension is used, as the fan-in of the query/key/value kernel
        initializers. May be a `tf.TensorShape` or a plain tuple/list.
    """
    size_per_head = self.hidden_size // self.num_heads

    def _glorot_initializer(fan_in, fan_out):
      # Glorot-uniform limit computed from explicit fans rather than being
      # derived from the rank-3 EinsumDense kernel shapes below.
      limit = math.sqrt(6.0 / (fan_in + fan_out))
      return tf_keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    # Normalize so both TensorShape and tuple/list shapes are accepted.
    fan_in = tf.TensorShape(input_shape).as_list()[-1]
    attention_initializer = _glorot_initializer(fan_in, self.hidden_size)

    # Layers for linearly projecting the queries, keys, and values. Each maps
    # [batch, length, hidden] -> [batch, length, num_heads, size_per_head],
    # so splitting into heads happens inside the projection.
    self.query_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="query")
    self.key_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="key")
    self.value_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="value")

    # Output projection recombines the heads: [B, T, N, H] -> [B, T, hidden].
    output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
    self.output_dense_layer = tf_keras.layers.EinsumDense(
        "BTNH,NHE->BTE",
        output_shape=(None, self.hidden_size),
        kernel_initializer=output_initializer,
        bias_axes=None,
        name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
    """Returns the serializable config, including base `Layer` settings.

    The base config keys (such as `name`, `trainable`, `dtype`) are accepted
    by `__init__` through `**kwargs`, so `Attention(**config)` round-trips.
    """
    config = super(Attention, self).get_config()
    config.update({
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    })
    return config

  def call(self,
           query_input,
           source_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not. Attention dropout is
        applied only when this is truthy.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
             "v": tensor with shape [batch_size, i, heads, dim_per_head]} where
        i is the current decoded length for non-padded decode, or max
        sequence length for padded decode. The dictionary is updated in place
        with the combined keys and values.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU. When set, the cache is treated as
        pre-allocated and this step's key/value is written into that slot.

    Returns:
      Attention layer output with shape [batch_size, length_query, hidden_size]
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
      if decode_loop_step is not None:
        # Padded decode: the cache already spans the maximum length, so add
        # this step's key/value at position `decode_loop_step` through a
        # one-hot mask instead of concatenating. Cache length stays fixed.
        # NOTE(review): the broadcast below assumes key/value hold a single
        # time step (length 1) — confirm with the decoding caller.
        cache_k_shape = cache["k"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1, 1])
        # Cast like the non-padded branch so the cache dtype matches the
        # freshly computed key/value.
        key = tf.cast(cache["k"], key.dtype) + key * indices

        cache_v_shape = cache["v"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1, 1])
        value = tf.cast(cache["v"], value.dtype) + value * indices
      else:
        # Non-padded decode: grow the cache along the length axis.
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache
      cache["k"] = key
      cache["v"] = value

    # Scale query to prevent the dot product between query and key from growing
    # too large.
    depth = (self.hidden_size // self.num_heads)
    query *= depth**-0.5

    # Calculate dot product attention: logits have shape
    # [batch, num_heads, length_query (F), length_source (T)].
    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

    # Run the outputs through another linear projection layer. Recombining heads
    # is automatically done --> [batch_size, length, hidden_size]
    attention_output = self.output_dense_layer(attention_output)
    return attention_output
class SelfAttention(Attention):
  """Multi-headed attention in which a sequence attends to itself."""

  def call(self,
           query_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    """Runs `Attention.call` with `query_input` as its own key/value source.

    Args:
      query_input: Tensor used both as the queries and as the key/value
        source.
      bias: Additive attention bias, passed through unchanged.
      training: Whether the layer runs in training mode. It controls
        attention dropout in `Attention.call`.
      cache: Optional key/value cache dict for incremental decoding.
        `Attention.call` updates it in place.
      decode_loop_step: Optional decoding step index for padded decode.

    Returns:
      The result of `Attention.call` with `source_input=query_input`.
    """
    # Self-attention: keys and values are projected from the same tensor.
    source_input = query_input
    return super().call(query_input, source_input, bias, training, cache,
                        decode_loop_step)