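# Web-viewer residue captured with this file (not part of the module):
# "deanna-emery's picture"; commit "updates" (93528c6); links "raw" and
# "history blame"; file size 5.41 kB.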
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# encoding=utf-8
"""Tests for sentence prediction labels."""
import functools
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from official.nlp.modeling.ops import segment_extractor
class NextSentencePredictionTest(parameterized.TestCase, tf.test.TestCase):
  """Exercises `segment_extractor.get_next_sentence_labels` on a tiny corpus."""

  # Two documents (three and two sentences) shared by every case below.
  _SENTENCES = [
      [b"Hello there.", b"La la la.", b"Such is life."],
      [b"Who let the dogs out?", b"Who?."],
  ]

  @parameterized.parameters([
      {
          # Threshold 0.0: every expected label is False.
          "test_description": "all random",
          "sentences": _SENTENCES,
          "expected_segment": [
              [b"Who let the dogs out?", b"Who?.", b"Who let the dogs out?"],
              [b"Hello there.", b"Hello there."],
          ],
          "expected_labels": [
              [False, False, False],
              [False, False],
          ],
          "random_threshold": 0.0,
      },
      {
          # Threshold 1.0: the following sentence is expected wherever the
          # document has one; the final sentence of each document is expected
          # to be paired with a sentence from the other document and labelled
          # False.
          "test_description": "all next",
          "sentences": _SENTENCES,
          "expected_segment": [
              [b"La la la.", b"Such is life.", b"Who let the dogs out?"],
              [b"Who?.", b"Hello there."],
          ],
          "expected_labels": [
              [True, True, False],
              [True, False],
          ],
          "random_threshold": 1.0,
      },
  ])
  def testNextSentencePrediction(self,
                                 sentences,
                                 expected_segment,
                                 expected_labels,
                                 random_threshold=0.5,
                                 test_description=""):
    """Compares extracted segments and labels with the expected values.

    Args:
      sentences: Nested list of byte strings, one inner list per document.
      expected_segment: Nested list of the segments expected back.
      expected_labels: Nested list of the boolean labels expected back.
      random_threshold: Threshold forwarded to `get_next_sentence_labels`.
      test_description: Human-readable case name; not used by the test body.
    """
    # A stateless uniform sampler pinned to a fixed seed keeps the random
    # draws, and therefore the expected outputs above, reproducible.
    seeded_uniform = functools.partial(
        tf.random.stateless_uniform, seed=(2, 3))
    ragged_sentences = tf.ragged.constant(sentences)

    segment, labels = segment_extractor.get_next_sentence_labels(
        ragged_sentences, random_threshold, random_fn=seeded_uniform)

    self.assertAllEqual(expected_segment, segment)
    self.assertAllEqual(expected_labels, labels)
class SentenceOrderLabelsTest(parameterized.TestCase, tf.test.TestCase):
  """Exercises `segment_extractor.get_sentence_order_labels` on a tiny corpus."""

  # Two documents (three and two sentences) shared by every case below.
  _SENTENCES = [
      [b"Hello there.", b"La la la.", b"Such is life."],
      [b"Who let the dogs out?", b"Who?."],
  ]

  @parameterized.parameters([
      {
          # Both thresholds 0.0: every expected label is True.
          "test_description": "all random",
          "sentences": _SENTENCES,
          "expected_segment": [
              [b"Who let the dogs out?", b"Who?.", b"Who let the dogs out?"],
              [b"Hello there.", b"Hello there."],
          ],
          "expected_labels": [[True, True, True], [True, True]],
          "random_threshold": 0.0,
          "random_next_threshold": 0.0,
      },
      {
          # random_threshold 1.0, random_next_threshold 0.0: the following
          # sentence is expected where one exists; all labels stay True.
          "test_description": "all next",
          "sentences": _SENTENCES,
          "expected_segment": [
              [b"La la la.", b"Such is life.", b"Who let the dogs out?"],
              [b"Who?.", b"Hello there."],
          ],
          "expected_labels": [[True, True, True], [True, True]],
          "random_threshold": 1.0,
          "random_next_threshold": 0.0,
      },
      {
          # Both thresholds 1.0: only the first sentence of each document is
          # expected to keep a True label.
          "test_description": "all preceeding",
          "sentences": _SENTENCES,
          "expected_segment": [
              [b"La la la.", b"Hello there.", b"Hello there."],
              [b"Who?.", b"Who let the dogs out?"],
          ],
          "expected_labels": [
              [True, False, False],
              [True, False],
          ],
          "random_threshold": 1.0,
          "random_next_threshold": 1.0,
      },
  ])
  def testSentenceOrderPrediction(self,
                                  sentences,
                                  expected_segment,
                                  expected_labels,
                                  random_threshold=0.5,
                                  random_next_threshold=0.5,
                                  test_description=""):
    """Compares extracted segments and order labels with the expected values.

    Args:
      sentences: Nested list of byte strings, one inner list per document.
      expected_segment: Nested list of the segments expected back.
      expected_labels: Nested list of the boolean labels expected back.
      random_threshold: Forwarded to `get_sentence_order_labels`.
      random_next_threshold: Forwarded to `get_sentence_order_labels`.
      test_description: Human-readable case name; not used by the test body.
    """
    # A stateless uniform sampler pinned to a fixed seed keeps the random
    # draws, and therefore the expected outputs above, reproducible.
    seeded_uniform = functools.partial(
        tf.random.stateless_uniform, seed=(2, 3))
    ragged_sentences = tf.ragged.constant(sentences)

    segment, labels = segment_extractor.get_sentence_order_labels(
        ragged_sentences,
        random_threshold=random_threshold,
        random_next_threshold=random_next_threshold,
        random_fn=seeded_uniform)

    self.assertAllEqual(expected_segment, segment)
    self.assertAllEqual(expected_labels, labels)
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()