|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import json |
|
import os |
|
import time |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import tensorflow as tf |
|
|
|
from official.nlp.albert import configs as albert_configs |
|
from official.nlp.bert import run_squad_helper |
|
from official.nlp.bert import tokenization |
|
from official.nlp.data import squad_lib_sp |
|
from official.utils.misc import distribution_utils |
|
|
|
# ALBERT-specific flag: path of the SentencePiece model that the tokenizer in
# this file is constructed from (read as FLAGS.sp_model_file).
flags.DEFINE_string(
    name='sp_model_file',
    default=None,
    help=(
        'The path to the sentence piece model. Used by sentence piece tokenizer '
        'employed by ALBERT.'))

# Registers the remaining flags through the shared SQuAD helper. Presumably
# this is where the other FLAGS.* values read in this file (mode, model_dir,
# bert_config_file, ...) are defined -- confirm against run_squad_helper.
run_squad_helper.define_common_squad_flags()

# Module-level handle to the global absl flag values.
FLAGS = flags.FLAGS
|
|
|
|
|
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Runs SQuAD training with an ALBERT config.

  Loads the model config from the JSON file named by `FLAGS.bert_config_file`
  and hands everything to `run_squad_helper.train_squad`.

  Args:
    strategy: passed through unchanged to the helper; `main` supplies the
      distribution strategy here.
    input_meta_data: passed through unchanged to the helper; `main` supplies
      the object decoded from the input meta data JSON file.
    custom_callbacks: passed through unchanged to the helper. Defaults to None.
    run_eagerly: passed through unchanged to the helper. Defaults to False.
  """
  config_path = FLAGS.bert_config_file
  albert_config = albert_configs.AlbertConfig.from_json_file(config_path)
  run_squad_helper.train_squad(
      strategy,
      input_meta_data,
      albert_config,
      custom_callbacks,
      run_eagerly)
|
|
|
|
|
def predict_squad(strategy, input_meta_data):
  """Runs SQuAD prediction with an ALBERT config and SentencePiece tokenizer.

  The config is loaded from `FLAGS.bert_config_file` and the tokenizer is built
  from `FLAGS.sp_model_file`; both are then forwarded, together with the
  `squad_lib_sp` module, to `run_squad_helper.predict_squad`.

  Args:
    strategy: passed through unchanged to the helper.
    input_meta_data: passed through unchanged to the helper.
  """
  config_path = FLAGS.bert_config_file
  albert_config = albert_configs.AlbertConfig.from_json_file(config_path)

  sp_tokenizer = tokenization.FullSentencePieceTokenizer(
      sp_model_file=FLAGS.sp_model_file)

  run_squad_helper.predict_squad(
      strategy, input_meta_data, sp_tokenizer, albert_config, squad_lib_sp)
|
|
|
|
|
def eval_squad(strategy, input_meta_data):
  """Runs SQuAD evaluation with an ALBERT config and SentencePiece tokenizer.

  The config is loaded from `FLAGS.bert_config_file` and the tokenizer is built
  from `FLAGS.sp_model_file`; both are then forwarded, together with the
  `squad_lib_sp` module, to `run_squad_helper.eval_squad`.

  Args:
    strategy: passed through unchanged to the helper.
    input_meta_data: passed through unchanged to the helper.

  Returns:
    Whatever `run_squad_helper.eval_squad` returns. `main` indexes the result
    with the key 'final_f1'.
  """
  config_path = FLAGS.bert_config_file
  albert_config = albert_configs.AlbertConfig.from_json_file(config_path)

  sp_tokenizer = tokenization.FullSentencePieceTokenizer(
      sp_model_file=FLAGS.sp_model_file)

  return run_squad_helper.eval_squad(
      strategy, input_meta_data, sp_tokenizer, albert_config, squad_lib_sp)
|
|
|
|
|
def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Loads the ALBERT config from `FLAGS.bert_config_file` and delegates the
  export itself to `run_squad_helper.export_squad`.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    An error when the export path is not specified (empty string or None).
    NOTE(review): the check and the exception type are not visible in this
    file -- presumably raised by `run_squad_helper.export_squad`; confirm.
  """
  config_path = FLAGS.bert_config_file
  albert_config = albert_configs.AlbertConfig.from_json_file(config_path)
  run_squad_helper.export_squad(
      model_export_path, input_meta_data, albert_config)
|
|
|
|
|
def main(_):
  """Runs the SQuAD stages selected by `FLAGS.mode`.

  Reads the input meta data JSON from `FLAGS.input_meta_data_path`. When the
  mode is exactly 'export_only' the model is exported and nothing else runs.
  Otherwise a distribution strategy is created and the train / predict / eval
  stages run, in that order, for each of those words contained in the mode
  string.

  Args:
    _: unused positional argument supplied by `app.run`.
  """
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
    raw_meta = meta_file.read()
  input_meta_data = json.loads(raw_meta.decode('utf-8'))

  mode = FLAGS.mode
  if mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Cluster configuration is only attempted when GPUs are requested; its
  # return value is not needed here.
  if FLAGS.num_gpus > 0:
    distribution_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  # Substring checks, so a combined mode string can trigger several stages.
  if 'train' in mode:
    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)

  if 'predict' in mode:
    predict_squad(strategy, input_meta_data)

  if 'eval' in mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)

    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    metrics_path = os.path.join(summary_dir, 'eval_metrics.json')

    writer = tf.summary.create_file_writer(summary_dir)
    with writer.as_default():
      # The scalar is always recorded at step 0, regardless of training step.
      tf.summary.scalar('F1-score', f1_score, step=0)
      writer.flush()

    # The full metrics are also dumped as JSON next to the summary files.
    squad_lib_sp.write_to_json_files(eval_metrics, metrics_path)

    # NOTE(review): the reason for this fixed 60-second pause is not visible
    # in this file -- confirm before changing or removing it.
    time.sleep(60)
|
|
|
|
|
if __name__ == '__main__':
  # Both flags must be supplied on the command line before `main` runs.
  for required_flag in ('bert_config_file', 'model_dir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)
|
|