# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities used for data preparation.""" import collections import json import os from absl import logging import numpy as np import tensorflow as tf, tf_keras special_symbols = { "": 0, "": 1, "": 2, "": 3, "": 4, "": 5, "": 6, "": 7, "": 8, } VOCAB_SIZE = 32000 UNK_ID = special_symbols[""] CLS_ID = special_symbols[""] SEP_ID = special_symbols[""] MASK_ID = special_symbols[""] EOD_ID = special_symbols[""] SEG_ID_P = 0 SEG_ID_Q = 1 SEG_ID_CLS = 2 SEG_ID_PAD = 3 OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [ "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words", "min_num_words" ]) def file_based_input_fn_builder(input_file, name_to_features, batch_size, is_training): """Creates an `input_fn` closure.""" logging.info("Input tfrecord file %s", input_file) def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example.""" example = tf.io.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.cast(t, tf.int32) example[name] = t return example def input_fn(): """Returns dataset for training/evaluation.""" num_threads = 8 if isinstance(input_file, str): d = tf.data.TFRecordDataset(input_file) # For training, we want a lot of parallel reading and shuffling. # For eval, we want no shuffling and parallel reading doesn't matter. if is_training: d = d.shuffle(2048) d = d.repeat() else: cycle_length = min(num_threads, len(input_file)) d = tf.data.Dataset.from_tensor_slices(input_file) # file level shuffle d = d.shuffle(len(input_file)).repeat() d = d.interleave( tf.data.TFRecordDataset, cycle_length=cycle_length) if is_training: # sample level shuffle d = d.shuffle(buffer_size=2048) d = d.map( lambda record: _decode_record(record, name_to_features), num_parallel_calls=tf.data.experimental.AUTOTUNE) d = d.batch(batch_size, drop_remainder=is_training) # When `input_file` is a path to a single file or a list # containing a single path, disable auto sharding so that # same input file is sent to all workers. if isinstance(input_file, str) or len(input_file) == 1: options = tf.data.Options() options.experimental_distribute.auto_shard_policy = ( tf.data.experimental.AutoShardPolicy.OFF) d = d.with_options(options) d = d.prefetch(tf.data.experimental.AUTOTUNE) return d return input_fn def create_classification_dataset(file_path, seq_length, batch_size, is_training): """Creates input dataset from (tf)records files for pretraining.""" name_to_features = { "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32), "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64), "label_ids": tf.io.FixedLenFeature([], tf.int64), "is_real_example": tf.io.FixedLenFeature([], tf.int64), } input_fn = file_based_input_fn_builder(file_path, name_to_features, batch_size, is_training) dataset = input_fn() return dataset def create_squad_dataset(file_path, seq_length, batch_size, is_training): """Creates input dataset from (tf)records files for pretraining.""" name_to_features = { "unique_ids": tf.io.FixedLenFeature([], tf.int64), "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32), "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64), "cls_index": tf.io.FixedLenFeature([], tf.int64), "p_mask": tf.io.FixedLenFeature([seq_length], tf.float32) } if is_training: name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64) name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64) name_to_features["is_impossible"] = tf.io.FixedLenFeature([], tf.float32) input_fn = file_based_input_fn_builder(file_path, name_to_features, batch_size, is_training) dataset = input_fn() return dataset def get_input_iterator(input_fn, strategy): """Returns distributed dataset iterator.""" # When training with TPU pods, datasets needs to be cloned across # workers. Since Dataset instance cannot be cloned in eager mode, we instead # pass callable that returns a dataset. input_data = input_fn() if callable(input_data): iterator = iter(strategy.distribute_datasets_from_function(input_data)) else: iterator = iter(strategy.experimental_distribute_dataset(input_data)) return iterator def get_classification_input_data(batch_size, seq_len, strategy, is_training, file_path): """Returns input dataset from input file string.""" # When using TPU pods, we need to clone dataset across # workers and need to pass in function that returns the dataset rather # than passing dataset instance itself. use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy) if use_dataset_fn: if batch_size % strategy.num_replicas_in_sync != 0: raise ValueError( "Batch size must be divisible by number of replicas : {}".format( strategy.num_replicas_in_sync)) # As auto rebatching is not supported in # `distribute_datasets_from_function()` API, which is # required when cloning dataset to multiple workers in eager mode, # we use per-replica batch size. batch_size = int(batch_size / strategy.num_replicas_in_sync) def _dataset_fn(ctx=None): del ctx train_dataset = create_classification_dataset( file_path=file_path, seq_length=seq_len, batch_size=batch_size, is_training=is_training) return train_dataset return _dataset_fn if use_dataset_fn else _dataset_fn() def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training, file_path): """Returns input dataset from input file string.""" # When using TPU pods, we need to clone dataset across # workers and need to pass in function that returns the dataset rather # than passing dataset instance itself. use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy) if use_dataset_fn: if batch_size % strategy.num_replicas_in_sync != 0: raise ValueError( "Batch size must be divisible by number of replicas : {}".format( strategy.num_replicas_in_sync)) # As auto rebatching is not supported in # `distribute_datasets_from_function()` API, which is # required when cloning dataset to multiple workers in eager mode, # we use per-replica batch size. batch_size = int(batch_size / strategy.num_replicas_in_sync) if is_training: input_glob = os.path.join( file_path, "spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len)) global_input_paths = tf.io.gfile.glob(input_glob) else: global_input_paths = file_path def _dataset_fn(ctx=None): del ctx train_dataset = create_squad_dataset( file_path=global_input_paths, seq_length=seq_len, batch_size=batch_size, is_training=is_training) return train_dataset return _dataset_fn if use_dataset_fn else _dataset_fn() def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict): """Turn beg and end indices into actual mask.""" non_func_mask = tf.logical_and( tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID)) all_indices = tf.where(non_func_mask, tf.range(tgt_len, dtype=tf.int64), tf.constant(-1, shape=[tgt_len], dtype=tf.int64)) candidate_matrix = tf.cast( tf.logical_and(all_indices[None, :] >= beg_indices[:, None], all_indices[None, :] < end_indices[:, None]), tf.float32) cumsum_matrix = tf.reshape( tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len]) masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32) target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0) is_masked = tf.cast(target_mask, tf.bool) return is_masked, target_mask def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words, boundary): """Sample whole word spans as prediction targets.""" # Note: 1.2 is the token-to-word ratio mask_alpha = tgt_len / num_predict / 1.2 round_to_int = lambda x: tf.cast(tf.round(x), tf.int64) # Sample span lengths from a zipf distribution span_len_seq = np.arange(min_num_words, max_num_words + 1) probs = np.array([1.0 / (i + 1) for i in span_len_seq]) probs /= np.sum(probs) logits = tf.constant(np.log(probs), dtype=tf.float32) # Sample `num_predict` words here: note that this is over sampling span_lens = tf.random.categorical( logits=logits[None], num_samples=num_predict, dtype=tf.int64, )[0] + min_num_words # Sample the ratio [0.0, 1.0) of left context lengths span_lens_float = tf.cast(span_lens, tf.float32) left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0) left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1) left_ctx_len = round_to_int(left_ctx_len) right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len beg_indices = ( tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True)) end_indices = beg_indices + span_lens # Remove out of range indices max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64) valid_idx_mask = end_indices < max_boundary_index beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask) end_indices = tf.boolean_mask(end_indices, valid_idx_mask) beg_indices = tf.gather(boundary, beg_indices) end_indices = tf.gather(boundary, end_indices) # Shuffle valid indices num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64) order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64)) beg_indices = tf.gather(beg_indices, order) end_indices = tf.gather(end_indices, order) return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict) def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens, max_num_tokens): """Sample token spans as prediction targets.""" mask_alpha = tgt_len / num_predict round_to_int = lambda x: tf.cast(tf.round(x), tf.int64) # Sample span lengths from a zipf distribution span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1) probs = np.array([1.0 / (i + 1) for i in span_len_seq]) probs /= np.sum(probs) logits = tf.constant(np.log(probs), dtype=tf.float32) span_lens = tf.random.categorical( logits=logits[None], num_samples=num_predict, dtype=tf.int64, )[0] + min_num_tokens # Sample the ratio [0.0, 1.0) of left context lengths span_lens_float = tf.cast(span_lens, tf.float32) left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0) left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1) left_ctx_len = round_to_int(left_ctx_len) # Compute the offset from left start to the right end right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len # Get the actual begin and end indices beg_indices = ( tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True)) end_indices = beg_indices + span_lens # Remove out of range indices valid_idx_mask = end_indices < tgt_len beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask) end_indices = tf.boolean_mask(end_indices, valid_idx_mask) # Shuffle valid indices num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64) order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64)) beg_indices = tf.gather(beg_indices, order) end_indices = tf.gather(end_indices, order) return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict) def _whole_word_mask(inputs, tgt_len, num_predict, boundary): """Sample whole words as prediction targets.""" pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1) cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict] beg_indices = cand_pair_indices[:, 0] end_indices = cand_pair_indices[:, 1] return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict) def _single_token_mask(inputs, tgt_len, num_predict): """Sample individual tokens as prediction targets.""" all_indices = tf.range(tgt_len, dtype=tf.int64) non_func_mask = tf.logical_and( tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID)) non_func_indices = tf.boolean_mask(all_indices, non_func_mask) masked_pos = tf.random.shuffle(non_func_indices) masked_pos = tf.sort(masked_pos[:num_predict]) target_mask = tf.sparse_to_dense( sparse_indices=masked_pos, output_shape=[tgt_len], sparse_values=1.0, default_value=0.0) is_masked = tf.cast(target_mask, tf.bool) return is_masked, target_mask def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config, boundary=None): """Sample target positions to predict.""" logging.info("Online sample with strategy: `%s`.", online_masking_config.sample_strategy) if online_masking_config.sample_strategy == "single_token": return _single_token_mask(inputs, tgt_len, num_predict) elif online_masking_config.sample_strategy == "whole_word": assert boundary is not None, "whole word sampling requires `boundary`" return _whole_word_mask(inputs, tgt_len, num_predict, boundary) elif online_masking_config.sample_strategy == "token_span": return _token_span_mask(inputs, tgt_len, num_predict, online_masking_config.min_num_tokens, online_masking_config.max_num_tokens) elif online_masking_config.sample_strategy == "word_span": assert boundary is not None, "word span sampling requires `boundary`" return _word_span_mask(inputs, tgt_len, num_predict, online_masking_config.min_num_words, online_masking_config.max_num_words, boundary) else: raise NotImplementedError def create_pretrain_dataset(file_names, bsz_per_core, seq_len, reuse_len, perm_size, leak_ratio, online_masking_config, num_predict=None, input_pipeline_context=None): """Creates pretrain dataset.""" def parser(record): """Function used to parse tfrecord.""" record_spec = { "input": tf.io.FixedLenFeature([seq_len], tf.int64), "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64), "label": tf.io.FixedLenFeature([1], tf.int64), } if online_masking_config.sample_strategy in ["whole_word", "word_span"]: logging.info("Add `boundary` spec for %s", online_masking_config.sample_strategy) record_spec["boundary"] = tf.io.VarLenFeature(tf.int64) # retrieve serialized example example = tf.io.parse_single_example( serialized=record, features=record_spec) inputs = example.pop("input") if online_masking_config.sample_strategy in ["whole_word", "word_span"]: boundary = tf.sparse.to_dense(example.pop("boundary")) else: boundary = None is_masked, _ = _online_sample_masks( inputs, seq_len, num_predict, online_masking_config, boundary=boundary) if reuse_len > 0: ##### Use memory # permutate the reuse and non-reuse parts separately non_reuse_len = seq_len - reuse_len assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0 # Creates permutation mask and target mask for the first reuse_len tokens. # The tokens in this part are reused from the last sequence. perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm( inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len, leak_ratio) # Creates permutation mask and target mask for the rest of tokens in # current example, which are concatentation of two new segments. perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm( inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len, leak_ratio) perm_mask_0 = tf.concat( [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1) perm_mask_1 = tf.concat( [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1) perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0) target_mask = tf.concat([target_mask_0, target_mask_1], axis=0) input_k = tf.concat([input_k_0, input_k_1], axis=0) input_q = tf.concat([input_q_0, input_q_1], axis=0) else: ##### Do not use memory assert seq_len % perm_size == 0 # permutate the entire sequence together perm_mask, target_mask, input_k, input_q = _local_perm( inputs, is_masked, perm_size, seq_len, leak_ratio) # reshape back to fixed shape example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len]) example["input_ids"] = tf.reshape(input_k, [seq_len]) example["input_q"] = tf.reshape(input_q, [seq_len]) # Directly use raw inputs as the target target = inputs if num_predict is not None: indices = tf.range(seq_len, dtype=tf.int64) bool_target_mask = tf.cast(target_mask, tf.bool) indices = tf.boolean_mask(indices, bool_target_mask) ##### extra padding due to CLS/SEP introduced after prepro actual_num_predict = tf.shape(indices)[0] pad_len = num_predict - actual_num_predict ##### target_mapping target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32) paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype) target_mapping = tf.concat([target_mapping, paddings], axis=0) example["target_mapping"] = tf.reshape(target_mapping, [num_predict, seq_len]) ##### target target = tf.boolean_mask(target, bool_target_mask) paddings = tf.zeros([pad_len], dtype=target.dtype) target = tf.concat([target, paddings], axis=0) example["target"] = tf.reshape(target, [num_predict]) ##### target mask target_mask = tf.concat([ tf.ones([actual_num_predict], dtype=tf.float32), tf.zeros([pad_len], dtype=tf.float32) ], axis=0) example["target_mask"] = tf.reshape(target_mask, [num_predict]) else: example["target"] = tf.reshape(target, [seq_len]) example["target_mask"] = tf.reshape(target_mask, [seq_len]) for key in list(example.keys()): val = example[key] if tf_keras.backend.is_sparse(val): val = tf.sparse.to_dense(val) if val.dtype == tf.int64: val = tf.cast(val, tf.int32) example[key] = val for k, v in example.items(): logging.info("%s: %s", k, v) return example dataset = parse_files_to_dataset( parser=parser, file_paths=file_names, bsz_per_core=bsz_per_core, sequential=reuse_len > 0, input_pipeline_context=input_pipeline_context) return dataset def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None, uncased=False): """Generates input file name pattern.""" if reuse_len is not None and reuse_len > 0: reuse_str = "reuse-{}.".format(reuse_len) bsz_str = "hostbsz-{}.".format(bsz_per_host) else: reuse_str = "" bsz_str = "" if not uncased: case_str = "" else: case_str = "uncased." file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str, case_str, suffix) return file_name def get_pretrain_input_data(batch_size, seq_len, strategy, file_path, reuse_len, perm_size, leak_ratio, num_predict, uncased, online_masking_config, num_hosts=1): """Returns input dataset from input file string.""" # When using TPU pods, we need to clone dataset across # workers and need to pass in function that returns the dataset rather # than passing dataset instance itself. use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy) split = "train" bsz_per_host = int(batch_size / num_hosts) record_glob_base = format_filename( prefix="meta.{}.pass-*".format(split), suffix="json*", bsz_per_host=bsz_per_host, seq_len=seq_len, reuse_len=reuse_len, uncased=uncased) def _get_num_batch(info): if "num_batch" in info: return info["num_batch"] elif "num_example" in info: return info["num_example"] / bsz_per_host else: raise ValueError("Do not have sample info.") if use_dataset_fn: if batch_size % strategy.num_replicas_in_sync != 0: raise ValueError( "Batch size must be divisible by number of replicas : {}".format( strategy.num_replicas_in_sync)) # As auto rebatching is not supported in # `distribute_datasets_from_function()` API, which is # required when cloning dataset to multiple workers in eager mode, # we use per-replica batch size. batch_size = int(batch_size / strategy.num_replicas_in_sync) record_info = {"num_batch": 0, "filenames": []} tfrecord_dirs = file_path.split(",") logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs) for idx, record_dir in enumerate(tfrecord_dirs): record_glob = os.path.join(record_dir, record_glob_base) logging.info("[%d] Record glob: %s", idx, record_glob) record_paths = sorted(tf.io.gfile.glob(record_glob)) logging.info("[%d] Num of record info path: %d", idx, len(record_paths)) cur_record_info = {"num_batch": 0, "filenames": []} for record_info_path in record_paths: with tf.io.gfile.GFile(record_info_path, "r") as fp: info = json.load(fp) cur_record_info["num_batch"] += int(_get_num_batch(info)) cur_record_info["filenames"] += info["filenames"] # overwrite directory for `cur_record_info` new_filenames = [] for filename in cur_record_info["filenames"]: basename = os.path.basename(filename) new_filename = os.path.join(record_dir, basename) new_filenames.append(new_filename) cur_record_info["filenames"] = new_filenames logging.info("[Dir %d] Number of chosen batches: %s", idx, cur_record_info["num_batch"]) logging.info("[Dir %d] Number of chosen files: %s", idx, len(cur_record_info["filenames"])) logging.info(cur_record_info["filenames"]) # add `cur_record_info` to global `record_info` record_info["num_batch"] += cur_record_info["num_batch"] record_info["filenames"] += cur_record_info["filenames"] logging.info("Total number of batches: %d", record_info["num_batch"]) logging.info("Total number of files: %d", len(record_info["filenames"])) logging.info(record_info["filenames"]) def _dataset_fn(ctx=None): """Function that can create a pretrain dataset.""" train_dataset = create_pretrain_dataset( file_names=record_info["filenames"], bsz_per_core=batch_size, seq_len=seq_len, reuse_len=reuse_len, perm_size=perm_size, leak_ratio=leak_ratio, online_masking_config=online_masking_config, num_predict=num_predict, input_pipeline_context=ctx) return train_dataset return _dataset_fn if use_dataset_fn else _dataset_fn() def parse_files_to_dataset(parser, file_paths, bsz_per_core, sequential, input_pipeline_context=None): """Creates the dataset given file paths.""" dataset = tf.data.Dataset.from_tensor_slices(file_paths) # Note: we cannot perform sample-level shuffle here because this will violate # the consecutive requirement of data stream. if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1: dataset = dataset.shard(input_pipeline_context.num_input_pipelines, input_pipeline_context.input_pipeline_id) # file-level shuffle if len(file_paths) > 1: dataset = dataset.shuffle(len(file_paths)) if sequential: # Note: cannot perform sample-level shuffle here because this will violate # the consecutive requirement of data stream. dataset = tf.data.TFRecordDataset(dataset) else: # `cycle_length` is the number of parallel files that get read. cycle_length = min(8, len(file_paths)) logging.info("Interleave %d files", cycle_length) dataset = dataset.apply( tf.data.experimental.parallel_interleave( tf.data.TFRecordDataset, cycle_length=cycle_length)) buffer_size = 2048 logging.info("Perform sample-level shuffle with size %d", buffer_size) dataset = dataset.shuffle(buffer_size=buffer_size) dataset = dataset.cache().repeat().map(parser) dataset = dataset.batch(bsz_per_core, drop_remainder=True) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) return dataset def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio): """Samples a permutation of the factorization order. Creates perm_mask and target_mask accordingly. Args: inputs: int64 Tensor in shape [seq_len], input ids. is_masked: bool Tensor in shape [seq_len]. True means being selected for partial prediction. perm_size: the length of longest permutation. Could be set to be reuse_len. Should not be larger than reuse_len or there will be data leaks. seq_len: int, sequence length. leak_ratio: float, percent of masked tokens that are leaked. Returns: perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1. If perm_mask[i][j] == 1, it means the ith token (in original order) cannot attend to the jth token (in original order). This case will happen only when the ith token's permutated position <= the jth token's permutated position, and the jth token is masked or is func token. If perm_mask[i][j] == 0, it means the ith token (in original order) can attend to the jth token (in original order). Note that non-masked tokens can be attended by all other tokens, which is different from the description in original paper. target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If target_mask[i] == 1, the ith token needs to be predicted and mask will be used as input. This token will count for loss. If target_mask[i] == 0, token (or [SEP], [CLS]) will be used as input. This token will not count for loss. inputs_k: int64 Tensor in shape [seq_len], input ids. inputs_q: float32 Tensor in shape [seq_len], the same as target_mask. """ # Generate permutation indices index = tf.range(seq_len, dtype=tf.int64) index = tf.transpose(tf.reshape(index, [-1, perm_size])) index = tf.random.shuffle(index) index = tf.reshape(tf.transpose(index), [-1]) # non-functional tokens non_func_tokens = tf.logical_not( tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID))) masked_tokens = tf.logical_and(is_masked, non_func_tokens) non_masked_or_func_tokens = tf.logical_not(masked_tokens) smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64) # Similar to BERT, randomly leak some masked tokens if leak_ratio > 0: leak_tokens = tf.logical_and( masked_tokens, tf.random.uniform([seq_len], maxval=1.0) < leak_ratio) can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens) else: can_attend_self = non_masked_or_func_tokens to_index = tf.where(can_attend_self, smallest_index, index) from_index = tf.where(can_attend_self, to_index + 1, to_index) # For masked tokens, can attend if i > j # For context tokens, always can attend each other can_attend = from_index[:, None] > to_index[None, :] # In modeling, 1 indicates cannot attend. Hence, reverse the value here. perm_mask = 1.0 - tf.cast(can_attend, tf.float32) # Only masked tokens are included in the loss target_mask = tf.cast(masked_tokens, tf.float32) # construct inputs_k inputs_k = inputs # construct inputs_q inputs_q = masked_tokens return perm_mask, target_mask, inputs_k, inputs_q