# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities used for data preparation."""

import collections
import json
import os

from absl import logging
import numpy as np
import tensorflow as tf, tf_keras

special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]

SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3

OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
    "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
    "min_num_words"
])


def file_based_input_fn_builder(input_file, name_to_features, batch_size,
                                is_training):
  """Creates an `input_fn` closure."""

  logging.info("Input tfrecord file %s", input_file)

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def input_fn():
    """Returns dataset for training/evaluation."""
    num_threads = 8
    if isinstance(input_file, str):
      d = tf.data.TFRecordDataset(input_file)
      # For training, we want a lot of parallel reading and shuffling.
      # For eval, we want no shuffling, and parallel reading doesn't matter.
      if is_training:
        d = d.shuffle(2048)
        d = d.repeat()
    else:
      cycle_length = min(num_threads, len(input_file))
      d = tf.data.Dataset.from_tensor_slices(input_file)
      # File-level shuffle.
      d = d.shuffle(len(input_file)).repeat()

      d = d.interleave(tf.data.TFRecordDataset, cycle_length=cycle_length)

      if is_training:
        # Sample-level shuffle.
        d = d.shuffle(buffer_size=2048)

    d = d.map(
        lambda record: _decode_record(record, name_to_features),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    d = d.batch(batch_size, drop_remainder=is_training)

    # When `input_file` is a path to a single file or a list containing a
    # single path, disable auto sharding so that the same input file is sent
    # to all workers.
    if isinstance(input_file, str) or len(input_file) == 1:
      options = tf.data.Options()
      options.experimental_distribute.auto_shard_policy = (
          tf.data.experimental.AutoShardPolicy.OFF)
      d = d.with_options(options)

    d = d.prefetch(tf.data.experimental.AUTOTUNE)
    return d

  return input_fn
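

# Illustrative usage (not part of the original module): building and invoking
# an `input_fn` for a toy record spec. The tfrecord path is hypothetical.
def _demo_file_based_input_fn():
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([128], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
  }
  input_fn = file_based_input_fn_builder(
      input_file="/tmp/demo.tf_record",  # hypothetical path
      name_to_features=name_to_features,
      batch_size=8,
      is_training=False)
  # Returns a batched tf.data.Dataset; int64 features are cast to int32.
  return input_fn()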


def create_classification_dataset(file_path, seq_length, batch_size,
                                  is_training):
  """Creates input dataset from (tf)records files for classification."""
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
      "is_real_example": tf.io.FixedLenFeature([], tf.int64),
  }

  input_fn = file_based_input_fn_builder(file_path, name_to_features,
                                         batch_size, is_training)
  dataset = input_fn()
  return dataset


def create_squad_dataset(file_path, seq_length, batch_size, is_training):
  """Creates input dataset from (tf)records files for SQuAD."""
  name_to_features = {
      "unique_ids": tf.io.FixedLenFeature([], tf.int64),
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "cls_index": tf.io.FixedLenFeature([], tf.int64),
      "p_mask": tf.io.FixedLenFeature([seq_length], tf.float32)
  }
  if is_training:
    name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["is_impossible"] = tf.io.FixedLenFeature([], tf.float32)

  input_fn = file_based_input_fn_builder(file_path, name_to_features,
                                         batch_size, is_training)
  dataset = input_fn()
  return dataset


def get_input_iterator(input_fn, strategy):
  """Returns distributed dataset iterator."""
  # When training with TPU pods, datasets need to be cloned across workers.
  # Since a Dataset instance cannot be cloned in eager mode, we instead pass a
  # callable that returns a dataset.
  input_data = input_fn()
  if callable(input_data):
    iterator = iter(strategy.distribute_datasets_from_function(input_data))
  else:
    iterator = iter(strategy.experimental_distribute_dataset(input_data))
  return iterator
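

# Illustrative sketch (not part of the original module): driving
# `get_input_iterator` with the default (no-op) distribution strategy.
# `_toy_input_fn` is a hypothetical stand-in for the `get_*_input_data`
# results below; returning the dataset directly exercises the
# `experimental_distribute_dataset` branch.
def _demo_get_input_iterator():
  def _toy_input_fn():
    return tf.data.Dataset.from_tensor_slices(tf.range(16)).batch(4)

  strategy = tf.distribute.get_strategy()
  iterator = get_input_iterator(_toy_input_fn, strategy)
  return next(iterator)  # first distributed batch: [0, 1, 2, 3]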


def get_classification_input_data(batch_size, seq_len, strategy, is_training,
                                  file_path):
  """Returns input dataset from input file string."""

  # When using TPU pods, we need to clone the dataset across workers and pass
  # in a function that returns the dataset, rather than the dataset instance
  # itself.
  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas : {}".format(
              strategy.num_replicas_in_sync))

    # Auto rebatching is not supported by the
    # `distribute_datasets_from_function()` API, which is required when
    # cloning a dataset to multiple workers in eager mode, so we use the
    # per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  def _dataset_fn(ctx=None):
    del ctx
    train_dataset = create_classification_dataset(
        file_path=file_path,
        seq_length=seq_len,
        batch_size=batch_size,
        is_training=is_training)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()


def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
                         file_path):
  """Returns input dataset from input file string."""

  # When using TPU pods, we need to clone the dataset across workers and pass
  # in a function that returns the dataset, rather than the dataset instance
  # itself.
  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas : {}".format(
              strategy.num_replicas_in_sync))

    # Auto rebatching is not supported by the
    # `distribute_datasets_from_function()` API, which is required when
    # cloning a dataset to multiple workers in eager mode, so we use the
    # per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  if is_training:
    input_glob = os.path.join(
        file_path,
        "spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len))

    global_input_paths = tf.io.gfile.glob(input_glob)
  else:
    global_input_paths = file_path

  def _dataset_fn(ctx=None):
    del ctx
    train_dataset = create_squad_dataset(
        file_path=global_input_paths,
        seq_length=seq_len,
        batch_size=batch_size,
        is_training=is_training)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()


def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
  """Turn beg and end indices into actual mask."""
  non_func_mask = tf.logical_and(
      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
  all_indices = tf.where(non_func_mask, tf.range(tgt_len, dtype=tf.int64),
                         tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
  candidate_matrix = tf.cast(
      tf.logical_and(all_indices[None, :] >= beg_indices[:, None],
                     all_indices[None, :] < end_indices[:, None]), tf.float32)
  cumsum_matrix = tf.reshape(
      tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len])
  masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
  target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
  is_masked = tf.cast(target_mask, tf.bool)
  return is_masked, target_mask
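

# Worked example (illustrative, not part of the original module): with
# tgt_len=6, candidate spans [1, 3) and [3, 5), and num_predict=3, the
# flattened-cumsum trick keeps positions 1, 2, 3 and drops position 4,
# because the running count of candidate tokens exceeds num_predict there.
def _demo_idx_pair_to_mask():
  inputs = tf.constant([10, 11, 12, 13, 14, 15], dtype=tf.int64)
  beg_indices = tf.constant([1, 3], dtype=tf.int64)
  end_indices = tf.constant([3, 5], dtype=tf.int64)
  is_masked, target_mask = _idx_pair_to_mask(
      beg_indices, end_indices, inputs, tgt_len=6, num_predict=3)
  # is_masked == [False, True, True, True, False, False]
  return is_masked, target_mask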


def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words,
                    boundary):
  """Sample whole word spans as prediction targets."""
  # Note: 1.2 is the token-to-word ratio.
  mask_alpha = tgt_len / num_predict / 1.2
  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)

  # Sample span lengths from a Zipf-like distribution.
  span_len_seq = np.arange(min_num_words, max_num_words + 1)
  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
  probs /= np.sum(probs)
  logits = tf.constant(np.log(probs), dtype=tf.float32)

  # Sample `num_predict` words here: note that this is over-sampling.
  span_lens = tf.random.categorical(
      logits=logits[None],
      num_samples=num_predict,
      dtype=tf.int64,
  )[0] + min_num_words

  # Sample the ratio [0.0, 1.0) of left context lengths.
  span_lens_float = tf.cast(span_lens, tf.float32)
  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
  left_ctx_len = round_to_int(left_ctx_len)
  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

  beg_indices = (
      tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
  end_indices = beg_indices + span_lens

  # Remove out-of-range indices.
  max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
  valid_idx_mask = end_indices < max_boundary_index
  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

  beg_indices = tf.gather(boundary, beg_indices)
  end_indices = tf.gather(boundary, end_indices)

  # Shuffle valid indices.
  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
  beg_indices = tf.gather(beg_indices, order)
  end_indices = tf.gather(end_indices, order)

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)


def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
                     max_num_tokens):
  """Sample token spans as prediction targets."""
  mask_alpha = tgt_len / num_predict
  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)

  # Sample span lengths from a Zipf-like distribution.
  span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
  probs /= np.sum(probs)
  logits = tf.constant(np.log(probs), dtype=tf.float32)
  span_lens = tf.random.categorical(
      logits=logits[None],
      num_samples=num_predict,
      dtype=tf.int64,
  )[0] + min_num_tokens

  # Sample the ratio [0.0, 1.0) of left context lengths.
  span_lens_float = tf.cast(span_lens, tf.float32)
  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
  left_ctx_len = round_to_int(left_ctx_len)

  # Compute the offset from left start to the right end.
  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

  # Get the actual begin and end indices.
  beg_indices = (
      tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
  end_indices = beg_indices + span_lens

  # Remove out-of-range indices.
  valid_idx_mask = end_indices < tgt_len
  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

  # Shuffle valid indices.
  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
  beg_indices = tf.gather(beg_indices, order)
  end_indices = tf.gather(end_indices, order)

  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)
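

# Illustrative sketch (not part of the original module): the span-length prior
# used by `_token_span_mask` and `_word_span_mask` above. A length l in
# [min_num_tokens, max_num_tokens] is drawn with probability proportional to
# 1 / (l + 1), i.e. a Zipf-like decay that favors short spans.
def _demo_span_length_prior(min_num_tokens=1, max_num_tokens=5):
  span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
  probs = 1.0 / (span_len_seq + 1)
  probs /= np.sum(probs)
  # For the defaults: p(1)~0.34, p(2)~0.23, p(3)~0.17, p(4)~0.14, p(5)~0.11.
  return dict(zip(span_len_seq.tolist(), probs.tolist()))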


def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
  """Sample whole words as prediction targets."""
  pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
  cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
  beg_indices = cand_pair_indices[:, 0]
  end_indices = cand_pair_indices[:, 1]
  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
                           num_predict)


def _single_token_mask(inputs, tgt_len, num_predict):
  """Sample individual tokens as prediction targets."""
  all_indices = tf.range(tgt_len, dtype=tf.int64)
  non_func_mask = tf.logical_and(
      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
  non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
  masked_pos = tf.random.shuffle(non_func_indices)
  masked_pos = tf.sort(masked_pos[:num_predict])
  # `tf.sparse_to_dense` is TF1-only; build the dense 0/1 mask with
  # `tf.scatter_nd` instead.
  target_mask = tf.scatter_nd(
      indices=masked_pos[:, None],
      updates=tf.ones([tf.shape(masked_pos)[0]], dtype=tf.float32),
      shape=[tgt_len])
  is_masked = tf.cast(target_mask, tf.bool)
  return is_masked, target_mask


def _online_sample_masks(inputs,
                         tgt_len,
                         num_predict,
                         online_masking_config,
                         boundary=None):
  """Sample target positions to predict."""
  logging.info("Online sample with strategy: `%s`.",
               online_masking_config.sample_strategy)
  if online_masking_config.sample_strategy == "single_token":
    return _single_token_mask(inputs, tgt_len, num_predict)
  elif online_masking_config.sample_strategy == "whole_word":
    assert boundary is not None, "whole word sampling requires `boundary`"
    return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
  elif online_masking_config.sample_strategy == "token_span":
    return _token_span_mask(inputs, tgt_len, num_predict,
                            online_masking_config.min_num_tokens,
                            online_masking_config.max_num_tokens)
  elif online_masking_config.sample_strategy == "word_span":
    assert boundary is not None, "word span sampling requires `boundary`"
    return _word_span_mask(inputs, tgt_len, num_predict,
                           online_masking_config.min_num_words,
                           online_masking_config.max_num_words, boundary)
  else:
    raise NotImplementedError("Unknown sample_strategy: `%s`." %
                              online_masking_config.sample_strategy)
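

# Illustrative usage (not part of the original module): sampling prediction
# targets for a toy sequence with the "single_token" strategy. Config fields
# that this strategy ignores are set to None.
def _demo_online_sample_masks():
  config = OnlineMaskingConfig(
      sample_strategy="single_token",
      max_num_tokens=None,
      min_num_tokens=None,
      max_num_words=None,
      min_num_words=None)
  inputs = tf.constant([CLS_ID, 100, 101, 102, SEP_ID], dtype=tf.int64)
  is_masked, target_mask = _online_sample_masks(
      inputs, tgt_len=5, num_predict=2, online_masking_config=config)
  # CLS/SEP are never selected; exactly 2 of the 3 content tokens are masked.
  return is_masked, target_mask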


def create_pretrain_dataset(file_names,
                            bsz_per_core,
                            seq_len,
                            reuse_len,
                            perm_size,
                            leak_ratio,
                            online_masking_config,
                            num_predict=None,
                            input_pipeline_context=None):
  """Creates pretrain dataset."""

  def parser(record):
    """Function used to parse tfrecord."""

    record_spec = {
        "input": tf.io.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
        "label": tf.io.FixedLenFeature([1], tf.int64),
    }

    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
      logging.info("Add `boundary` spec for %s",
                   online_masking_config.sample_strategy)
      record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)

    # Retrieve serialized example.
    example = tf.io.parse_single_example(
        serialized=record, features=record_spec)

    inputs = example.pop("input")
    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
      boundary = tf.sparse.to_dense(example.pop("boundary"))
    else:
      boundary = None
    is_masked, _ = _online_sample_masks(
        inputs, seq_len, num_predict, online_masking_config, boundary=boundary)

    if reuse_len > 0:
      ##### Use memory.
      # Permute the reuse and non-reuse parts separately.
      non_reuse_len = seq_len - reuse_len
      assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0

      # Creates permutation mask and target mask for the first reuse_len
      # tokens. The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
          inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
          leak_ratio)

      # Creates permutation mask and target mask for the rest of tokens in
      # the current example, which are the concatenation of two new segments.
      perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
          inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
          leak_ratio)

      perm_mask_0 = tf.concat(
          [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
      perm_mask_1 = tf.concat(
          [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      input_k = tf.concat([input_k_0, input_k_1], axis=0)
      input_q = tf.concat([input_q_0, input_q_1], axis=0)
    else:
      ##### Do not use memory.
      assert seq_len % perm_size == 0
      # Permute the entire sequence together.
      perm_mask, target_mask, input_k, input_q = _local_perm(
          inputs, is_masked, perm_size, seq_len, leak_ratio)

    # Reshape back to fixed shape.
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_ids"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    # Directly use raw inputs as the target.
    target = inputs

    if num_predict is not None:
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      ##### Extra padding due to CLS/SEP introduced after prepro.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.float32),
          tf.zeros([pad_len], dtype=tf.float32)
      ], axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # Cast int64 to int32 and convert sparse tensors to dense.
    for key in list(example.keys()):
      val = example[key]
      if tf_keras.backend.is_sparse(val):
        val = tf.sparse.to_dense(val)
      if val.dtype == tf.int64:
        val = tf.cast(val, tf.int32)
      example[key] = val

    for k, v in example.items():
      logging.info("%s: %s", k, v)

    return example

  dataset = parse_files_to_dataset(
      parser=parser,
      file_paths=file_names,
      bsz_per_core=bsz_per_core,
      sequential=reuse_len > 0,
      input_pipeline_context=input_pipeline_context)

  return dataset
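

# Illustrative usage (not part of the original module): assembling a
# pretraining dataset with token-span masking. The record path is
# hypothetical; note the parser asserts reuse_len % perm_size == 0 and
# (seq_len - reuse_len) % perm_size == 0.
def _demo_create_pretrain_dataset():
  config = OnlineMaskingConfig(
      sample_strategy="token_span",
      max_num_tokens=5,
      min_num_tokens=1,
      max_num_words=None,
      min_num_words=None)
  return create_pretrain_dataset(
      file_names=["/tmp/pretrain.tf_record"],  # hypothetical path
      bsz_per_core=4,
      seq_len=512,
      reuse_len=256,
      perm_size=256,
      leak_ratio=0.1,
      online_masking_config=config,
      num_predict=85)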


def format_filename(prefix,
                    suffix,
                    bsz_per_host,
                    seq_len,
                    reuse_len=None,
                    uncased=False):
  """Generates input file name pattern."""
  if reuse_len is not None and reuse_len > 0:
    reuse_str = "reuse-{}.".format(reuse_len)
    bsz_str = "hostbsz-{}.".format(bsz_per_host)
  else:
    reuse_str = ""
    bsz_str = ""
  if not uncased:
    case_str = ""
  else:
    case_str = "uncased."

  file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str,
                                          case_str, suffix)
  return file_name
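

# Worked example (illustrative, not part of the original module): the pattern
# produced for a cased run with per-host batch size 32, sequence length 512,
# and reuse length 256. This matches the `record_glob_base` built in
# `get_pretrain_input_data` below.
def _demo_format_filename():
  name = format_filename(
      prefix="meta.train.pass-*",
      suffix="json*",
      bsz_per_host=32,
      seq_len=512,
      reuse_len=256,
      uncased=False)
  # name == "meta.train.pass-*.seq-512.reuse-256.hostbsz-32.json*"
  return name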


def get_pretrain_input_data(batch_size,
                            seq_len,
                            strategy,
                            file_path,
                            reuse_len,
                            perm_size,
                            leak_ratio,
                            num_predict,
                            uncased,
                            online_masking_config,
                            num_hosts=1):
  """Returns input dataset from input file string."""

  # When using TPU pods, we need to clone the dataset across workers and pass
  # in a function that returns the dataset, rather than the dataset instance
  # itself.
  use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
  split = "train"
  bsz_per_host = int(batch_size / num_hosts)
  record_glob_base = format_filename(
      prefix="meta.{}.pass-*".format(split),
      suffix="json*",
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      reuse_len=reuse_len,
      uncased=uncased)

  def _get_num_batch(info):
    if "num_batch" in info:
      return info["num_batch"]
    elif "num_example" in info:
      return info["num_example"] / bsz_per_host
    else:
      raise ValueError("Do not have sample info.")

  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas : {}".format(
              strategy.num_replicas_in_sync))

    # Auto rebatching is not supported by the
    # `distribute_datasets_from_function()` API, which is required when
    # cloning a dataset to multiple workers in eager mode, so we use the
    # per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = file_path.split(",")
  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.io.gfile.glob(record_glob))
    logging.info("[%d] Num of record info path: %d", idx, len(record_paths))

    cur_record_info = {"num_batch": 0, "filenames": []}

    for record_info_path in record_paths:
      with tf.io.gfile.GFile(record_info_path, "r") as fp:
        info = json.load(fp)
        cur_record_info["num_batch"] += int(_get_num_batch(info))
        cur_record_info["filenames"] += info["filenames"]

    # Overwrite directory for `cur_record_info`.
    new_filenames = []
    for filename in cur_record_info["filenames"]:
      basename = os.path.basename(filename)
      new_filename = os.path.join(record_dir, basename)
      new_filenames.append(new_filename)
    cur_record_info["filenames"] = new_filenames

    logging.info("[Dir %d] Number of chosen batches: %s", idx,
                 cur_record_info["num_batch"])
    logging.info("[Dir %d] Number of chosen files: %s", idx,
                 len(cur_record_info["filenames"]))
    logging.info(cur_record_info["filenames"])

    # Add `cur_record_info` to global `record_info`.
    record_info["num_batch"] += cur_record_info["num_batch"]
    record_info["filenames"] += cur_record_info["filenames"]

  logging.info("Total number of batches: %d", record_info["num_batch"])
  logging.info("Total number of files: %d", len(record_info["filenames"]))
  logging.info(record_info["filenames"])

  def _dataset_fn(ctx=None):
    """Function that can create a pretrain dataset."""
    train_dataset = create_pretrain_dataset(
        file_names=record_info["filenames"],
        bsz_per_core=batch_size,
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        leak_ratio=leak_ratio,
        online_masking_config=online_masking_config,
        num_predict=num_predict,
        input_pipeline_context=ctx)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()


def parse_files_to_dataset(parser,
                           file_paths,
                           bsz_per_core,
                           sequential,
                           input_pipeline_context=None):
  """Creates the dataset given file paths."""

  dataset = tf.data.Dataset.from_tensor_slices(file_paths)

  # Note: we cannot perform sample-level shuffle here because this would
  # violate the consecutive requirement of the data stream.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)

  # File-level shuffle.
  if len(file_paths) > 1:
    dataset = dataset.shuffle(len(file_paths))

  if sequential:
    # Note: cannot perform sample-level shuffle here because this would
    # violate the consecutive requirement of the data stream.
    dataset = tf.data.TFRecordDataset(dataset)
  else:
    # `cycle_length` is the number of parallel files that get read.
    cycle_length = min(8, len(file_paths))
    logging.info("Interleave %d files", cycle_length)

    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=cycle_length))
    buffer_size = 2048
    logging.info("Perform sample-level shuffle with size %d", buffer_size)
    dataset = dataset.shuffle(buffer_size=buffer_size)

  dataset = dataset.cache().repeat().map(parser)
  dataset = dataset.batch(bsz_per_core, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  return dataset


def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
  """Samples a permutation of the factorization order.

  Creates perm_mask and target_mask accordingly.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected for
      partial prediction.
    perm_size: the length of the longest permutation. Could be set to be
      reuse_len. Should not be larger than reuse_len or there will be data
      leaks.
    seq_len: int, sequence length.
    leak_ratio: float, percent of masked tokens that are leaked.

  Returns:
    perm_mask: float32 Tensor in shape [seq_len, seq_len] consisting of 0s and
      1s. If perm_mask[i][j] == 1, the ith token (in original order) cannot
      attend to the jth token (in original order). This happens only when the
      ith token's permuted position <= the jth token's permuted position, and
      the jth token is masked or is a functional token. If
      perm_mask[i][j] == 0, the ith token (in original order) can attend to
      the jth token (in original order). Note that non-masked tokens can be
      attended to by all other tokens, which differs from the description in
      the original paper.
    target_mask: float32 Tensor in shape [seq_len] consisting of 0s and 1s. If
      target_mask[i] == 1, the ith token needs to be predicted and the mask
      will be used as input. This token counts toward the loss. If
      target_mask[i] == 0, the token (or [SEP], [CLS]) will be used as input.
      This token does not count toward the loss.
    inputs_k: int64 Tensor in shape [seq_len], input ids.
    inputs_q: bool Tensor in shape [seq_len], True at the same positions where
      target_mask is 1.
  """

  # Generate permutation indices.
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random.shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])

  # Non-functional tokens.
  non_func_tokens = tf.logical_not(
      tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
  masked_tokens = tf.logical_and(is_masked, non_func_tokens)
  non_masked_or_func_tokens = tf.logical_not(masked_tokens)

  smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)

  # Similar to BERT, randomly leak some masked tokens.
  if leak_ratio > 0:
    leak_tokens = tf.logical_and(
        masked_tokens,
        tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
    can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
  else:
    can_attend_self = non_masked_or_func_tokens
  to_index = tf.where(can_attend_self, smallest_index, index)
  from_index = tf.where(can_attend_self, to_index + 1, to_index)

  # For masked tokens, can attend if i > j.
  # For context tokens, always can attend each other.
  can_attend = from_index[:, None] > to_index[None, :]

  # In modeling, 1 indicates cannot attend. Hence, reverse the value here.
  perm_mask = 1.0 - tf.cast(can_attend, tf.float32)

  # Only masked tokens are included in the loss.
  target_mask = tf.cast(masked_tokens, tf.float32)

  # Construct inputs_k.
  inputs_k = inputs

  # Construct inputs_q.
  inputs_q = masked_tokens

  return perm_mask, target_mask, inputs_k, inputs_q
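

# Illustrative sketch (not part of the original module): `_local_perm` on a
# five-token sequence where positions 2 and 3 were selected by the masking
# strategies above. With perm_size == seq_len, the whole sequence is permuted
# together; leak_ratio=0.0 means no masked token may attend to itself.
def _demo_local_perm():
  inputs = tf.constant([CLS_ID, 100, 101, 102, SEP_ID], dtype=tf.int64)
  is_masked = tf.constant([False, False, True, True, False])
  perm_mask, target_mask, inputs_k, inputs_q = _local_perm(
      inputs, is_masked, perm_size=5, seq_len=5, leak_ratio=0.0)
  # target_mask == [0., 0., 1., 1., 0.]: only the two masked content tokens
  # count toward the loss. perm_mask is [5, 5]; rows for context/functional
  # tokens contain 1 only at masked columns, never at context columns.
  return perm_mask, target_mask, inputs_k, inputs_q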