# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based kernel attention layer."""

import functools
import math

import tensorflow as tf, tf_keras

from official.modeling import tf_utils

_NUMERIC_STABLER = 1e-6


class KernelMask(tf_keras.layers.Layer):
  """Creates kernel attention mask.

  inputs:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    mask: a Tensor of shape [batch_size, from_seq_length] marking which
      positions may be attended. KernelAttention multiplies the transformed
      keys by this mask, so positions with a zero entry are excluded.

  Returns:
    Tensor of shape [batch_size, from_seq_length], cast to the dtype of
    `from_tensor`, that KernelAttention takes as mask.
  """

  def call(self, inputs, mask):
    mask = tf.cast(mask, inputs.dtype)
    return mask


def pad_to_chunk_length(tensor, axis, chunk_length, padding=None):
  """Pads a tensor so that shape[axis] is divisible by chunk_length.

  Args:
    tensor: Input tensor to pad.
    axis: Axis to pad along. May be negative.
    chunk_length: The output tensor will have shape[axis] divisible by
      chunk_length.
    padding: Pad the input tensor across the axis from either left or right if
      padding is set to "left" or "right"; applies no padding if padding is
      set to None. In the latter case, the axis dimension of the input tensor
      must be divisible by the chunk_length.

  Returns:
    Padded tensor with shape[axis] divisible by chunk_length.

  Raises:
    ValueError: If padding is not one of "left", "right" or None.
  """
  if padding is None:
    return tensor
  if padding not in ("left", "right"):
    raise ValueError(
        "Illegal padding value; must be one of \"left\", \"right\" or None.")

  shape = tf.shape(tensor)
  rank = tf.rank(tensor)
  if axis < 0:
    axis += rank
  axis_length = shape[axis]
  # Floor-modulo of the negated length gives the number of elements needed to
  # reach the next multiple of chunk_length (0 when already aligned).
  pad_length = -axis_length % chunk_length
  if padding == "right":
    axis_paddings = [[0, pad_length]]
  else:
    axis_paddings = [[pad_length, 0]]
  # Zero padding for every axis except `axis`.
  paddings = tf.concat([
      tf.zeros([axis, 2], dtype=tf.int32), axis_paddings,
      tf.zeros([rank - axis - 1, 2], dtype=tf.int32)
  ],
                       axis=0)
  return tf.pad(tensor, paddings)


def split_tensor_into_chunks(tensor, axis, chunk_length):
  """Reshape tensor along given axis using chunk_length.

  Args:
    tensor: Input tensor.
    axis: Reshape tensor along this axis. May be negative.
    chunk_length: Split the axis into [axis/chunk_length, chunk_length]

  Returns:
    Reshaped tensor.
  """
  shape = tf.shape(tensor)
  num_chunks = shape[axis] // chunk_length
  # For axis == -1, `shape[axis + 1:]` would be `shape[0:]` (the full shape),
  # so the trailing dimensions must be taken as empty explicitly.
  trailing_dims = shape[(axis + 1):] if axis != -1 else shape[:0]
  new_shape = tf.concat(
      [shape[:axis], [num_chunks, chunk_length], trailing_dims], axis=0)
  return tf.reshape(tensor, new_shape)


def rectangular_window_sum(tensor, window_length):
  """Summarizes tensor elements over a sliding rectangular window.

  Sums elements of the input tensor of shape [B, T', C', H, dim]
  across a rectangular window sliding along the dimension T'.

  Args:
    tensor: Tensor of shape `[B, T', C', H, dim]`.
    window_length: The length of the rectangular window.

  Returns:
    A tensor of shape [B, T', C', H, dim] containing sums over the
    window.
  """
  tensor_cumsum = tf.cumsum(tensor, axis=-4)
  # Subtracting the cumulative sum shifted by window_length along T' leaves
  # the sum of the last window_length elements at each position.
  tensor_winsum = tensor_cumsum - tf.pad(
      tensor_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
  return tensor_winsum


def weighted_window_sum(tensor, window_length, window_weights):
  """Summarizes tensor elements over a sliding weighted window.

  Computes a weighted sum of elements of the input tensor of shape [B,
  T', C', H, dim] across a window sliding along the dimension T'.

  Args:
    tensor: Tensor of shape `[B, T', C', H, dim]`.
    window_length: The length of the window.
    window_weights: Tensor of shape [window_length] containing window weights.

  Returns:
    A tensor of shape [B, T', C', H, dim] containing sums over the window.
  """
  # Flatten the last three dimensions of the [B, T', C', H, dim] shape
  # into a single channels dimension.
  tensor_shape = tf.shape(tensor)
  tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])

  # Apply the same weights to all channels.
  conv_filter = tf.tile(
      tf.reshape(window_weights, [-1, 1, 1, 1]),
      multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
  # Left-only padding along T' keeps the convolution causal.
  tensor_winsum_2d = tf.nn.depthwise_conv2d(
      tensor_2d,
      conv_filter,
      strides=[1, 1, 1, 1],
      padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])

  # Unflatten the channels dimension into the original shape.
  tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
  return tensor_winsum


def causal_windowed_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
                                        window_decay=None,
                                        padding=None,
                                        cache=None):
  """Applies windowed causal kernel attention with query, key, value tensors.

  We partition the T-length input sequence into N chunks, each of
  chunk_length tokens (thus: T = N * chunk_length). Within each chunk,
  we apply bidirectional (non-causal) Performers' implicit attention
  and we model relationships between different chunks using
  Performers' causal attention. We consider windowed causal variant of
  performer, where the current chunk attends only to the window of
  window_length of the most recent chunks.

  Below is an example with T=9, chunk_length=3, window_length=2. In
  this example 1 indicates attention is computed between the pair
  while 0 indicates attention is not computed between the pairs:

  111000000
  111000000
  111000000
  111111000
  111111000
  111111000
  000111111
  000111111
  000111111

  User can ensure sequence_length is divisible by chunk_length or use
  padding="left"/"right" to pad the sequence length either at the left
  or right respectively and make it divisible by chunk_length.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of attention window in chunks.
    window_decay: Float window decay factor or `None`. If set,
      exponentially decay past attention window values by this factor
      before summation.
    padding: Pad the query, value and key input tensors across the
      axis from either left or right if padding is set to "left" or
      "right"; apply no padding if padding is set to None. In the
      latter case, the axis dimension of the query, value and key
      input tensors must be divisible by the chunk_length.
    cache: Cache to accumulate history in memory. Used at inference time
      (streaming, decoding) for causal attention. Its "kv" and "k_sum"
      entries are updated in place.

  Returns:
    Window causal performer attention of shape `[B, T, H, out_dim]`.

  Raises:
    ValueError: In streaming mode (cache given) if window_decay is None or
      outside [0.0, 1.0].
  """
  if cache is None:  # Training
    old_shape = tf.shape(value_matrix)

    query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
    key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
    value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)

    new_shape = tf.shape(value_matrix)
    # [-1, T//chunk_length, chunk_length, H, dim]
    chunked_query_matrix = split_tensor_into_chunks(query_matrix, -3,
                                                    chunk_length)
    # [-1, T//chunk_length, chunk_length, H, dim]
    chunked_key_matrix = split_tensor_into_chunks(key_matrix, -3, chunk_length)
    # [-1, T//chunk_length, chunk_length, H, out_dim]
    chunked_value_matrix = split_tensor_into_chunks(value_matrix, -3,
                                                    chunk_length)

    # Per-chunk key/value outer products and key sums.
    kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
                     chunked_value_matrix)
    k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)

    if window_decay is None:
      kp_v_winsum = rectangular_window_sum(kp_v, window_length)
      k_winsum = rectangular_window_sum(k_sum, window_length)
    else:
      # Compute exponentially decaying weights; the most recent chunk (last
      # window position) gets weight window_decay**0 == 1.
      decaying_weights = tf.math.pow(
          tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
          tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))

      kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
      k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)

    numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix,
                          kp_v_winsum)

    k_winsum = tf.squeeze(k_winsum, -3)
    denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
    denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER

    attention = numerator / denominator
    attention = tf.reshape(attention, new_shape)

    # Crop away the chunk padding. Left padding prepends the pad rows, so the
    # real tokens start at the pad offset; only the sequence axis differs
    # between new_shape and old_shape, so their difference is that offset.
    if padding == "left":
      start = new_shape - old_shape
    else:
      start = tf.zeros_like(old_shape)
    attention = tf.slice(attention, start, old_shape)

  # Queued window cache (drop instead of decay) not yet supported.
  else:  # Streaming
    if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
      raise ValueError("window_decay should be in [0.0, 1.0] and not None.")
    kv = window_decay * cache["kv"] + tf.einsum(
        "BTHD,BTHO->BHOD", key_matrix, value_matrix)
    cache["kv"] = kv
    k_sum = window_decay * cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
    cache["k_sum"] = k_sum
    denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
    # The below is equivalent to but converts to TF Lite better than:
    # tf.einsum("BTHD,BTH->BTHD",
    #           query_matrix, 1.0 / (denominator + _NUMERIC_STABLER))
    inverse_denominator = 1.0 / (denominator + _NUMERIC_STABLER)
    # Add another dimension to align for the broadcast multiplication.
    fused_query_denominator = query_matrix * tf.expand_dims(
        inverse_denominator, -1)
    attention = tf.einsum("BTHD,BHOD->BTHO", fused_query_denominator, kv)
  return attention


def create_projection_matrix(m, d, seed=None):
  r"""Constructs the matrix of random projections.

  Constructs a matrix of random orthogonal projections. Each projection vector
  has a direction chosen uniformly at random and a length taken from the
  \chi(d) distribution.

  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed used to construct projections. If None, the stateful
      random API is used instead.

  Returns:
    The matrix of random projections of the shape [m, d].
  """
  nb_full_blocks = math.ceil(m / d)
  block_list = tf.TensorArray(
      tf.float32, size=tf.cast(nb_full_blocks, dtype=tf.int32))
  stateful = False
  if seed is None:
    stateful = True
    # Dummy seed so `current_seed` is always defined; the stateless ops that
    # consume it are not taken on the stateful path.
    seed = tf.constant([0, 1])
  current_seed = seed
  for i in range(nb_full_blocks):
    if stateful:
      unstructured_block = tf.random.normal((d, d))
    else:
      unstructured_block = tf.random.stateless_normal((d, d), seed=current_seed)
      # Derive the next seed so each block is different but reproducible.
      current_seed = tf.random.stateless_uniform([2],
                                                 seed=current_seed,
                                                 minval=None,
                                                 dtype=tf.int32)
    q, _ = tf.linalg.qr(unstructured_block)
    q = tf.transpose(q)
    block_list = block_list.write(i, q)
  final_matrix = block_list.concat()[:m]
  # `stateful` is a bool; the previous `stateful is None` test was never true,
  # so redraws always reused the same dummy-seeded row norms.
  if stateful:
    multiplier = tf.norm(tf.random.normal((m, d)), axis=1)
  else:
    multiplier = tf.norm(
        tf.random.stateless_normal((m, d), seed=current_seed), axis=1)
  return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)


def _generalized_kernel(x, y, is_query, projection_matrix, f, h):
  """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.

  Args:
    x: The feature being transformed with shape [B, T, N ,H].
    y: The extra stats-tensor of shape [B, T, N ,H]. Unused.
    is_query: True if x is a query-tensor. Unused.
    projection_matrix: The matrix with shape [M, H] that we project x to, where
      M is the number of projections.
    f: A non-linear function applied on x or projected x.
    h: A multiplier which is a function of x applied after projected and
      transformed. Only applied if projection_matrix is not None.

  Returns:
    Transformed feature.
  """
  del y
  del is_query
  if projection_matrix is None:
    return h(x) * f(x)
  else:
    x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
    return h(x) * f(x_projected) / tf.math.sqrt(
        tf.cast(tf.shape(projection_matrix)[0], tf.float32))


def expplus(data_orig,
            other_data,
            is_query,
            projection_matrix=None,
            numerical_stabilizer=0.000001,
            normalize_data=True,
            numerical_renormalizer=True,
            extra_renormalize_exp_fun=False):
  """FAVOR++ mechanism from the CRT paper: https://arxiv.org/abs/2205.15317 .

  Args:
    data_orig: data tensor of shape [B,T,H,D] for which random features are to
      be computed
    other_data: additional tensor of the shape [B,F,H,D] used to collect stats
      to determine the exact instantiation of the random feature mechanism
    is_query: boolean indicating whether tensor is a query tensor
    projection_matrix: tensor of the shape [M,D] encoding random projections for
      random features (M stands for the number of random features)
    numerical_stabilizer: numerical stabilizer for the kernel features
    normalize_data: whether to sqrt-d-normalize queries/keys as in the regular
      attention
    numerical_renormalizer: whether to apply additional renormalization for
      numerical stability
    extra_renormalize_exp_fun: extra renormalizer for the exponential mapping
      applied to construct random features

  Returns:
    Random feature map tensor for the unbiased softmax-kernel estimation.
    Returns `data_orig` unchanged when projection_matrix is None.
  """

  data = data_orig
  if projection_matrix is None:
    return data_orig
  projection_matrix = tf.cast(projection_matrix, data.dtype)
  if normalize_data:
    data_normalizer = 1.0 / tf.math.sqrt(
        (tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
  else:
    data_normalizer = 1.0
  # Normalize every feature vector to unit L2 norm along the last axis.
  lengths = tf.math.square(data)
  lengths = tf.reduce_sum(lengths, axis=tf_keras.backend.ndim(data) - 1)
  lengths = tf.expand_dims(lengths, axis=tf_keras.backend.ndim(data) - 1)
  lengths = tf.math.sqrt(lengths)
  data /= lengths
  ratio = 1.0 / tf.math.sqrt(
      tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
  data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
                        projection_matrix)
  diag_data = tf.math.square(data)
  diag_data = tf.math.reduce_sum(
      diag_data, axis=tf_keras.backend.ndim(data) - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = tf.expand_dims(diag_data, axis=tf_keras.backend.ndim(data) - 1)

  # Calculating coefficients A, B of the FAVOR++ mechanism:
  _, l, _, _ = tf_utils.get_shape_list(data_orig)

  l = tf.cast(l, dtype=tf.float32)
  first_sum_of_squares = tf.math.square(data)
  first_sum_of_squares = tf.math.reduce_sum(
      first_sum_of_squares, axis=(1, -1), keepdims=True)
  first_sum_of_squares *= (data_normalizer * data_normalizer)
  first_sum_of_squares /= l  # data.shape[1]
  second_sum_of_squares = tf.math.square(other_data)
  second_sum_of_squares = tf.math.reduce_sum(
      second_sum_of_squares, axis=(1, -1), keepdims=True)
  second_sum_of_squares *= (data_normalizer * data_normalizer)
  second_sum_of_squares /= l  # other_data.shape[1]
  data_sum = tf.math.reduce_sum(data, axis=(1,), keepdims=True)
  other_data_sum = tf.math.reduce_sum(other_data, axis=(1,), keepdims=True)
  d_prod = tf.einsum("blhd,blhd->blh", data_sum, other_data_sum)
  d_prod = tf.expand_dims(d_prod, axis=-1)
  d_prod *= (data_normalizer * data_normalizer)
  d_prod *= (2.0 / (l * l))
  ave = first_sum_of_squares + second_sum_of_squares + d_prod
  dim = projection_matrix.shape[-1]
  a_coeff = (1.0 / (4.0 * ave)) * (
      tf.math.sqrt((2.0 * ave + dim) * (2.0 * ave + dim) + 8.0 * dim * ave) -
      2.0 * ave - dim)
  a_coeff = (1.0 - 1.0 / a_coeff) / 8.0
  b_coeff = tf.math.sqrt(1.0 - 4.0 * a_coeff)
  d_coeff = tf.math.pow(1.0 - 4.0 * a_coeff, dim / 4.0)
  # The coefficients are treated as constants during backpropagation.
  a_coeff = tf.stop_gradient(a_coeff)
  b_coeff = tf.stop_gradient(b_coeff)
  d_coeff = tf.stop_gradient(d_coeff)

  # Calculating diag_omega for the FAVOR++ mechanism:
  diag_omega = tf.math.square(projection_matrix)
  diag_omega = tf.math.reduce_sum(
      diag_omega, axis=tf_keras.backend.ndim(projection_matrix) - 1)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = a_coeff * diag_omega

  if numerical_renormalizer:
    # Subtract a max before exponentiating to avoid overflow: per position
    # for queries, globally for keys.
    if is_query:
      last_dims_t = (len(data_dash.shape) - 1,)
      stab = b_coeff * tf.math.reduce_max(
          data_dash, axis=last_dims_t, keepdims=True)
    else:
      stab = b_coeff * tf.math.reduce_max(data_dash, keepdims=True)
    if extra_renormalize_exp_fun:
      extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
      stab = tf.math.maximum(stab, extra_stab)
    data_dash = ratio * d_coeff * (
        tf.math.exp(b_coeff * data_dash - stab - diag_data + diag_omega) +
        numerical_stabilizer)
  else:
    data_dash = ratio * d_coeff * (
        tf.math.exp(b_coeff * data_dash - diag_data + diag_omega) +
        numerical_stabilizer)

  return data_dash


# pylint: disable=g-long-lambda
_CAUSAL_SUPPORT_TRANSFORM_MAP = {
    "elu":
        functools.partial(
            _generalized_kernel,
            f=lambda x: tf_keras.activations.elu(x) + 1,
            h=lambda x: 1),
    "relu":
        functools.partial(
            _generalized_kernel,
            # Improve numerical stability and avoid NaNs in some cases by adding
            # a tiny epsilon.
            f=lambda x: tf_keras.activations.relu(x) + 1e-3,
            h=lambda x: 1),
    "square":
        functools.partial(_generalized_kernel, f=tf.math.square, h=lambda x: 1),
    "exp":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
                x, axis=[1, 2, 3], keepdims=True)),
            h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
                tf.math.square(x), axis=-1, keepdims=True)),
        ),
    "expmod":
        functools.partial(
            _generalized_kernel,
            # Avoid exp explosion by shifting.
            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
                x, axis=[1, 2, 3], keepdims=True)),
            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                tf.cast(tf.shape(x)[-1], tf.float32))),
        ),
    "identity":
        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}

_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
    "expplus": expplus,
}

_TRANSFORM_MAP = {
    **_CAUSAL_SUPPORT_TRANSFORM_MAP,
    **_NON_CAUSAL_SUPPORT_TRANSFORM_MAP
}
# pylint: enable=g-long-lambda


class KernelAttention(tf_keras.layers.MultiHeadAttention):
  """A variant of efficient transformers which replaces softmax with kernels.

  This module combines ideas from the following papers:

  Rethinking Attention with Performers
  (https://arxiv.org/abs/2009.14794)
  - exp (Lemma 1, positive), relu
  - random/deterministic projection
  Chefs' Random Tables: Non-Trigonometric Random Features
  (https://arxiv.org/abs/2205.15317)
  - expplus (OPRF mechanism)

  Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
  (https://arxiv.org/abs/2006.16236)
  - elu

  with the theory of approximating angular Performer kernels from go/performer.

  The module enables computing efficient attention in both: long sequence and
  shorter sequence regimes. In the former setting, the attention matrix is never
  explicitly computed and instead its low-rank decomposition obtained with given
  kernel feature maps is leveraged to conduct attention module calculations
  (see: https://arxiv.org/abs/2006.16236). In the latter setting, attention
  matrix is constructed, but kernel features providing dimensionality reduction
  are applied, resulting in more efficient computation of the attention matrix.
  """

  def __init__(self,
               feature_transform="exp",
               num_random_features=256,
               seed=0,
               redraw=False,
               is_short_seq=False,
               begin_kernel=0,
               scale=None,
               scale_by_length=False,
               use_causal_windowed=False,
               causal_chunk_length=1,
               causal_window_length=3,
               causal_window_decay=None,
               causal_padding=None,
               **kwargs):
    r"""Constructor of KernelAttention.

    Args:
      feature_transform: A non-linear transform of the keys and queries.
        Possible transforms are "elu", "relu", "square", "exp", "expplus",
        "expmod", "identity".
      num_random_features: Number of random features to be used for projection.
        If num_random_features <= 0, no projection is used before transform.
      seed: The seed to begin drawing random features. Once the seed is set, the
        pseudo-random number generation is deterministic. Users should pass
        different seed for different layers. For multi-worker, each layer will
        use the same projection at each step.
      redraw: Whether to redraw projection every forward pass during training.
        The argument is only effective when num_random_features > 0.
      is_short_seq: boolean predicate indicating whether input data consists of
        very short sequences or not; in most cases this should be False (default
        option).
      begin_kernel: Apply kernel_attention after this sequence id and apply
        softmax attention before this.
      scale: The value to scale the dot product as described in `Attention Is
        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
      scale_by_length: boolean predicate indicating whether additionally scale
        the dot product based on key length. Set as log_512^(n) to stabilize
        attention entropy against length. Refer to
        https://kexue.fm/archives/8823 for details.
      use_causal_windowed: If true perform windowed causal attention. See
        causal_windowed_performer_attention function docstring for more details.
      causal_chunk_length: Length of each chunk in tokens.
      causal_window_length: Length of attention window in chunks.
      causal_window_decay: Float window decay factor or `None`. If set,
        exponentially decay past attention window values by this factor before
        summation.
      causal_padding: Pad the query, value and key input tensors across the axis
        from either left or right if padding is set to "left" or "right"; apply
        no padding if padding is set to None. In the latter case, the axis
        dimension of the query, value and key input tensors must be divisible by
        the chunk_length.
      **kwargs: The same arguments `MultiHeadAttention` layer.

    Raises:
      ValueError: On an unsupported feature_transform, on redraw without random
        features, or when use_causal_windowed and is_short_seq are both set.
    """
    if feature_transform not in _TRANSFORM_MAP:
      raise ValueError("Unsupported feature_transform. The supported "
                       "feature_transform are %s. "
                       "Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
    if num_random_features <= 0 and redraw:
      raise ValueError(
          "There is nothing to redraw when num_random_features <= 0.")
    # Validate before doing any work (e.g. building the projection matrix).
    if use_causal_windowed and is_short_seq:
      raise ValueError(
          "use_causal_windowed and short_seq methods are mutually exclusive")
    self._feature_transform = feature_transform
    self._num_random_features = num_random_features
    self._redraw = redraw
    self._is_short_seq = is_short_seq
    self._begin_kernel = begin_kernel
    self._scale_by_length = scale_by_length
    # We use the seed for two scenarios:
    # 1. inference
    # 2. no redraw
    self._seed = seed
    super().__init__(**kwargs)
    if scale is None:
      self._scale = 1.0 / math.sqrt(float(self._key_dim))
    else:
      self._scale = scale
    self._projection_matrix = None
    if num_random_features > 0:
      self._projection_matrix = create_projection_matrix(
          self._num_random_features, self._key_dim,
          tf.constant([self._seed, self._seed + 1]))
    self.use_causal_windowed = use_causal_windowed
    self.causal_chunk_length = causal_chunk_length
    self.causal_window_length = causal_window_length
    self.causal_window_decay = causal_window_decay
    self.causal_padding = causal_padding

  def _compute_attention(self,
                         query,
                         key,
                         value,
                         feature_transform,
                         is_short_seq,
                         attention_mask=None,
                         cache=None,
                         training=False,
                         numeric_stabler=_NUMERIC_STABLER):
    """Applies kernel attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementation.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` of shape `[B, S, N, key_dim]`.
      value: Projected value `Tensor` of shape `[B, S, N, value_dim]`.
      feature_transform: A non-linear transform of the keys and queries.
      is_short_seq: boolean predicate indicating whether input data consists of
        short or long sequences; usually short sequence is defined as having
        length L <= 1024.
      attention_mask: a boolean or numeric mask of shape `[B, S]`, that prevents
        attending to masked positions. Note that the mask is only applied to
        the keys. User may want to mask the output if query contains pads.
      cache: Cache to accumulate history in memory. Used at inference time
        (streaming, decoding) for causal attention.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
      numeric_stabler: A scalar value added to avoid divide by 0 in the
        non-causal long-sequence path.

    Returns:
      attention_output: Multi-headed outputs of attention computation.
    """
    if attention_mask is not None:
      # The mask is multiplied into the keys (and summed for length scaling),
      # so it must be numeric; this also accepts boolean masks.
      attention_mask = tf.cast(attention_mask, query.dtype)

    projection_matrix = None
    if self._num_random_features > 0:
      if self._redraw and training:
        projection_matrix = create_projection_matrix(self._num_random_features,
                                                     self._key_dim)
      else:
        projection_matrix = self._projection_matrix

    if self._scale_by_length:
      if attention_mask is not None:
        key_length = tf.reduce_sum(attention_mask, axis=-1)
      else:
        # Without a mask every key position counts towards the length.
        key_length = tf.cast(tf.shape(key)[1], query.dtype)
      scale = tf.math.log(key_length) * self._scale / math.log(512)
      scale = tf.reshape(scale, [-1, 1, 1, 1])
    else:
      scale = self._scale
    if is_short_seq:
      # Note: Applying scalar multiply at the smaller end of einsum improves
      # XLA performance, but may introduce slight numeric differences in
      # the Transformer attention head.
      query = query * scale
    else:
      # Note: we suspect splitting the scale to key, query yields smaller
      # approximation variance when random projection is used.
      # For simplicity, we also split when there's no random projection.
      key *= tf.math.sqrt(scale)
      query *= tf.math.sqrt(scale)

    key_prime = _TRANSFORM_MAP[feature_transform](key, query, False,
                                                  projection_matrix)
    query_prime = _TRANSFORM_MAP[feature_transform](query, key, True,
                                                    projection_matrix)

    if attention_mask is not None:
      key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)

    if is_short_seq:
      # NOTE(review): masked keys only get a zero score here, so they still
      # receive some softmax weight.
      attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
      attention_scores = tf.nn.softmax(attention_scores, axis=2)
      attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
    elif self.use_causal_windowed:
      attention_output = causal_windowed_performer_attention(
          query_prime,
          key_prime,
          value,
          chunk_length=self.causal_chunk_length,
          window_length=self.causal_window_length,
          window_decay=self.causal_window_decay,
          padding=self.causal_padding,
          cache=cache)
    else:
      # Linear attention: aggregate keys/values once, then query them.
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
      denominator = 1.0 / (
          tf.einsum("BTNH,BNH->BTN", query_prime,
                    tf.reduce_sum(key_prime, axis=1)) + numeric_stabler)
      attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv,
                                   denominator)
    return attention_output

  def _build_from_signature(self, query, value, key=None):
    super()._build_from_signature(query=query, value=value, key=key)  # pytype: disable=attribute-error  # typed-keras
    if self._begin_kernel > 0:
      common_kwargs = dict(
          kernel_initializer=self._kernel_initializer,
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activity_regularizer=self._activity_regularizer,
          kernel_constraint=self._kernel_constraint,
          bias_constraint=self._bias_constraint)
      # Separate output projection and dropout for the softmax-attended prefix.
      self._output_dense_softmax = self._make_output_dense(
          self._query_shape.rank - 1,
          common_kwargs,
          name="attention_output_softmax")
      self._dropout_softmax = tf_keras.layers.Dropout(rate=self._dropout)

  def call(self,
           query,
           value,
           key=None,
           attention_mask=None,
           cache=None,
           training=False):
    """Compute attention with kernel mechanism.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
        `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean or numeric mask of shape `[B, S]`, that
        prevents attending to masked positions. Note that the mask is only
        applied to the keys. User may want to mask the output if query contains
        pads.
      cache: Cache to accumulate history in memory. Used at inference time
        (streaming, decoding) for causal attention.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      Multi-headed outputs of attention computation.

    Raises:
      ValueError: If `cache` is given while training, without
        use_causal_windowed, with begin_kernel set, or with a feature transform
        that does not support causal attention.
    """
    if cache is not None:
      if training:
        raise ValueError(
            "Cache is not supported when training is True.")
      if not self.use_causal_windowed:
        raise ValueError(
            "Cache is not supported for non use_causal_windowed case.")
      if self._begin_kernel:
        raise ValueError(
            "Cache is not supported when begin_kernel is set since the bahvior "
            "is too complicated.")
      if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
        raise ValueError("Cache is not supported for feature_transform %s" %
                         (self._feature_transform))

    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    #   N = `num_attention_heads`
    #   H = `size_per_head`
    # `query` = [B, T, N ,H]
    query = self._query_dense(query)

    # `key` = [B, S, N, H]
    key = self._key_dense(key)

    # `value` = [B, S, N, D]
    value = self._value_dense(value)

    if self._begin_kernel > 0:
      # Queries before `begin_kernel` use the explicit attention-matrix
      # (softmax) path with the "identity" transform; the rest use the
      # configured kernel. Arguments are passed by keyword: passing `training`
      # positionally would land it in the `cache` slot.
      attention_output_softmax = self._compute_attention(
          query[:, :self._begin_kernel],
          key,
          value,
          "identity",
          True,
          attention_mask=attention_mask,
          training=training)
      attention_output_softmax = self._dropout_softmax(attention_output_softmax)
      attention_output_softmax = self._output_dense_softmax(
          attention_output_softmax)

      attention_output_kernel = self._compute_attention(
          query[:, self._begin_kernel:],
          key,
          value,
          self._feature_transform,
          self._is_short_seq,
          attention_mask=attention_mask,
          training=training)
      attention_output_kernel = self._dropout_layer(attention_output_kernel)
      attention_output_kernel = self._output_dense(attention_output_kernel)
      attention_output = tf.concat(
          [attention_output_softmax, attention_output_kernel], axis=1)
    else:
      attention_output = self._compute_attention(
          query,
          key,
          value,
          self._feature_transform,
          self._is_short_seq,
          attention_mask=attention_mask,
          cache=cache,
          training=training)
      # This is actually dropping out entire tokens to attend to, which might
      # seem a bit unusual, but is taken from the original Transformer paper.
      attention_output = self._dropout_layer(attention_output)
      attention_output = self._output_dense(attention_output)
    return attention_output

  def get_config(self):
    config = {
        "feature_transform": self._feature_transform,
        "num_random_features": self._num_random_features,
        "seed": self._seed,
        "redraw": self._redraw,
        "is_short_seq": self._is_short_seq,
        "begin_kernel": self._begin_kernel,
        "scale": self._scale,
        "scale_by_length": self._scale_by_length,
        "use_causal_windowed": self.use_causal_windowed,
        "causal_chunk_length": self.causal_chunk_length,
        "causal_window_length": self.causal_window_length,
        "causal_window_decay": self.causal_window_decay,
        "causal_padding": self.causal_padding,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))