# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based BigBird attention layer."""

import numpy as np
import tensorflow as tf, tf_keras

MAX_SEQ_LEN = 4096


def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
  """Create a band attention mask from blocked input masks.

  Args:
    from_blocked_mask: int32 Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: int32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].

  Returns:
    float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4,
      from_block_size, 3*to_block_size].
  """
  exp_blocked_to_pad = tf.concat([
      to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2],
      to_blocked_mask[:, 3:-1]
  ], 2)
  band_mask = tf.einsum("BLQ,BLK->BLQK", from_blocked_mask[:, 2:-2],
                        exp_blocked_to_pad)
  band_mask = tf.expand_dims(band_mask, 1)
  return band_mask


def bigbird_block_rand_mask(from_seq_length,
                            to_seq_length,
                            from_block_size,
                            to_block_size,
                            num_rand_blocks,
                            last_idx=-1):
  """Create adjacency list of random attention.

  Args:
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
    num_rand_blocks: int. Number of random chunks per row.
    last_idx: if -1 then num_rand_blocks blocks are chosen anywhere in the to
      sequence, if positive then num_rand_blocks blocks are chosen only up to
      last_idx.

  Returns:
    adjacency list of size from_seq_length//from_block_size-2 by
      num_rand_blocks.
  """
  assert from_seq_length//from_block_size == to_seq_length//to_block_size, \
      "Error: the number of blocks needs to be the same!"
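  # Random attention is only generated for from-blocks 1 .. num_blocks-2;
  # block 0 and the last block attend globally and are handled separately by
  # the sparse attention computation. For each remaining block,
  # num_rand_blocks to-blocks are sampled from the middle of the sequence
  # (excluding the two global blocks), skipping a small neighborhood around
  # the current block so random attention does not duplicate the
  # sliding-window attention. A positive last_idx restricts sampling to the
  # first last_idx tokens.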
  rand_attn = np.zeros(
      (from_seq_length // from_block_size - 2, num_rand_blocks),
      dtype=np.int32)
  middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
  last = to_seq_length // to_block_size - 1
  if last_idx > (2 * to_block_size):
    last = (last_idx // to_block_size) - 1

  r = num_rand_blocks  # shorthand
  for i in range(1, from_seq_length // from_block_size - 1):
    start = i - 2
    end = i
    if i == 1:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
    elif i == 2:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
    elif i == from_seq_length // from_block_size - 3:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
      # Missing -3: should have been sliced till last-3
    elif i == from_seq_length // from_block_size - 2:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
      # Missing -4: should have been sliced till last-4
    else:
      if start > last:
        start = last
        rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
      elif (end + 1) == last:
        rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
      else:
        rand_attn[i - 1, :] = np.random.permutation(
            np.concatenate((middle_seq[:start], middle_seq[end + 1:last])))[:r]
  return rand_attn


def create_rand_mask_from_inputs(from_blocked_mask, to_blocked_mask, rand_attn,
                                 num_attention_heads, num_rand_blocks,
                                 batch_size, from_seq_length, from_block_size):
  """Create a random attention mask from blocked input masks.

  Args:
    from_blocked_mask: int32 Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: int32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].
    rand_attn: [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, num_rand_blocks]
    num_attention_heads: int. Number of attention heads.
    num_rand_blocks: int. Number of random chunks per row.
    batch_size: int. Batch size for computation.
    from_seq_length: int. length of from sequence.
    from_block_size: int. size of block in from sequence.

  Returns:
    float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, from_block_size,
      num_rand_blocks*to_block_size].
  """
  num_windows = from_seq_length // from_block_size - 2
  rand_mask = tf.reshape(
      tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [
          batch_size, num_attention_heads, num_windows,
          num_rand_blocks * from_block_size
      ])
  rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
                        rand_mask)
  return rand_mask


def bigbird_block_sparse_attention(
    query_layer, key_layer, value_layer, band_mask, from_mask, to_mask,
    from_blocked_mask, to_blocked_mask, rand_attn, num_attention_heads,
    num_rand_blocks, size_per_head, batch_size, from_seq_length, to_seq_length,
    from_block_size, to_block_size):
  """BigBird attention sparse calculation using blocks in linear time.

  Assumes from_seq_length//from_block_size == to_seq_length//to_block_size.

  Args:
    query_layer: float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length, size_per_head]
    key_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    value_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    band_mask: (optional) int32 Tensor of shape [batch_size, 1,
      from_seq_length//from_block_size-4, from_block_size, 3*to_block_size].
      The values should be 1 or 0. The attention scores will effectively be
      set to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
from_mask: (optional) int32 Tensor of shape [batch_size, 1, from_seq_length, 1]. The values should be 1 or 0. The attention scores will effectively be set to -infinity for any positions in the mask that are 0, and will be unchanged for positions that are 1. to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1, to_seq_length]. The values should be 1 or 0. The attention scores will effectively be set to -infinity for any positions in the mask that are 0, and will be unchanged for positions that are 1. from_blocked_mask: (optional) int32 Tensor of shape [batch_size, from_seq_length//from_block_size, from_block_size]. Same as from_mask, just reshaped. to_blocked_mask: (optional) int32 Tensor of shape [batch_size, to_seq_length//to_block_size, to_block_size]. Same as to_mask, just reshaped. rand_attn: [batch_size, num_attention_heads, from_seq_length//from_block_size-2, num_rand_blocks] num_attention_heads: int. Number of attention heads. num_rand_blocks: int. Number of random chunks per row. size_per_head: int. Size of each attention head. batch_size: int. Batch size for computation. from_seq_length: int. length of from sequence. to_seq_length: int. length of to sequence. from_block_size: int. size of block in from sequence. to_block_size: int. size of block in to sequence. Returns: float Tensor of shape [batch_size, from_seq_length, num_attention_heads, size_per_head]. """ rand_attn = tf.expand_dims(rand_attn, 0) rand_attn = tf.repeat(rand_attn, batch_size, 0) rand_mask = create_rand_mask_from_inputs( from_blocked_mask, to_blocked_mask, rand_attn, num_attention_heads, num_rand_blocks, batch_size, from_seq_length, from_block_size, ) # Define shorthands h = num_attention_heads r = num_rand_blocks d = size_per_head b = batch_size m = from_seq_length n = to_seq_length wm = from_block_size wn = to_block_size dtype = query_layer.dtype query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3]) key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3]) value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3]) blocked_query_matrix = tf.reshape(query_layer, (b, h, m // wm, wm, -1)) blocked_key_matrix = tf.reshape(key_layer, (b, h, n // wn, wn, -1)) blocked_value_matrix = tf.reshape(value_layer, (b, h, n // wn, wn, -1)) gathered_key = tf.reshape( tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"), (b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1] gathered_value = tf.reshape( tf.gather( blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"), (b, h, m // wm - 2, r * wn, -1)) # [b, h, n//wn-2, r, wn, -1] first_product = tf.einsum( "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0], key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n] first_product = tf.multiply(first_product, 1.0 / np.sqrt(d)) first_product += (1.0 - tf.cast(to_mask, dtype=dtype)) * -10000.0 first_attn_weights = tf.nn.softmax(first_product) # [b, h, wm, n] first_context_layer = tf.einsum( "BHQK,BHKD->BHQD", first_attn_weights, value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1] first_context_layer = tf.expand_dims(first_context_layer, 2) second_key_mat = tf.concat([ blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1], blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :, -1], gathered_key[:, :, 0] ], 2) # [b, h, (4+r)*wn, -1] second_value_mat = tf.concat([ blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1], blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1], gathered_value[:, :, 0] ], 2) # [b, h, (4+r)*wn, -1] second_product = tf.einsum( 
"BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat ) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn] second_seq_pad = tf.concat([ to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:], tf.ones([b, 1, 1, r * wn], dtype=dtype) ], 3) second_rand_pad = tf.concat([ tf.ones([b, h, wm, 4 * wn], dtype=dtype), rand_mask[:, :, 0] ], 3) second_product = tf.multiply(second_product, 1.0 / np.sqrt(d)) second_product += (1.0 - tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0 second_attn_weights = tf.nn.softmax(second_product) # [b , h, wm, (4+r)*wn] second_context_layer = tf.einsum( "BHQK,BHKD->BHQD", second_attn_weights, second_value_mat ) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1] second_context_layer = tf.expand_dims(second_context_layer, 2) exp_blocked_key_matrix = tf.concat([ blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1] ], 3) # [b, h, m//wm-4, 3*wn, -1] exp_blocked_value_matrix = tf.concat([ blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1] ], 3) # [b, h, m//wm-4, 3*wn, -1] middle_query_matrix = blocked_query_matrix[:, :, 2:-2] inner_band_product = tf.einsum( "BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix ) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1] # ==> [b, h, m//wm-4, wm, 3*wn] inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d)) rand_band_product = tf.einsum( "BHLQD,BHLKD->BHLQK", middle_query_matrix, gathered_key[:, :, 1:-1]) # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1] # ==> [b, h, m//wm-4, wm, r*wn] rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d)) first_band_product = tf.einsum( "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0] ) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn] first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d)) last_band_product = tf.einsum( "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1] ) # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn] last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d)) inner_band_product += (1.0 - band_mask) * -10000.0 first_band_product += (1.0 - tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0 last_band_product += (1.0 - tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0 rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0 band_product = tf.concat([ first_band_product, inner_band_product, rand_band_product, last_band_product ], -1) # [b, h, m//wm-4, wm, (5+r)*wn] attn_weights = tf.nn.softmax(band_product) # [b, h, m//wm-4, wm, (5+r)*wn] context_layer = tf.einsum( "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, wn:4 * wn], exp_blocked_value_matrix ) # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1] # ==> [b, h, m//wm-4, wm, -1] context_layer += tf.einsum( "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, 4 * wn:-wn], gathered_value[:, :, 1:-1] ) # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1] # ==> [b, h, m//wm-4, wm, -1] context_layer += tf.einsum( "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn], blocked_value_matrix[:, :, 0] ) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1] context_layer += tf.einsum( "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, -wn:], blocked_value_matrix[:, :, -1] ) # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1] second_last_key_mat = tf.concat([ blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3], 
blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1], gathered_key[:, :, -1] ], 2) # [b, h, (4+r)*wn, -1] second_last_value_mat = tf.concat([ blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3], blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1], gathered_value[:, :, -1] ], 2) # [b, h, (4+r)*wn, -1] second_last_product = tf.einsum( "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat ) # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn] second_last_seq_pad = tf.concat([ to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:], tf.ones([b, 1, 1, r * wn], dtype=dtype) ], 3) second_last_rand_pad = tf.concat( [tf.ones([b, h, wm, 4 * wn], dtype=dtype), rand_mask[:, :, -1]], 3) second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d)) second_last_product += ( 1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0 second_last_attn_weights = tf.nn.softmax( second_last_product) # [b, h, wm, (4+r)*wn] second_last_context_layer = tf.einsum( "BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat ) # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1] second_last_context_layer = tf.expand_dims(second_last_context_layer, 2) last_product = tf.einsum( "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1], key_layer) # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n] last_product = tf.multiply(last_product, 1.0 / np.sqrt(d)) last_product += (1.0 - to_mask) * -10000.0 last_attn_weights = tf.nn.softmax(last_product) # [b, h, wm, n] last_context_layer = tf.einsum( "BHQK,BHKD->BHQD", last_attn_weights, value_layer) # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1] last_context_layer = tf.expand_dims(last_context_layer, 2) context_layer = tf.concat([ first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer ], 2) context_layer = tf.reshape(context_layer, (b, h, m, -1)) * from_mask context_layer = tf.transpose(context_layer, (0, 2, 1, 3)) return context_layer class BigBirdMasks(tf_keras.layers.Layer): """Creates bigbird attention masks.""" def __init__(self, block_size, **kwargs): super().__init__(**kwargs) self._block_size = block_size def call(self, inputs, mask): encoder_shape = tf.shape(mask) mask = tf.cast(mask, inputs.dtype) batch_size, seq_length = encoder_shape[0], encoder_shape[1] # reshape for blocking blocked_encoder_mask = tf.reshape( mask, (batch_size, seq_length // self._block_size, self._block_size)) encoder_from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1)) encoder_to_mask = tf.reshape(mask, (batch_size, 1, 1, seq_length)) band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) return [band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask] @tf_keras.utils.register_keras_serializable(package="Text") class BigBirdAttention(tf_keras.layers.MultiHeadAttention): """BigBird, a sparse attention mechanism. This layer follows the paper "Big Bird: Transformers for Longer Sequences" (https://arxiv.org/abs/2007.14062). It reduces this quadratic dependency of attention computation to linear. Arguments are the same as `MultiHeadAttention` layer. """ def __init__(self, num_rand_blocks=3, from_block_size=64, to_block_size=64, max_rand_mask_length=MAX_SEQ_LEN, seed=None, **kwargs): super().__init__(**kwargs) self._num_rand_blocks = num_rand_blocks self._from_block_size = from_block_size self._to_block_size = to_block_size self._seed = seed # Generates random attention. 
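    # The random block pattern is sampled once at construction time (fixed by
    # `seed`) for a sequence of length `max_rand_mask_length`, one pattern per
    # attention head; `_compute_attention` later slices the first
    # from_seq_length // from_block_size - 2 rows to match the actual input
    # length.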
np.random.seed(self._seed) # pylint: disable=g-complex-comprehension rand_attn = [ bigbird_block_rand_mask( max_rand_mask_length, max_rand_mask_length, from_block_size, to_block_size, num_rand_blocks, last_idx=1024) for _ in range(self._num_heads) ] # pylint: enable=g-complex-comprehension rand_attn = np.stack(rand_attn, axis=0) self.rand_attn = tf.constant(rand_attn, dtype=tf.int32) def _compute_attention(self, query, key, value, attention_mask=None): (band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask) = attention_mask query_shape = tf.shape(query) from_seq_length = query_shape[1] to_seq_length = tf.shape(key)[1] rand_attn = self.rand_attn[:, :(from_seq_length // self._from_block_size - 2)] return bigbird_block_sparse_attention( query, key, value, band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask, blocked_encoder_mask, num_attention_heads=self._num_heads, num_rand_blocks=self._num_rand_blocks, size_per_head=self._key_dim, batch_size=query_shape[0], from_seq_length=from_seq_length, to_seq_length=to_seq_length, from_block_size=self._from_block_size, to_block_size=self._to_block_size, rand_attn=rand_attn) def call(self, query, value, key=None, attention_mask=None, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks if not self._built_from_signature: self._build_from_signature(query=query, value=value, key=key) if key is None: key = value # N = `num_attention_heads` # H = `size_per_head` # `query` = [B, T, N ,H] query = self._query_dense(query) # `key` = [B, S, N, H] key = self._key_dense(key) # `value` = [B, S, N, H] value = self._value_dense(value) attention_output = self._compute_attention(query, key, value, attention_mask) attention_output.set_shape([None, None, self._num_heads, self._value_dim]) attention_output = self._output_dense(attention_output) return attention_output def get_config(self): config = { "num_rand_blocks": self._num_rand_blocks, "from_block_size": self._from_block_size, "to_block_size": self._to_block_size, "seed": self._seed } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items()))
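

# A minimal usage sketch (not part of the library API): the toy dimensions,
# variable names, and the helper function below are illustrative assumptions.
# It shows how the `BigBirdMasks` output feeds the `attention_mask` argument
# of `BigBirdAttention`. The sequence length must be a multiple of the block
# size and at most MAX_SEQ_LEN, which bounds the precomputed random pattern.
def _example_usage():
  """Illustrative sketch wiring BigBirdMasks into BigBirdAttention."""
  batch_size, seq_length, hidden_size = 2, 1024, 768  # assumed toy shapes
  inputs = tf.random.uniform((batch_size, seq_length, hidden_size))
  # 1 marks a valid token, 0 marks padding.
  padding_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)

  masks_layer = BigBirdMasks(block_size=64)
  attention_layer = BigBirdAttention(
      num_heads=12,
      key_dim=64,
      from_block_size=64,
      to_block_size=64,
      num_rand_blocks=3,
      seed=0)

  # BigBirdMasks returns [band_mask, from_mask, to_mask, blocked_mask], which
  # BigBirdAttention consumes as its `attention_mask` argument.
  bigbird_masks = masks_layer(inputs, padding_mask)
  outputs = attention_layer(inputs, inputs, attention_mask=bigbird_masks)
  return outputs  # [batch_size, seq_length, hidden_size]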