deanna-emery's picture
updates
93528c6
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
# Import libraries
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf, tf_keras
from official.common import distribute_utils
from official.legacy.bert import bert_models
from official.legacy.bert import common_flags
from official.legacy.bert import configs
from official.legacy.bert import input_pipeline
from official.legacy.bert import model_training_utils
from official.modeling import performance
from official.nlp import optimization
# Command-line flags specific to BERT pre-training. Flags shared with the other
# BERT scripts are registered by `common_flags.define_common_bert_flags()`.
flags.DEFINE_string('input_files', None,
                    'File path to retrieve training data for pre-training.')

# Model training specific flags.
flags.DEFINE_integer(
    'max_seq_length', 128,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter '
    'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum predictions per sequence_output.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
# The summary interval is a step count (and may be negative per its help text),
# so it must be an integer flag: a boolean flag cannot parse values such as
# `--train_summary_interval=100` or `--train_summary_interval=-1`.
flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
                     'summaries. If the value is a negative number, '
                     'then training summaries are not enabled.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS
def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size,
                            use_next_sentence_label=True):
  """Returns a function that builds the pre-training input dataset.

  Args:
    input_file_pattern: String holding one or more input file patterns
      separated by commas.
    seq_length: Sequence length forwarded to
      `input_pipeline.create_pretrain_dataset`.
    max_predictions_per_seq: Forwarded to
      `input_pipeline.create_pretrain_dataset`.
    global_batch_size: Batch size across all replicas. It is converted to a
      per-replica batch size when an input context is supplied.
    use_next_sentence_label: Forwarded to
      `input_pipeline.create_pretrain_dataset`.

  Returns:
    A callable accepting an optional input context `ctx` and returning the
    dataset produced by `input_pipeline.create_pretrain_dataset`.
  """

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    input_patterns = input_file_pattern.split(',')
    # `ctx` is optional: without an input context there is no per-replica
    # split to compute, so the global batch size is used unchanged instead of
    # dereferencing None.
    if ctx is None:
      batch_size = global_batch_size
    else:
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    train_dataset = input_pipeline.create_pretrain_dataset(
        input_patterns,
        seq_length,
        max_predictions_per_seq,
        batch_size,
        is_training=True,
        input_pipeline_context=ctx,
        use_next_sentence_label=use_next_sentence_label)
    return train_dataset

  return _dataset_fn
def get_loss_fn():
  """Builds the loss callable used for BERT pre-training.

  Returns:
    A function taking `(unused_labels, losses, **unused_args)` that ignores
    the labels and any extra keyword arguments and returns
    `tf.reduce_mean(losses)`.
  """

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
    # Only `losses` contributes to the result; the other inputs are dropped.
    del unused_labels, unused_args
    mean_loss = tf.reduce_mean(losses)
    return mean_loss

  return _bert_pretrain_loss_fn
def run_customized_training(strategy,
                            bert_config,
                            init_checkpoint,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            end_lr,
                            optimizer_type,
                            input_files,
                            train_batch_size,
                            use_next_sentence_label=True,
                            train_summary_interval=0,
                            custom_callbacks=None,
                            explicit_allreduce=False,
                            pre_allreduce_callbacks=None,
                            post_allreduce_callbacks=None,
                            allreduce_bytes_per_pack=0):
  """Trains the BERT pre-training model with the customized training loop.

  Creates the dataset function and a model-building function, then delegates
  to `model_training_utils.run_customized_training_loop`. Besides its
  arguments, this function reads the global `FLAGS.scale_loss` flag.

  Args:
    strategy: Distribution strategy forwarded to the training loop.
    bert_config: Configuration passed to `bert_models.pretrain_model`.
    init_checkpoint: Forwarded to the training loop as `init_checkpoint`.
    max_seq_length: Sequence length used for both the dataset and the model.
    max_predictions_per_seq: Used for both the dataset and the model.
    model_dir: Forwarded to the training loop as `model_dir`.
    steps_per_epoch: Steps per epoch; also multiplied by `epochs` to give the
      step count passed to `optimization.create_optimizer`.
    steps_per_loop: Forwarded to the training loop.
    epochs: Number of epochs.
    initial_lr: First argument to `optimization.create_optimizer`.
    warmup_steps: Forwarded to `optimization.create_optimizer`.
    end_lr: Forwarded to `optimization.create_optimizer`.
    optimizer_type: Forwarded to `optimization.create_optimizer`.
    input_files: Comma-separated input file patterns for the dataset.
    train_batch_size: Global batch size handed to the dataset function.
    use_next_sentence_label: Passed to both the dataset and model builders.
    train_summary_interval: Forwarded to the training loop.
    custom_callbacks: Forwarded to the training loop.
    explicit_allreduce: Forwarded to the training loop.
    pre_allreduce_callbacks: Forwarded to the training loop.
    post_allreduce_callbacks: Forwarded to the training loop.
    allreduce_bytes_per_pack: Forwarded to the training loop.

  Returns:
    The value returned by `model_training_utils.run_customized_training_loop`.
  """
  dataset_fn = get_pretrain_dataset_fn(
      input_files,
      max_seq_length,
      max_predictions_per_seq,
      train_batch_size,
      use_next_sentence_label)

  def _build_model_with_optimizer():
    """Creates the pre-training model pair and attaches its optimizer."""
    model, sub_model = bert_models.pretrain_model(
        bert_config,
        max_seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)
    total_train_steps = steps_per_epoch * epochs
    base_optimizer = optimization.create_optimizer(
        initial_lr, total_train_steps, warmup_steps, end_lr, optimizer_type)
    model.optimizer = performance.configure_optimizer(
        base_optimizer, use_float16=common_flags.use_float16())
    return model, sub_model

  return model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_build_model_with_optimizer,
      loss_fn=get_loss_fn(),
      scale_loss=FLAGS.scale_loss,
      model_dir=model_dir,
      init_checkpoint=init_checkpoint,
      train_input_fn=dataset_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model',
      explicit_allreduce=explicit_allreduce,
      pre_allreduce_callbacks=pre_allreduce_callbacks,
      post_allreduce_callbacks=post_allreduce_callbacks,
      allreduce_bytes_per_pack=allreduce_bytes_per_pack,
      train_summary_interval=train_summary_interval,
      custom_callbacks=custom_callbacks)
def run_bert_pretrain(strategy, custom_callbacks=None):
  """Runs BERT pre-training using settings taken from command-line flags.

  Loads the BERT config from `FLAGS.bert_config_file`, sets the mixed
  precision policy from `common_flags.dtype()`, and calls
  `run_customized_training` with the remaining values read from `FLAGS`.

  Args:
    strategy: Distribution strategy to train under. Must be truthy.
    custom_callbacks: Optional callbacks forwarded to
      `run_customized_training`.

  Returns:
    The value returned by `run_customized_training`.

  Raises:
    ValueError: If `strategy` is falsy.
  """
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  if not strategy:
    raise ValueError('Distribution strategy is not specified.')

  # Runs customized training loop.
  # The adjacent string literals are concatenated, so the first one must end
  # with a space to keep "distributed strategy" as two words.
  logging.info('Training using customized training loop TF 2.0 with distributed '
               'strategy.')

  performance.set_mixed_precision_policy(common_flags.dtype())

  # Only when explicit_allreduce = True, post_allreduce_callbacks and
  # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
  # longer implicitly allreduce gradients, users manually allreduce gradient and
  # pass the allreduced grads_and_vars to apply_gradients().
  # With explicit_allreduce = True, clip_by_global_norm is moved to after
  # allreduce.
  return run_customized_training(
      strategy,
      bert_config,
      FLAGS.init_checkpoint,  # Used to initialize only the BERT submodel.
      FLAGS.max_seq_length,
      FLAGS.max_predictions_per_seq,
      FLAGS.model_dir,
      FLAGS.num_steps_per_epoch,
      FLAGS.steps_per_loop,
      FLAGS.num_train_epochs,
      FLAGS.learning_rate,
      FLAGS.warmup_steps,
      FLAGS.end_lr,
      FLAGS.optimizer_type,
      FLAGS.input_files,
      FLAGS.train_batch_size,
      FLAGS.use_next_sentence_label,
      FLAGS.train_summary_interval,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=FLAGS.explicit_allreduce,
      pre_allreduce_callbacks=[
          model_training_utils.clip_by_global_norm_callback
      ],
      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
def main(_):
  """Script entry point invoked by `app.run`.

  Parses gin configuration, fills in a default model directory when none is
  set, builds the distribution strategy from flags, and starts pre-training.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  # Fall back to a fixed directory when --model_dir was left empty.
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    # The return value is not needed here.
    distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)

  strategy_options = dict(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
  strategy = distribute_utils.get_distribution_strategy(**strategy_options)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)


if __name__ == '__main__':
  app.run(main)