# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.kernel.attention.""" | |
import itertools | |
from absl.testing import parameterized | |
import tensorflow as tf, tf_keras | |
from official.nlp.modeling.layers import kernel_attention as attention | |
_FEATURE_TRANSFORM = ["relu", "elu", "exp", "expplus"] | |
_REDRAW = [True, False] | |
_TRAINING = [True, False] | |
_IS_SHORT_SEQ = [True, False] | |
_BEGIN_KERNEL = [0, 512] | |


class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):

  # expplus is designed only for the bi-directional use case.
  # exp can be numerically unstable.
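  # The streaming test below verifies that processing the sequence chunk by
  # chunk with a cache reproduces a single full-sequence forward pass.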
  # NOTE: The parameterized grids in this file are illustrative
  # reconstructions from the test bodies and the module-level constants; the
  # original parameter values may differ.
  @parameterized.parameters(
      itertools.product(["relu", "elu"], [2, 4], [None, 0.97]))
  def test_causal_windowed_attention_projection_streaming(
      self, feature_transform, causal_chunk_length, causal_weight_decay):
    num_heads = 12
    key_dim = 64
    seq_length = 16
    num_chunks = seq_length // causal_chunk_length
    causal_window_length = num_chunks
    batch_size = 2
    training = False
    num_random_features = 0
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=False,
        is_short_seq=False,
        begin_kernel=0,
        use_causal_windowed=True,
        causal_chunk_length=causal_chunk_length,
        causal_window_length=causal_window_length,
        causal_window_decay=causal_weight_decay,
        causal_padding=None,
    )
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim), seed=2)
    value = query
    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query,
        value=value,
        attention_mask=masks,
        training=training)
    dim = num_random_features if num_random_features > 0 else key_dim
    kv_cache = tf.zeros((batch_size, num_heads, dim, dim))
    k_sum_cache = tf.zeros((batch_size, num_heads, dim))
    stream_output = []
    cache = {"kv": kv_cache, "k_sum": k_sum_cache}
    for i in range(num_chunks):
      stream_output.append(
          test_layer(
              query=query[:, i * causal_chunk_length:(i + 1) *
                          causal_chunk_length, :],
              value=value[:, i * causal_chunk_length:(i + 1) *
                          causal_chunk_length, :],
              attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
                                   causal_chunk_length],
              cache=cache,
              training=training))
    stream_output = tf.concat(stream_output, axis=1)
    self.assertAllClose(output, stream_output)
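
  # Output-shape check across feature transforms, random-feature counts,
  # training modes, redraw, short-sequence handling, and begin_kernel settings.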
  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [128], _TRAINING, _REDRAW,
                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
  def test_attention_projection(
      self, feature_transform, num_random_features, training, redraw, is_short,
      begin_kernel):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=is_short,
        begin_kernel=begin_kernel)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query,
        value=value,
        attention_mask=masks,
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
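
  # Output-shape check for causal windowed (chunked) attention with optional
  # window decay and padding.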
  @parameterized.parameters(
      itertools.product(["relu", "elu"], [128], _TRAINING, _REDRAW,
                        _BEGIN_KERNEL, [None, 0.97], [None]))
  def test_causal_windowed_attention_projection(
      self, feature_transform, num_random_features, training, redraw,
      begin_kernel, causal_window_decay, causal_padding):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=False,
        begin_kernel=begin_kernel,
        use_causal_windowed=True,
        causal_chunk_length=8,
        causal_window_length=3,
        causal_window_decay=causal_window_decay,
        causal_padding=causal_padding)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query,
        value=value,
        attention_mask=masks,
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
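
  # Same shape check without a random-feature projection
  # (num_random_features=0), so redraw must stay False.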
  @parameterized.parameters(
      itertools.product(["relu", "elu"], [0], _TRAINING, [False],
                        _IS_SHORT_SEQ, _BEGIN_KERNEL))
  def test_attention_no_projection(
      self, feature_transform, num_random_features, training, redraw, is_short,
      begin_kernel):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform=feature_transform,
        num_random_features=num_random_features,
        redraw=redraw,
        is_short_seq=is_short,
        begin_kernel=begin_kernel)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output = test_layer(
        query=query,
        value=value,
        attention_mask=masks,
        training=training)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
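
  # scale_by_length applies a log(seq_length, base=512) scaling factor, so the
  # scaled and unscaled outputs coincide only at seq_length == 512.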
  @parameterized.parameters([64, 512, 1024])
  def test_attention_scale_by_length(self, seq_length):
    num_heads = 12
    key_dim = 64
    batch_size = 2
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        num_random_features=0,
        scale_by_length=True)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
    value = query
    encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
    masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
    output_scale_by_length = test_layer(
        query=query, value=value, attention_mask=masks)
    test_layer._scale_by_length = False
    output_no_scale_by_length = test_layer(
        query=query, value=value, attention_mask=masks)
    if seq_length == 512:  # Equal because log(seq_length, base=512) = 1.0.
      self.assertAllClose(output_scale_by_length, output_no_scale_by_length)
    else:
      self.assertNotAllClose(output_scale_by_length, output_no_scale_by_length)

  def test_unsupported_feature_transform(self):
    with self.assertRaisesRegex(ValueError, "Unsupported feature_transform.*"):
      _ = attention.KernelAttention(feature_transform="test")

  def test_redraw_true_no_projection(self):
    with self.assertRaisesRegex(
        ValueError, "There is nothing to redraw when num_random_features.*"):
      _ = attention.KernelAttention(
          num_heads=2, key_dim=64, feature_transform="elu",
          num_random_features=0, redraw=True)

  def test_config(self):
    num_heads = 12
    key_dim = 64
    test_layer = attention.KernelAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        feature_transform="exp",
        num_random_features=128,
        is_short_seq=True)
    new_layer = attention.KernelAttention.from_config(
        test_layer.get_config())
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
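
  # rectangular_window_sum computes a trailing (inclusive) sum over a window of
  # chunks along axis 1; for an all-ones input with window 3 the per-position
  # sums are [1, 2, 3, 3, 3].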
  def test_rectangular_window_sum(self):
    x = tf.ones([2, 5, 2, 2, 2])
    winsum = attention.rectangular_window_sum(x, 3)
    self.assertEqual(winsum.shape, x.shape)
    self.assertAllClose(
        tf.tile(
            tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
            [2, 1, 2, 2, 2]),
        winsum)
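
  # weighted_window_sum applies per-offset weights within the trailing window;
  # with weights [0.01, 0.1, 1.] the current position gets weight 1. and older
  # positions smaller weights, giving sums [1., 1.1, 1.11, 1.11, 1.11].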
  def test_weighted_window_sum(self):
    x = tf.ones([2, 5, 2, 2, 2])
    winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
    self.assertEqual(winsum.shape, x.shape)
    self.assertAllClose(
        tf.tile(
            tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
            [2, 1, 2, 2, 2]),
        winsum)
if __name__ == "__main__": | |
tf.test.main() | |