# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""

import collections
import json
import os

from absl import flags
from absl import logging
import tensorflow as tf, tf_keras

from official.legacy.bert import bert_models
from official.legacy.bert import common_flags
from official.legacy.bert import input_pipeline
from official.legacy.bert import model_saving_utils
from official.legacy.bert import model_training_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.data import squad_lib_sp
from official.nlp.tools import squad_evaluate_v1_1
from official.nlp.tools import squad_evaluate_v2_0
from official.utils.misc import keras_utils


def define_common_squad_flags():
  """Defines common flags used by SQuAD tasks."""
  flags.DEFINE_enum(
      'mode', 'train_and_eval', [
          'train_and_eval', 'train_and_predict', 'train', 'eval', 'predict',
          'export_only'
      ], 'One of {"train_and_eval", "train_and_predict", '
      '"train", "eval", "predict", "export_only"}. '
      '`train_and_eval`: train & predict to json files & compute eval metrics. '
      '`train_and_predict`: train & predict to json files. '
      '`train`: only trains the model. '
      '`eval`: predict answers from squad json file & compute eval metrics. '
      '`predict`: predict answers from the squad json file. '
      '`export_only`: will take the latest checkpoint inside '
      'model_dir and export a `SavedModel`.')
  flags.DEFINE_string('train_data_path', '',
                      'Training data path with train tfrecords.')
  flags.DEFINE_string(
      'input_meta_data_path', None,
      'Path to file that contains meta data about input '
      'to be used for training and evaluation.')
  # Model training specific flags.
  flags.DEFINE_integer('train_batch_size', 32,
                       'Total batch size for training.')
  # Predict processing related.
  flags.DEFINE_string(
      'predict_file', None, 'SQuAD prediction json file path. '
      '`predict` mode supports multiple files: one can use '
      'wildcard to specify multiple files and it can also be '
      'multiple file patterns separated by comma. Note that '
      '`eval` mode only supports a single predict file.')
  flags.DEFINE_bool(
      'do_lower_case', True,
      'Whether to lower case the input text. Should be True for uncased '
      'models and False for cased models.')
  flags.DEFINE_float(
      'null_score_diff_threshold', 0.0,
      'If null_score - best_non_null is greater than the threshold, '
      'predict null. This is only used for SQuAD v2.')
  flags.DEFINE_bool(
      'verbose_logging', False,
      'If true, all of the warnings related to data processing will be '
      'printed. A number of warnings are expected for a normal SQuAD '
      'evaluation.')
  flags.DEFINE_integer('predict_batch_size', 8,
                       'Total batch size for prediction.')
  flags.DEFINE_integer(
      'n_best_size', 20,
      'The total number of n-best predictions to generate in the '
      'nbest_predictions.json output file.')
  flags.DEFINE_integer(
      'max_answer_length', 30,
      'The maximum length of an answer that can be generated. This is needed '
      'because the start and end predictions are not conditioned on one '
      'another.')

  common_flags.define_common_bert_flags()


FLAGS = flags.FLAGS

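
# Usage sketch (hypothetical, not part of the original module): a
# run_squad-style driver script would call `define_common_squad_flags()` at
# import time and dispatch on `FLAGS.mode` after absl parses argv. The
# function below is illustrative only and is never invoked here.
def _example_dispatch_on_mode():
  """Shows how a driver might branch on the `mode` flag defined above."""
  if FLAGS.mode in ('train', 'train_and_eval', 'train_and_predict'):
    logging.info('Training with batch size %d from %s.',
                 FLAGS.train_batch_size, FLAGS.train_data_path)
  if FLAGS.mode in ('predict', 'eval', 'train_and_eval', 'train_and_predict'):
    logging.info('Predicting on %s with batch size %d.', FLAGS.predict_file,
                 FLAGS.predict_batch_size)
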

def squad_loss_fn(start_positions, end_positions, start_logits, end_logits):
  """Returns sparse categorical crossentropy for start/end logits."""
  start_loss = tf_keras.losses.sparse_categorical_crossentropy(
      start_positions, start_logits, from_logits=True)
  end_loss = tf_keras.losses.sparse_categorical_crossentropy(
      end_positions, end_logits, from_logits=True)

  total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
  return total_loss


def get_loss_fn():
  """Gets a loss function for the squad task."""

  def _loss_fn(labels, model_outputs):
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    start_logits, end_logits = model_outputs
    return squad_loss_fn(start_positions, end_positions, start_logits,
                         end_logits)

  return _loss_fn


RawResult = collections.namedtuple('RawResult',
                                   ['unique_id', 'start_logits', 'end_logits'])


def get_raw_results(predictions):
  """Converts multi-replica predictions to RawResult."""
  for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
                                                  predictions['start_logits'],
                                                  predictions['end_logits']):
    for values in zip(unique_ids.numpy(), start_logits.numpy(),
                      end_logits.numpy()):
      yield RawResult(
          unique_id=values[0],
          start_logits=values[1].tolist(),
          end_logits=values[2].tolist())


def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
                   is_training):
  """Gets a closure to create a dataset."""

  def _dataset_fn(ctx=None):
    """Returns a tf.data.Dataset for the SQuAD task, sharded per replica."""
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_squad_dataset(
        input_file_pattern,
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx)
    return dataset

  return _dataset_fn

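
# A minimal sketch of how the closure above is consumed under tf.distribute,
# mirroring `predict_squad_customized` below. The record path and sequence
# length here are illustrative assumptions.
def _example_distribute_dataset(strategy):
  """Builds a per-replica SQuAD eval dataset under `strategy`."""
  dataset_fn = get_dataset_fn(
      '/tmp/eval.tf_record',  # Hypothetical tfrecord path.
      max_seq_length=384,
      global_batch_size=8,
      is_training=False)
  # Each replica gets its own shard; the input context supplies the
  # per-replica batch size.
  return iter(strategy.distribute_datasets_from_function(dataset_fn))
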

def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
                               input_meta_data):
  """Gets a squad model to make predictions."""
  with strategy.scope():
    # Prediction always uses float32, even if training uses mixed precision.
    tf_keras.mixed_precision.set_global_policy('float32')
    squad_model, _ = bert_models.squad_model(
        bert_config,
        input_meta_data['max_seq_length'],
        hub_module_url=FLAGS.hub_module_url)

  if checkpoint_path is None:
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
  logging.info('Restoring checkpoints from %s', checkpoint_path)
  checkpoint = tf.train.Checkpoint(model=squad_model)
  checkpoint.restore(checkpoint_path).expect_partial()
  return squad_model


def predict_squad_customized(strategy, input_meta_data, predict_tfrecord_path,
                             num_steps, squad_model):
  """Makes predictions using a BERT-based squad model."""
  predict_dataset_fn = get_dataset_fn(
      predict_tfrecord_path,
      input_meta_data['max_seq_length'],
      FLAGS.predict_batch_size,
      is_training=False)
  predict_iterator = iter(
      strategy.distribute_datasets_from_function(predict_dataset_fn))

  @tf.function
  def predict_step(iterator):
    """Predicts on distributed devices."""

    def _replicated_step(inputs):
      """Replicated prediction calculation."""
      x, _ = inputs
      unique_ids = x.pop('unique_ids')
      start_logits, end_logits = squad_model(x, training=False)
      return dict(
          unique_ids=unique_ids,
          start_logits=start_logits,
          end_logits=end_logits)

    outputs = strategy.run(_replicated_step, args=(next(iterator),))
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  all_results = []
  for _ in range(num_steps):
    predictions = predict_step(predict_iterator)
    for result in get_raw_results(predictions):
      all_results.append(result)
    if len(all_results) % 100 == 0:
      logging.info('Made predictions for %d records.', len(all_results))
  return all_results

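
# A minimal sketch tying the two helpers above together, as
# `prediction_output_squad` below does with real features. The tfrecord path
# and step count are illustrative assumptions; checkpoint_path=None falls back
# to the latest checkpoint in FLAGS.model_dir.
def _example_collect_raw_results(strategy, bert_config, input_meta_data):
  """Restores a model and collects RawResults for a prediction tfrecord."""
  squad_model = get_squad_model_to_predict(
      strategy, bert_config, checkpoint_path=None,
      input_meta_data=input_meta_data)
  return predict_squad_customized(
      strategy, input_meta_data, '/tmp/eval.tf_record', num_steps=100,
      squad_model=squad_model)
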

def train_squad(strategy,
                input_meta_data,
                bert_config,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None,
                sub_model_export_name=None):
  """Runs BERT squad training."""
  if strategy:
    logging.info('Training using customized training loop with distribution'
                 ' strategy.')
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  epochs = FLAGS.num_train_epochs
  num_train_examples = input_meta_data['train_data_size']
  max_seq_length = input_meta_data['max_seq_length']
  steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
  warmup_steps = int(epochs * num_train_examples * 0.1 /
                     FLAGS.train_batch_size)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)

  def _get_squad_model():
    """Get Squad model and optimizer."""
    squad_model, core_model = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                              steps_per_epoch * epochs,
                                              warmup_steps, FLAGS.end_lr,
                                              FLAGS.optimizer_type)

    squad_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return squad_model, core_model

  # post_allreduce_callbacks and allreduce_bytes_per_pack take effect only
  # when explicit_allreduce = True. With explicit allreduce,
  # optimizer.apply_gradients() no longer implicitly all-reduces gradients;
  # instead, gradients are all-reduced manually and the reduced grads_and_vars
  # are passed to apply_gradients(). clip_by_global_norm is then applied after
  # the allreduce.
  model_training_utils.run_customized_training_loop(
      strategy=strategy,
      model_fn=_get_squad_model,
      loss_fn=get_loss_fn(),
      model_dir=FLAGS.model_dir,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=FLAGS.steps_per_loop,
      epochs=epochs,
      train_input_fn=train_input_fn,
      init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
      sub_model_export_name=sub_model_export_name,
      run_eagerly=run_eagerly,
      custom_callbacks=custom_callbacks,
      explicit_allreduce=FLAGS.explicit_allreduce,
      pre_allreduce_callbacks=[
          model_training_utils.clip_by_global_norm_callback
      ],
      allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)

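
# A minimal usage sketch, assuming a run_squad-style driver: load the input
# meta data JSON, pick a distribution strategy, and train. The strategy choice
# and the caller-supplied bert_config are assumptions here.
def _example_train(bert_config):
  """Trains a squad model on all local GPUs with the current flag settings."""
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  strategy = tf.distribute.MirroredStrategy()
  train_squad(strategy, input_meta_data, bert_config)
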

def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
                            predict_file, squad_model):
  """Makes predictions for a squad dataset."""
  doc_stride = input_meta_data['doc_stride']
  max_query_length = input_meta_data['max_query_length']
  # Whether data should be in Ver 2.0 format.
  version_2_with_negative = input_meta_data.get('version_2_with_negative',
                                                False)
  eval_examples = squad_lib.read_squad_examples(
      input_file=predict_file,
      is_training=False,
      version_2_with_negative=version_2_with_negative)

  eval_writer = squad_lib.FeatureWriter(
      filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
      is_training=False)
  eval_features = []

  def _append_feature(feature, is_padding):
    if not is_padding:
      eval_features.append(feature)
    eval_writer.process_feature(feature)

  # TPU requires a fixed batch size for all batches, therefore the number
  # of examples must be a multiple of the batch size, or else examples
  # will get dropped. So we pad with fake examples which are ignored
  # later on.
  kwargs = dict(
      examples=eval_examples,
      tokenizer=tokenizer,
      max_seq_length=input_meta_data['max_seq_length'],
      doc_stride=doc_stride,
      max_query_length=max_query_length,
      is_training=False,
      output_fn=_append_feature,
      batch_size=FLAGS.predict_batch_size)

  # squad_lib_sp requires one more argument 'do_lower_case'.
  if squad_lib == squad_lib_sp:
    kwargs['do_lower_case'] = FLAGS.do_lower_case
  dataset_size = squad_lib.convert_examples_to_features(**kwargs)
  eval_writer.close()

  logging.info('***** Running predictions *****')
  logging.info('  Num orig examples = %d', len(eval_examples))
  logging.info('  Num split examples = %d', len(eval_features))
  logging.info('  Batch size = %d', FLAGS.predict_batch_size)

  num_steps = int(dataset_size / FLAGS.predict_batch_size)
  all_results = predict_squad_customized(strategy, input_meta_data,
                                         eval_writer.filename, num_steps,
                                         squad_model)

  all_predictions, all_nbest_json, scores_diff_json = (
      squad_lib.postprocess_output(
          eval_examples,
          eval_features,
          all_results,
          FLAGS.n_best_size,
          FLAGS.max_answer_length,
          FLAGS.do_lower_case,
          version_2_with_negative=version_2_with_negative,
          null_score_diff_threshold=FLAGS.null_score_diff_threshold,
          verbose=FLAGS.verbose_logging))

  return all_predictions, all_nbest_json, scores_diff_json


def dump_to_files(all_predictions,
                  all_nbest_json,
                  scores_diff_json,
                  squad_lib,
                  version_2_with_negative,
                  file_prefix=''):
  """Saves output to json files."""
  output_prediction_file = os.path.join(FLAGS.model_dir,
                                        '%spredictions.json' % file_prefix)
  output_nbest_file = os.path.join(FLAGS.model_dir,
                                   '%snbest_predictions.json' % file_prefix)
  output_null_log_odds_file = os.path.join(FLAGS.model_dir,
                                           '%snull_odds.json' % file_prefix)
  logging.info('Writing predictions to: %s', output_prediction_file)
  logging.info('Writing nbest to: %s', output_nbest_file)

  squad_lib.write_to_json_files(all_predictions, output_prediction_file)
  squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
  if version_2_with_negative:
    squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)


def _get_matched_files(input_path):
  """Returns all files that match the input_path."""
  input_patterns = input_path.strip().split(',')
  all_matched_files = []
  for input_pattern in input_patterns:
    input_pattern = input_pattern.strip()
    if not input_pattern:
      continue
    matched_files = tf.io.gfile.glob(input_pattern)
    if not matched_files:
      raise ValueError('%s does not match any files.' % input_pattern)
    all_matched_files.extend(matched_files)
  return sorted(all_matched_files)

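
# Illustrative sketch only: `_get_matched_files` accepts comma-separated glob
# patterns and returns the sorted union of matches. The paths below are
# assumptions for demonstration and do not ship with this module.
def _example_matched_files():
  """E.g. returns ['/data/squad_dev.json', '/data/xquad.ar.json', ...]."""
  return _get_matched_files('/data/xquad.*.json, /data/squad_dev.json')
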

def predict_squad(strategy,
                  input_meta_data,
                  tokenizer,
                  bert_config,
                  squad_lib,
                  init_checkpoint=None):
  """Gets prediction results and writes them to disk."""
  if init_checkpoint is None:
    init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
  all_predict_files = _get_matched_files(FLAGS.predict_file)
  squad_model = get_squad_model_to_predict(strategy, bert_config,
                                           init_checkpoint, input_meta_data)
  for idx, predict_file in enumerate(all_predict_files):
    all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
        strategy, input_meta_data, tokenizer, squad_lib, predict_file,
        squad_model)
    if len(all_predict_files) == 1:
      file_prefix = ''
    else:
      # If predict_file is /path/xquad.ar.json, the `file_prefix` will be
      # "xquad.ar-".
      file_prefix = '%s-' % os.path.splitext(
          os.path.basename(all_predict_files[idx]))[0]
    dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                  input_meta_data.get('version_2_with_negative', False),
                  file_prefix)


def eval_squad(strategy,
               input_meta_data,
               tokenizer,
               bert_config,
               squad_lib,
               init_checkpoint=None):
  """Gets prediction results and evaluates them against ground truth."""
  if init_checkpoint is None:
    init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
  all_predict_files = _get_matched_files(FLAGS.predict_file)
  if len(all_predict_files) != 1:
    raise ValueError('`eval_squad` only supports one predict file, '
                     'but got %s' % all_predict_files)
  squad_model = get_squad_model_to_predict(strategy, bert_config,
                                           init_checkpoint, input_meta_data)
  all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
      strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
      squad_model)
  dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                input_meta_data.get('version_2_with_negative', False))

  # `predict_file` may be a glob pattern, so read from the single matched file.
  with tf.io.gfile.GFile(all_predict_files[0], 'r') as reader:
    dataset_json = json.load(reader)
    pred_dataset = dataset_json['data']
  if input_meta_data.get('version_2_with_negative', False):
    eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
                                                scores_diff_json)
  else:
    eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
  return eval_metrics


def export_squad(model_export_path, input_meta_data, bert_config):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel
      directory.
    input_meta_data: dictionary containing meta data about input and model.
    bert_config: Bert configuration file to define core bert layers.

  Raises:
    ValueError: if the export path is not specified (empty string or None).
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  # Export uses float32 for now, even if training uses mixed precision.
  tf_keras.mixed_precision.set_global_policy('float32')
  squad_model, _ = bert_models.squad_model(bert_config,
                                           input_meta_data['max_seq_length'])
  model_saving_utils.export_bert_model(
      model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
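

# A minimal sketch of consuming the exported `SavedModel`. The input tensor
# names and the 'serving_default' signature key are assumptions; the exact
# signature depends on how `model_saving_utils.export_bert_model` saved the
# model, so inspect `loaded.signatures` before relying on these names.
def _example_load_exported_model(model_export_path, max_seq_length):
  """Loads the exported SavedModel and runs one forward pass on zeros."""
  loaded = tf.saved_model.load(model_export_path)
  infer = loaded.signatures['serving_default']
  dummy = tf.zeros([1, max_seq_length], dtype=tf.int32)
  # Assumed BERT input names: word ids, mask, and segment/type ids.
  return infer(
      input_word_ids=dummy, input_mask=dummy, input_type_ids=dummy)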