"""Utilities used for data preparation.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import collections |
|
import json |
|
import os |
|
from absl import logging |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
special_symbols = { |
|
"<unk>": 0, |
|
"<s>": 1, |
|
"</s>": 2, |
|
"<cls>": 3, |
|
"<sep>": 4, |
|
"<pad>": 5, |
|
"<mask>": 6, |
|
"<eod>": 7, |
|
"<eop>": 8, |
|
} |
|
|
|
VOCAB_SIZE = 32000 |
|
UNK_ID = special_symbols["<unk>"] |
|
CLS_ID = special_symbols["<cls>"] |
|
SEP_ID = special_symbols["<sep>"] |
|
MASK_ID = special_symbols["<mask>"] |
|
EOD_ID = special_symbols["<eod>"] |
|
SEG_ID_P = 0 |
|
SEG_ID_Q = 1 |
|
SEG_ID_CLS = 2 |
|
SEG_ID_PAD = 3 |
|
|
|
|
|
OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [ |
|
"sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words", |
|
"min_num_words"]) |
|
|
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
                                is_training):
  """Creates an `input_fn` closure."""

  logging.info("Input tfrecord file %s", input_file)

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so cast all int64 features to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def input_fn():
    """Returns dataset for training/evaluation."""
    num_threads = 8
    if isinstance(input_file, str):
      d = tf.data.TFRecordDataset(input_file)
      # For training, shuffle and repeat; for evaluation, read records in
      # order exactly once.
      if is_training:
        d = d.shuffle(2048)
        d = d.repeat()
    else:
      cycle_length = min(num_threads, len(input_file))
      d = tf.data.Dataset.from_tensor_slices(input_file)
      # File-level shuffle.
      d = d.shuffle(len(input_file)).repeat()

      # `Dataset.interleave` has no `sloppy` argument; non-deterministic
      # ordering during training is requested via `deterministic`.
      d = d.interleave(
          tf.data.TFRecordDataset,
          cycle_length=cycle_length,
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
          deterministic=not is_training)

      if is_training:
        # Sample-level shuffle.
        d = d.shuffle(buffer_size=2048)
    d = d.map(
        lambda record: _decode_record(record, name_to_features),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    d = d.batch(batch_size, drop_remainder=is_training)

    # When `input_file` is a single path (or a list containing a single path),
    # disable auto sharding so that the same input file is sent to all
    # workers.
    if isinstance(input_file, str) or len(input_file) == 1:
      options = tf.data.Options()
      options.experimental_distribute.auto_shard_policy = (
          tf.data.experimental.AutoShardPolicy.OFF)
      d = d.with_options(options)

    d = d.prefetch(tf.data.experimental.AUTOTUNE)
    return d

  return input_fn
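

# Illustrative sketch only: wiring `file_based_input_fn_builder` to a minimal
# feature spec. The record path and feature shapes below are hypothetical.
def _example_file_based_input_fn():
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([128], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
  }
  input_fn = file_based_input_fn_builder(
      input_file="/tmp/example.tf_record",  # hypothetical path
      name_to_features=name_to_features,
      batch_size=32,
      is_training=True)
  return input_fn()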
|
|
def create_classification_dataset(file_path, seq_length, batch_size,
                                  is_training):
  """Creates input dataset from (tf)records files for classification."""
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
      "is_real_example": tf.io.FixedLenFeature([], tf.int64),
  }

  input_fn = file_based_input_fn_builder(file_path, name_to_features,
                                         batch_size, is_training)
  dataset = input_fn()
  return dataset
|
|
def create_squad_dataset(file_path, seq_length, batch_size, is_training):
  """Creates input dataset from (tf)records files for SQuAD."""
  name_to_features = {
      "unique_ids": tf.io.FixedLenFeature([], tf.int64),
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "cls_index": tf.io.FixedLenFeature([], tf.int64),
      "p_mask": tf.io.FixedLenFeature([seq_length], tf.float32)
  }

  if is_training:
    name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["is_impossible"] = tf.io.FixedLenFeature([], tf.float32)

  input_fn = file_based_input_fn_builder(file_path, name_to_features,
                                         batch_size, is_training)
  dataset = input_fn()
  return dataset
|
|
def get_input_iterator(input_fn, strategy):
  """Returns distributed dataset iterator."""

  # When training with TPU pods, datasets need to be cloned across workers.
  # Since a Dataset instance cannot be cloned in eager mode, we instead pass a
  # callable that returns a dataset.
  input_data = input_fn()
  if callable(input_data):
    iterator = iter(
        strategy.experimental_distribute_datasets_from_function(input_data))
  else:
    iterator = iter(strategy.experimental_distribute_dataset(input_data))
  return iterator
|
|
def get_classification_input_data(batch_size, seq_len, strategy, is_training,
                                  file_path):
  """Returns input dataset from input file string."""

  # When using TPU pods, the dataset has to be cloned across workers, which
  # requires passing a dataset function rather than a dataset instance.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas: {}".format(
              strategy.num_replicas_in_sync))

    # Auto rebatching is not supported by
    # `experimental_distribute_datasets_from_function()`, which is required
    # when cloning the dataset to multiple workers in eager mode, so use the
    # per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  def _dataset_fn(ctx=None):
    del ctx

    train_dataset = create_classification_dataset(
        file_path=file_path,
        seq_length=seq_len,
        batch_size=batch_size,
        is_training=is_training)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()
|
|
def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
                         file_path):
  """Returns input dataset from input file string."""

  # When using TPU pods, the dataset has to be cloned across workers, which
  # requires passing a dataset function rather than a dataset instance.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas: {}".format(
              strategy.num_replicas_in_sync))

    # Auto rebatching is not supported by
    # `experimental_distribute_datasets_from_function()`, which is required
    # when cloning the dataset to multiple workers in eager mode, so use the
    # per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  if is_training:
    input_glob = os.path.join(
        file_path,
        "spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len))

    global_input_paths = tf.io.gfile.glob(input_glob)
  else:
    global_input_paths = file_path

  def _dataset_fn(ctx=None):
    del ctx

    train_dataset = create_squad_dataset(
        file_path=global_input_paths,
        seq_length=seq_len,
        batch_size=batch_size,
        is_training=is_training)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()
|
|
def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
  """Turns beg and end indices into an actual mask."""
  # Exclude functional tokens ([SEP], [CLS]) from candidate positions.
  non_func_mask = tf.logical_and(
      tf.not_equal(inputs, SEP_ID),
      tf.not_equal(inputs, CLS_ID))
  all_indices = tf.where(
      non_func_mask,
      tf.range(tgt_len, dtype=tf.int64),
      tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
  # candidate_matrix[i][j] == 1 iff position j falls inside the i-th span.
  candidate_matrix = tf.cast(
      tf.logical_and(
          all_indices[None, :] >= beg_indices[:, None],
          all_indices[None, :] < end_indices[:, None]),
      tf.float32)
  # Running count of candidate positions, used to cap the total number of
  # masked positions at `num_predict`.
  cumsum_matrix = tf.reshape(
      tf.cumsum(tf.reshape(candidate_matrix, [-1])),
      [-1, tgt_len])
  masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
  target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
  is_masked = tf.cast(target_mask, tf.bool)

  return is_masked, target_mask
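

# Illustrative sketch only: a toy `_idx_pair_to_mask` call. Positions 0 and 3
# hold functional tokens, so the span [1, 3) masks positions 1-2, and the
# cumulative-sum cap then rejects the second span at position 4 because
# `num_predict` candidates have already been consumed.
def _example_idx_pair_to_mask():
  inputs = tf.constant([CLS_ID, 10, 11, SEP_ID, 12, 13], dtype=tf.int64)
  beg_indices = tf.constant([1, 4], dtype=tf.int64)
  end_indices = tf.constant([3, 5], dtype=tf.int64)
  # Returns is_masked == [F, T, T, F, F, F] and the matching float mask.
  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len=6,
                           num_predict=2)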
|
|
def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
                    max_num_words, boundary):
  """Sample whole word spans as prediction targets."""
  # 1.2 is an (assumed) average token-to-word ratio.
  mask_alpha = tgt_len / num_predict / 1.2
  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)

  # Sample span lengths from a zipf distribution.
  span_len_seq = np.arange(min_num_words, max_num_words + 1)
  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
  probs /= np.sum(probs)
  logits = tf.constant(np.log(probs), dtype=tf.float32)

  # Sample `num_predict` span lengths; note that this is over-sampling, since
  # out-of-range spans are dropped below.
  span_lens = tf.random.categorical(
      logits=logits[None],
      num_samples=num_predict,
      dtype=tf.int64,
  )[0] + min_num_words

  # Sample the ratio [0.0, 1.0) of left context lengths.
  span_lens_float = tf.cast(span_lens, tf.float32)
  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)

  left_ctx_len = round_to_int(left_ctx_len)
  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

  # Get the actual begin and end word indices.
  beg_indices = (tf.cumsum(left_ctx_len) +
                 tf.cumsum(right_offset, exclusive=True))
  end_indices = beg_indices + span_lens

  # Remove out-of-range indices.
  max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
  valid_idx_mask = end_indices < max_boundary_index
  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

  # Map word indices to token positions.
  beg_indices = tf.gather(boundary, beg_indices)
  end_indices = tf.gather(boundary, end_indices)

  # Shuffle the valid indices.
  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
  beg_indices = tf.gather(beg_indices, order)
  end_indices = tf.gather(end_indices, order)

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)
|
|
def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
                     max_num_tokens):
  """Sample token spans as prediction targets."""
  mask_alpha = tgt_len / num_predict
  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)

  # Sample span lengths from a zipf distribution.
  span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
  probs = np.array([1.0 / (i + 1) for i in span_len_seq])

  probs /= np.sum(probs)
  logits = tf.constant(np.log(probs), dtype=tf.float32)
  span_lens = tf.random.categorical(
      logits=logits[None],
      num_samples=num_predict,
      dtype=tf.int64,
  )[0] + min_num_tokens

  # Sample the ratio [0.0, 1.0) of left context lengths.
  span_lens_float = tf.cast(span_lens, tf.float32)
  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
  left_ctx_len = round_to_int(left_ctx_len)

  # Compute the offset from the left start to the right end.
  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

  # Get the actual begin and end indices.
  beg_indices = (tf.cumsum(left_ctx_len) +
                 tf.cumsum(right_offset, exclusive=True))
  end_indices = beg_indices + span_lens

  # Remove out-of-range indices.
  valid_idx_mask = end_indices < tgt_len
  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

  # Shuffle the valid indices.
  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
  beg_indices = tf.gather(beg_indices, order)
  end_indices = tf.gather(end_indices, order)

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)
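

# Illustrative sketch only: a toy `_token_span_mask` call. With tgt_len=32 and
# num_predict=8, mask_alpha is 4, so each sampled span of length L sits inside
# a budget of roughly L * mask_alpha positions, leaving unmasked context
# around it.
def _example_token_span_mask():
  inputs = tf.constant([CLS_ID] + list(range(10, 40)) + [SEP_ID],
                       dtype=tf.int64)
  return _token_span_mask(inputs, tgt_len=32, num_predict=8,
                          min_num_tokens=1, max_num_tokens=3)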
|
|
def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
  """Sample whole words as prediction targets."""
  # Adjacent boundary markers delimit one word: [boundary[i], boundary[i+1]).
  pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
  cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
  beg_indices = cand_pair_indices[:, 0]
  end_indices = cand_pair_indices[:, 1]

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)
|
|
def _single_token_mask(inputs, tgt_len, num_predict):
  """Sample individual tokens as prediction targets."""
  all_indices = tf.range(tgt_len, dtype=tf.int64)
  non_func_mask = tf.logical_and(
      tf.not_equal(inputs, SEP_ID),
      tf.not_equal(inputs, CLS_ID))
  non_func_indices = tf.boolean_mask(all_indices, non_func_mask)

  masked_pos = tf.random.shuffle(non_func_indices)
  masked_pos = tf.sort(masked_pos[:num_predict])
  # Scatter 1.0 into the sampled positions. (`tf.sparse_to_dense` was removed
  # in TF2; `tf.scatter_nd` is equivalent here since `masked_pos` is sorted
  # and unique.)
  target_mask = tf.scatter_nd(
      indices=masked_pos[:, None],
      updates=tf.ones([tf.shape(masked_pos)[0]], dtype=tf.float32),
      shape=[tgt_len])

  is_masked = tf.cast(target_mask, tf.bool)

  return is_masked, target_mask
|
|
def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
                         boundary=None):
  """Samples target positions to predict."""
  logging.info("Online sample with strategy: `%s`.",
               online_masking_config.sample_strategy)
  if online_masking_config.sample_strategy == "single_token":
    return _single_token_mask(inputs, tgt_len, num_predict)
  elif online_masking_config.sample_strategy == "whole_word":
    assert boundary is not None, "whole word sampling requires `boundary`"
    return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
  elif online_masking_config.sample_strategy == "token_span":
    return _token_span_mask(inputs, tgt_len, num_predict,
                            online_masking_config.min_num_tokens,
                            online_masking_config.max_num_tokens)
  elif online_masking_config.sample_strategy == "word_span":
    assert boundary is not None, "word span sampling requires `boundary`"
    return _word_span_mask(inputs, tgt_len, num_predict,
                           online_masking_config.min_num_words,
                           online_masking_config.max_num_words,
                           boundary)
  else:
    raise NotImplementedError("Unknown sample strategy: {}".format(
        online_masking_config.sample_strategy))
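

# Illustrative sketch only: dispatching with the "single_token" strategy,
# which ignores the span-length fields and needs no `boundary` tensor. The
# zeros below are placeholders.
def _example_online_sample_masks():
  config = OnlineMaskingConfig(
      sample_strategy="single_token",
      max_num_tokens=0,
      min_num_tokens=0,
      max_num_words=0,
      min_num_words=0)
  inputs = tf.constant([CLS_ID, 10, 11, 12, SEP_ID], dtype=tf.int64)
  return _online_sample_masks(inputs, tgt_len=5, num_predict=2,
                              online_masking_config=config)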
|
|
def create_pretrain_dataset(file_names,
                            bsz_per_core,
                            seq_len,
                            reuse_len,
                            perm_size,
                            leak_ratio,
                            online_masking_config,
                            num_predict=None,
                            input_pipeline_context=None):
  """Creates pretrain dataset."""

  def parser(record):
    """Function used to parse tfrecord."""

    record_spec = {
        "input": tf.io.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
        "label": tf.io.FixedLenFeature([1], tf.int64),
    }

    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
      logging.info("Add `boundary` spec for %s",
                   online_masking_config.sample_strategy)
      record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)

    # Retrieve the serialized example.
    example = tf.io.parse_single_example(
        serialized=record, features=record_spec)

    inputs = example.pop("input")
    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
      boundary = tf.sparse.to_dense(example.pop("boundary"))
    else:
      boundary = None
    is_masked, _ = _online_sample_masks(
        inputs, seq_len, num_predict, online_masking_config, boundary=boundary)

    if reuse_len > 0:
      # When memory is used, permute the reused and non-reused parts
      # separately.
      non_reuse_len = seq_len - reuse_len
      assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0

      # Creates the permutation mask and target mask for the first reuse_len
      # tokens. The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
          inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
          leak_ratio)

      # Creates the permutation mask and target mask for the rest of the
      # tokens in the current example, i.e. the concatenation of two new
      # segments.
      perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
          inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
          leak_ratio)

      # Reused tokens cannot attend to the new segments, while new-segment
      # tokens can always attend to the reused part.
      perm_mask_0 = tf.concat(
          [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
      perm_mask_1 = tf.concat(
          [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      input_k = tf.concat([input_k_0, input_k_1], axis=0)
      input_q = tf.concat([input_q_0, input_q_1], axis=0)
    else:
      # Permute the entire sequence together.
      assert seq_len % perm_size == 0

      perm_mask, target_mask, input_k, input_q = _local_perm(
          inputs, is_masked, perm_size, seq_len, leak_ratio)

    # Reshape back to fixed shapes.
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    # Directly use the raw inputs as the target.
    target = inputs

    if num_predict is not None:
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # Pad up to `num_predict` positions (fewer may be masked due to
      # functional tokens such as CLS/SEP).
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      # target_mapping
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      # target
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      # target_mask
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # Convert sparse features to dense, and cast int64 to int32 for TPU
    # compatibility.
    for key in list(example.keys()):
      val = example[key]
      if tf.keras.backend.is_sparse(val):
        val = tf.sparse.to_dense(val)
      if val.dtype == tf.int64:
        val = tf.cast(val, tf.int32)

      example[key] = val

    for k, v in example.items():
      logging.info("%s: %s", k, v)

    return example

  dataset = parse_files_to_dataset(
      parser=parser,
      file_paths=file_names,
      bsz_per_core=bsz_per_core,
      sequential=reuse_len > 0,
      input_pipeline_context=input_pipeline_context)

  return dataset
|
|
def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
                    uncased=False):
  """Generates input file name pattern."""
  if reuse_len is not None and reuse_len > 0:
    reuse_str = "reuse-{}.".format(reuse_len)
    bsz_str = "hostbsz-{}.".format(bsz_per_host)
  else:
    reuse_str = ""
    bsz_str = ""

  case_str = "uncased." if uncased else ""

  file_name = "{}.seq-{}.{}{}{}{}".format(
      prefix, seq_len, reuse_str, bsz_str, case_str, suffix)

  return file_name
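

# Illustrative sketch only: with these hypothetical arguments,
# `format_filename` returns "meta.train.seq-512.reuse-256.hostbsz-32.json".
def _example_format_filename():
  return format_filename(
      prefix="meta.train",
      suffix="json",
      bsz_per_host=32,
      seq_len=512,
      reuse_len=256,
      uncased=False)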
|
|
def get_pretrain_input_data(batch_size,
                            seq_len,
                            strategy,
                            file_path,
                            reuse_len,
                            perm_size,
                            leak_ratio,
                            num_predict,
                            uncased,
                            online_masking_config,
                            num_hosts=1):
  """Returns input dataset from input file string."""

  # When using TPU pods, the dataset has to be cloned across workers, which
  # requires passing a dataset function rather than a dataset instance.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  split = "train"
  bsz_per_host = int(batch_size / num_hosts)
  record_glob_base = format_filename(
      prefix="meta.{}.pass-*".format(split),
      suffix="json*",
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      reuse_len=reuse_len,
      uncased=uncased)

  def _get_num_batch(info):
    if "num_batch" in info:
      return info["num_batch"]
    elif "num_example" in info:
      return info["num_example"] / bsz_per_host
    else:
      raise ValueError("Record info has neither `num_batch` nor "
                       "`num_example`.")

  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas: {}".format(
              strategy.num_replicas_in_sync))

    # Auto rebatching is not supported by
    # `experimental_distribute_datasets_from_function()`, which is required
    # when cloning the dataset to multiple workers in eager mode, so use the
    # per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = file_path.split(",")
  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.io.gfile.glob(record_glob))
    logging.info("[%d] Num of record info path: %d", idx, len(record_paths))

    cur_record_info = {"num_batch": 0, "filenames": []}

    for record_info_path in record_paths:
      with tf.io.gfile.GFile(record_info_path, "r") as fp:
        info = json.load(fp)
        cur_record_info["num_batch"] += int(_get_num_batch(info))
        cur_record_info["filenames"] += info["filenames"]

    # Point the recorded file names at the current directory.
    new_filenames = []
    for filename in cur_record_info["filenames"]:
      basename = os.path.basename(filename)
      new_filename = os.path.join(record_dir, basename)
      new_filenames.append(new_filename)
    cur_record_info["filenames"] = new_filenames

    logging.info("[Dir %d] Number of chosen batches: %s", idx,
                 cur_record_info["num_batch"])
    logging.info("[Dir %d] Number of chosen files: %s", idx,
                 len(cur_record_info["filenames"]))
    logging.info(cur_record_info["filenames"])

    # Accumulate the result from the current directory.
    record_info["num_batch"] += cur_record_info["num_batch"]
    record_info["filenames"] += cur_record_info["filenames"]

  logging.info("Total number of batches: %d", record_info["num_batch"])
  logging.info("Total number of files: %d", len(record_info["filenames"]))
  logging.info(record_info["filenames"])

  def _dataset_fn(ctx=None):
    """Function that can create a pretrain dataset."""

    train_dataset = create_pretrain_dataset(
        file_names=record_info["filenames"],
        bsz_per_core=batch_size,
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        leak_ratio=leak_ratio,
        online_masking_config=online_masking_config,
        num_predict=num_predict,
        input_pipeline_context=ctx)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()
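

# Illustrative sketch only: requesting a pretraining dataset on the default
# (non-TPU) strategy. The directory and hyperparameters are hypothetical; note
# that `perm_size` must divide both `reuse_len` and `seq_len - reuse_len`.
def _example_get_pretrain_input_data():
  masking_config = OnlineMaskingConfig(
      sample_strategy="token_span",
      max_num_tokens=5,
      min_num_tokens=1,
      max_num_words=5,
      min_num_words=1)
  return get_pretrain_input_data(
      batch_size=32,
      seq_len=512,
      strategy=tf.distribute.get_strategy(),
      file_path="/tmp/tfrecords",  # hypothetical directory of records
      reuse_len=256,
      perm_size=256,
      leak_ratio=0.1,
      num_predict=85,
      uncased=False,
      online_masking_config=masking_config)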
|
|
def parse_files_to_dataset(parser,
                           file_paths,
                           bsz_per_core,
                           sequential,
                           input_pipeline_context=None):
  """Creates the dataset given file paths."""

  dataset = tf.data.Dataset.from_tensor_slices(file_paths)

  # Shard the dataset across input pipelines (hosts) if more than one is
  # present.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)

  # File-level shuffle.
  if len(file_paths) > 1:
    dataset = dataset.shuffle(len(file_paths))

  if sequential:
    # In sequential mode records must stay consecutive (memory is reused
    # across examples), so read files one by one without sample shuffling.
    dataset = tf.data.TFRecordDataset(dataset)
  else:
    # `cycle_length` is the number of parallel files that get read.
    cycle_length = min(8, len(file_paths))
    logging.info("Interleave %d files", cycle_length)

    # The deprecated `tf.data.experimental.parallel_interleave` is replaced by
    # `Dataset.interleave` with `deterministic=False` (the equivalent of
    # `sloppy=True`).
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=cycle_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False)
    buffer_size = 2048
    logging.info("Perform sample-level shuffle with size %d", buffer_size)
    dataset = dataset.shuffle(buffer_size=buffer_size)

  dataset = dataset.cache().repeat().map(parser)
  dataset = dataset.batch(bsz_per_core, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  return dataset
|
|
def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
  """Samples a permutation of the factorization order.

  Creates perm_mask and target_mask accordingly.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected for
      partial prediction.
    perm_size: the length of the longest permutation. Could be set to be
      reuse_len. Should not be larger than reuse_len or there will be data
      leaks.
    seq_len: int, sequence length.
    leak_ratio: float, percent of masked tokens that are leaked.

  Returns:
    perm_mask: float32 Tensor in shape [seq_len, seq_len] consisting of 0 and
      1. If perm_mask[i][j] == 1, the i-th token (in original order) cannot
      attend to the j-th token (in original order). This happens only when the
      i-th token's permuted position <= the j-th token's permuted position,
      and the j-th token is masked or is a functional token. If
      perm_mask[i][j] == 0, the i-th token (in original order) can attend to
      the j-th token (in original order). Note that non-masked tokens can be
      attended to by all other tokens, which differs from the description in
      the original paper.
    target_mask: float32 Tensor in shape [seq_len] consisting of 0 and 1. If
      target_mask[i] == 1, the i-th token needs to be predicted and a mask
      will be used as input. This token counts towards the loss. If
      target_mask[i] == 0, the token (or [SEP], [CLS]) is used as input and
      does not count towards the loss.
    inputs_k: int64 Tensor in shape [seq_len], input ids.
    inputs_q: bool Tensor in shape [seq_len], True at the positions to
      predict (the same positions where target_mask is 1).
  """

  # Generate permutation indices: positions are permuted within each
  # consecutive block of `perm_size` tokens.
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random.shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])

  # Non-functional tokens (everything except [SEP] and [CLS]).
  non_func_tokens = tf.logical_not(tf.logical_or(
      tf.equal(inputs, SEP_ID),
      tf.equal(inputs, CLS_ID)))
  masked_tokens = tf.logical_and(is_masked, non_func_tokens)
  non_masked_or_func_tokens = tf.logical_not(masked_tokens)

  smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)

  # Similar to BERT, randomly leak some masked tokens so that they can attend
  # to themselves.
  if leak_ratio > 0:
    leak_tokens = tf.logical_and(
        masked_tokens,
        tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
    can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
  else:
    can_attend_self = non_masked_or_func_tokens
  to_index = tf.where(can_attend_self, smallest_index, index)
  from_index = tf.where(can_attend_self, to_index + 1, to_index)

  # A masked token can be attended to only by tokens that come later in the
  # permuted order; context tokens can always attend to each other.
  can_attend = from_index[:, None] > to_index[None, :]

  # In modeling, 1 indicates "cannot attend", so reverse the value here.
  perm_mask = 1.0 - tf.cast(can_attend, tf.float32)

  # Only masked tokens are included in the loss.
  target_mask = tf.cast(masked_tokens, tf.float32)

  # Construct inputs_k: the content stream sees the original ids.
  inputs_k = inputs

  # Construct inputs_q: the query-stream indicator, aligned with target_mask.
  inputs_q = masked_tokens

  return perm_mask, target_mask, inputs_k, inputs_q
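

# Illustrative sketch only: a toy `_local_perm` call. `perm_size` must divide
# `seq_len`; here the factorization order is shuffled within each block of 4
# positions, and no masked positions are leaked.
def _example_local_perm():
  inputs = tf.constant([CLS_ID, 10, 11, 12, 13, 14, 15, SEP_ID],
                       dtype=tf.int64)
  is_masked = tf.constant([False, True, False, True,
                           False, False, True, False])
  return _local_perm(inputs, is_masked, perm_size=4, seq_len=8,
                     leak_ratio=0.0)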