# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """XLNet SQUAD finetuning runner in tf2.0.""" import functools import json import os import pickle # Import libraries from absl import app from absl import flags from absl import logging import tensorflow as tf, tf_keras # pylint: disable=unused-import import sentencepiece as spm from official.common import distribute_utils from official.legacy.xlnet import common_flags from official.legacy.xlnet import data_utils from official.legacy.xlnet import optimization from official.legacy.xlnet import squad_utils from official.legacy.xlnet import training_utils from official.legacy.xlnet import xlnet_config from official.legacy.xlnet import xlnet_modeling as modeling flags.DEFINE_string( "test_feature_path", default=None, help="Path to feature of test set.") flags.DEFINE_integer("query_len", default=64, help="Max query length.") flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.") flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.") flags.DEFINE_string( "predict_dir", default=None, help="Path to write predictions.") flags.DEFINE_string( "predict_file", default=None, help="Path to json file of test set.") flags.DEFINE_integer( "n_best_size", default=5, help="n best size for predictions.") flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.") # Data preprocessing config flags.DEFINE_string( "spiece_model_file", default=None, help="Sentence Piece model path.") flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.") flags.DEFINE_integer("max_query_length", default=64, help="Max query length.") flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.") FLAGS = flags.FLAGS class InputFeatures(object): """A single set of features of data.""" def __init__(self, unique_id, example_index, doc_span_index, tok_start_to_orig_index, tok_end_to_orig_index, token_is_max_context, input_ids, input_mask, p_mask, segment_ids, paragraph_len, cls_index, start_position=None, end_position=None, is_impossible=None): self.unique_id = unique_id self.example_index = example_index self.doc_span_index = doc_span_index self.tok_start_to_orig_index = tok_start_to_orig_index self.tok_end_to_orig_index = tok_end_to_orig_index self.token_is_max_context = token_is_max_context self.input_ids = input_ids self.input_mask = input_mask self.p_mask = p_mask self.segment_ids = segment_ids self.paragraph_len = paragraph_len self.cls_index = cls_index self.start_position = start_position self.end_position = end_position self.is_impossible = is_impossible # pylint: disable=unused-argument def run_evaluation(strategy, test_input_fn, eval_examples, eval_features, original_data, eval_steps, input_meta_data, model, current_step, eval_summary_writer): """Run evaluation for SQUAD task. Args: strategy: distribution strategy. test_input_fn: input function for evaluation data. eval_examples: tf.Examples of the evaluation set. eval_features: Feature objects of the evaluation set. original_data: The original json data for the evaluation set. eval_steps: total number of evaluation steps. input_meta_data: input meta data. model: keras model object. current_step: current training step. eval_summary_writer: summary writer used to record evaluation metrics. Returns: A float metric, F1 score. """ def _test_step_fn(inputs): """Replicated validation step.""" inputs["mems"] = None res = model(inputs, training=False) return res, inputs["unique_ids"] @tf.function def _run_evaluation(test_iterator): """Runs validation steps.""" res, unique_ids = strategy.run( _test_step_fn, args=(next(test_iterator),)) return res, unique_ids test_iterator = data_utils.get_input_iterator(test_input_fn, strategy) cur_results = [] for _ in range(eval_steps): results, unique_ids = _run_evaluation(test_iterator) unique_ids = strategy.experimental_local_results(unique_ids) for result_key in results: results[result_key] = ( strategy.experimental_local_results(results[result_key])) for core_i in range(strategy.num_replicas_in_sync): bsz = int(input_meta_data["test_batch_size"] / strategy.num_replicas_in_sync) for j in range(bsz): result = {} for result_key in results: result[result_key] = results[result_key][core_i].numpy()[j] result["unique_ids"] = unique_ids[core_i].numpy()[j] # We appended a fake example into dev set to make data size can be # divided by test_batch_size. Ignores this fake example during # evaluation. if result["unique_ids"] == 1000012047: continue unique_id = int(result["unique_ids"]) start_top_log_probs = ([ float(x) for x in result["start_top_log_probs"].flat ]) start_top_index = [int(x) for x in result["start_top_index"].flat] end_top_log_probs = ([ float(x) for x in result["end_top_log_probs"].flat ]) end_top_index = [int(x) for x in result["end_top_index"].flat] cls_logits = float(result["cls_logits"].flat[0]) cur_results.append( squad_utils.RawResult( unique_id=unique_id, start_top_log_probs=start_top_log_probs, start_top_index=start_top_index, end_top_log_probs=end_top_log_probs, end_top_index=end_top_index, cls_logits=cls_logits)) if len(cur_results) % 1000 == 0: logging.info("Processing example: %d", len(cur_results)) output_prediction_file = os.path.join(input_meta_data["predict_dir"], "predictions.json") output_nbest_file = os.path.join(input_meta_data["predict_dir"], "nbest_predictions.json") output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"], "null_odds.json") results = squad_utils.write_predictions( eval_examples, eval_features, cur_results, input_meta_data["n_best_size"], input_meta_data["max_answer_length"], output_prediction_file, output_nbest_file, output_null_log_odds_file, original_data, input_meta_data["start_n_top"], input_meta_data["end_n_top"]) # Log current results. log_str = "Result | " for key, val in results.items(): log_str += "{} {} | ".format(key, val) logging.info(log_str) with eval_summary_writer.as_default(): tf.summary.scalar("best_f1", results["best_f1"], step=current_step) tf.summary.scalar("best_exact", results["best_exact"], step=current_step) eval_summary_writer.flush() return results["best_f1"] def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top): model = modeling.QAXLNetModel( model_config, run_config, start_n_top=start_n_top, end_n_top=end_n_top, name="model") return model def main(unused_argv): del unused_argv strategy = distribute_utils.get_distribution_strategy( distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu) if strategy: logging.info("***** Number of cores used : %d", strategy.num_replicas_in_sync) train_input_fn = functools.partial(data_utils.get_squad_input_data, FLAGS.train_batch_size, FLAGS.seq_len, FLAGS.query_len, strategy, True, FLAGS.train_tfrecord_path) test_input_fn = functools.partial(data_utils.get_squad_input_data, FLAGS.test_batch_size, FLAGS.seq_len, FLAGS.query_len, strategy, False, FLAGS.test_tfrecord_path) total_training_steps = FLAGS.train_steps steps_per_loop = FLAGS.iterations eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size) optimizer, learning_rate_fn = optimization.create_optimizer( FLAGS.learning_rate, total_training_steps, FLAGS.warmup_steps, adam_epsilon=FLAGS.adam_epsilon) model_config = xlnet_config.XLNetConfig(FLAGS) run_config = xlnet_config.create_run_config(True, False, FLAGS) input_meta_data = {} input_meta_data["start_n_top"] = FLAGS.start_n_top input_meta_data["end_n_top"] = FLAGS.end_n_top input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate input_meta_data["predict_dir"] = FLAGS.predict_dir input_meta_data["n_best_size"] = FLAGS.n_best_size input_meta_data["max_answer_length"] = FLAGS.max_answer_length input_meta_data["test_batch_size"] = FLAGS.test_batch_size input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size / strategy.num_replicas_in_sync) input_meta_data["mem_len"] = FLAGS.mem_len model_fn = functools.partial(get_qaxlnet_model, model_config, run_config, FLAGS.start_n_top, FLAGS.end_n_top) eval_examples = squad_utils.read_squad_examples( FLAGS.predict_file, is_training=False) if FLAGS.test_feature_path: logging.info("start reading pickle file...") with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f: eval_features = pickle.load(f) logging.info("finishing reading pickle file...") else: sp_model = spm.SentencePieceProcessor() sp_model.LoadFromSerializedProto( tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read()) spm_basename = os.path.basename(FLAGS.spiece_model_file) eval_features = squad_utils.create_eval_data( spm_basename, sp_model, eval_examples, FLAGS.max_seq_length, FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased) with tf.io.gfile.GFile(FLAGS.predict_file) as f: original_data = json.load(f)["data"] eval_fn = functools.partial(run_evaluation, strategy, test_input_fn, eval_examples, eval_features, original_data, eval_steps, input_meta_data) training_utils.train( strategy=strategy, model_fn=model_fn, input_meta_data=input_meta_data, eval_fn=eval_fn, metric_fn=None, train_input_fn=train_input_fn, init_checkpoint=FLAGS.init_checkpoint, init_from_transformerxl=FLAGS.init_from_transformerxl, total_training_steps=total_training_steps, steps_per_loop=steps_per_loop, optimizer=optimizer, learning_rate_fn=learning_rate_fn, model_dir=FLAGS.model_dir, save_steps=FLAGS.save_steps) if __name__ == "__main__": app.run(main)