# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
"""Script to pre-process pre-training data into tfrecords."""

import json
import os
import random

# Import libraries
from absl import app
from absl import flags
from absl import logging

import numpy as np

import tensorflow.compat.v1 as tf

import sentencepiece as spm
from official.legacy.xlnet import preprocess_utils

FLAGS = flags.FLAGS

# Hard-coded ids of XLNet's special symbols. Each key must be distinct,
# otherwise the dict collapses and every *_ID below resolves to the same id.
# NOTE(review): assumed to match the SentencePiece model given by --sp_path.
special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]


def _int64_feature(values):
  """Wraps `values` in an int64-list `tf.train.Feature`."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(values):
  """Wraps `values` in a float-list `tf.train.Feature`."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
                    mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
                    fixed_num_predict=None):
  """Builds a file name that encodes the preprocessing configuration.

  Args:
    prefix: leading part of the name.
    bsz_per_host: batch size per host, recorded as "bsz-{}".
    seq_len: sequence length, recorded as "seqlen-{}".
    bi_data: truthy -> "bi", otherwise "uni".
    suffix: trailing extension, e.g. "tfrecords" or "json".
    mask_alpha: recorded as "alpha-{}".
    mask_beta: recorded as "beta-{}".
    reuse_len: when not None, recorded as "reuse-{}.".
    uncased: when truthy, "uncased." is included.
    fixed_num_predict: when not None, recorded as "fnp-{}.".

  Returns:
    The formatted file name.
  """
  reuse_len_str = "" if reuse_len is None else "reuse-{}.".format(reuse_len)
  uncased_str = "uncased." if uncased else ""
  bi_data_str = "bi" if bi_data else "uni"
  if fixed_num_predict is not None:
    fnp_str = "fnp-{}.".format(fixed_num_predict)
  else:
    fnp_str = ""

  file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
      prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
      mask_alpha, mask_beta, fnp_str, suffix)

  return file_name


def _create_data(idx, input_paths):
  """Encodes `input_paths` and writes them as one tfrecord file.

  Args:
    idx: task index; used in logging and in the tfrecord base name.
    input_paths: input files assigned to this task.

  Returns:
    A dict with "filenames" (list of written tfrecord paths) and
    "num_batch" (number of batches written).
  """
  # Load sentence-piece model
  sp = spm.SentencePieceProcessor()
  sp.Load(FLAGS.sp_path)

  input_shards = []
  total_line_cnt = 0
  for input_path in input_paths:
    input_data, sent_ids = [], []
    sent_id, line_cnt = True, 0
    logging.info("Processing %s", input_path)
    # `with` releases the file handle once the file has been consumed.
    with tf.gfile.Open(input_path) as input_file:
      for line in input_file:
        if line_cnt % 100000 == 0:
          logging.info("Loading line %d", line_cnt)
        line_cnt += 1

        if not line.strip():
          # A blank line marks a document boundary.
          if FLAGS.use_eod:
            sent_id = not sent_id
            cur_sent = [EOD_ID]
          else:
            continue
        else:
          if FLAGS.from_raw_text:
            cur_sent = preprocess_utils.preprocess_text(
                line.strip(), lower=FLAGS.uncased)
            cur_sent = preprocess_utils.encode_ids(sp, cur_sent)
          else:
            cur_sent = list(map(int, line.strip().split()))

        input_data.extend(cur_sent)
        # Consecutive sentences alternate their boolean id.
        sent_ids.extend([sent_id] * len(cur_sent))
        sent_id = not sent_id

    logging.info("Finish with line %d", line_cnt)
    # Skip files that produced no tokens (empty, or blank-only without
    # --use_eod): an empty shard would break the sent_ids[0] / sent_ids[-1]
    # indexing below.
    if not input_data:
      continue

    input_data = np.array(input_data, dtype=np.int64)
    sent_ids = np.array(sent_ids, dtype=bool)

    total_line_cnt += line_cnt
    input_shards.append((input_data, sent_ids))

  logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")

  filenames, num_batch = [], 0

  # Randomly shuffle input shards (with a fixed but distinct random seed)
  np.random.seed(100 * FLAGS.task + FLAGS.pass_id)

  perm_indices = np.random.permutation(len(input_shards))
  logging.info("Using perm indices %s for pass %d",
               perm_indices.tolist(), FLAGS.pass_id)

  input_data_list, sent_ids_list = [], []
  prev_sent_id = None
  for perm_idx in perm_indices:
    input_data, sent_ids = input_shards[perm_idx]
    # Make sure `sent_ids[0] != prev_sent_id` so the sentence boundary
    # between two concatenated shards stays visible.
    if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
      sent_ids = np.logical_not(sent_ids)

    # append to temporary list
    input_data_list.append(input_data)
    sent_ids_list.append(sent_ids)

    # update `prev_sent_id`
    prev_sent_id = sent_ids[-1]

  input_data = np.concatenate(input_data_list)
  sent_ids = np.concatenate(sent_ids_list)

  file_name, cur_num_batch = create_tfrecords(
      save_dir=tfrecord_dir,
      basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
      data=[input_data, sent_ids],
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      bi_data=FLAGS.bi_data,
      sp=sp,
  )

  filenames.append(file_name)
  num_batch += cur_num_batch

  record_info = {
      "filenames": filenames,
      "num_batch": num_batch
  }

  return record_info


def create_data(_):
  """Creates pretrain data for the task/pass selected by FLAGS.

  Writes `corpus_info.json` (task 0, pass 0 only), the tfrecords for this
  task's share of the input files, and a `record_info-*.json` summary.
  """
  # Validate FLAGS
  assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
  if not FLAGS.use_tpu:
    FLAGS.num_core_per_host = 1  # forced to be one

  # Make workdirs
  if not tf.gfile.Exists(FLAGS.save_dir):
    tf.gfile.MakeDirs(FLAGS.save_dir)

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
  if not tf.gfile.Exists(tfrecord_dir):
    tf.gfile.MakeDirs(tfrecord_dir)

  # Create and dump corpus_info from task 0
  if FLAGS.task == 0 and FLAGS.pass_id == 0:
    corpus_info = {
        "vocab_size": VOCAB_SIZE,
        "bsz_per_host": FLAGS.bsz_per_host,
        "num_core_per_host": FLAGS.num_core_per_host,
        "seq_len": FLAGS.seq_len,
        "reuse_len": FLAGS.reuse_len,
        "uncased": FLAGS.uncased,
        "bi_data": FLAGS.bi_data,
        "mask_alpha": FLAGS.mask_alpha,
        "mask_beta": FLAGS.mask_beta,
        "num_predict": FLAGS.num_predict,
        "use_eod": FLAGS.use_eod,
        "sp_path": FLAGS.sp_path,
        "input_glob": FLAGS.input_glob,
    }
    corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
    with tf.gfile.Open(corpus_info_path, "w") as fp:
      json.dump(corpus_info, fp)

  # Split the work across FLAGS.num_task tasks in an interleaved fashion.
  file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
  logging.info("Use glob: %s", FLAGS.input_glob)
  logging.info("Find %d files: %s", len(file_paths), file_paths)

  task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
  if not task_file_paths:
    logging.info("Exit: task %d has no file to process.", FLAGS.task)
    return

  logging.info("Task %d process %d files: %s",
               FLAGS.task, len(task_file_paths), task_file_paths)
  record_info = _create_data(FLAGS.task, task_file_paths)

  record_prefix = "record_info-{}-{}-{}".format(
      FLAGS.split, FLAGS.task, FLAGS.pass_id)
  record_name = format_filename(
      prefix=record_prefix,
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      suffix="json",
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict)
  record_info_path = os.path.join(tfrecord_dir, record_name)

  with tf.gfile.Open(record_info_path, "w") as fp:
    json.dump(record_info, fp)


def batchify(data, bsz_per_host, sent_ids=None):
  """Reshapes a 1-D stream into `bsz_per_host` rows, dropping the remainder.

  Args:
    data: 1-D numpy array of token ids.
    bsz_per_host: number of rows to produce.
    sent_ids: optional array aligned with `data`; reshaped the same way.

  Returns:
    `data` with shape [bsz_per_host, len(data) // bsz_per_host], or the
    tuple (data, sent_ids) when `sent_ids` is given.
  """
  num_step = len(data) // bsz_per_host
  data = data[:bsz_per_host * num_step]
  data = data.reshape(bsz_per_host, num_step)
  if sent_ids is not None:
    sent_ids = sent_ids[:bsz_per_host * num_step]
    sent_ids = sent_ids.reshape(bsz_per_host, num_step)

  if sent_ids is not None:
    return data, sent_ids
  return data


def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
  """Split two segments from `data` starting from the index `begin_idx`.

  With probability 0.5 (or always, when no sentence boundary is found) B is
  sampled from a random position (label 0); otherwise B directly follows A
  (label 1).

  Returns:
    None when not enough data remains. Otherwise
    [a_data, b_data, label, new_begin], extended with [a_target, b_target]
    when `extend_target` is True.
  """
  data_len = data.shape[0]
  if begin_idx + tot_len >= data_len:
    logging.info("[_split_a_and_b] returns None: "
                 "begin_idx %d + tot_len %d >= data_len %d",
                 begin_idx, tot_len, data_len)
    return None

  # Collect sentence boundaries reachable within `tot_len` tokens.
  end_idx = begin_idx + 1
  cut_points = []
  while end_idx < data_len:
    if sent_ids[end_idx] != sent_ids[end_idx - 1]:
      if end_idx - begin_idx >= tot_len:
        break
      cut_points.append(end_idx)
    end_idx += 1

  a_begin = begin_idx
  if len(cut_points) == 0 or random.random() < 0.5:  # pylint:disable=g-explicit-length-test
    label = 0
    if len(cut_points) == 0:  # pylint:disable=g-explicit-length-test
      a_end = end_idx
    else:
      a_end = random.choice(cut_points)

    b_len = max(1, tot_len - (a_end - a_begin))
    # (zihangd): `data_len - 1` to account for extend_target
    b_begin = random.randint(0, data_len - 1 - b_len)
    b_end = b_begin + b_len
    # Extend B outward to whole-sentence boundaries.
    while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
      b_begin -= 1
    # (zihangd): `data_len - 1` to account for extend_target
    while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]:
      b_end += 1

    new_begin = a_end
  else:
    label = 1
    a_end = random.choice(cut_points)
    b_begin = a_end
    b_end = end_idx

    new_begin = b_end

  # Trim the longer segment until A + B fits in `tot_len`.
  while a_end - a_begin + b_end - b_begin > tot_len:
    if a_end - a_begin > b_end - b_begin:
      # delete the right side only for the LM objective
      a_end -= 1
    else:
      b_end -= 1

  ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]

  if extend_target:
    if a_end >= data_len or b_end >= data_len:
      logging.info("[_split_a_and_b] returns None: "
                   "a_end %d or b_end %d >= data_len %d",
                   a_end, b_end, data_len)
      return None
    a_target = data[a_begin + 1: a_end + 1]
    b_target = data[b_begin: b_end + 1]
    ret.extend([a_target, b_target])

  return ret


def _is_start_piece(piece):
  """Returns True if `piece` begins a new word or is a special/punct piece."""
  # NOTE(review): `\"` duplicates `"`; `'` may have been intended. Left as-is
  # so that masking behavior is unchanged.
  special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
  return (piece.startswith("▁") or piece.startswith("<")
          or piece in special_pieces)


def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
  """Samples `goal_num_predict` tokens for partial prediction.

  Spans of n pieces (n drawn with probability proportional to 1/n) are
  masked, each starting at a word-start piece and separated by context
  of size `n * mask_alpha // mask_beta`. Any shortfall is topped up with
  randomly chosen single positions.

  Returns:
    bool numpy array of length len(seg); True marks a position to predict.
  """
  seg_len = len(seg)
  mask = np.array([False] * seg_len, dtype=bool)

  num_predict = 0

  ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
  pvals = 1. / np.arange(1, max_gram + 1)
  pvals /= pvals.sum(keepdims=True)

  if reverse:
    seg = np.flip(seg, 0)

  cur_len = 0
  while cur_len < seg_len:
    if goal_num_predict is not None and num_predict >= goal_num_predict:
      break

    n = np.random.choice(ngrams, p=pvals)
    if goal_num_predict is not None:
      n = min(n, goal_num_predict - num_predict)
    ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
    l_ctx = np.random.choice(ctx_size)
    r_ctx = ctx_size - l_ctx

    # Find the start position of a complete token
    beg = cur_len + l_ctx
    while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
      beg += 1
    if beg >= seg_len:
      break

    # Find the end position of the n-gram. Note this counts pieces, not
    # whole words: the span is exactly n pieces long.
    end = beg + 1
    cnt_ngram = 1
    while end < seg_len:
      cnt_ngram += 1
      if cnt_ngram > n:
        break
      end += 1
    if end >= seg_len:
      break

    # Update
    mask[beg:end] = True
    num_predict += end - beg

    cur_len = end + r_ctx

  # Top up with random single positions to reach the goal exactly.
  while goal_num_predict is not None and num_predict < goal_num_predict:
    i = np.random.randint(seg_len)
    if not mask[i]:
      mask[i] = True
      num_predict += 1

  if reverse:
    mask = np.flip(mask, 0)

  return mask


def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
                       goal_num_predict=None):
  """Sample `goal_num_predict` tokens for partial prediction.

  Like `_sample_mask`, but n counts whole words (start pieces), and the
  masked span stops as soon as `goal_num_predict` is reached.

  Returns:
    bool numpy array of length len(seg); True marks a position to predict.
  """
  seg_len = len(seg)
  mask = np.array([False] * seg_len, dtype=bool)

  num_predict = 0

  ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
  pvals = 1. / np.arange(1, max_gram + 1)
  pvals /= pvals.sum(keepdims=True)

  if reverse:
    seg = np.flip(seg, 0)

  cur_len = 0
  while cur_len < seg_len:
    if goal_num_predict is not None and num_predict >= goal_num_predict:
      break

    n = np.random.choice(ngrams, p=pvals)
    if goal_num_predict is not None:
      n = min(n, goal_num_predict - num_predict)
    ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
    l_ctx = np.random.choice(ctx_size)
    r_ctx = ctx_size - l_ctx

    # Find the start position of a complete token
    beg = cur_len + l_ctx
    while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
      beg += 1
    if beg >= seg_len:
      break

    # Find the end position of the n-gram (start pos of the n+1-th gram)
    end = beg
    cnt_ngram = 0
    while end < seg_len:
      if _is_start_piece(sp.IdToPiece(seg[end].item())):
        cnt_ngram += 1
        if cnt_ngram > n:
          break

      # select current piece
      mask[end] = True

      # update the end pointer and increment num_predict
      end += 1
      num_predict += 1

      if goal_num_predict is not None and num_predict >= goal_num_predict:
        break

    cur_len = end + r_ctx

  while goal_num_predict is not None and num_predict < goal_num_predict:
    i = np.random.randint(seg_len)
    if not mask[i]:
      mask[i] = True
      num_predict += 1

  if reverse:
    mask = np.flip(mask, 0)

  return mask


def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                     bi_data, sp):
  """Creates TFRecords.

  Args:
    save_dir: output directory.
    basename: prefix passed to `format_filename`.
    data: [token_ids, sent_ids] as 1-D numpy arrays of equal length.
    bsz_per_host: batch size per host.
    seq_len: length of each written example.
    bi_data: if True, half of each core's rows are the reversed stream.
    sp: loaded SentencePieceProcessor, used to find word-start pieces.

  Returns:
    (save_path, num_batch): path of the written file and number of batches.
  """
  data, sent_ids = data[0], data[1]

  num_core = FLAGS.num_core_per_host
  bsz_per_core = bsz_per_host // num_core

  if bi_data:
    assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
    fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)

    fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
    fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)

    bwd_data = fwd_data[:, :, :, ::-1]
    bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]

    data = np.concatenate(
        [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
    sent_ids = np.concatenate(
        [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
  else:
    data, sent_ids = batchify(data, bsz_per_host, sent_ids)

  logging.info("Raw data shape %s.", data.shape)

  file_name = format_filename(
      prefix=basename,
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="tfrecords",
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict
  )
  save_path = os.path.join(save_dir, file_name)
  record_writer = tf.python_io.TFRecordWriter(save_path)
  logging.info("Start writing %s.", save_path)

  # try/finally guarantees the writer is closed even if an assert fires.
  try:
    num_batch = 0
    reuse_len = FLAGS.reuse_len

    # [sep] x 2 + [cls]
    assert reuse_len < seq_len - 3

    data_len = data.shape[1]
    sep_array = np.array([SEP_ID], dtype=np.int64)
    cls_array = np.array([CLS_ID], dtype=np.int64)

    i = 0
    while i + seq_len <= data_len:
      if num_batch % 500 == 0:
        logging.info("Processing batch %d", num_batch)

      all_ok = True
      features = []
      for idx in range(bsz_per_host):
        inp = data[idx, i: i + reuse_len]
        tgt = data[idx, i + 1: i + reuse_len + 1]

        results = _split_a_and_b(
            data[idx],
            sent_ids[idx],
            begin_idx=i + reuse_len,
            tot_len=seq_len - reuse_len - 3,
            extend_target=True)
        if results is None:
          logging.info("Break out with seq idx %d", i)
          all_ok = False
          break

        # unpack the results
        (a_data, b_data, label, _, a_target, b_target) = tuple(results)

        # sample ngram spans to predict
        reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
        if FLAGS.num_predict is None:
          num_predict_0 = num_predict_1 = None
        else:
          num_predict_1 = FLAGS.num_predict // 2
          num_predict_0 = FLAGS.num_predict - num_predict_1
        mask_0 = _sample_mask(sp, inp, reverse=reverse,
                              goal_num_predict=num_predict_0)
        mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
                                                  sep_array, cls_array]),
                              reverse=reverse, goal_num_predict=num_predict_1)

        # concatenate data
        cat_data = np.concatenate([inp, a_data, sep_array, b_data,
                                   sep_array, cls_array])
        seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
                  [1] * b_data.shape[0] + [1] + [2])
        assert cat_data.shape[0] == seq_len
        assert mask_0.shape[0] == seq_len // 2
        assert mask_1.shape[0] == seq_len // 2

        # the last two CLS's are not used, just for padding purposes
        tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
        assert tgt.shape[0] == seq_len

        is_masked = np.concatenate([mask_0, mask_1], 0)
        if FLAGS.num_predict is not None:
          assert np.sum(is_masked) == FLAGS.num_predict

        feature = {
            "input": _int64_feature(cat_data),
            "is_masked": _int64_feature(is_masked),
            "target": _int64_feature(tgt),
            "seg_id": _int64_feature(seg_id),
            "label": _int64_feature([label]),
        }
        features.append(feature)

      # Only write complete batches.
      if all_ok:
        assert len(features) == bsz_per_host
        for feature in features:
          example = tf.train.Example(features=tf.train.Features(feature=feature))
          record_writer.write(example.SerializeToString())
        num_batch += 1
      else:
        break

      i += reuse_len
  finally:
    record_writer.close()

  logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)

  return save_path, num_batch


################
# get_input_fn #
################
def _convert_example(example, use_bfloat16):
  """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16.

  Mutates `example` in place; sparse tensors are densified first.
  """
  for key in list(example.keys()):
    val = example[key]
    # Checked directly on the tensor type, so no Keras dependency is needed.
    if isinstance(val, tf.SparseTensor):
      val = tf.sparse.to_dense(val)
    if val.dtype == tf.int64:
      val = tf.cast(val, tf.int32)
    if use_bfloat16 and val.dtype == tf.float32:
      val = tf.cast(val, tf.bfloat16)

    example[key] = val


def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
                           host_id, num_core_per_host, bsz_per_core):
  """Parses this host's share of `file_names` into a batched dataset.

  Files are split contiguously across hosts, and the last host also takes
  the remainder. Only the "train" split is supported.
  """
  del num_batch
  # list of file paths
  num_files = len(file_names)
  num_files_per_host = num_files // num_hosts
  my_start_file_id = host_id * num_files_per_host
  my_end_file_id = (host_id + 1) * num_files_per_host
  if host_id == num_hosts - 1:
    my_end_file_id = num_files
  file_paths = file_names[my_start_file_id: my_end_file_id]
  logging.info("Host %d handles %d files", host_id, len(file_paths))

  assert split == "train"
  dataset = tf.data.Dataset.from_tensor_slices(file_paths)

  # file-level shuffle
  if len(file_paths) > 1:
    dataset = dataset.shuffle(len(file_paths))

  # Note: we cannot perform sample-level shuffle here because this will violate
  # the consecutive requirement of data stream.
  dataset = tf.data.TFRecordDataset(dataset)

  # Note: since we are doing online preprocessing, the parsed result of
  # the same input at each time will be different. Thus, cache processed data
  # is not helpful. It will use a lot of memory and lead to container OOM.
  # So, change to cache non-parsed raw data instead.
  dataset = dataset.cache().map(parser).repeat()
  dataset = dataset.batch(bsz_per_core, drop_remainder=True)
  dataset = dataset.prefetch(num_core_per_host * bsz_per_core)

  return dataset


def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Samples a permutation of the factorization order, and create a mask.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of longest permutation. Could be set to be reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
      `seq_len` must be divisible by it (see the reshape below).
    seq_len: int, sequence length.

  Returns:
    A tuple (perm_mask, new_targets, target_mask, inputs_k, inputs_q):
    perm_mask is a float32 [seq_len, seq_len] attention mask, target_mask a
    float32 [seq_len] loss mask, and inputs_q equals target_mask.
  """
  # Generate permutation indices: shuffle within interleaved groups of
  # `perm_size` positions.
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random_shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])

  # `perm_mask` and `target_mask`
  # non-functional tokens
  non_func_tokens = tf.logical_not(tf.logical_or(
      tf.equal(inputs, SEP_ID),
      tf.equal(inputs, CLS_ID)))

  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Set the permutation indices of non-masked (& non-functional) tokens to the
  # smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leak
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, index)

  # Create `target_mask`: non-functional and masked tokens
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Create `perm_mask`
  # `target_tokens` cannot see themselves
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  perm_mask = tf.logical_and(
      self_rev_index[:, None] <= rev_index[None, :],
      masked_or_func_tokens)
  perm_mask = tf.cast(perm_mask, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0: 1], targets[: -1]],
                          axis=0)

  # construct inputs_k
  inputs_k = inputs

  # construct inputs_q
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q


def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
                num_batch, seq_len, reuse_len, perm_size, mask_alpha,
                mask_beta, use_bfloat16=False, num_predict=None):
  """Gets the dataset.

  Each record is parsed and given a fresh factorization-order permutation,
  independently for the reused prefix and the remaining suffix.
  """
  del mask_alpha
  del mask_beta

  bsz_per_core = params["batch_size"]
  if num_hosts > 1:
    host_id = params["context"].current_host
  else:
    host_id = 0

  #### Function used to parse tfrecord
  def parser(record):
    """function used to parse tfrecord."""

    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.parse_single_example(
        serialized=record,
        features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len],
        target[:reuse_len],
        is_masked[:reuse_len],
        perm_size,
        reuse_len)

    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:],
        target[reuse_len:],
        is_masked[reuse_len:],
        perm_size,
        non_reuse_len)

    # The reused prefix cannot attend to the suffix; the suffix can attend
    # to the whole prefix.
    perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
                            axis=1)
    perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                            axis=1)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    target = tf.concat([target_0, target_1], axis=0)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      ##### extra padding due to CLS/SEP introduced after prepro
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for k, v in example.items():
      logging.info("%s: %s", k, v)

    return example

  # Get dataset
  dataset = parse_files_to_dataset(
      parser=parser,
      file_names=file_names,
      split=split,
      num_batch=num_batch,
      num_hosts=num_hosts,
      host_id=host_id,
      num_core_per_host=num_core_per_host,
      bsz_per_core=bsz_per_core)

  return dataset


def get_input_fn(
    tfrecord_dir,
    split,
    bsz_per_host,
    seq_len,
    reuse_len,
    bi_data,
    num_hosts=1,
    num_core_per_host=1,
    perm_size=None,
    mask_alpha=None,
    mask_beta=None,
    uncased=False,
    num_passes=None,
    use_bfloat16=False,
    num_predict=None):
  """Gets the input function.

  Args:
    tfrecord_dir: comma-separated list of directories holding record infos.
    (Remaining args describe the preprocessing config and input pipeline.)

  Returns:
    (input_fn, record_info): `input_fn(params)` builds the dataset, and
    `record_info` holds the merged "filenames" and "num_batch".
  """
  # Merge all record infos into a single one
  record_glob_base = format_filename(
      prefix="record_info-{}-*".format(split),
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="json",
      mask_alpha=mask_alpha,
      mask_beta=mask_beta,
      reuse_len=reuse_len,
      uncased=uncased,
      fixed_num_predict=num_predict)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = tfrecord_dir.split(",")
  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.gfile.Glob(record_glob))
    logging.info("[%d] Num of record info path: %d", idx, len(record_paths))

    cur_record_info = {"num_batch": 0, "filenames": []}

    for record_info_path in record_paths:
      if num_passes is not None:
        record_info_name = os.path.basename(record_info_path)
        fields = record_info_name.split(".")[0].split("-")
        pass_id = int(fields[-1])
        # NOTE(review): `create_data` names record infos
        # "record_info-{split}-{task}-{pass_id}...", which splits into 4
        # fields, so this 5-field guard never matches those files and no pass
        # is skipped. Confirm the intended naming before changing it.
        if len(fields) == 5 and pass_id >= num_passes:
          logging.info("Skip pass %d: %s", pass_id, record_info_name)
          continue

      with tf.gfile.Open(record_info_path, "r") as fp:
        info = json.load(fp)
        if num_passes is not None:
          eff_num_passes = min(num_passes, len(info["filenames"]))
          ratio = eff_num_passes / len(info["filenames"])
          cur_record_info["num_batch"] += int(info["num_batch"] * ratio)
          cur_record_info["filenames"] += info["filenames"][:eff_num_passes]
        else:
          cur_record_info["num_batch"] += info["num_batch"]
          cur_record_info["filenames"] += info["filenames"]

    # overwrite directory for `cur_record_info`
    new_filenames = []
    for filename in cur_record_info["filenames"]:
      basename = os.path.basename(filename)
      new_filename = os.path.join(record_dir, basename)
      new_filenames.append(new_filename)
    cur_record_info["filenames"] = new_filenames

    logging.info("[Dir %d] Number of chosen batches: %s",
                 idx, cur_record_info["num_batch"])
    logging.info("[Dir %d] Number of chosen files: %s",
                 idx, len(cur_record_info["filenames"]))
    logging.info(cur_record_info["filenames"])

    # add `cur_record_info` to global `record_info`
    record_info["num_batch"] += cur_record_info["num_batch"]
    record_info["filenames"] += cur_record_info["filenames"]

  logging.info("Total number of batches: %d", record_info["num_batch"])
  logging.info("Total number of files: %d", len(record_info["filenames"]))
  logging.info(record_info["filenames"])

  def input_fn(params):
    """Builds the dataset for `params["batch_size"]` examples per core."""
    assert params["batch_size"] * num_core_per_host == bsz_per_host

    dataset = get_dataset(
        params=params,
        num_hosts=num_hosts,
        num_core_per_host=num_core_per_host,
        split=split,
        file_names=record_info["filenames"],
        num_batch=record_info["num_batch"],
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        mask_alpha=mask_alpha,
        mask_beta=mask_beta,
        use_bfloat16=use_bfloat16,
        num_predict=num_predict)

    return dataset

  return input_fn, record_info


def define_flags():
  """Defines relevant flags."""
  flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
  flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
  flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")

  flags.DEFINE_integer("seq_len", 512, help="Sequence length.")
  flags.DEFINE_integer(
      "reuse_len", 256,
      help="Number of token that can be reused as memory. "
      "Could be half of `seq_len`.")
  flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
  flags.DEFINE_bool(
      "bi_data", True,
      help="whether to create bidirectional data")
  flags.DEFINE_integer(
      "mask_alpha", default=6,
      help="How many tokens to form a group.")
  flags.DEFINE_integer(
      "mask_beta", default=1,
      help="How many tokens to mask within each group.")
  flags.DEFINE_bool(
      "use_eod", True,
      help="whether to append EOD at the end of a doc.")
  flags.DEFINE_bool(
      "from_raw_text", True,
      help="Whether the input is raw text or encoded ids.")
  flags.DEFINE_integer(
      "num_predict", default=85,
      help="Num of tokens to predict.")

  flags.DEFINE_string("input_glob", "data/example/*.txt",
                      help="Input file glob.")
  flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
  flags.DEFINE_string("save_dir", "proc_data/example",
                      help="Directory for saving the processed data.")
  flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
                    help="Save the data as which split.")

  flags.DEFINE_integer("pass_id", 0, help="ID of the current pass. "
                       "Different passes sample different negative segment.")
  flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
  flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
                       "using multiple workers to identify each worker.")


if __name__ == "__main__":
  define_flags()
  logging.set_verbosity(logging.INFO)
  app.run(create_data)