# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT classification or regression finetuning runner in TF 2.x.""" | |
import functools | |
import json | |
import math | |
import os | |
# Import libraries | |
from absl import app | |
from absl import flags | |
from absl import logging | |
import gin | |
import tensorflow as tf, tf_keras | |
from official.common import distribute_utils | |
from official.legacy.bert import bert_models | |
from official.legacy.bert import common_flags | |
from official.legacy.bert import configs as bert_configs | |
from official.legacy.bert import input_pipeline | |
from official.legacy.bert import model_saving_utils | |
from official.modeling import performance | |
from official.nlp import optimization | |
from official.utils.misc import keras_utils | |

flags.DEFINE_enum(
    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
    'trains the model and evaluates it as training progresses. '
    '`export_only`: takes the latest checkpoint inside '
    'model_dir and exports a `SavedModel`. `predict`: takes a checkpoint and '
    'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
                    'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
                    'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    'input_meta_data_path', None,
    'Path to the file that contains meta data about the input '
    'to be used for training and evaluation.')
flags.DEFINE_integer('train_data_size', None,
                     'Number of training samples to use. If None, uses the '
                     'full train data. (default: None)')
flags.DEFINE_string('predict_checkpoint_path', None,
                    'Path to the checkpoint for predictions.')
flags.DEFINE_integer(
    'num_eval_per_epoch', 1,
    'Number of evaluations per epoch. The purpose of this flag is to provide '
    'more granular evaluation scores and checkpoints. For example, if the '
    'original data has N samples and num_eval_per_epoch is n, then evaluation '
    'runs after every N/n training samples.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS

# Maps the `label_type` field of the input meta data to the tf.DType of the
# labels: integer labels for classification, float labels for regression.
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}

def get_loss_fn(num_classes):
  """Gets the classification loss function."""

  def classification_loss_fn(labels, logits):
    """Classification loss."""
    labels = tf.reshape(labels, [-1])
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    one_hot_labels = tf.one_hot(
        tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(
        tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn
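

# For reference, the hand-rolled loss above matches sparse categorical
# cross-entropy on logits with mean reduction, up to the explicit label
# reshaping and dtype casts. An illustrative equivalent:
#
#   loss = tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True)
#   loss(labels, logits)  # ~= get_loss_fn(num_classes)(labels, logits)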


def get_dataset_fn(input_file_pattern,
                   max_seq_length,
                   global_batch_size,
                   is_training,
                   label_type=tf.int64,
                   include_sample_weights=False,
                   num_samples=None):
  """Gets a closure to create a dataset."""

  def _dataset_fn(ctx=None):
    """Returns a tf.data.Dataset for distributed BERT classifier training."""
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type,
        include_sample_weights=include_sample_weights,
        num_samples=num_samples)
    return dataset

  return _dataset_fn
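

# Illustrative usage (the file path is a placeholder):
#
#   train_fn = get_dataset_fn('/tmp/train.tf_record', max_seq_length=128,
#                             global_batch_size=32, is_training=True)
#   dataset = train_fn()  # tf.data.Dataset of (features, labels) batches
#
# Under a distribution strategy, pass the closure to
# `strategy.distribute_datasets_from_function(train_fn)` instead, so each
# replica receives its per-replica batch size via `ctx`.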


def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None,
                        custom_metrics=None):
  """Runs BERT classifier training using the Keras compile/fit API."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data.get('num_labels', 1)
  is_regression = num_classes == 1

  def _get_classifier_model():
    """Gets a classifier model."""
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            max_seq_length,
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=FLAGS.hub_module_trainable))
    optimizer = optimization.create_optimizer(initial_lr,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=common_flags.use_float16())
    return classifier_model, core_model

  # tf_keras.losses objects accept an optional sample_weight argument (e.g.
  # coming from the dataset) to compute a weighted loss, which is used for
  # regression tasks. Classification tasks use the custom loss from
  # get_loss_fn, which does not accept sample weights.
  loss_fn = (tf_keras.losses.MeanSquaredError() if is_regression
             else get_loss_fn(num_classes))

  # Defines the evaluation metric function, which creates metrics under the
  # correct device and strategy scope.
  if custom_metrics:
    metric_fn = custom_metrics
  elif is_regression:
    metric_fn = functools.partial(
        tf_keras.metrics.MeanSquaredError,
        'mean_squared_error',
        dtype=tf.float32)
  else:
    metric_fn = functools.partial(
        tf_keras.metrics.SparseCategoricalAccuracy,
        'accuracy',
        dtype=tf.float32)

  # Start training using the Keras compile/fit API.
  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir,
      strategy,
      _get_classifier_model,
      train_input_fn,
      eval_input_fn,
      loss_fn,
      metric_fn,
      init_checkpoint,
      epochs,
      steps_per_epoch,
      steps_per_loop,
      eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)


def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API."""

  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      # Track the sub model under both the `model` and `encoder` names so
      # checkpoints written with either key can be restored.
      checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
      checkpoint.read(init_checkpoint).assert_existing_objects_matched()

    if not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn],
        steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf_keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    if training_callbacks:
      if custom_callbacks is not None:
        custom_callbacks += [summary_callback, checkpoint_callback]
      else:
        custom_callbacks = [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=custom_callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
    return bert_model, stats
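

# The `stats` dict returned above has this shape (values are illustrative):
#
#   {'total_training_steps': 2343,
#    'train_loss': 0.31,     # last logged training loss, when present
#    'eval_metrics': 0.87}   # last `val_accuracy`, when present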


def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               is_regression=False,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that the list of labels is returned along with the predictions because
  the order changes when the dataset is distributed over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data.
    is_regression: Whether it is a regression task.
    return_probs: Whether to return probabilities of classes.

  Returns:
    predictions: List of predictions.
    labels: List of gold labels corresponding to predictions.
  """

  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      if not is_regression:
        probabilities = tf.nn.softmax(logits)
        return probabilities, labels
      else:
        return logits, labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
    # outputs: the current batch's outputs as a tuple of per-replica results.
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _run_evaluation(test_iterator):
    """Runs evaluation steps."""
    preds, golds = list(), list()
    try:
      with tf.experimental.async_scope():
        while True:
          probabilities, labels = test_step(test_iterator)
          for cur_probs, cur_labels in zip(probabilities, labels):
            if return_probs:
              preds.extend(cur_probs.numpy().tolist())
            else:
              preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
            golds.extend(cur_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)
  return predictions, labels
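

# Illustrative usage, assuming a restored `classifier_model` and an
# `eval_input_fn` built with get_dataset_fn above:
#
#   preds, labels = get_predictions_and_labels(
#       strategy, classifier_model, eval_input_fn, return_probs=True)
#   # preds[i] is a probability vector; labels[i] is the matching gold label.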


def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel
      directory.
    input_meta_data: dictionary containing meta data about input and model.
    bert_config: Bert configuration file to define core bert layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    ValueError: If `model_export_path` or `model_dir` is not specified (empty
      string or None).
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Model directory is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision.
  tf_keras.mixed_precision.set_global_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config,
      input_meta_data.get('num_labels', 1),
      hub_module_url=FLAGS.hub_module_url,
      hub_module_trainable=False)[0]

  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)
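

# The exported directory can then be loaded with standard TensorFlow APIs
# (path is a placeholder; Keras exports typically expose 'serving_default'):
#
#   loaded = tf.saved_model.load('/tmp/bert_export')
#   infer = loaded.signatures['serving_default']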


def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None,
             custom_metrics=None):
  """Run BERT training."""
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
  train_data_size = (
      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
  if FLAGS.train_data_size:
    train_data_size = min(train_data_size, FLAGS.train_data_size)
    logging.info('Updated train_data_size: %s', train_data_size)
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
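  # Worked example (numbers are illustrative): with train_data_size=25000,
  # train_batch_size=32, num_train_epochs=3, and num_eval_per_epoch=1:
  #   steps_per_epoch = int(25000 / 32) = 781
  #   warmup_steps    = int(3 * 25000 * 0.1 / 32) = 234  (~10% of training)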

  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  if not custom_callbacks:
    custom_callbacks = []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model


def custom_main(custom_callbacks=None, custom_metrics=None):
  """Runs classification or regression.

  Args:
    custom_callbacks: list of tf_keras.Callbacks passed to the training loop.
    custom_metrics: list of metrics passed to the training loop.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
  include_sample_weights = input_meta_data.get('has_sample_weights', False)

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False,
      label_type=label_type,
      include_sample_weights=include_sample_weights)

  if FLAGS.mode == 'predict':
    num_labels = input_meta_data.get('num_labels', 1)
    with strategy.scope():
      classifier_model = bert_models.classifier_model(
          bert_config, num_labels)[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      assert latest_checkpoint_file
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy,
          classifier_model,
          eval_input_fn,
          is_regression=(num_labels == 1),
          return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True,
      label_type=label_type,
      include_sample_weights=include_sample_weights,
      num_samples=FLAGS.train_data_size)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)


def main(_):
  custom_main(custom_callbacks=None, custom_metrics=None)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('input_meta_data_path')
  flags.mark_flag_as_required('model_dir')
  app.run(main)
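
# Example invocation (paths are placeholders; flag names follow the
# definitions above and those added by common_flags.define_common_bert_flags):
#
#   python run_classifier.py \
#     --mode=train_and_eval \
#     --bert_config_file=/tmp/bert/bert_config.json \
#     --input_meta_data_path=/tmp/data/input_meta_data \
#     --train_data_path=/tmp/data/train.tf_record \
#     --eval_data_path=/tmp/data/eval.tf_record \
#     --model_dir=/tmp/bert_classifier \
#     --train_batch_size=32 --eval_batch_size=32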