# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet SQUAD finetuning runner in tf2.0."""
import functools
import json
import os
import pickle
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf, tf_keras
# pylint: disable=unused-import
import sentencepiece as spm
from official.common import distribute_utils
from official.legacy.xlnet import common_flags
from official.legacy.xlnet import data_utils
from official.legacy.xlnet import optimization
from official.legacy.xlnet import squad_utils
from official.legacy.xlnet import training_utils
from official.legacy.xlnet import xlnet_config
from official.legacy.xlnet import xlnet_modeling as modeling
flags.DEFINE_string(
    "test_feature_path", default=None,
    help="Path to precomputed features of the test set.")
flags.DEFINE_integer("query_len", default=64, help="Max query length.")
flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.")
flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.")
flags.DEFINE_string(
"predict_dir", default=None, help="Path to write predictions.")
flags.DEFINE_string(
    "predict_file", default=None, help="Path to the test set JSON file.")
flags.DEFINE_integer(
    "n_best_size", default=5, help="Number of n-best predictions to keep.")
flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
# Data preprocessing config
flags.DEFINE_string(
    "spiece_model_file", default=None, help="SentencePiece model path.")
flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")
FLAGS = flags.FLAGS
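# Note: InputFeatures is not referenced elsewhere in this file; it documents
# the per-example feature schema and (presumably) matches the record type
# stored in pickled --test_feature_path files.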
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tok_start_to_orig_index,
tok_end_to_orig_index,
token_is_max_context,
input_ids,
input_mask,
p_mask,
segment_ids,
paragraph_len,
cls_index,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tok_start_to_orig_index = tok_start_to_orig_index
self.tok_end_to_orig_index = tok_end_to_orig_index
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.p_mask = p_mask
self.segment_ids = segment_ids
self.paragraph_len = paragraph_len
self.cls_index = cls_index
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
# pylint: disable=unused-argument
def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
original_data, eval_steps, input_meta_data, model,
current_step, eval_summary_writer):
"""Run evaluation for SQUAD task.
Args:
strategy: distribution strategy.
test_input_fn: input function for evaluation data.
    eval_examples: SQuAD examples of the evaluation set, as returned by
      squad_utils.read_squad_examples.
eval_features: Feature objects of the evaluation set.
original_data: The original json data for the evaluation set.
eval_steps: total number of evaluation steps.
input_meta_data: input meta data.
model: keras model object.
current_step: current training step.
eval_summary_writer: summary writer used to record evaluation metrics.
Returns:
A float metric, F1 score.
"""
def _test_step_fn(inputs):
"""Replicated validation step."""
inputs["mems"] = None
res = model(inputs, training=False)
return res, inputs["unique_ids"]
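  # tf.function compiles the distributed step into a graph on the first call;
  # subsequent calls reuse the traced computation.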
@tf.function
def _run_evaluation(test_iterator):
"""Runs validation steps."""
res, unique_ids = strategy.run(
_test_step_fn, args=(next(test_iterator),))
return res, unique_ids
test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
cur_results = []
for _ in range(eval_steps):
results, unique_ids = _run_evaluation(test_iterator)
unique_ids = strategy.experimental_local_results(unique_ids)
for result_key in results:
results[result_key] = (
strategy.experimental_local_results(results[result_key]))
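    # Each replica holds a per-core shard of the batch; iterate over replicas
    # and over examples within each shard to reassemble per-example results.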
for core_i in range(strategy.num_replicas_in_sync):
bsz = int(input_meta_data["test_batch_size"] /
strategy.num_replicas_in_sync)
for j in range(bsz):
result = {}
for result_key in results:
result[result_key] = results[result_key][core_i].numpy()[j]
result["unique_ids"] = unique_ids[core_i].numpy()[j]
        # A fake example was appended to the dev set so that its size is
        # divisible by test_batch_size. Ignore this fake example during
        # evaluation.
if result["unique_ids"] == 1000012047:
continue
unique_id = int(result["unique_ids"])
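        # Flatten the per-example beam outputs (top-k start/end log-probs and
        # indices, plus the CLS logit used for null-answer odds) into plain
        # Python lists and scalars.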
start_top_log_probs = ([
float(x) for x in result["start_top_log_probs"].flat
])
start_top_index = [int(x) for x in result["start_top_index"].flat]
end_top_log_probs = ([
float(x) for x in result["end_top_log_probs"].flat
])
end_top_index = [int(x) for x in result["end_top_index"].flat]
cls_logits = float(result["cls_logits"].flat[0])
cur_results.append(
squad_utils.RawResult(
unique_id=unique_id,
start_top_log_probs=start_top_log_probs,
start_top_index=start_top_index,
end_top_log_probs=end_top_log_probs,
end_top_index=end_top_index,
cls_logits=cls_logits))
if len(cur_results) % 1000 == 0:
logging.info("Processing example: %d", len(cur_results))
output_prediction_file = os.path.join(input_meta_data["predict_dir"],
"predictions.json")
output_nbest_file = os.path.join(input_meta_data["predict_dir"],
"nbest_predictions.json")
output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
"null_odds.json")
results = squad_utils.write_predictions(
eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
input_meta_data["max_answer_length"], output_prediction_file,
output_nbest_file, output_null_log_odds_file, original_data,
input_meta_data["start_n_top"], input_meta_data["end_n_top"])
# Log current results.
log_str = "Result | "
for key, val in results.items():
log_str += "{} {} | ".format(key, val)
logging.info(log_str)
with eval_summary_writer.as_default():
tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
eval_summary_writer.flush()
return results["best_f1"]
def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
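  """Builds the XLNet question-answering model for SQuAD span prediction."""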
model = modeling.QAXLNetModel(
model_config,
run_config,
start_n_top=start_n_top,
end_n_top=end_n_top,
name="model")
return model
def main(unused_argv):
del unused_argv
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
train_input_fn = functools.partial(data_utils.get_squad_input_data,
FLAGS.train_batch_size, FLAGS.seq_len,
FLAGS.query_len, strategy, True,
FLAGS.train_tfrecord_path)
test_input_fn = functools.partial(data_utils.get_squad_input_data,
FLAGS.test_batch_size, FLAGS.seq_len,
FLAGS.query_len, strategy, False,
FLAGS.test_tfrecord_path)
total_training_steps = FLAGS.train_steps
steps_per_loop = FLAGS.iterations
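  # Assumes test_data_size is an exact multiple of test_batch_size; the eval
  # data is padded with a fake example (see run_evaluation) to guarantee this.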
eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
optimizer, learning_rate_fn = optimization.create_optimizer(
FLAGS.learning_rate,
total_training_steps,
FLAGS.warmup_steps,
adam_epsilon=FLAGS.adam_epsilon)
model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS)
input_meta_data = {}
input_meta_data["start_n_top"] = FLAGS.start_n_top
input_meta_data["end_n_top"] = FLAGS.end_n_top
input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
input_meta_data["predict_dir"] = FLAGS.predict_dir
input_meta_data["n_best_size"] = FLAGS.n_best_size
input_meta_data["max_answer_length"] = FLAGS.max_answer_length
input_meta_data["test_batch_size"] = FLAGS.test_batch_size
input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
strategy.num_replicas_in_sync)
input_meta_data["mem_len"] = FLAGS.mem_len
model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
FLAGS.start_n_top, FLAGS.end_n_top)
eval_examples = squad_utils.read_squad_examples(
FLAGS.predict_file, is_training=False)
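  # Eval features are either loaded from a precomputed pickle or built on the
  # fly with the SentencePiece model.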
if FLAGS.test_feature_path:
logging.info("start reading pickle file...")
with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
eval_features = pickle.load(f)
logging.info("finishing reading pickle file...")
else:
sp_model = spm.SentencePieceProcessor()
sp_model.LoadFromSerializedProto(
tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
spm_basename = os.path.basename(FLAGS.spiece_model_file)
eval_features = squad_utils.create_eval_data(
spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)
with tf.io.gfile.GFile(FLAGS.predict_file) as f:
original_data = json.load(f)["data"]
eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
eval_examples, eval_features, original_data,
eval_steps, input_meta_data)
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=None,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
init_from_transformerxl=FLAGS.init_from_transformerxl,
total_training_steps=total_training_steps,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir,
save_steps=FLAGS.save_steps)
if __name__ == "__main__":
app.run(main)