# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""

# Import libraries
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf, tf_keras

from official.common import distribute_utils
from official.legacy.bert import bert_models
from official.legacy.bert import common_flags
from official.legacy.bert import configs
from official.legacy.bert import input_pipeline
from official.legacy.bert import model_training_utils
from official.modeling import performance
from official.nlp import optimization


flags.DEFINE_string('input_files', None,
                    'File path to retrieve training data for pre-training.')
# Model training specific flags.
flags.DEFINE_integer(
    'max_seq_length', 128,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter '
    'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum number of masked LM predictions per sequence.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
flags.DEFINE_integer('train_summary_interval', 0,
                     'Step interval for training summaries. If the value is a '
                     'negative number, then training summaries are not '
                     'enabled.')
common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS


def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size,
                            use_next_sentence_label=True):
  """Returns input dataset from input file string."""

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    input_patterns = input_file_pattern.split(',')
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    train_dataset = input_pipeline.create_pretrain_dataset(
        input_patterns,
        seq_length,
        max_predictions_per_seq,
        batch_size,
        is_training=True,
        input_pipeline_context=ctx,
        use_next_sentence_label=use_next_sentence_label)
    return train_dataset

  return _dataset_fn


def get_loss_fn():
  """Returns loss function for BERT pretraining."""

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    return tf.reduce_mean(losses)

  return _bert_pretrain_loss_fn


def run_customized_training(strategy,
                            bert_config,
                            init_checkpoint,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            end_lr,
                            optimizer_type,
                            input_files,
                            train_batch_size,
                            use_next_sentence_label=True,
                            train_summary_interval=0,
                            custom_callbacks=None,
                            explicit_allreduce=False,
                            pre_allreduce_callbacks=None,
                            post_allreduce_callbacks=None,
                            allreduce_bytes_per_pack=0):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size,
                                           use_next_sentence_label)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, end_lr,
                                              optimizer_type)
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
      scale_loss=FLAGS.scale_loss,
      model_dir=model_dir,
      init_checkpoint=init_checkpoint,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model',
      explicit_allreduce=explicit_allreduce,
      pre_allreduce_callbacks=pre_allreduce_callbacks,
      post_allreduce_callbacks=post_allreduce_callbacks,
      allreduce_bytes_per_pack=allreduce_bytes_per_pack,
      train_summary_interval=train_summary_interval,
      custom_callbacks=custom_callbacks)
  return trained_model


def run_bert_pretrain(strategy, custom_callbacks=None):
  """Runs BERT pre-training."""
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  if not strategy:
    raise ValueError('Distribution strategy is not specified.')

  # Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')

  performance.set_mixed_precision_policy(common_flags.dtype())

  # post_allreduce_callbacks and allreduce_bytes_per_pack take effect only when
  # explicit_allreduce is True. In that case optimizer.apply_gradients() no
  # longer implicitly all-reduces gradients; gradients are all-reduced
  # explicitly and the all-reduced grads_and_vars are passed to
  # apply_gradients(). With explicit_allreduce=True, clip_by_global_norm is
  # moved to after the allreduce.
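  # For illustration, the callback contract assumed here is that each pre-/
  # post-allreduce callback takes a list of (gradient, variable) pairs and
  # returns a transformed list of the same form, e.g. a hypothetical
  # gradient-scaling callback could be written as:
  #
  #   def scale_grads_callback(grads_and_vars):
  #     return [(g * 0.5, v) for g, v in grads_and_vars]
  #
  # clip_by_global_norm_callback, passed below, is assumed to follow the same
  # signature.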
  return run_customized_training(
      strategy,
      bert_config,
      FLAGS.init_checkpoint,  # Used to initialize only the BERT submodel.
      FLAGS.max_seq_length,
      FLAGS.max_predictions_per_seq,
      FLAGS.model_dir,
      FLAGS.num_steps_per_epoch,
      FLAGS.steps_per_loop,
      FLAGS.num_train_epochs,
      FLAGS.learning_rate,
      FLAGS.warmup_steps,
      FLAGS.end_lr,
      FLAGS.optimizer_type,
      FLAGS.input_files,
      FLAGS.train_batch_size,
      FLAGS.use_next_sentence_label,
      FLAGS.train_summary_interval,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=FLAGS.explicit_allreduce,
      pre_allreduce_callbacks=[
          model_training_utils.clip_by_global_norm_callback
      ],
      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
                                           FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)


if __name__ == '__main__':
  app.run(main)
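# Example invocation: a minimal sketch, assuming this module is saved as
# run_pretraining.py and that the tf_examples.tfrecord* inputs were produced
# by the BERT pretraining data pipeline; all paths and hyperparameter values
# below are illustrative only.
#
#   python3 run_pretraining.py \
#     --input_files=/path/to/tf_examples.tfrecord* \
#     --bert_config_file=/path/to/bert_config.json \
#     --model_dir=/tmp/bert20 \
#     --max_seq_length=128 \
#     --max_predictions_per_seq=20 \
#     --train_batch_size=32 \
#     --num_train_epochs=2 \
#     --num_steps_per_epoch=1000 \
#     --warmup_steps=10000 \
#     --learning_rate=2e-5 \
#     --distribution_strategy=mirrored \
#     --num_gpus=1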