# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of multiheaded attention and self-attention layers."""

import math

import tensorflow as tf, tf_keras

from official.modeling import tf_utils


class Attention(tf_keras.layers.Layer):
  """Multi-headed attention layer."""
  def __init__(self, hidden_size, num_heads, attention_dropout):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of attention heads.
      attention_dropout: float, dropout rate inside attention for training.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout
  def build(self, input_shape):
    """Builds the layer."""
    # Layers for linearly projecting the queries, keys, and values.
    size_per_head = self.hidden_size // self.num_heads

    def _glorot_initializer(fan_in, fan_out):
      limit = math.sqrt(6.0 / (fan_in + fan_out))
      return tf_keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
                                                self.hidden_size)
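    # In the einsum equations below, B is the batch dimension, T the sequence
    # length, E the input hidden size, N the number of heads, and H the size
    # per head, so "BTE,ENH->BTNH" projects and splits heads in a single step.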
    self.query_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="query")
    self.key_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="key")
    self.value_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="value")
    output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
    self.output_dense_layer = tf_keras.layers.EinsumDense(
        "BTNH,NHE->BTE",
        output_shape=(None, self.hidden_size),
        kernel_initializer=output_initializer,
        bias_axes=None,
        name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
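    """Returns the constructor arguments needed to re-create this layer.

    Illustrative note (not part of the original docstring): an identically
    configured copy can be built as `Attention(**layer.get_config())`.
    """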
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }
  def call(self,
           query_input,
           source_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
             "v": tensor with shape [batch_size, i, heads, dim_per_head]}
        where i is the current decoded length for non-padded decode, or the max
        sequence length for padded decode.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_query, hidden_size].
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
      if decode_loop_step is not None:
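        # Padded decode on TPU: the cache has a fixed maximum length, so the
        # new key/value for this step is scattered into position
        # `decode_loop_step` via a one-hot mask instead of a dynamic concat.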
        cache_k_shape = cache["k"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1, 1])
        key = cache["k"] + key * indices
        cache_v_shape = cache["v"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1, 1])
        value = cache["v"] + value * indices
      else:
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache.
      cache["k"] = key
      cache["v"] = value
    # Scale query to prevent the dot product between query and key from growing
    # too large.
    depth = (self.hidden_size // self.num_heads)
    query *= depth**-0.5

    # Calculate dot product attention.
    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
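    # `logits` has shape [batch_size, num_heads, length_query, length_source],
    # which broadcasts against the [batch_size, 1, length_query, length_source]
    # bias below.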
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

    # Run the outputs through another linear projection layer. Recombining
    # heads is automatically done --> [batch_size, length, hidden_size].
    attention_output = self.output_dense_layer(attention_output)
    return attention_output


class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self,
           query_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    return super(SelfAttention, self).call(query_input, query_input, bias,
                                           training, cache, decode_loop_step)
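

if __name__ == "__main__":
  # Minimal smoke-test sketch (not part of the original module). It assumes a
  # working TensorFlow / tf_keras / `official` install, uses arbitrary
  # illustration hyperparameters, and shows one way SelfAttention might be
  # called with a causal (look-ahead) bias.
  batch_size, length, hidden_size, num_heads = 2, 5, 64, 4
  layer = SelfAttention(hidden_size, num_heads, attention_dropout=0.1)

  # Bias of shape [1, 1, length, length]: zero on and below the diagonal and a
  # large negative value above it, so each position attends only to itself and
  # to earlier positions.
  lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
  causal_bias = tf.reshape((1.0 - lower_triangle) * -1e9,
                           [1, 1, length, length])

  inputs = tf.random.uniform([batch_size, length, hidden_size])
  outputs = layer(inputs, causal_bias, training=False)
  print("output shape:", outputs.shape)  # Expected: (2, 5, 64)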