|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""XLNet pretraining runner in tf2.0."""
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import functools |
|
import os |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import tensorflow as tf |
|
|
|
from official.nlp.xlnet import common_flags |
|
from official.nlp.xlnet import data_utils |
|
from official.nlp.xlnet import optimization |
|
from official.nlp.xlnet import training_utils |
|
from official.nlp.xlnet import xlnet_config |
|
from official.nlp.xlnet import xlnet_modeling as modeling |
|
from official.utils.misc import tpu_lib |
|
|
|
# Pretraining-specific flags. Other flags read in this file (strategy_type,
# train_batch_size, model_dir, ...) are not defined here; presumably they are
# registered by the `common_flags` import -- NOTE(review): confirm.

flags.DEFINE_integer(
    "num_predict",
    default=None,
    help="Number of tokens to predict in partial prediction.")
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
flags.DEFINE_float(
    "leak_ratio",
    default=0.1,
    help="Percent of masked tokens that are leaked.")

# Online masking: how prediction targets are sampled, and span-length bounds
# for the span-based strategies. Adjacent string literals are concatenated,
# so each first half ends with a space to keep the help text readable.
flags.DEFINE_enum(
    "sample_strategy",
    default="token_span",
    enum_values=["single_token", "whole_word", "token_span", "word_span"],
    help="Strategy used to sample prediction targets.")
flags.DEFINE_integer(
    "max_num_tokens",
    default=5,
    help="Maximum number of tokens to sample in a span. "
    "Effective when token_span strategy is used.")
flags.DEFINE_integer(
    "min_num_tokens",
    default=1,
    help="Minimum number of tokens to sample in a span. "
    "Effective when token_span strategy is used.")
flags.DEFINE_integer(
    "max_num_words",
    default=5,
    help="Maximum number of whole words to sample in a span. "
    "Effective when word_span strategy is used.")
flags.DEFINE_integer(
    "min_num_words",
    default=1,
    help="Minimum number of whole words to sample in a span. "
    "Effective when word_span strategy is used.")

FLAGS = flags.FLAGS
|
|
|
|
|
def get_pretrainxlnet_model(model_config, run_config):
  """Constructs the XLNet pretraining model.

  Args:
    model_config: the XLNet architecture config (`main` passes an
      `xlnet_config.XLNetConfig` built from FLAGS).
    run_config: the run-time config (`main` passes the result of
      `xlnet_config.create_run_config`).

  Returns:
    A `modeling.PretrainingXLNetModel` named "model", built with
    `use_proj=True`.
  """
  model_kwargs = dict(
      xlnet_config=model_config,
      run_config=run_config,
      use_proj=True,
      name="model",
  )
  return modeling.PretrainingXLNetModel(**model_kwargs)
|
|
|
|
|
def _create_distribution_strategy():
  """Creates the tf.distribute strategy selected by --strategy_type.

  Returns:
    A `(strategy, num_hosts)` tuple. `num_hosts` is 1 for the "mirror"
    strategy. For "tpu" it is derived from --tpu_topology and
    --num_core_per_host.

  Raises:
    ValueError: if --strategy_type is neither "mirror" nor "tpu", or if
      --tpu_topology does not have at least two "x"-separated components.
  """
  if FLAGS.strategy_type == "mirror":
    return tf.distribute.MirroredStrategy(), 1

  if FLAGS.strategy_type == "tpu":
    # Validate the topology before touching the TPU, so that a typo fails fast
    # with a clear message instead of a bare IndexError.
    topology = FLAGS.tpu_topology.split("x")
    if len(topology) < 2:
      raise ValueError(
          "--tpu_topology must look like '<dim0>x<dim1>', got: %s" %
          FLAGS.tpu_topology)
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    # Only the first two topology dimensions are used. The factor of 2
    # presumably means two cores per chip -- NOTE(review): confirm for the
    # TPU generation in use.
    total_num_core = 2 * int(topology[0]) * int(topology[1])
    num_hosts = total_num_core // FLAGS.num_core_per_host
    return strategy, num_hosts

  raise ValueError("The distribution strategy type is not supported: %s" %
                   FLAGS.strategy_type)


def _build_input_meta_data(strategy):
  """Collects the model/training metadata passed to `training_utils.train`.

  Args:
    strategy: the `tf.distribute` strategy used for training. Its replica count
      determines the per-core batch size.

  Returns:
    A dict with keys "d_model", "mem_len", "batch_size_per_core", "n_layer"
    and "lr_layer_decay_rate".
  """
  num_replicas = strategy.num_replicas_in_sync
  batch_size_per_core = FLAGS.train_batch_size // num_replicas
  if FLAGS.train_batch_size % num_replicas:
    # Not fatal, but the remainder is silently dropped from the per-core size.
    logging.warning(
        "train_batch_size (%d) is not divisible by the number of replicas "
        "(%d); batch_size_per_core is rounded down to %d.",
        FLAGS.train_batch_size, num_replicas, batch_size_per_core)
  return {
      "d_model": FLAGS.d_model,
      "mem_len": FLAGS.mem_len,
      "batch_size_per_core": batch_size_per_core,
      "n_layer": FLAGS.n_layer,
      "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
  }


def main(unused_argv):
  """Pretrains XLNet and exports the Transformer-XL weights as a checkpoint.

  Builds the distribution strategy, input pipeline, optimizer and model from
  FLAGS and runs `training_utils.train`. It then saves the trained
  `transformerxl_model` under `<model_dir>/pretrained/transformer_xl.ckpt`.

  Args:
    unused_argv: remaining command-line arguments (ignored).

  Raises:
    ValueError: if --strategy_type or --tpu_topology is invalid.
  """
  del unused_argv

  strategy, num_hosts = _create_distribution_strategy()
  logging.info("***** Number of cores used : %d",
               strategy.num_replicas_in_sync)
  logging.info("***** Number of hosts used : %d", num_hosts)

  online_masking_config = data_utils.OnlineMaskingConfig(
      sample_strategy=FLAGS.sample_strategy,
      max_num_tokens=FLAGS.max_num_tokens,
      min_num_tokens=FLAGS.min_num_tokens,
      max_num_words=FLAGS.max_num_words,
      min_num_words=FLAGS.min_num_words)

  # Positional arguments must stay in the order expected by
  # data_utils.get_pretrain_input_data.
  train_input_fn = functools.partial(
      data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
      strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
      FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
      num_hosts)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations

  optimizer, learning_rate_fn = optimization.create_optimizer(
      init_lr=FLAGS.learning_rate,
      num_train_steps=total_training_steps,
      num_warmup_steps=FLAGS.warmup_steps,
      min_lr_ratio=FLAGS.min_lr_ratio,
      adam_epsilon=FLAGS.adam_epsilon,
      weight_decay_rate=FLAGS.weight_decay_rate)

  model_config = xlnet_config.XLNetConfig(FLAGS)
  # NOTE(review): the positional booleans presumably mean (is_training=True,
  # is_finetune=False) -- confirm against xlnet_config.create_run_config.
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  input_meta_data = _build_input_meta_data(strategy)
  model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                               run_config)

  model = training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=None,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

  # Export only the Transformer-XL sub-model, so it can be reused without the
  # pretraining head.
  checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
  saved_path = checkpoint.save(
      os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
  logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
               saved_path)
|
|
|
|
|
if __name__ == "__main__":
  # absl parses the command-line flags, then invokes `main` with the leftover
  # argv.
  app.run(main=main)
|
|