"""Keras-based attention layer.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import collections |
|
import math |
|
import string |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from official.nlp.modeling.layers import masked_softmax |
|
|
|
EinsumDense = tf.keras.layers.experimental.EinsumDense |
|
_CHR_IDX = string.ascii_lowercase |
|
|
|
|
|
def _build_attention_equation(qkv_rank, attn_axes): |
|
"""Builds einsum equations for the attention computation. |
|
|
|
  Query, key, value inputs after projection are expected to have the shape:
|
(bs, <non-attention dims>, <attention dims>, num_heads, channels). |
|
bs and <non-attention dims> are treated as <batch dims>. |
|
The attention operations can be generalized: |
|
(1) Query-key dot product: |
|
(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>, |
|
<key attention dims>, num_heads, channels) -> (<batch dims>, |
|
num_heads, <query attention dims>, <key attention dims>) |
|
(2) Combination: |
|
(<batch dims>, num_heads, <query attention dims>, <key attention dims>), |
|
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>, |
|
<query attention dims>, num_heads, channels) |
|
|
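  For example, with qkv_rank=4 (projected shapes of
  (bs, seq, num_heads, channels)) and attn_axes=(1,), this builds
  dot_product_equation="aecd,abcd->acbe" and
  combine_equation="acbe,aecd->abcd", with attn_scores_rank=4.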
|
Args: |
|
qkv_rank: the rank of query, key, value tensors. |
|
    attn_axes: a list/tuple of axes, in [1, rank), over which attention is
      applied.
|
|
|
Returns: |
|
Einsum equations. |
|
""" |
|
target_notation = _CHR_IDX[:qkv_rank] |
|
|
|
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,))) |
|
letter_offset = qkv_rank |
|
source_notation = "" |
|
for i in range(qkv_rank): |
|
if i in batch_dims or i == qkv_rank - 1: |
|
source_notation += target_notation[i] |
|
else: |
|
source_notation += _CHR_IDX[letter_offset] |
|
letter_offset += 1 |
|
|
|
product_notation = "".join([target_notation[i] for i in batch_dims] + |
|
[target_notation[i] for i in attn_axes] + |
|
[source_notation[i] for i in attn_axes]) |
|
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation, |
|
product_notation) |
|
attn_scores_rank = len(product_notation) |
|
combine_equation = "%s,%s->%s" % (product_notation, source_notation, |
|
target_notation) |
|
return dot_product_equation, combine_equation, attn_scores_rank |
|
|
|
|
|
def _build_proj_equation(free_dims, bound_dims, output_dims): |
|
"""Builds an einsum equation for projections inside multi-head attention.""" |
|
input_str = "" |
|
kernel_str = "" |
|
output_str = "" |
|
bias_axes = "" |
|
letter_offset = 0 |
|
for i in range(free_dims): |
|
char = _CHR_IDX[i + letter_offset] |
|
input_str += char |
|
output_str += char |
|
|
|
letter_offset += free_dims |
|
for i in range(bound_dims): |
|
char = _CHR_IDX[i + letter_offset] |
|
input_str += char |
|
kernel_str += char |
|
|
|
letter_offset += bound_dims |
|
for i in range(output_dims): |
|
char = _CHR_IDX[i + letter_offset] |
|
kernel_str += char |
|
output_str += char |
|
bias_axes += char |
|
equation = "%s,%s->%s" % (input_str, kernel_str, output_str) |
|
|
|
return equation, bias_axes, len(output_str) |
|
|
|
|
|
def _get_output_shape(output_rank, known_last_dims): |
|
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) |
|
|
|
|
|
@tf.keras.utils.register_keras_serializable(package="Text") |
|
class MultiHeadAttention(tf.keras.layers.Layer): |
|
"""MultiHeadAttention layer. |
|
|
|
This is an implementation of multi-headed attention based on "Attention |
|
  Is All You Need". If `query`, `key`, `value` are the same, then
|
this is self-attention. Each timestep in `query` attends to the |
|
corresponding sequence in `key`, and returns a fixed-width vector. |
|
|
|
This layer first projects `query`, `key` and `value`. These are |
|
  (effectively) a list of tensors of length `num_heads`, where the
|
corresponding shapes are [batch_size, <query dimensions>, key_size], |
|
[batch_size, <key/value dimensions>, key_size], |
|
[batch_size, <key/value dimensions>, value_size]. |
|
|
|
Then, the query and key tensors are dot-producted and scaled. These are |
|
softmaxed to obtain attention probabilities. The value tensors are then |
|
interpolated by these probabilities, then concatenated back to a single |
|
tensor. |
|
|
|
  Finally, the result tensor, with value_size as its last dimension, can take
  a linear projection and be returned.
|
|
|
Examples: |
|
|
|
Performs 1D cross-attention over two sequence inputs with an attention mask. |
|
Returns the additional attention weights over heads. |
|
|
|
>>> layer = MultiHeadAttention(num_heads=2, key_size=2, |
|
... return_attention_scores=True) |
|
>>> target = tf.keras.Input(shape=[8, 16]) |
|
>>> source = tf.keras.Input(shape=[4, 16]) |
|
>>> mask_tensor = tf.keras.Input(shape=[8, 4]) |
|
  >>> output_tensor, weights = layer([target, source], mask_tensor)
|
  >>> print(output_tensor.shape, weights.shape)
|
(None, 8, 16) (None, 2, 8, 4) |
|
|
|
Performs 2D self-attention over a 5D input tensor on axes 2 and 3. |
|
|
|
>>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3)) |
|
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16]) |
|
>>> output_tensor = layer([input_tensor, input_tensor]) |
|
>>> print(output_tensor.shape) |
|
(None, 5, 3, 4, 16) |
|
|
|
Arguments: |
|
num_heads: Number of attention heads. |
|
key_size: Size of each attention head for query and key. |
|
value_size: Size of each attention head for value. |
|
dropout: Dropout probability. |
|
use_bias: Boolean, whether the dense layers use bias vectors/matrices. |
|
output_shape: The expected shape of an output tensor, besides the batch and |
|
sequence dims. If not specified, projects back to the key feature dim. |
|
    attention_axes: axes over which the attention is applied. `None` means
      attention over all axes except batch, heads, and features.
|
return_attention_scores: bool, if `True`, returns the multi-head |
|
attention scores as an additional output argument. |
|
kernel_initializer: Initializer for dense layer kernels. |
|
bias_initializer: Initializer for dense layer biases. |
|
kernel_regularizer: Regularizer for dense layer kernels. |
|
bias_regularizer: Regularizer for dense layer biases. |
|
activity_regularizer: Regularizer for dense layer activity. |
|
kernel_constraint: Constraint for dense layer kernels. |
|
    bias_constraint: Constraint for dense layer biases.
|
""" |
|
|
|
def __init__(self, |
|
num_heads, |
|
key_size, |
|
value_size=None, |
|
dropout=0.0, |
|
use_bias=True, |
|
output_shape=None, |
|
attention_axes=None, |
|
return_attention_scores=False, |
|
kernel_initializer="glorot_uniform", |
|
bias_initializer="zeros", |
|
kernel_regularizer=None, |
|
bias_regularizer=None, |
|
activity_regularizer=None, |
|
kernel_constraint=None, |
|
bias_constraint=None, |
|
**kwargs): |
|
super(MultiHeadAttention, self).__init__(**kwargs) |
|
self._num_heads = num_heads |
|
self._key_size = key_size |
|
self._value_size = value_size if value_size else key_size |
|
self._dropout = dropout |
|
self._use_bias = use_bias |
|
self._output_shape = output_shape |
|
self._return_attention_scores = return_attention_scores |
|
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) |
|
self._bias_initializer = tf.keras.initializers.get(bias_initializer) |
|
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
|
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint) |
|
self._bias_constraint = tf.keras.constraints.get(bias_constraint) |
|
if attention_axes is not None and not isinstance(attention_axes, |
|
collections.abc.Sized): |
|
self._attention_axes = (attention_axes,) |
|
else: |
|
self._attention_axes = attention_axes |
|
|
|
def get_config(self): |
|
config = { |
|
"num_heads": |
|
self._num_heads, |
|
"key_size": |
|
self._key_size, |
|
"value_size": |
|
self._value_size, |
|
"dropout": |
|
self._dropout, |
|
"use_bias": |
|
self._use_bias, |
|
"output_shape": |
|
self._output_shape, |
|
"attention_axes": |
|
self._attention_axes, |
|
"return_attention_scores": |
|
self._return_attention_scores, |
|
"kernel_initializer": |
|
tf.keras.initializers.serialize(self._kernel_initializer), |
|
"bias_initializer": |
|
tf.keras.initializers.serialize(self._bias_initializer), |
|
"kernel_regularizer": |
|
tf.keras.regularizers.serialize(self._kernel_regularizer), |
|
"bias_regularizer": |
|
tf.keras.regularizers.serialize(self._bias_regularizer), |
|
"activity_regularizer": |
|
tf.keras.regularizers.serialize(self._activity_regularizer), |
|
"kernel_constraint": |
|
tf.keras.constraints.serialize(self._kernel_constraint), |
|
"bias_constraint": |
|
tf.keras.constraints.serialize(self._bias_constraint) |
|
} |
|
base_config = super(MultiHeadAttention, self).get_config() |
|
return dict(list(base_config.items()) + list(config.items())) |
|
|
|
def build(self, input_shape): |
|
inputs_len = len(input_shape) |
|
if inputs_len > 3 or inputs_len < 2: |
|
raise ValueError( |
|
"Expects inputs list of length 2 or 3, namely [query, value] or " |
|
"[query, value, key]. " |
|
"Given length: %d" % inputs_len) |
|
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape) |
|
query_shape = tensor_shapes[0] |
|
value_shape = tensor_shapes[1] |
|
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape |
|
|
|
common_kwargs = dict( |
|
kernel_initializer=self._kernel_initializer, |
|
bias_initializer=self._bias_initializer, |
|
kernel_regularizer=self._kernel_regularizer, |
|
bias_regularizer=self._bias_regularizer, |
|
activity_regularizer=self._activity_regularizer, |
|
kernel_constraint=self._kernel_constraint, |
|
bias_constraint=self._bias_constraint) |
|
|
|
free_dims = query_shape.rank - 1 |
|
einsum_equation, bias_axes, output_rank = _build_proj_equation( |
|
free_dims, bound_dims=1, output_dims=2) |
|
self._query_dense = EinsumDense( |
|
einsum_equation, |
|
output_shape=_get_output_shape(output_rank - 1, |
|
[self._num_heads, self._key_size]), |
|
bias_axes=bias_axes if self._use_bias else None, |
|
name="query", |
|
**common_kwargs) |
|
einsum_equation, bias_axes, output_rank = _build_proj_equation( |
|
key_shape.rank - 1, bound_dims=1, output_dims=2) |
|
self._key_dense = EinsumDense( |
|
einsum_equation, |
|
output_shape=_get_output_shape(output_rank - 1, |
|
[self._num_heads, self._key_size]), |
|
bias_axes=bias_axes if self._use_bias else None, |
|
name="key", |
|
**common_kwargs) |
|
einsum_equation, bias_axes, output_rank = _build_proj_equation( |
|
value_shape.rank - 1, bound_dims=1, output_dims=2) |
|
self._value_dense = EinsumDense( |
|
einsum_equation, |
|
output_shape=_get_output_shape(output_rank - 1, |
|
[self._num_heads, self._value_size]), |
|
bias_axes=bias_axes if self._use_bias else None, |
|
name="value", |
|
**common_kwargs) |
|
|
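    # Builds the einsum equations, masked softmax, and dropout used by
    # `_compute_attention` for multi-head dot-product attention.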
self._build_attention(output_rank) |
|
if self._output_shape: |
|
if not isinstance(self._output_shape, collections.abc.Sized): |
|
output_shape = [self._output_shape] |
|
else: |
|
output_shape = self._output_shape |
|
else: |
|
output_shape = [query_shape[-1]] |
|
einsum_equation, bias_axes, output_rank = _build_proj_equation( |
|
free_dims, bound_dims=2, output_dims=len(output_shape)) |
|
self._output_dense = EinsumDense( |
|
einsum_equation, |
|
output_shape=_get_output_shape(output_rank - 1, output_shape), |
|
bias_axes=bias_axes if self._use_bias else None, |
|
name="attention_output", |
|
**common_kwargs) |
|
super(MultiHeadAttention, self).build(input_shape) |
|
|
|
def _build_attention(self, qkv_rank): |
|
"""Builds multi-head dot-product attention computations. |
|
|
|
    This function builds the attributes necessary for `_compute_attention`,
    which can be overridden to customize the attention computation and
    replace the default dot-product attention.
|
|
|
Args: |
|
qkv_rank: the rank of query, key, value tensors. |
|
""" |
|
if self._attention_axes is None: |
|
self._attention_axes = tuple(range(1, qkv_rank - 2)) |
|
else: |
|
self._attention_axes = tuple(self._attention_axes) |
|
self._dot_product_equation, self._combine_equation, attn_scores_rank = ( |
|
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes)) |
|
norm_axes = tuple( |
|
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) |
|
self._masked_softmax = masked_softmax.MaskedSoftmax( |
|
mask_expansion_axes=[1], normalization_axes=norm_axes) |
|
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout) |
|
|
|
def _compute_attention(self, |
|
query_tensor, |
|
key_tensor, |
|
value_tensor, |
|
attention_mask=None): |
|
"""Applies Dot-product attention with query, key, value tensors. |
|
|
|
This function defines the computation inside `call` with projected |
|
multi-head Q, K, V inputs. Users can override this function for customized |
|
attention implementation. |
|
|
|
Args: |
|
      query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
      key_tensor: Projected key `Tensor` of shape `[B, S, N, key_size]`.
      value_tensor: Projected value `Tensor` of shape `[B, S, N, value_size]`.
|
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents |
|
attention to certain positions. |
|
|
|
Returns: |
|
attention_output: Multi-headed outputs of attention computation. |
|
attention_scores: Multi-headed attention weights. |
|
""" |
|
|
|
|
|
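    # Take the dot product between "query" and "key" to get the raw
    # attention scores.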
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, |
|
query_tensor) |
|
attention_scores = tf.multiply(attention_scores, |
|
1.0 / math.sqrt(float(self._key_size))) |
|
|
|
|
|
|
|
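    # Normalize the attention scores to probabilities
    # (`attention_scores` = [B, N, T, S] in the default 1D case).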
attention_scores = self._masked_softmax(attention_scores, attention_mask) |
|
|
|
|
|
|
|
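    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.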
attention_scores_dropout = self._dropout_layer(attention_scores) |
|
|
|
|
|
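    # `attention_output` = [B, T, N, V]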
attention_output = tf.einsum(self._combine_equation, |
|
attention_scores_dropout, value_tensor) |
|
return attention_output, attention_scores |
|
|
|
def call(self, inputs, attention_mask=None): |
|
"""Implements the forward pass. |
|
|
|
Size glossary: |
|
* Number of heads (H): the number of attention heads. |
|
* Value size (V): the size of each value embedding per head. |
|
* Key size (K): the size of each key embedding per head. Equally, the size |
|
of each query embedding per head. Typically K <= V. |
|
* Batch dimensions (B). |
|
* Query (target) attention axes shape (T). |
|
    * Value (source) attention axes shape (S); the rank must match the target.
|
|
|
Args: |
|
inputs: List of the following tensors: |
|
* query: Query `Tensor` of shape `[B, T, dim]`. |
|
* value: Value `Tensor` of shape `[B, S, dim]`. |
|
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will |
|
use `value` for both `key` and `value`, which is the most common case. |
|
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents |
|
attention to certain positions. |
|
|
|
Returns: |
|
      attention_output: The result of the computation, of shape [B, T, E],
        where `T` is for target sequence shapes and `E` is the query input
        last dimension if `output_shape` is `None`. Otherwise, the multi-head
        outputs are projected to the shape specified by `output_shape`.
|
      attention_scores: [Optional] multi-head attention coefficients over
        attention axes.
|
""" |
|
inputs_len = len(inputs) |
|
if inputs_len > 3 or inputs_len < 2: |
|
raise ValueError( |
|
"Expects inputs list of length 2 or 3, namely [query, value] or " |
|
"[query, value, key]. " |
|
"Given length: %d" % inputs_len) |
|
query = inputs[0] |
|
value = inputs[1] |
|
key = inputs[2] if inputs_len == 3 else value |
|
|
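    # Linearly project the query, key and value. In the notation of the size
    # glossary above, with N heads:
    # `query_tensor` = [B, T, N, K]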
query_tensor = self._query_dense(query) |
|
|
|
|
|
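    # `key_tensor` = [B, S, N, K]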
key_tensor = self._key_dense(key) |
|
|
|
|
|
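    # `value_tensor` = [B, S, N, V]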
value_tensor = self._value_dense(value) |
|
|
|
attention_output, attention_scores = self._compute_attention( |
|
query_tensor, key_tensor, value_tensor, attention_mask) |
|
attention_output = self._output_dense(attention_output) |
|
|
|
if self._return_attention_scores: |
|
return attention_output, attention_scores |
|
return attention_output |
|
|
|
|
|
@tf.keras.utils.register_keras_serializable(package="Text") |
|
class CachedAttention(MultiHeadAttention): |
|
"""Attention layer with cache used for auto-agressive decoding. |
|
|
|
Arguments are the same as `MultiHeadAttention` layer. |
|
""" |
|
|
|
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step): |
|
"""Updates cache states and gets full-length key/value tensors.""" |
|
|
|
if decode_loop_step is not None: |
|
|
|
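      # The cache has a fixed sequence length, so write this step's key/value
      # into position `decode_loop_step` with a one-hot mask. This keeps
      # shapes static, e.g. for decoding inside an XLA while loop.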
key_seq_dim = cache["key"].shape.as_list()[1] |
|
indices = tf.reshape( |
|
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype), |
|
[1, key_seq_dim, 1, 1]) |
|
key_tensor = cache["key"] + key_tensor * indices |
|
value_seq_dim = cache["value"].shape.as_list()[1] |
|
indices = tf.reshape( |
|
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype), |
|
[1, value_seq_dim, 1, 1]) |
|
value_tensor = cache["value"] + value_tensor * indices |
|
else: |
|
key_tensor = tf.concat( |
|
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1) |
|
value_tensor = tf.concat( |
|
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1) |
|
|
|
|
|
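    # Update the cache so later decoding steps see the full-length tensors.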
cache["key"] = key_tensor |
|
cache["value"] = value_tensor |
|
|
|
return key_tensor, value_tensor |
|
|
|
def call(self, |
|
inputs, |
|
attention_mask=None, |
|
cache=None, |
|
decode_loop_step=None): |
|
from_tensor = inputs[0] |
|
to_tensor = inputs[1] |
|
|
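    # Scalar dimensions referenced here: B = batch size,
    # F = `from_tensor` length, T = `to_tensor` length, N = num_heads,
    # K = key_size, V = value_size.
    # `query_tensor` = [B, F, N, K]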
query_tensor = self._query_dense(from_tensor) |
|
|
|
|
|
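    # `key_tensor` = [B, T, N, K]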
key_tensor = self._key_dense(to_tensor) |
|
|
|
|
|
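    # `value_tensor` = [B, T, N, V]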
value_tensor = self._value_dense(to_tensor) |
|
|
|
if cache: |
|
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor, |
|
cache, decode_loop_step) |
|
|
|
|
|
|
|
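    # Dot-product the (possibly cached) keys with the queries to get the raw
    # attention scores, as in `_compute_attention` above.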
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, |
|
query_tensor) |
|
attention_scores = tf.multiply(attention_scores, |
|
1.0 / math.sqrt(float(self._key_size))) |
|
|
|
|
|
|
|
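    # Normalize the attention scores to probabilities.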
attention_scores = self._masked_softmax(attention_scores, attention_mask) |
|
|
|
|
|
|
|
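    # Drop out entire tokens to attend to, as in `_compute_attention` above.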
attention_scores = self._dropout_layer(attention_scores) |
|
|
|
attention_output = tf.einsum(self._combine_equation, attention_scores, |
|
value_tensor) |
|
attention_output = self._output_dense(attention_output) |
|
if self._return_attention_scores: |
|
return attention_output, attention_scores, cache |
|
return attention_output, cache |
|
|