# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet training utils."""
import os
import re
from typing import Any, Callable, Dict, Optional, Text
from absl import logging
import tensorflow as tf, tf_keras
from official.legacy.bert import model_training_utils
from official.legacy.xlnet import data_utils
# pytype: disable=attribute-error
# pylint: disable=g-bare-generic,unused-import
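# Minimum `steps_per_loop` required before a training summary writer is
# created; see the summary-writer setup in `train` below.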
_MIN_SUMMARY_STEPS = 10
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to with provided checkpoint prefix."""
checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
saved_path = checkpoint.save(checkpoint_path)
logging.info("Saving model as TF checkpoint: %s", saved_path)
return
def _float_metric_value(metric):
"""Gets the value of a float-value keras metric."""
return metric.result().numpy().astype(float)
def train(
strategy: tf.distribute.Strategy,
model_fn: Callable,
input_meta_data: Dict,
train_input_fn: Callable,
total_training_steps: int,
steps_per_loop: int,
optimizer: tf_keras.optimizers.Optimizer,
learning_rate_fn: tf_keras.optimizers.schedules.LearningRateSchedule,
eval_fn: Optional[Callable[[tf_keras.Model, int, tf.summary.SummaryWriter],
Any]] = None,
metric_fn: Optional[Callable[[], tf_keras.metrics.Metric]] = None,
init_checkpoint: Optional[Text] = None,
init_from_transformerxl: Optional[bool] = False,
model_dir: Optional[Text] = None,
save_steps: Optional[int] = None,
run_eagerly: Optional[bool] = False):
"""Runs customized training.
Args:
    strategy: Distribution strategy on which to run the low-level training
      loop.
    model_fn: A function that returns a keras.Model.
input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
`n_layer`, `batch_size_per_core` and `d_model`.
    train_input_fn: A function that returns a tf.data.Dataset used for
      training.
total_training_steps: Number of steps to train in total.
    steps_per_loop: Number of steps per graph-mode loop. To reduce
      communication in an eager context, training logs are printed every
      `steps_per_loop` steps.
    optimizer: The optimizer for the model.
    learning_rate_fn: The learning rate schedule.
    eval_fn: An optional evaluation callback that takes a keras.Model, the
      current step, and an evaluation summary writer.
    metric_fn: An optional function that returns a Keras Metric object, used
      to record evaluation results on the evaluation dataset or on the
      training dataset after every epoch.
    init_checkpoint: Optional checkpoint to load into the model returned by
      `model_fn`.
    init_from_transformerxl: Whether to load the checkpoint into the model's
      `transformerxl_model` sub-model rather than into the full model.
model_dir: The directory of model (checkpoints, summaries).
    save_steps: The frequency, in steps, at which to save checkpoints. Every
      `save_steps` steps a model checkpoint is saved, and evaluation is run if
      an evaluation dataset is provided.
run_eagerly: Whether to run training eagerly.
Returns:
    The trained model.
Raises:
    TypeError: if model directory is not specified.
    ValueError: if any of the required arguments is missing.
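
  Example:
    A minimal usage sketch. `my_model_fn` and `my_train_input_fn` are
    hypothetical user-provided callables (not defined in this module), and
    the hyperparameter values below are illustrative only:

      strategy = tf.distribute.MirroredStrategy()
      optimizer = tf_keras.optimizers.Adam(learning_rate=1e-5)
      lr_fn = tf_keras.optimizers.schedules.PolynomialDecay(
          initial_learning_rate=1e-5, decay_steps=10000,
          end_learning_rate=0.0)
      trained_model = train(
          strategy=strategy,
          model_fn=my_model_fn,
          input_meta_data={
              "mem_len": 0,
              "lr_layer_decay_rate": 1.0,
              "n_layer": 12,
              "batch_size_per_core": 8,
              "d_model": 768,
          },
          train_input_fn=my_train_input_fn,
          total_training_steps=10000,
          steps_per_loop=100,
          optimizer=optimizer,
          learning_rate_fn=lr_fn,
          model_dir="/tmp/xlnet_model",
          save_steps=1000)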
"""
required_arguments = [
train_input_fn, total_training_steps, steps_per_loop, optimizer,
learning_rate_fn, save_steps
]
  if any(arg is None for arg in required_arguments):
raise ValueError("`train_input_fn`, `total_training_steps`, "
"`steps_per_loop`, `optimizer`, `save_steps` and "
"`learning_rate_fn` are required parameters.")
if not model_dir:
raise TypeError("Model directory must be specified.")
train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.mkdir(model_dir)
# Create summary writers
summary_dir = os.path.join(model_dir, "summaries")
if not tf.io.gfile.exists(summary_dir):
tf.io.gfile.mkdir(summary_dir)
train_summary_writer = None
eval_summary_writer = None
if eval_fn:
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, "eval"))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
    # Only write training summaries when statistics have been aggregated over
    # a sufficient number of steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, "train"))
with strategy.scope():
model = model_fn()
if init_checkpoint:
logging.info("restore from %s", init_checkpoint)
if init_from_transformerxl:
checkpoint = tf.train.Checkpoint(
transformer_xl=model.transformerxl_model)
else:
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
model.optimizer = optimizer
if not hasattr(model, "optimizer"):
raise ValueError("User should set optimizer attribute to model.")
train_loss_metric = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
train_metric = None
if metric_fn:
train_metric = metric_fn()
def _replicated_step(inputs, mem=None):
"""Replicated training step."""
inputs["mems"] = mem
with tf.GradientTape() as tape:
mem, logits = model(inputs, training=True)
loss = model.losses
train_loss_metric.update_state(loss)
if train_metric:
train_metric.update_state(inputs["label_ids"], logits)
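      # Scale the loss by 1/num_replicas_in_sync so that gradients summed
      # across replicas by the optimizer correspond to a per-replica average.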
scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)
# Collects training variables.
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
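    # Layer-wise learning rate decay: gradients of lower transformer layers
    # are scaled down geometrically by `lr_layer_decay_rate`, so layers closer
    # to the input receive smaller updates than layers near the output.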
if input_meta_data["lr_layer_decay_rate"] != 1.0:
n_layer = 0
for i in range(len(clipped)):
m = re.search(r"model/transformer/layer_(\d+?)/", tvars[i].name)
if not m:
continue
n_layer = max(n_layer, int(m.group(1)) + 1)
for i in range(len(clipped)):
for l in range(n_layer):
if "model/transformer/layer_{}/".format(l) in tvars[i].name:
abs_rate = input_meta_data["lr_layer_decay_rate"]**(
n_layer - 1 - l)
clipped[i] *= abs_rate
logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
abs_rate, l, tvars[i].name))
break
optimizer.apply_gradients(zip(clipped, tvars))
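    # When Transformer-XL memory is enabled, return the updated memory so the
    # outer training loop can feed it into the next step.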
if input_meta_data["mem_len"] > 0:
return mem
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop.
Args:
      iterator: The distributed iterator over the training dataset.
      steps: A tf.int32 tensor specifying the number of steps to run inside
        the host training loop.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
if not isinstance(steps, tf.Tensor):
raise ValueError("steps should be an Tensor. Python object may cause "
"retracing.")
def cache_fn():
"""Initializes memory tensor used in XLNet pretraining."""
mems = []
if input_meta_data["mem_len"] > 0:
for _ in range(input_meta_data["n_layer"]):
zeros = tf.zeros([
input_meta_data["batch_size_per_core"],
input_meta_data["mem_len"],
input_meta_data["d_model"]
],
dtype=tf.float32)
mems.append(zeros)
return mems
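    # Thread the Transformer-XL memory state through consecutive steps when
    # memory is enabled; otherwise run plain replicated steps.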
if input_meta_data["mem_len"] > 0:
mem = strategy.run(cache_fn)
for _ in tf.range(steps):
mem = strategy.run(
_replicated_step, args=(
next(iterator),
mem,
))
else:
for _ in tf.range(steps):
strategy.run(_replicated_step, args=(next(iterator),))
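  # Unless eager execution was requested, compile the training loop into a
  # graph with tf.function to reduce per-step Python overhead.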
if not run_eagerly:
train_steps = tf.function(train_steps)
logging.info("Start training...")
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info("Checkpoint file %s found and restoring from checkpoint",
latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
logging.info("Loading from checkpoint file completed")
current_step = optimizer.iterations.numpy()
checkpoint_name = "xlnet_step_{step}.ckpt"
while current_step < total_training_steps:
train_loss_metric.reset_states()
if train_metric:
train_metric.reset_states()
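    # Run up to `steps_per_loop` steps, trimmed so that the loop lands exactly
    # on `save_steps` boundaries for checkpointing and evaluation.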
steps = model_training_utils.steps_to_run(current_step, save_steps,
steps_per_loop)
train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
current_step += steps
train_loss = _float_metric_value(train_loss_metric)
log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
current_step, total_training_steps, learning_rate_fn(current_step),
train_loss)
if train_metric:
log_stream += " / %s = %f" % (train_metric.name,
_float_metric_value(train_metric))
logging.info(log_stream)
if train_summary_writer:
with train_summary_writer.as_default():
tf.summary.scalar(
"learning_rate",
learning_rate_fn(current_step),
step=current_step)
tf.summary.scalar(
train_loss_metric.name, train_loss, step=current_step)
if train_metric:
tf.summary.scalar(
train_metric.name,
_float_metric_value(train_metric),
step=current_step)
train_summary_writer.flush()
if model_dir and current_step % save_steps == 0:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_fn and current_step % save_steps == 0:
logging.info("Running evaluation after step: %s.", current_step)
eval_fn(model, current_step, eval_summary_writer)
if model_dir:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_fn:
logging.info("Running final evaluation after training is complete.")
eval_metric = eval_fn(model, current_step, eval_summary_writer)
training_summary = {
"total_training_steps": total_training_steps,
"train_loss": _float_metric_value(train_loss_metric),
}
if train_metric:
training_summary["last_train_metrics"] = _float_metric_value(train_metric)
if eval_fn:
# eval_metric is supposed to be a float.
training_summary["eval_metrics"] = eval_metric
model_training_utils.write_txt_summary(training_summary, summary_dir)
return model