"""BERT classification or regression finetuning runner in TF 2.x.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import functools |
|
import json |
|
import math |
|
import os |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import gin |
|
import tensorflow as tf |
|
from official.modeling import performance |
|
from official.nlp import optimization |
|
from official.nlp.bert import bert_models |
|
from official.nlp.bert import common_flags |
|
from official.nlp.bert import configs as bert_configs |
|
from official.nlp.bert import input_pipeline |
|
from official.nlp.bert import model_saving_utils |
|
from official.utils.misc import distribution_utils |
|
from official.utils.misc import keras_utils |
|
|
|
flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
    'trains the model while evaluating it periodically. '
    '`export_only`: takes the latest checkpoint inside '
    'model_dir and exports a `SavedModel`. `predict`: takes a checkpoint and '
    'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to a file containing meta data about the input to be used for '
    'training and evaluation.')
flags.DEFINE_string('predict_checkpoint_path', None,
                    'Path to the checkpoint for predictions.')
flags.DEFINE_integer(
    'num_eval_per_epoch', 1,
    'Number of evaluations per epoch. The purpose of this flag is to provide '
    'more granular evaluation scores and checkpoints. For example, if the '
    'original data has N samples and num_eval_per_epoch is n, then the model '
    'is evaluated every N/n training samples.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS
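
# An example invocation (illustrative only; the script name and all flag
# values below are placeholders):
#   python run_classifier.py \
#     --mode=train_and_eval \
#     --input_meta_data_path=/path/to/input_meta_data \
#     --train_data_path=/path/to/train.tf_record \
#     --eval_data_path=/path/to/eval.tf_record \
#     --bert_config_file=/path/to/bert_config.json \
#     --init_checkpoint=/path/to/pretrained_checkpoint \
#     --model_dir=/path/to/model_dir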

LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}

def get_loss_fn(num_classes):
  """Gets the classification loss function."""

  def classification_loss_fn(labels, logits):
    """Classification loss."""
    labels = tf.squeeze(labels)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(
        tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn


def get_dataset_fn(input_file_pattern,
                   max_seq_length,
                   global_batch_size,
                   is_training,
                   label_type=tf.int64,
                   include_sample_weights=False):
  """Gets a closure to create a dataset."""

  def _dataset_fn(ctx=None):
    """Returns a tf.data.Dataset for BERT classifier training or evaluation."""
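    # When running under a distribution strategy, `ctx` is a
    # tf.distribute.InputContext that provides the per-replica batch size
    # derived from the global batch size; without a strategy, the global
    # batch size is used directly.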
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type,
        include_sample_weights=include_sample_weights)
    return dataset

  return _dataset_fn


def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None,
                        custom_metrics=None):
  """Run BERT classifier training using the Keras compile/fit API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data.get('num_labels', 1)
  is_regression = num_classes == 1

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16(),
        use_graph_rewrite=common_flags.use_graph_rewrite())
    return classifier_model, core_model

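  # Loss selection: regression tasks (num_labels == 1) use mean squared error,
  # which also supports optional sample weights coming from the dataset;
  # classification tasks use the sparse softmax cross-entropy returned by
  # get_loss_fn.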
  loss_fn = (tf.keras.losses.MeanSquaredError() if is_regression
             else get_loss_fn(num_classes))

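  # Metrics: prefer caller-supplied metrics; otherwise report mean squared
  # error for regression and sparse categorical accuracy for classification.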
  if custom_metrics:
    metric_fn = custom_metrics
  elif is_regression:
    metric_fn = functools.partial(
        tf.keras.metrics.MeanSquaredError,
        'mean_squared_error',
        dtype=tf.float32)
  else:
    metric_fn = functools.partial(
        tf.keras.metrics.SparseCategoricalAccuracy,
        'accuracy',
        dtype=tf.float32)

  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir,
      strategy,
      _get_classifier_model,
      train_input_fn,
      eval_input_fn,
      loss_fn,
      metric_fn,
      init_checkpoint,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)


def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API."""

  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer
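
    # The initial checkpoint is expected to contain only the pretrained core
    # (encoder) weights, so it is restored into `sub_model` rather than the
    # full classifier.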
    if init_checkpoint:
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

    if not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn],
        experimental_steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    if training_callbacks:
      if custom_callbacks is not None:
        custom_callbacks += [summary_callback, checkpoint_callback]
      else:
        custom_callbacks = [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=custom_callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
    return bert_model, stats


def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that the list of labels is returned along with the predictions because
  the order may change when the dataset is distributed over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data.
    return_probs: Whether to return probabilities of classes.

  Returns:
    predictions: List of predictions.
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      probabilities = tf.nn.softmax(logits)
      return probabilities, labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))

    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _run_evaluation(test_iterator):
    """Runs evaluation steps."""
    preds, golds = list(), list()
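    # Iterate until the dataset is exhausted. With asynchronous (e.g. TPU)
    # execution, the end-of-data error may surface asynchronously, so it is
    # caught here and cleared via tf.experimental.async_clear_error().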
    try:
      with tf.experimental.async_scope():
        while True:
          probabilities, labels = test_step(test_iterator)
          for cur_probs, cur_labels in zip(probabilities, labels):
            if return_probs:
              preds.extend(cur_probs.numpy().tolist())
            else:
              preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
            golds.extend(cur_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(
      strategy.experimental_distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)

  return predictions, labels


def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel
      directory.
    input_meta_data: dictionary containing meta data about input and model.
    bert_config: BERT configuration used to define the core BERT layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    ValueError: if `model_export_path` or `model_dir` is not specified.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Model directory is not specified: %s' % model_dir)

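  # Build the export model under a float32 policy, even if training used
  # mixed precision.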
  tf.keras.mixed_precision.experimental.set_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config, input_meta_data.get('num_labels', 1))[0]

  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)


def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None,
             custom_metrics=None):
  """Run BERT training."""

  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

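  # Scale the schedule by num_eval_per_epoch: each 1/num_eval_per_epoch slice
  # of the training data is treated as one Keras "epoch", so evaluation and
  # checkpointing happen num_eval_per_epoch times per pass over the data.
  # Warmup covers the first 10% of all training steps.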
  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
  train_data_size = (
      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  if not custom_callbacks:
    custom_callbacks = []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model


def custom_main(custom_callbacks=None, custom_metrics=None):
  """Run classification or regression.

  Args:
    custom_callbacks: list of `tf.keras.callbacks.Callback` passed to the
      training loop.
    custom_metrics: list of metrics passed to the training loop.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
  include_sample_weights = input_meta_data.get('has_sample_weights', False)
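
  # The meta-data file is a JSON dict; an illustrative example (field values
  # are placeholders):
  #   {"max_seq_length": 128, "num_labels": 2, "train_data_size": 25000,
  #    "eval_data_size": 25000, "label_type": "int",
  #    "has_sample_weights": false}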

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False,
      label_type=label_type,
      include_sample_weights=include_sample_weights)

  if FLAGS.mode == 'predict':
    with strategy.scope():
      classifier_model = bert_models.classifier_model(
          bert_config, input_meta_data['num_labels'])[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      assert latest_checkpoint_file
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy, classifier_model, eval_input_fn, return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True,
      label_type=label_type,
      include_sample_weights=include_sample_weights)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)


def main(_):
  custom_main(custom_callbacks=None, custom_metrics=None)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)