File size: 8,616 Bytes
97b6013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run NHNet model training and eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from absl import logging
from six.moves import zip
import tensorflow as tf
from official.modeling.hyperparams import params_dict
from official.nlp.nhnet import evaluation
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
# Global absl flag registry; the flags read throughout this module are
# registered by `define_flags()` before `app.run` parses argv.
FLAGS = flags.FLAGS
def define_flags():
  """Registers every absl command-line flag consumed by the NHNet trainer.

  Must be called exactly once, before `app.run` parses argv.
  """
  # Run configuration: what to do and where inputs/outputs live.
  flags.DEFINE_enum("mode", "train", ["train", "eval", "train_and_eval"],
                    "Execution mode.")
  flags.DEFINE_string("train_file_pattern", "", "Train file pattern.")
  flags.DEFINE_string("eval_file_pattern", "", "Eval file pattern.")
  flags.DEFINE_string(
      "model_dir", None,
      "The output directory where the model checkpoints will be written.")

  # Training infrastructure and initialization.
  flags.DEFINE_enum(
      "distribution_strategy", "mirrored", ["tpu", "mirrored"],
      "Distribution Strategy type to use for training. `tpu` uses TPUStrategy "
      "for running on TPUs, `mirrored` uses GPUs with single host.")
  flags.DEFINE_string("tpu", "", "TPU address to connect to.")
  flags.DEFINE_string(
      "init_checkpoint", None,
      "Initial checkpoint (usually from a pre-trained BERT model).")

  # Integer knobs for the train/eval loops and the model shape. They are
  # registered from one table, in the listed order.
  integer_flag_specs = (
      ("train_steps", 100000, "Max train steps"),
      ("eval_steps", 32, "Number of eval steps per run."),
      ("eval_timeout", 3000, "Timeout waiting for checkpoints."),
      ("train_batch_size", 32, "Total batch size for training."),
      ("eval_batch_size", 4, "Total batch size for evaluation."),
      ("steps_per_loop", 1000,
       "Number of steps per graph-mode loop. Only training step "
       "happens inside the loop."),
      ("checkpoint_interval", 2000, "Checkpointing interval."),
      ("len_title", 15, "Title length."),
      ("len_passage", 200, "Passage length."),
      ("num_encoder_layers", 12, "Number of hidden layers of encoder."),
      ("num_decoder_layers", 12, "Number of hidden layers of decoder."),
  )
  for flag_name, default_value, help_text in integer_flag_specs:
    flags.DEFINE_integer(flag_name, default_value, help_text)

  # Model selection and hyperparameter overrides.
  flags.DEFINE_string("model_type", "nhnet",
                      "Model type to choose a model configuration.")
  flags.DEFINE_integer(
      "num_nhnet_articles", 5,
      "Maximum number of articles in NHNet, only used when model_type=nhnet")
  flags.DEFINE_string(
      "params_override",
      default=None,
      help=("a YAML/JSON string or a YAML file which specifies additional "
            "overrides over the default parameters"))
# pylint: disable=protected-access
class Trainer(tf.keras.Model):
  """Training-only Keras wrapper that gives an NHNet model a `train_step`.

  `fit` on this wrapper drives the wrapped model's forward pass, computes the
  transformer loss against the target ids, and applies gradients.
  """

  def __init__(self, model, params):
    super(Trainer, self).__init__()
    self.model = model
    self.params = params
    # Captured once at construction time from the current (scope) strategy.
    active_strategy = tf.distribute.get_strategy()
    self._num_replicas_in_sync = active_strategy.num_replicas_in_sync

  def call(self, inputs, mode="train"):
    """Delegates to the wrapped model, passing `mode` positionally."""
    return self.model(inputs, mode)

  def train_step(self, inputs):
    """Runs one optimization step on a single batch.

    Args:
      inputs: A batch of features; must contain "target_ids".

    Returns:
      A dict with the unscaled per-replica "training_loss" and the current
      "learning_rate" (read after the gradient update).
    """
    with tf.GradientTape() as tape:
      logits, _, _ = self(inputs, mode="train", training=True)
      labels = models.remove_sos_from_seq(inputs["target_ids"],
                                          self.params.pad_token_id)
      replica_loss = transformer_metrics.transformer_loss(
          logits, labels, self.params.label_smoothing, self.params.vocab_size)
      # Dividing by the replica count means the cross-replica sum of
      # gradients corresponds to the average loss across replicas.
      loss_for_backprop = replica_loss / self._num_replicas_in_sync

    weights = self.trainable_variables
    gradients = tape.gradient(loss_for_backprop, weights)
    self.optimizer.apply_gradients(list(zip(gradients, weights)))
    lr_now = self.optimizer._decayed_lr(var_dtype=tf.float32)
    return {"training_loss": replica_loss, "learning_rate": lr_now}
def train(params, strategy, dataset=None):
  """Trains the model with Keras `fit`, checkpointing into `--model_dir`.

  Builds the model and optimizer under `strategy` and restores the latest
  checkpoint in `--model_dir` if there is one. It then trains until the
  optimizer's iteration count reaches `epochs * steps_per_epoch`, which is
  the largest multiple of the epoch length not exceeding `--train_steps`.
  When resuming from a checkpoint, training continues from the restored
  step instead of running the whole schedule again.

  Args:
    params: Hyperparameters consumed by `models.create_model`,
      `optimizer.create_optimizer` and `Trainer`.
    strategy: The `tf.distribute.Strategy` to build and train under.
    dataset: Optional training dataset. When None, one is built from
      `--train_file_pattern` and `--train_batch_size`.

  Returns:
    `{"training_loss": float}` holding the loss reported for the last epoch.
    Returns an empty dict when the restored global step already meets the
    target and no training is run.

  Raises:
    ValueError: If `--train_steps` or `--checkpoint_interval` is not positive
      (the epoch arithmetic below would otherwise divide by zero).
  """
  if FLAGS.train_steps <= 0:
    raise ValueError(
        "--train_steps must be positive, got %d" % FLAGS.train_steps)
  if FLAGS.checkpoint_interval <= 0:
    raise ValueError("--checkpoint_interval must be positive, got %d" %
                     FLAGS.checkpoint_interval)

  if dataset is None:
    dataset = input_pipeline.get_input_dataset(
        FLAGS.train_file_pattern,
        FLAGS.train_batch_size,
        params,
        is_training=True,
        strategy=strategy)

  with strategy.scope():
    model = models.create_model(
        FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
    opt = optimizer.create_optimizer(params)
    trainer = Trainer(model, params)
    # The optimizer's iteration counter doubles as the global step that is
    # checkpointed and used as the CheckpointManager step counter.
    model.global_step = opt.iterations
    trainer.compile(
        optimizer=opt,
        experimental_steps_per_execution=FLAGS.steps_per_loop)

  summary_dir = os.path.join(FLAGS.model_dir, "summaries")
  summary_callback = tf.keras.callbacks.TensorBoard(
      summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
  checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=FLAGS.model_dir,
      max_to_keep=10,
      step_counter=model.global_step,
      checkpoint_interval=FLAGS.checkpoint_interval)
  if checkpoint_manager.restore_or_initialize():
    logging.info("Training restored from the checkpoints in: %s",
                 FLAGS.model_dir)
  checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

  # One Keras epoch spans at most one checkpoint interval (SimpleCheckpoint
  # is assumed to save at epoch ends -- confirm in keras_utils). Steps left
  # over that do not fill a whole epoch are not trained; --train_steps is
  # treated as a maximum.
  steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
  epochs = FLAGS.train_steps // steps_per_epoch

  # Resume at the epoch matching the restored global step. Otherwise `fit`
  # would run the full schedule again on top of the restored progress.
  current_step = int(opt.iterations.numpy())
  initial_epoch = current_step // steps_per_epoch
  if initial_epoch >= epochs:
    logging.info(
        "Global step %d already reaches the training target of %d steps; "
        "skipping training.", current_step, epochs * steps_per_epoch)
    return {}

  history = trainer.fit(
      x=dataset,
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      initial_epoch=initial_epoch,
      callbacks=[summary_callback, checkpoint_callback],
      verbose=2)
  train_hist = history.history
  # Gets final loss from training.
  return dict(training_loss=float(train_hist["training_loss"][-1]))
def run():
  """Runs NHNet training and/or evaluation according to `--mode`.

  Builds the distribution strategy and model params from flags, applies flag
  overrides to the params, then trains (modes "train" and "train_and_eval")
  and/or runs continuous evaluation (modes "eval" and "train_and_eval").

  Returns:
    The stats dict from the last phase that ran: evaluation stats if
    evaluation ran, otherwise training stats.
  """
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)

  params = models.get_model_params(FLAGS.model_type)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          "len_title": FLAGS.len_title,
          "len_passage": FLAGS.len_passage,
          "num_hidden_layers": FLAGS.num_encoder_layers,
          "num_decoder_layers": FLAGS.num_decoder_layers,
          # One identifier per article: consecutive letters starting at "b".
          "passage_list":
              [chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)],
      },
      is_strict=False)

  stats = {}
  if FLAGS.mode in ("train", "train_and_eval"):
    stats = train(params, strategy)
  if FLAGS.mode in ("eval", "train_and_eval"):
    # NOTE(review): timeout=0 after in-process training presumably means
    # "evaluate existing checkpoints without waiting for new ones" -- confirm
    # against evaluation.continuous_eval.
    timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
    # Uses padded decoding for TPU. Always uses cache.
    # Depending on the TF version, the TPU strategy may be exported as
    # tf.distribute.TPUStrategy and/or tf.distribute.experimental.TPUStrategy,
    # which may be unrelated classes. Accept whichever ones exist so that
    # padded decoding is not silently disabled on TPU.
    tpu_strategy_types = tuple(
        cls for cls in (
            getattr(tf.distribute, "TPUStrategy", None),
            getattr(tf.distribute.experimental, "TPUStrategy", None),
        ) if cls is not None)
    padded_decode = isinstance(strategy, tpu_strategy_types)
    params.override({
        "padded_decode": padded_decode,
    }, is_strict=False)
    stats = evaluation.continuous_eval(
        strategy,
        params,
        model_type=FLAGS.model_type,
        eval_file_pattern=FLAGS.eval_file_pattern,
        batch_size=FLAGS.eval_batch_size,
        eval_steps=FLAGS.eval_steps,
        model_dir=FLAGS.model_dir,
        timeout=timeout)
  return stats
def main(_):
  """absl entry point: runs the configured mode and logs any stats."""
  run_stats = run()
  if not run_stats:
    return
  logging.info("Stats:\n%s", run_stats)
if __name__ == "__main__":
  # Flags must be registered before app.run parses the command line.
  define_flags()
  app.run(main=main)
|