"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x.""" |

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import time

from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf

from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils


flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')

run_squad_helper.define_common_squad_flags()

FLAGS = flags.FLAGS
|
|
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None,
                sub_model_export_name=None):
  """Runs BERT SQuAD training."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  # An explicitly passed checkpoint takes precedence over the flag value.
  init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
  run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
                               custom_callbacks, run_eagerly, init_checkpoint,
                               sub_model_export_name=sub_model_export_name)
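
# A minimal sketch of driving train_squad programmatically rather than via
# app.run (illustrative only: flags must already be parsed, since train_squad
# reads FLAGS.bert_config_file, and input_meta_data must be loaded as in
# main() below):
#
#   strategy = tf.distribute.MirroredStrategy()
#   train_squad(strategy, input_meta_data,
#               init_checkpoint='/path/to/bert_model.ckpt')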
|
|
def predict_squad(strategy, input_meta_data):
  """Makes predictions for the SQuAD dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  run_squad_helper.predict_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
|
|
def eval_squad(strategy, input_meta_data):
  """Evaluates on the SQuAD dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  eval_metrics = run_squad_helper.eval_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
  return eval_metrics
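
# Note: main() below relies only on the 'final_f1' entry of the returned
# dict; the remaining metrics are passed through unchanged to the json dump.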
|
|
def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: A string specifying the path to the SavedModel
      directory.
    input_meta_data: A dictionary containing meta data about input and model.

  Raises:
    ValueError: If the export path is not specified, i.e. is an empty string
      or None.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
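
# A minimal sketch of inspecting the exported artifact (the available
# signatures depend on run_squad_helper.export_squad and are best verified
# with the `saved_model_cli show` tool):
#
#   imported = tf.saved_model.load(model_export_path)
#   print(list(imported.signatures.keys()))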
|
|
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures the cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)
  # The distribution strategy is needed by train, predict, and eval alike.
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  if 'train' in FLAGS.mode:
    if FLAGS.log_steps:
      custom_callbacks = [keras_utils.TimeHistory(
          batch_size=FLAGS.train_batch_size,
          log_steps=FLAGS.log_steps,
          logdir=FLAGS.model_dir,
      )]
    else:
      custom_callbacks = None

    train_squad(
        strategy,
        input_meta_data,
        custom_callbacks=custom_callbacks,
        run_eagerly=FLAGS.run_eagerly,
        sub_model_export_name=FLAGS.sub_model_export_name,
    )
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    with summary_writer.as_default():
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Also writes the full eval_metrics dict to a json file.
    squad_lib_wp.write_to_json_files(
        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
  # Pause before exiting, presumably to give pending asynchronous writes
  # (summaries, checkpoints) time to finish.
  time.sleep(60)
|
|
if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('model_dir')
  app.run(main)