# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the attention layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tf_keras  # pylint: disable=unused-import
from tensorflow.python.distribute import combinations
from official.nlp.modeling.layers import relative_attention


def _create_mock_attention_data(
num_heads,
key_dim,
value_dim,
seq_length,
batch_size,
memory_length=0,
num_predictions=2,
two_stream=False,
include_state=False,
include_mask=False,
include_segment=False):
"""Creates mock testing data.
Args:
num_heads: `int`, Number of attention heads.
key_dim: `int`, Size of query head.
value_dim: `int`, Size of key, value dim.
seq_length: `int`, Sequence length of the input.
batch_size: `int`, the batch size.
memory_length: optional `int`, the length of the state. Defaults to 0.
num_predictions: `int`, the number of predictions used in two stream
attention.
two_stream: `bool`, whether or not to generate two stream data.
include_state: optional `bool`, whether or not to include state data.
include_mask: optional `bool`, whether or not to include mask data.
include_segment: optional `bool`, whether or not to include segment data.
Returns:
A dictionary with `str` as keys and `Tensor` as values.
"""
query_shape = (batch_size, seq_length, key_dim)
value_shape = (batch_size, seq_length, value_dim)
encoding_shape = (batch_size, seq_length * 2, key_dim)
attention_bias_shape = (num_heads, key_dim)
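  # The content and positional attention biases are per-head learnable vectors
  # of shape [num_heads, key_dim] in relative attention, shared across the
  # batch; the position encoding covers 2 * seq_length relative offsets.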
data = dict(
relative_position_encoding=tf.random.normal(shape=encoding_shape),
content_attention_bias=tf.random.normal(shape=attention_bias_shape),
positional_attention_bias=tf.random.normal(shape=attention_bias_shape))
if two_stream:
query_stream_shape = (batch_size, num_predictions, key_dim)
target_mapping_shape = (batch_size, num_predictions, seq_length)
stream_data = dict(
content_stream=tf.random.normal(shape=query_shape),
query_stream=tf.random.normal(shape=query_stream_shape),
target_mapping=tf.random.normal(shape=target_mapping_shape))
else:
stream_data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
data.update(stream_data)
if include_state:
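    # `state` acts as Transformer-XL-style memory: the layer attends over the
    # concatenation of the memory and the current sequence, hence the longer
    # total sequence length used for the mask and segment data below.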
total_seq_length = seq_length + memory_length
state_data = dict(
state=tf.random.normal(shape=(batch_size, memory_length, value_dim)))
data.update(state_data)
else:
total_seq_length = seq_length
if include_mask:
mask_shape = (batch_size, num_heads, seq_length, total_seq_length)
mask_data = np.random.randint(2, size=mask_shape).astype("float32")
if two_stream:
mask_data = dict(
content_attention_mask=mask_data,
query_attention_mask=mask_data)
else:
mask_data = dict(attention_mask=mask_data)
data.update(mask_data)
if include_segment:
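    # The leading dimension of 2 in the segment encoding corresponds to the
    # two segment classes (same segment vs. different segment), as in
    # XLNet-style segment-based relative attention.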
segment_encoding_shape = (2, num_heads, key_dim)
segment_matrix = np.random.randint(
2, size=(batch_size, seq_length, total_seq_length))
segment_matrix = tf.math.equal(segment_matrix, 1)
segment_data = dict(
segment_attention_bias=tf.random.normal(shape=attention_bias_shape),
segment_encoding=tf.random.normal(shape=segment_encoding_shape),
segment_matrix=segment_matrix)
data.update(segment_data)
return data
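
# Illustrative use of the helper above; the argument values here are
# hypothetical, and the shapes follow directly from the code:
#   data = _create_mock_attention_data(
#       num_heads=12, key_dim=64, value_dim=32, seq_length=8, batch_size=2)
#   data["query"].shape                      == (2, 8, 64)
#   data["value"].shape                      == (2, 8, 32)
#   data["relative_position_encoding"].shape == (2, 16, 64)
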
class MultiHeadRelativeAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(
value_dim=[32, 64],
memory_length=[0, 4],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_attention_scores(self,
value_dim,
memory_length,
state,
mask,
segment):
"""Tests combinations of attention score calculations."""
batch_size, num_heads, key_dim, seq_length = 2, 12, 64, 8
test_layer = relative_attention.MultiHeadRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim)
data = _create_mock_attention_data(
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim,
seq_length=seq_length,
memory_length=memory_length,
two_stream=False,
batch_size=batch_size,
include_state=state,
include_mask=mask,
include_segment=segment)
output = test_layer(**data)
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])


class TwoStreamRelativeAttentionTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(
num_predictions=[2, 10],
memory_length=[0, 4],
state=[True, False],
mask=[True, False],
segment=[True, False]))
def test_attention_scores(self,
num_predictions,
memory_length,
state,
mask,
segment):
"""Tests combinations of attention score calculations."""
batch_size, num_heads, key_dim, seq_length = 2, 12, 64, 8
test_layer = relative_attention.TwoStreamRelativeAttention(
num_heads=num_heads,
key_dim=key_dim,
value_dim=key_dim)
data = _create_mock_attention_data(
num_heads=num_heads,
key_dim=key_dim,
value_dim=key_dim,
seq_length=seq_length,
memory_length=memory_length,
num_predictions=num_predictions,
two_stream=True,
batch_size=batch_size,
include_state=state,
include_mask=mask,
include_segment=segment)
    content_output, query_output = test_layer(**data)
self.assertEqual(content_output.shape, [batch_size, seq_length, key_dim])
    self.assertEqual(query_output.shape, [batch_size, num_predictions, key_dim])


if __name__ == "__main__":
np.random.seed(0)
tf.random.set_seed(0)
tf.test.main()