deanna-emery's picture
updates
93528c6
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test beam search helper methods."""
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from official.nlp.modeling.ops import beam_search
class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
def test_expand_to_beam_size(self):
x = tf.ones([7, 4, 2, 5])
x = beam_search.expand_to_beam_size(x, 3)
shape = tf.shape(x)
self.assertAllEqual([7, 3, 4, 2, 5], shape)
def test_get_shape_keep_last_dim(self):
y = tf.constant(4.0)
x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
shape = beam_search._get_shape_keep_last_dim(x)
self.assertAllEqual([None, None, None, 5], shape.as_list())
def test_flatten_beam_dim(self):
x = tf.ones([7, 4, 2, 5])
x = beam_search.flatten_beam_dim(x)
self.assertAllEqual([28, 2, 5], tf.shape(x))
def test_unflatten_beam_dim(self):
x = tf.ones([28, 2, 5])
x = beam_search._unflatten_beam_dim(x, 7, 4)
self.assertAllEqual([7, 4, 2, 5], tf.shape(x))
def test_gather_beams(self):
x = tf.reshape(tf.range(24), [2, 3, 4])
# x looks like: [[[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
#
# [[12 13 14 15]
# [16 17 18 19]
# [20 21 22 23]]]
y = beam_search.SequenceBeamSearch._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
self.assertAllEqual(
[[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
y)
@parameterized.named_parameters([
('padded_decode_true_with_name', True, 0.0, 'decoding'),
('padded_decode_false_with_name', False, 0.0, 'decoding'),
('padded_decode_true_without_name', True, 0.0, None),
('padded_decode_false_without_name', False, 0.0, None),
('padded_decode_false_with_noise', False, 0.5, 'decoding'),
])
def test_sequence_beam_search(self, padded_decode, noise_multiplier, name):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
[0.1, 0.8, 0.1]],
[[0.1, 0.8, 0.1], [0.3, 0.4, 0.3],
[0.2, 0.1, 0.7]]])
# batch_size, max_decode_length, num_heads, embed_size per head
x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
def _get_test_symbols_to_logits_fn():
"""Test function that returns logits for next token."""
def symbols_to_logits_fn(_, i, cache):
logits = tf.cast(probabilities[:, i, :], tf.float32)
return logits, cache
return symbols_to_logits_fn
predictions, _ = beam_search.sequence_beam_search(
symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
initial_ids=tf.zeros([1], dtype=tf.int32),
initial_cache=cache,
vocab_size=3,
beam_size=2,
alpha=0.6,
max_decode_length=3,
eos_id=9,
padded_decode=padded_decode,
dtype=tf.float32,
noise_multiplier=noise_multiplier,
decoding_name=name,
)
if noise_multiplier > 0:
self.assertAllEqual([[[0, 1, 0, 1], [0, 0, 2, 2]]], predictions)
else:
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
@parameterized.named_parameters([
('padded_decode_true_with_name', True, 0.0, 'decoding'),
('padded_decode_false_with_name', False, 0.0, 'decoding'),
('padded_decode_true_without_name', True, 0.0, None),
('padded_decode_false_without_name', False, 0.0, None),
('padded_decode_false_with_noise', False, 0.5, 'decoding'),
])
def test_sequence_beam_search_multi_eos(
self, padded_decode, noise_multiplier, name
):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities = tf.constant([
[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2], [0.1, 0.8, 0.1]],
[[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.1, 0.7]],
])
# batch_size, max_decode_length, num_heads, embed_size per head
x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
def _get_test_symbols_to_logits_fn():
"""Test function that returns logits for next token."""
def symbols_to_logits_fn(_, i, cache):
logits = tf.cast(probabilities[:, i, :], tf.float32)
return logits, cache
return symbols_to_logits_fn
predictions, _ = beam_search.sequence_beam_search(
symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
initial_ids=tf.zeros([1], dtype=tf.int32),
initial_cache=cache,
vocab_size=3,
beam_size=2,
alpha=0.6,
max_decode_length=3,
eos_id=[9, 10],
padded_decode=padded_decode,
dtype=tf.float32,
noise_multiplier=noise_multiplier,
decoding_name=name,
)
if noise_multiplier > 0:
self.assertAllEqual([[[0, 1, 0, 1], [0, 0, 2, 2]]], predictions)
else:
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
if __name__ == '__main__':
tf.test.main()