Spaces:
Sleeping
Sleeping
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads datasets for the BERT and XLNet pretraining tasks."""
import dataclasses | |
from typing import Mapping, Optional | |
from absl import logging | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from official.common import dataset_fn | |
from official.core import config_definitions as cfg | |
from official.core import input_reader | |
from official.nlp.data import data_loader | |
from official.nlp.data import data_loader_factory | |
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm).

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: Length of the token-level features stored in each tf.Example
      (`input_mask`, the token/type id features and, if enabled,
      `position_ids`).
    max_predictions_per_seq: Length of the `masked_lm_positions`,
      `masked_lm_ids` and `masked_lm_weights` features.
    use_next_sentence_label: Whether to read and emit `next_sentence_labels`.
    use_position_id: Whether to read and emit `position_ids`.
    use_v2_feature_names: Whether the tf.Examples use `input_word_ids` /
      `input_type_ids` instead of `input_ids` / `segment_ids` as keys.
    file_type: Passed to `dataset_fn.pick_dataset_fn` to choose the reader
      used for `input_path`.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  # Historically, BERT implementations take `input_ids` and `segment_ids` as
  # feature names. Inside the TF Model Garden implementation, the Keras model
  # inputs are set as `input_word_ids` and `input_type_ids`. When
  # v2_feature_names is True, the data loader assumes the tf.Examples use
  # `input_word_ids` and `input_type_ids` as keys.
  use_v2_feature_names: bool = False
  file_type: str = 'tfrecord'
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for bert pretraining task."""

  def __init__(self, params: BertPretrainDataConfig):
    """Inits `BertPretrainDataLoader` class.

    Args:
      params: A `BertPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id

  def _name_to_features(self):
    """Builds the tf.Example feature spec for the configured options.

    Returns:
      A dict mapping feature names to `tf.io.FixedLenFeature` specs. Token
      level features have length `seq_length`; `masked_lm_*` features have
      length `max_predictions_per_seq`. The token id / type id keys depend on
      `use_v2_feature_names`.
    """
    name_to_features = {
        'input_mask':
            tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'masked_lm_positions':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
        'masked_lm_ids':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
        'masked_lm_weights':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
    }
    if self._params.use_v2_feature_names:
      name_to_features.update({
          'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
          'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
      })
    else:
      name_to_features.update({
          'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
          'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
      })
    if self._use_next_sentence_label:
      name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
                                                                       tf.int64)
    if self._use_position_id:
      name_to_features['position_ids'] = tf.io.FixedLenFeature(
          [self._seq_length], tf.int64)
    return name_to_features

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example, casting int64 features to int32."""
    example = tf.io.parse_single_example(record, self._name_to_features())
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model.

    The token id / type id features are always emitted under the model's
    input names (`input_word_ids`, `input_type_ids`), whichever key the
    tf.Example used. `next_sentence_labels` and `position_ids` are forwarded
    only when enabled in the config.
    """
    x = {
        'input_mask': record['input_mask'],
        'masked_lm_positions': record['masked_lm_positions'],
        'masked_lm_ids': record['masked_lm_ids'],
        'masked_lm_weights': record['masked_lm_weights'],
    }
    if self._params.use_v2_feature_names:
      x['input_word_ids'] = record['input_word_ids']
      x['input_type_ids'] = record['input_type_ids']
    else:
      # Map the legacy BERT feature names onto the model input names.
      x['input_word_ids'] = record['input_ids']
      x['input_type_ids'] = record['segment_ids']
    if self._use_next_sentence_label:
      x['next_sentence_labels'] = record['next_sentence_labels']
    if self._use_position_id:
      x['position_ids'] = record['position_ids']
    return x

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
  """Data config for XLNet pretraining task.

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: The length of each sequence. Must be a multiple of
      `permutation_size`; when `reuse_length` > 0, `seq_length - reuse_length`
      must be a multiple of `permutation_size` as well.
    max_predictions_per_seq: The number of predictions per sequence.
    reuse_length: The number of tokens in a previous segment to reuse. This
      should be the same value used during pretrain data creation. A value of
      0 disables the memory mechanism; otherwise it must be a multiple of
      `permutation_size`.
    sample_strategy: The strategy used to sample factorization permutations.
      Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
    min_num_tokens: The minimum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    max_num_tokens: The maximum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    min_num_words: The minimum number of words to sample in a span. This is used
      when `sample_strategy` is 'word_span'.
    max_num_words: The maximum number of words to sample in a span. This is used
      when `sample_strategy` is 'word_span'.
    permutation_size: The length of the longest permutation. This can be set to
      `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
      this may introduce data leaks.
    leak_ratio: The fraction (between 0 and 1) of masked tokens that are
      leaked, i.e. allowed to attend to themselves.
    segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
    segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  reuse_length: int = 256
  sample_strategy: str = 'word_span'
  min_num_tokens: int = 1
  max_num_tokens: int = 5
  min_num_words: int = 1
  max_num_words: int = 5
  permutation_size: int = 256
  leak_ratio: float = 0.1
  segment_sep_id: int = 4
  segment_cls_id: int = 3
@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for xlnet pretraining task."""

  def __init__(self, params: XLNetPretrainDataConfig):
    """Inits `XLNetPretrainDataLoader` class.

    Args:
      params: A `XLNetPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._reuse_length = params.reuse_length
    # Set by `load` when an input context is given; not read elsewhere here.
    self._num_replicas_in_sync = None
    self._permutation_size = params.permutation_size
    self._sep_id = params.segment_sep_id
    self._cls_id = params.segment_cls_id
    self._sample_strategy = params.sample_strategy
    self._leak_ratio = params.leak_ratio

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example.

    `input_word_ids` and `input_type_ids` are parsed as fixed-length
    [seq_length] features and `boundary_indices` as a variable-length
    (sparse) feature. Every int64 tensor is then cast to int32.
    """
    name_to_features = {
        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'boundary_indices': tf.io.VarLenFeature(tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model.

    Samples prediction targets with the configured `sample_strategy`, builds
    the permutation (factorization order) mask and, when
    `max_predictions_per_seq` is set, gathers and pads the prediction targets
    to that length.

    Args:
      record: The decoded example produced by `_decode`.

    Returns:
      A dict with `input_type_ids`, `permutation_mask` of shape
      [seq_length, seq_length], `input_word_ids`, `masked_tokens`, `target`,
      `target_mask` and, when `max_predictions_per_seq` is set,
      `target_mapping` of shape [max_predictions_per_seq, seq_length].

    Raises:
      ValueError: if the sequence lengths are not multiples of
        `permutation_size`.
    """
    x = {}

    inputs = record['input_word_ids']
    x['input_type_ids'] = record['input_type_ids']

    if self._sample_strategy in ['whole_word', 'word_span']:
      boundary = tf.sparse.to_dense(record['boundary_indices'])
    else:
      boundary = None

    input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)

    if self._reuse_length > 0:
      if self._permutation_size > self._reuse_length:
        logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.', self._permutation_size,
            self._reuse_length)

      # Enable the memory mechanism.
      # Permute the reuse and non-reuse segments separately.
      non_reuse_len = self._seq_length - self._reuse_length
      if not (self._reuse_length % self._permutation_size == 0 and
              non_reuse_len % self._permutation_size == 0):
        raise ValueError('`reuse_length` and `seq_length` should both be '
                         'a multiple of `permutation_size`.')

      # Creates permutation mask and target mask for the first reuse_len tokens.
      # The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
          inputs=inputs[:self._reuse_length],
          input_mask=input_mask[:self._reuse_length])

      # Creates permutation mask and target mask for the rest of tokens in
      # current example, which are concatenation of two new segments.
      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
          inputs[self._reuse_length:], input_mask[self._reuse_length:])

      # Reused tokens get 0 (cannot attend) for every non-reuse column, while
      # non-reuse tokens get 1 (can attend) for every reused column.
      perm_mask_0 = tf.concat([
          perm_mask_0,
          tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
      ], axis=1)
      perm_mask_1 = tf.concat([
          tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
          perm_mask_1
      ], axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      tokens = tf.concat([tokens_0, tokens_1], axis=0)
      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
    else:
      # Disable the memory mechanism.
      if self._seq_length % self._permutation_size != 0:
        raise ValueError('`seq_length` should be a multiple of '
                         '`permutation_size`.')
      # Permute the entire sequence together
      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
          inputs=inputs, input_mask=input_mask)
    x['permutation_mask'] = tf.reshape(perm_mask,
                                       [self._seq_length, self._seq_length])
    x['input_word_ids'] = tokens
    x['masked_tokens'] = masked_tokens

    target = tokens
    if self._max_predictions_per_seq is not None:
      indices = tf.range(self._seq_length, dtype=tf.int32)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # account for extra padding due to CLS/SEP.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = self._max_predictions_per_seq - actual_num_predict

      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
      paddings = tf.zeros([pad_len, self._seq_length],
                          dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      x['target_mapping'] = tf.reshape(
          target_mapping, [self._max_predictions_per_seq, self._seq_length])

      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])

      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.int32),
          tf.zeros([pad_len], dtype=tf.int32)
      ], axis=0)
      x['target_mask'] = tf.reshape(target_mask,
                                    [self._max_predictions_per_seq])
    else:
      # NOTE(review): `_online_sample_mask` raises when
      # `max_predictions_per_seq` is None, so this branch looks unreachable.
      x['target'] = tf.reshape(target, [self._seq_length])
      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
    return x

  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
                          end_indices: tf.Tensor,
                          inputs: tf.Tensor) -> tf.Tensor:
    """Converts beginning and end indices into an actual mask.

    A position `p` is a candidate for a span when `begin <= p < end` and the
    token at `p` is neither SEP nor CLS. Only the first
    `max_predictions_per_seq` (span, position) candidates, taken in span
    order, are kept.

    Args:
      begin_indices: Inclusive span start positions.
      end_indices: Exclusive span end positions, aligned with `begin_indices`.
      inputs: The input tokens.

    Returns:
      A `bool` Tensor of shape [seq_length].
    """
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    # SEP/CLS positions are replaced by -1 so they never fall inside a span.
    all_indices = tf.where(
        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
    candidate_matrix = tf.cast(
        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
                       all_indices[None, :] < end_indices[:, None]), tf.float32)
    # Running count over the flattened [num_spans, seq_length] matrix caps the
    # total number of selected candidates.
    cumsum_matrix = tf.reshape(
        tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
                            tf.float32)
    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
    return tf.cast(target_mask, tf.bool)

  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples individual tokens as prediction targets.

    Picks up to `max_predictions_per_seq` random positions whose token is
    neither SEP nor CLS.
    """
    all_indices = tf.range(self._seq_length, dtype=tf.int32)
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)

    masked_pos = tf.random.shuffle(non_func_indices)
    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])

    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
    sparse_indices = tf.cast(sparse_indices, tf.int64)

    sparse_indices = tf.sparse.SparseTensor(
        sparse_indices,
        values=tf.ones_like(masked_pos),
        dense_shape=(1, self._seq_length))

    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)

    # Squeeze only the leading axis so a length-1 sequence keeps rank 1.
    return tf.squeeze(tf.cast(target_mask, tf.bool), axis=0)

  def _whole_word_mask(self, inputs: tf.Tensor,
                       boundary: tf.Tensor) -> tf.Tensor:
    """Samples whole words as prediction targets.

    Consecutive boundary entries (boundary[i], boundary[i + 1]) form word
    spans; up to `max_predictions_per_seq` of them are picked at random.
    """
    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
    cand_pair_indices = tf.random.shuffle(
        pair_indices)[:self._max_predictions_per_seq]
    begin_indices = cand_pair_indices[:, 0]
    end_indices = cand_pair_indices[:, 1]

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _sample_spans(self, min_span_len: int, max_span_len: int,
                    mask_alpha: float):
    """Samples `max_predictions_per_seq` laid-out (begin, end) span pairs.

    Span lengths in [min_span_len, max_span_len] are drawn with probability
    proportional to 1 / (length + 1). Each span sits inside a window of about
    `length * mask_alpha` units with a random left context, and the windows
    are placed one after another via cumulative sums.

    Args:
      min_span_len: Smallest span length to sample.
      max_span_len: Largest span length to sample.
      mask_alpha: Ratio of window length to span length.

    Returns:
      A `(begin_indices, end_indices)` pair of int32 Tensors of shape
      [max_predictions_per_seq]. `end_indices` are exclusive and may be out
      of range; callers filter them.
    """

    def round_to_int(v):
      return tf.cast(tf.round(v), tf.int32)

    # Sample span lengths from a zipf distribution
    span_len_seq = np.arange(min_span_len, max_span_len + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_span_len

    # Sample the ratio [0.0, 1.0) of left context lengths
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = round_to_int(left_ratio * span_lens_float * (mask_alpha - 1))

    # Compute the offset from left start to the right end
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    # Get the actual begin and end indices
    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens
    return begin_indices, end_indices

  @staticmethod
  def _shuffle_spans(begin_indices: tf.Tensor, end_indices: tf.Tensor):
    """Applies one shared random permutation to aligned begin/end indices."""
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    return tf.gather(begin_indices, order), tf.gather(end_indices, order)

  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples token spans as prediction targets."""
    mask_alpha = self._seq_length / self._max_predictions_per_seq
    begin_indices, end_indices = self._sample_spans(
        self._params.min_num_tokens, self._params.max_num_tokens, mask_alpha)

    # Remove out of range indices
    valid_idx_mask = end_indices < self._seq_length
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Shuffle valid indices
    begin_indices, end_indices = self._shuffle_spans(begin_indices, end_indices)
    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
    """Sample whole word spans as prediction targets."""
    # Note: 1.2 is the token-to-word ratio
    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
    # Sample `num_predict` words here: note that this is over sampling
    word_begin, word_end = self._sample_spans(
        self._params.min_num_words, self._params.max_num_words, mask_alpha)

    # Remove out of range indices
    max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
    valid_idx_mask = word_end < max_boundary_index
    word_begin = tf.boolean_mask(word_begin, valid_idx_mask)
    word_end = tf.boolean_mask(word_end, valid_idx_mask)

    # Translate word indices into positions through the boundary table.
    begin_indices = tf.gather(boundary, word_begin)
    end_indices = tf.gather(boundary, word_end)

    # Shuffle valid indices
    begin_indices, end_indices = self._shuffle_spans(begin_indices, end_indices)
    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _online_sample_mask(self, inputs: tf.Tensor,
                          boundary: tf.Tensor) -> tf.Tensor:
    """Samples target positions for predictions.

    Descriptions of each strategy:
      - 'single_token': Samples individual tokens as prediction targets.
      - 'token_span': Samples spans of tokens as prediction targets.
      - 'whole_word': Samples individual words as prediction targets.
      - 'word_span': Samples spans of words as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The `int` Tensor of indices indicating whole word boundaries.
        This is used in 'whole_word' and 'word_span'

    Returns:
      The sampled `bool` input mask.

    Raises:
      `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
        not provided for 'whole_word' and 'word_span' sample strategies.
      `NotImplementedError`: if `sample_strategy` is not one of the above.
    """
    if self._max_predictions_per_seq is None:
      raise ValueError('`max_predictions_per_seq` must be set.')

    if boundary is None and 'word' in self._sample_strategy:
      raise ValueError('`boundary` must be provided for {} strategy'.format(
          self._sample_strategy))

    if self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
    elif self._sample_strategy == 'whole_word':
      return self._whole_word_mask(inputs, boundary)
    elif self._sample_strategy == 'word_span':
      return self._word_span_mask(inputs, boundary)
    else:
      raise NotImplementedError('Invalid sample strategy.')

  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
    """Samples a permutation of the factorization order.

    Args:
      inputs: the input tokens, of length n. n must be a multiple of
        `permutation_size`.
      input_mask: a Tensor of the same shape as `inputs`, cast to `bool`. If
        `True`, then this means select for partial prediction.

    Returns:
      perm_mask: An `int32` Tensor of shape [n, n] consisting of 0s and 1s.
        If perm_mask[i][j] == 0, then this means that the i-th token (in
        original order) cannot attend to the jth attention token.
      target_mask: An `int32` Tensor of shape [n] consisting of 0s and 1s.
        If target_mask[i] == 1, then the i-th token needs to be predicted and
        the mask will be used as input. This token will be included in the loss.
        If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
        input. This token will not be included in the loss.
      tokens: `inputs`, returned unchanged.
      masked_tokens: A `bool` Tensor of shape [n]; True for positions selected
        by `input_mask` that are neither SEP nor CLS.
    """
    factorization_length = tf.shape(inputs)[0]
    # Generate permutation indices. The same random permutation of offsets is
    # applied to every chunk of `permutation_size` consecutive positions.
    index = tf.range(factorization_length, dtype=tf.int32)
    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
    index = tf.random.shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    input_mask = tf.cast(input_mask, tf.bool)

    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(
            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
    non_masked_or_func_tokens = tf.logical_not(masked_tokens)

    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)

    # Similar to BERT, randomly leak some masked tokens
    if self._leak_ratio > 0:
      leak_tokens = tf.logical_and(
          masked_tokens,
          tf.random.uniform([factorization_length], maxval=1.0) <
          self._leak_ratio)
      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
    else:
      can_attend_self = non_masked_or_func_tokens
    to_index = tf.where(can_attend_self, smallest_index, index)
    from_index = tf.where(can_attend_self, to_index + 1, to_index)

    # For masked tokens, can attend if i > j
    # For context tokens, always can attend each other
    can_attend = from_index[:, None] > to_index[None, :]

    perm_mask = tf.cast(can_attend, tf.int32)

    # Only masked tokens are included in the loss
    target_mask = tf.cast(masked_tokens, tf.int32)

    return perm_mask, target_mask, inputs, masked_tokens

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    if input_context:
      self._num_replicas_in_sync = input_context.num_replicas_in_sync
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)