# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for extracting segments from sentences in documents."""

import tensorflow as tf, tf_keras


def _get_random(positions, random_fn):
  """Draws one random float per value of `positions`.

  Args:
    positions: a `RaggedTensor` whose ragged structure is reused for the
      result.
    random_fn: an op called with the keyword arguments `shape`, `minval`,
      `maxval` and `dtype` (the `tf.random.uniform` calling convention).

  Returns:
    A `RaggedTensor` partitioned like `positions` whose flat values are the
    float32 samples drawn by `random_fn` with `minval=0` and `maxval=1`.
  """
  flat_random = random_fn(
      shape=tf.shape(positions.flat_values),
      minval=0,
      maxval=1,
      dtype=tf.float32)
  return positions.with_flat_values(flat_random)


def _random_int_up_to(maxval, random_fn):
  """Samples, elementwise, an integer between 0 and `maxval`.

  Used to pick, for every position j in a row, a preceding position.

  Args:
    maxval: an integer tensor of per-element upper bounds.
    random_fn: an op called with the keyword arguments `shape`, `minval` and
      `maxval`.

  Returns:
    A tensor with the shape and dtype of `maxval`.
  """
  # The bounds are cast to float because the int kernel for uniform doesn't
  # support broadcasting. The float sample is then cast back to the integer
  # dtype of `maxval`; nothing is added to `maxval` here.
  float_maxval = tf.cast(maxval, tf.float32)
  return tf.cast(
      random_fn(
          shape=tf.shape(maxval),
          minval=tf.zeros_like(float_maxval),
          maxval=float_maxval),
      dtype=maxval.dtype)


def _random_int_from_range(minval, maxval, random_fn):
  """Samples, elementwise, an integer between `minval` and `maxval`.

  Args:
    minval: an integer tensor of per-element lower bounds.
    maxval: an integer tensor of per-element upper bounds.
    random_fn: an op called as `random_fn(shape, minval=..., maxval=...)`.

  Returns:
    A tensor with the shape and dtype of `maxval`.
  """
  # The bounds are cast to float because the int kernel for uniform doesn't
  # support broadcasting. The float sample is then cast back to the integer
  # dtype of `maxval`. Any offset (e.g. "position + 1") has to be applied by
  # the caller; this helper does not add one.
  float_minval = tf.cast(minval, tf.float32)
  float_maxval = tf.cast(maxval, tf.float32)
  return tf.cast(
      random_fn(tf.shape(maxval), minval=float_minval, maxval=float_maxval),
      maxval.dtype)


def _sample_from_other_batch(sentences, random_fn):
  """Samples, for every sentence, a sentence index from a different row.

  Args:
    sentences: a `RaggedTensor` of shape [batch, (num_sentences)].
    random_fn: an op with the `tf.random.uniform` calling convention.

  Returns:
    `sentences.with_values(...)` of stacked `[other_batch, other_sentence]`
    pairs, i.e. indices in the format expected by `tf.gather_nd`.
  """
  # other_batch: [num_sentences]: The batch to sample from for each
  # sentence. It is drawn with `maxval=nrows - 1` and then incremented wherever
  # it is >= the sentence's own row id, so the own row is skipped.
  # NOTE(review): with a single-row batch `maxval` is 0 -- confirm that
  # `random_fn` tolerates an empty integer range.
  other_batch = random_fn(
      shape=[tf.size(sentences)],
      minval=0,
      maxval=sentences.nrows() - 1,
      dtype=tf.int64)
  other_batch += tf.cast(other_batch >= sentences.value_rowids(), tf.int64)

  # other_sentence: [num_sentences]: The sentence within each batch
  # that we sampled.
  other_sentence = _random_int_up_to(
      tf.gather(sentences.row_lengths(), other_batch), random_fn)
  return sentences.with_values(tf.stack([other_batch, other_sentence], axis=1))


def get_sentence_order_labels(sentences,
                              random_threshold=0.5,
                              random_next_threshold=0.5,
                              random_fn=tf.random.uniform):
  """Extract segments and labels for sentence order prediction (SOP) task.

  Extracts the segment and labels for the sentence order prediction task
  defined in "ALBERT: A Lite BERT for Self-Supervised Learning of Language
  Representations" (https://arxiv.org/pdf/1909.11942.pdf)

  Args:
    sentences: a `RaggedTensor` of shape [batch, (num_sentences)] with string
      dtype.
    random_threshold: (optional) A float threshold between 0 and 1, used to
      determine whether to extract a random, out-of-batch sentence or a
      succeeding sentence. Higher value favors succeeding sentence.
    random_next_threshold: (optional) A float threshold between 0 and 1, used
      to determine whether to extract either a random, out-of-batch, or
      succeeding sentence or a preceding sentence. Higher value favors
      preceding sentences.
    random_fn: (optional) An op used to generate random float values.

  Returns:
    a tuple of (preceding_or_random_next, is_succeeding_or_random) where:
      preceding_or_random_next: a `RaggedTensor` of strings with the same
        shape as `sentences` and contains either a preceding, succeeding, or
        random out-of-batch sentence respective to its counterpart in
        `sentences` and dependent on its label in `is_succeeding_or_random`.
      is_succeeding_or_random: a `RaggedTensor` of bool values with the
        same shape as `sentences` and is True if its corresponding sentence in
        `preceding_or_random_next` is a random or succeeding sentence, False
        otherwise.
  """
  # Create a RaggedTensor in the same shape as sentences ([doc, (sentences)])
  # whose values are index positions.
  positions = tf.ragged.range(sentences.row_lengths())
  row_lengths_broadcasted = tf.expand_dims(positions.row_lengths(),
                                           -1) + 0 * positions
  row_lengths_broadcasted_flat = row_lengths_broadcasted.flat_values

  # Generate indices for all preceding, succeeding and random.
  # For every position j in a row, sample a preceding position, i.e. a
  # position in [0, j-1].
  all_preceding = tf.ragged.map_flat_values(_random_int_up_to, positions,
                                            random_fn)

  # For every position j, sample a position following j, i.e. a position in
  # [j+1, row_length-1]. The sample for the last position of a row is not
  # meaningful; it is masked out below via `valid_succeeding_mask`.
  all_succeeding = positions.with_flat_values(
      tf.ragged.map_flat_values(_random_int_from_range,
                                positions.flat_values + 1,
                                row_lengths_broadcasted_flat, random_fn))

  # Convert to format that is convenient for `gather_nd`
  rows_broadcasted = tf.expand_dims(tf.range(sentences.nrows()),
                                    -1) + 0 * positions
  all_preceding_nd = tf.stack([rows_broadcasted, all_preceding], -1)
  all_succeeding_nd = tf.stack([rows_broadcasted, all_succeeding], -1)
  all_random_nd = _sample_from_other_batch(positions, random_fn)

  # There are a few spots where there is no "preceding" or "succeeding" item
  # (e.g. first and last sentences in a document). Mark where these are and we
  # will patch them up to grab a random sentence from another document later.
  all_zeros = tf.zeros_like(positions)
  all_ones = tf.ones_like(positions)
  valid_preceding_mask = tf.cast(
      tf.concat([all_zeros[:, :1], all_ones[:, 1:]], -1), tf.bool)
  valid_succeeding_mask = tf.cast(
      tf.concat([all_ones[:, :-1], all_zeros[:, -1:]], -1), tf.bool)

  # Decide what to use for the segment: (1) random, out-of-batch, (2) preceding
  # item, or (3) succeeding.
  # Should get out-of-batch instead of succeeding item
  should_get_random = ((_get_random(positions, random_fn) > random_threshold)
                       | tf.logical_not(valid_succeeding_mask))
  random_or_succeeding_nd = tf.compat.v1.where(should_get_random,
                                               all_random_nd,
                                               all_succeeding_nd)
  # Choose which items should get a random succeeding item. Force positions
  # that don't have a valid preceding item to get a random succeeding item.
  should_get_random_or_succeeding = (
      (_get_random(positions, random_fn) > random_next_threshold)
      | tf.logical_not(valid_preceding_mask))
  gather_indices = tf.compat.v1.where(should_get_random_or_succeeding,
                                      random_or_succeeding_nd,
                                      all_preceding_nd)
  return (tf.gather_nd(sentences, gather_indices),
          should_get_random_or_succeeding)


def get_next_sentence_labels(sentences,
                             random_threshold=0.5,
                             random_fn=tf.random.uniform):
  """Extracts the next sentence label from sentences.

  Args:
    sentences: A `RaggedTensor` of strings w/ shape [batch, (num_sentences)].
    random_threshold: (optional) A float threshold between 0 and 1, used to
      determine whether to extract a random sentence or the immediate next
      sentence. Higher value favors next sentence.
    random_fn: (optional) An op used to generate random float values.

  Returns:
    A tuple of (next_sentence_or_random, is_next_sentence) where:

    next_sentence_or_random: A `RaggedTensor` with the same shape as
      `sentences` that contains, for each sentence, either its subsequent
      sentence or a randomly injected sentence from another row.
    is_next_sentence: A `RaggedTensor` of bool with the same shape as
      `sentences` that contains whether or not `next_sentence_or_random` is
      truly a subsequent sentence or not.
  """
  # shift everyone to get the next sentence predictions positions
  positions = tf.ragged.range(sentences.row_lengths())

  # Shift every position down to the right. The last position of each row
  # wraps around to 0; it is masked out below.
  next_sentences_pos = (positions + 1) % tf.expand_dims(sentences.row_lengths(),
                                                        1)
  rows_broadcasted = tf.expand_dims(tf.range(sentences.nrows()),
                                    -1) + 0 * positions
  next_sentences_pos_nd = tf.stack([rows_broadcasted, next_sentences_pos], -1)
  all_random_nd = _sample_from_other_batch(positions, random_fn)

  # Mark the items that don't have a next sentence (e.g. the last
  # sentences in the document). We will patch these up and force them to grab a
  # random sentence from a random document.
  # The mask is built by slicing ragged ones/zeros shaped like `positions`
  # (the same idiom as `get_sentence_order_labels`) so that every row keeps
  # exactly its own length -- including empty rows -- and the dtype always
  # matches `positions`.
  all_ones = tf.ones_like(positions)
  all_zeros = tf.zeros_like(positions)
  valid_next_sentences = tf.cast(
      tf.concat([all_ones[:, :-1], all_zeros[:, -1:]], -1), tf.bool)

  is_random = ((_get_random(positions, random_fn) > random_threshold)
               | tf.logical_not(valid_next_sentences))
  gather_indices = tf.compat.v1.where(is_random, all_random_nd,
                                      next_sentences_pos_nd)
  return tf.gather_nd(sentences, gather_indices), tf.logical_not(is_random)