|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Module for extracting segments from sentences in documents.""" |
|
|
|
import tensorflow as tf, tf_keras |
|
|
|
|
|
|
|
def _get_random(positions, random_fn):
  """Draws one random float for every element of `positions`.

  Args:
    positions: a `RaggedTensor`; only its row partitioning and the shape of
      its `flat_values` are used, not the values themselves.
    random_fn: an op called as `random_fn(shape=..., minval=0, maxval=1,
      dtype=tf.float32)`, e.g. `tf.random.uniform`.

  Returns:
    `positions` with its flat values replaced by the output of `random_fn`,
    so the result is partitioned into rows exactly like `positions`.
  """
  flat_shape = tf.shape(positions.flat_values)
  samples = random_fn(shape=flat_shape, minval=0, maxval=1, dtype=tf.float32)
  return positions.with_flat_values(samples)
|
|
|
|
|
|
|
|
|
def _random_int_up_to(maxval, random_fn):
  """Samples, element-wise, an integer between zero and `maxval`.

  The draw is made in float32 with a per-element upper bound and then cast
  back to the dtype of `maxval`. With `tf.random.uniform` as `random_fn` the
  upper bound is exclusive.

  NOTE: `tf.random.uniform` does not broadcast tensor-valued bounds for
  integer dtypes, which is presumably why the sampling is done on floats.

  Args:
    maxval: an integer tensor holding the upper bound for each element.
    random_fn: an op called as `random_fn(shape=..., minval=..., maxval=...)`.

  Returns:
    A tensor with the shape and dtype of `maxval`.
  """
  upper = tf.cast(maxval, tf.float32)
  lower = tf.zeros_like(upper)
  draws = random_fn(shape=tf.shape(maxval), minval=lower, maxval=upper)
  return tf.cast(draws, dtype=maxval.dtype)
|
|
|
|
|
def _random_int_from_range(minval, maxval, random_fn):
  """Samples, element-wise, an integer between `minval` and `maxval`.

  Both bounds are converted to float32, the draw is made on floats with
  per-element bounds, and the result is cast to the dtype of `maxval`. With
  `tf.random.uniform` as `random_fn` the lower bound is inclusive and the
  upper bound is exclusive.

  Args:
    minval: an integer tensor holding the lower bound for each element.
    maxval: an integer tensor holding the upper bound for each element; its
      shape and dtype determine those of the result.
    random_fn: an op called as `random_fn(shape, minval=..., maxval=...)`.

  Returns:
    A tensor with the shape and dtype of `maxval`.
  """
  lower = tf.cast(minval, tf.float32)
  upper = tf.cast(maxval, tf.float32)
  draws = random_fn(tf.shape(maxval), minval=lower, maxval=upper)
  return tf.cast(draws, maxval.dtype)
|
|
|
|
|
def _sample_from_other_batch(sentences, random_fn):
  """Samples, for every element, an index pair pointing into another row.

  NOTE(review): the row draw uses `maxval=nrows - 1` with an integer dtype,
  so this presumably requires at least two rows -- confirm against callers.

  Args:
    sentences: a `RaggedTensor`; only its row partitioning is used.
    random_fn: an op with the calling convention of `tf.random.uniform`.

  Returns:
    `sentences` with its values replaced by `[row, column]` int64 pairs, one
    pair per element, where `row` differs from the element's own row.
  """
  # Draw from one row fewer than exists; the shift below skips the own row.
  sampled_row = random_fn(
      shape=[tf.size(sentences)],
      minval=0,
      maxval=sentences.nrows() - 1,
      dtype=tf.int64)

  # Draws at or past the element's own row move up by one, so the own row is
  # never selected while every other row stays reachable.
  at_or_past_own_row = sampled_row >= sentences.value_rowids()
  sampled_row = sampled_row + tf.cast(at_or_past_own_row, tf.int64)

  # Pick a column inside whichever row was sampled.
  sampled_row_lengths = tf.gather(sentences.row_lengths(), sampled_row)
  sampled_column = _random_int_up_to(sampled_row_lengths, random_fn)

  index_pairs = tf.stack([sampled_row, sampled_column], axis=1)
  return sentences.with_values(index_pairs)
|
|
|
|
|
def get_sentence_order_labels(sentences,
                              random_threshold=0.5,
                              random_next_threshold=0.5,
                              random_fn=tf.random.uniform):
  """Builds segments and labels for the sentence order prediction (SOP) task.

  Implements segment selection for the sentence order prediction task
  defined in "ALBERT: A Lite BERT for Self-Supervised Learning of Language
  Representations" (https://arxiv.org/pdf/1909.11942.pdf).

  For every sentence one of three candidates is selected: a preceding
  sentence of the same row, a succeeding sentence of the same row, or a
  sentence sampled from a different row of the batch.

  Args:
    sentences: a `RaggedTensor` of shape [batch, (num_sentences)] with string
      dtype.
    random_threshold: (optional) A float threshold between 0 and 1, used to
      decide between a random sentence from another row and a succeeding
      sentence. A higher value favors the succeeding sentence.
    random_next_threshold: (optional) A float threshold between 0 and 1, used
      to decide between a random-or-succeeding sentence and a preceding
      sentence. A higher value favors the preceding sentence.
    random_fn: (optional) An op used to generate random values; it is called
      with the `shape`, `minval`, `maxval` and `dtype` arguments of
      `tf.random.uniform`.

  Returns:
    A tuple `(preceding_or_random_next, is_succeeding_or_random)` where:
      preceding_or_random_next: a `RaggedTensor` of strings with the same
        shape as `sentences`, holding for each sentence either a preceding
        sentence, a succeeding sentence, or a random sentence from another
        row, as indicated by `is_succeeding_or_random`.
      is_succeeding_or_random: a `RaggedTensor` of bool values with the same
        shape as `sentences`; True where the corresponding entry of
        `preceding_or_random_next` is a random or succeeding sentence, False
        where it is a preceding sentence.
  """
  # Index of every sentence within its own row.
  positions = tf.ragged.range(sentences.row_lengths())

  # Adding `0 * positions` broadcasts the per-row length to every element.
  row_length_per_position = (
      tf.expand_dims(positions.row_lengths(), -1) + 0 * positions)
  flat_row_lengths = row_length_per_position.flat_values

  # Candidate preceding index: a draw bounded above by the own position.
  preceding_pos = tf.ragged.map_flat_values(_random_int_up_to, positions,
                                            random_fn)

  # Candidate succeeding index: a draw between the next position and the end
  # of the row.
  flat_succeeding = tf.ragged.map_flat_values(
      _random_int_from_range, positions.flat_values + 1, flat_row_lengths,
      random_fn)
  succeeding_pos = positions.with_flat_values(flat_succeeding)

  # Pair every candidate column with its row id to form gather_nd indices.
  row_ids = tf.expand_dims(tf.range(sentences.nrows()), -1) + 0 * positions
  preceding_nd = tf.stack([row_ids, preceding_pos], -1)
  succeeding_nd = tf.stack([row_ids, succeeding_pos], -1)
  random_nd = _sample_from_other_batch(positions, random_fn)

  # The first sentence of a row has nothing before it and the last sentence
  # has nothing after it, so those candidates must never be selected there.
  zeros = tf.zeros_like(positions)
  ones = tf.ones_like(positions)
  has_preceding = tf.cast(tf.concat([zeros[:, :1], ones[:, 1:]], -1), tf.bool)
  has_succeeding = tf.cast(
      tf.concat([ones[:, :-1], zeros[:, -1:]], -1), tf.bool)

  # Stage 1: random sentence versus succeeding sentence.
  pick_random = tf.logical_or(
      _get_random(positions, random_fn) > random_threshold,
      tf.logical_not(has_succeeding))
  random_or_succeeding_nd = tf.compat.v1.where(pick_random, random_nd,
                                               succeeding_nd)

  # Stage 2: the stage-1 winner versus a preceding sentence.
  pick_random_or_succeeding = tf.logical_or(
      _get_random(positions, random_fn) > random_next_threshold,
      tf.logical_not(has_preceding))
  gather_indices = tf.compat.v1.where(pick_random_or_succeeding,
                                      random_or_succeeding_nd, preceding_nd)

  selected = tf.gather_nd(sentences, gather_indices)
  return (selected, pick_random_or_succeeding)
|
|
|
|
|
def get_next_sentence_labels(sentences,
                             random_threshold=0.5,
                             random_fn=tf.random.uniform):
  """Extracts the next sentence label from sentences.

  For every sentence, selects either the sentence that immediately follows
  it in the same row or a sentence sampled from a different row. The last
  sentence of a row has no successor and always receives a random sentence.

  Args:
    sentences: A `RaggedTensor` of strings w/ shape [batch, (num_sentences)].
    random_threshold: (optional) A float threshold between 0 and 1, used to
      determine whether to extract a random sentence or the immediate next
      sentence. Higher value favors next sentence.
    random_fn: (optional) An op used to generate random values; it is called
      with the `shape`, `minval`, `maxval` and `dtype` arguments of
      `tf.random.uniform`.

  Returns:
    A tuple of (next_sentence_or_random, is_next_sentence) where:

    next_sentence_or_random: A `RaggedTensor` of strings with the same shape
      as `sentences` that contains, for each sentence, either its subsequent
      sentence or a randomly injected sentence.
    is_next_sentence: A `RaggedTensor` of bool with the same shape as
      `sentences` that is True where `next_sentence_or_random` is truly the
      subsequent sentence and False where it is a random one.
  """
  # Index of every sentence within its own row.
  positions = tf.ragged.range(sentences.row_lengths())

  # The modulo wraps the last position back to 0 so the index stays in range;
  # that wrapped entry is never selected because the mask below forces a
  # random sentence for the last position of each row.
  next_sentences_pos = (positions + 1) % tf.expand_dims(
      sentences.row_lengths(), 1)
  # Adding `0 * positions` broadcasts the row id to every element.
  rows_broadcasted = tf.expand_dims(tf.range(sentences.nrows()),
                                    -1) + 0 * positions
  next_sentences_pos_nd = tf.stack([rows_broadcasted, next_sentences_pos], -1)
  all_random_nd = _sample_from_other_batch(positions, random_fn)

  # True everywhere except the last sentence of each row. The trailing zero
  # is sliced out of `zeros_like(positions)` rather than built as a dense
  # [nrows, 1] int64 column: slicing keeps empty rows empty (so the mask stays
  # aligned with `positions`) and follows the dtype of `positions`, matching
  # how `get_sentence_order_labels` builds its masks.
  valid_next_sentences = tf.cast(
      tf.concat([
          tf.ones_like(positions)[:, :-1],
          tf.zeros_like(positions)[:, -1:]
      ], -1), tf.bool)

  is_random = ((_get_random(positions, random_fn) > random_threshold)
               | tf.logical_not(valid_next_sentences))
  gather_indices = tf.compat.v1.where(is_random, all_random_nd,
                                      next_sentences_pos_nd)
  return tf.gather_nd(sentences, gather_indices), tf.logical_not(is_random)
|
|