"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils


flags.DEFINE_string('input_files', None,
                    'File path to retrieve training data for pre-training.')

flags.DEFINE_integer(
    'max_seq_length', 128,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter '
    'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum predictions per sequence.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
flags.DEFINE_integer('train_summary_interval', 0, 'Step interval for training '
                     'summaries. If the value is a negative number, '
                     'then training summaries are not enabled.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS
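
# An illustrative invocation (a sketch only; the script name, file paths and
# flag values below are placeholders, not prescribed by this module):
#
#   python run_pretraining.py \
#     --input_files=/path/to/pretrain_data-*.tfrecord \
#     --bert_config_file=/path/to/bert_config.json \
#     --model_dir=/tmp/bert20 \
#     --max_seq_length=128 \
#     --max_predictions_per_seq=20 \
#     --train_batch_size=32 \
#     --num_train_epochs=3 \
#     --distribution_strategy=mirrored \
#     --num_gpus=1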


def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size,
                            use_next_sentence_label=True):
  """Returns input dataset from input file string."""
  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    input_patterns = input_file_pattern.split(',')
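    # `ctx` is the tf.distribute.InputContext provided by the distribution
    # strategy; it converts the global batch size into a per-replica batch
    # size and is forwarded to the input pipeline so the input files can be
    # sharded across input pipelines.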
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    train_dataset = input_pipeline.create_pretrain_dataset(
        input_patterns,
        seq_length,
        max_predictions_per_seq,
        batch_size,
        is_training=True,
        input_pipeline_context=ctx,
        use_next_sentence_label=use_next_sentence_label)
    return train_dataset

  return _dataset_fn


def get_loss_fn():
  """Returns loss function for BERT pretraining."""

  def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
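    # The pretraining Keras model already returns the combined masked LM and
    # next-sentence loss as its output, so this loss_fn only averages that
    # per-example loss over the batch; the labels argument is unused here.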
    return tf.reduce_mean(losses)

  return _bert_pretrain_loss_fn


def run_customized_training(strategy,
                            bert_config,
                            init_checkpoint,
                            max_seq_length,
                            max_predictions_per_seq,
                            model_dir,
                            steps_per_epoch,
                            steps_per_loop,
                            epochs,
                            initial_lr,
                            warmup_steps,
                            end_lr,
                            optimizer_type,
                            input_files,
                            train_batch_size,
                            use_next_sentence_label=True,
                            train_summary_interval=0,
                            custom_callbacks=None):
  """Run BERT pretrain model training using low-level API."""

  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                           max_predictions_per_seq,
                                           train_batch_size,
                                           use_next_sentence_label)

  def _get_pretrain_model():
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label)
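    # The optimizer follows a schedule that warms up for `warmup_steps` and
    # then decays toward `end_lr` over the total number of training steps,
    # i.e. steps_per_epoch * epochs.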
    optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps,
        end_lr, optimizer_type)
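    # Depending on the common BERT flags, wrap the optimizer for float16
    # training (loss scaling) or enable the automatic mixed precision graph
    # rewrite.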
    pretrain_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return pretrain_model, core_model
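
  # The customized training loop runs the epoch/step iteration under the given
  # distribution strategy; `sub_model_export_name` also exports checkpoints of
  # the core BERT encoder (without the pretraining heads) for use in
  # downstream fine-tuning.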
  trained_model = model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_pretrain_model,
      loss_fn=get_loss_fn(),
      scale_loss=FLAGS.scale_loss,
      model_dir=model_dir,
      init_checkpoint=init_checkpoint,
      train_input_fn=train_input_fn,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      epochs=epochs,
      sub_model_export_name='pretrained/bert_model',
      train_summary_interval=train_summary_interval,
      custom_callbacks=custom_callbacks)

  return trained_model


def run_bert_pretrain(strategy, custom_callbacks=None):
  """Runs BERT pre-training."""

  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  if not strategy:
    raise ValueError('Distribution strategy is not specified.')

  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
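
  # Set the Keras mixed precision policy according to the dtype selected via
  # the common BERT flags before any model variables are created.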
  performance.set_mixed_precision_policy(common_flags.dtype())

  return run_customized_training(
      strategy,
      bert_config,
      FLAGS.init_checkpoint,
      FLAGS.max_seq_length,
      FLAGS.max_predictions_per_seq,
      FLAGS.model_dir,
      FLAGS.num_steps_per_epoch,
      FLAGS.steps_per_loop,
      FLAGS.num_train_epochs,
      FLAGS.learning_rate,
      FLAGS.warmup_steps,
      FLAGS.end_lr,
      FLAGS.optimizer_type,
      FLAGS.input_files,
      FLAGS.train_batch_size,
      FLAGS.use_next_sentence_label,
      FLAGS.train_summary_interval,
      custom_callbacks=custom_callbacks)


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
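  # Build the tf.distribute strategy (e.g. 'mirrored' for multi-GPU or 'tpu')
  # from the --distribution_strategy, --num_gpus and --tpu flags.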
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)


if __name__ == '__main__':
  app.run(main)