# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet training utils."""

import os
import re
from typing import Any, Callable, Dict, Optional, Text

from absl import logging
import tensorflow as tf, tf_keras

from official.legacy.bert import model_training_utils
from official.legacy.xlnet import data_utils

# pytype: disable=attribute-error
# pylint: disable=g-bare-generic,unused-import

_MIN_SUMMARY_STEPS = 10


def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Saves the model with the provided checkpoint prefix."""

  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
  saved_path = checkpoint.save(checkpoint_path)
  logging.info("Saving model as TF checkpoint: %s", saved_path)
  return


def _float_metric_value(metric):
  """Gets the value of a float-valued Keras metric."""
  return metric.result().numpy().astype(float)


def train(
    strategy: tf.distribute.Strategy,
    model_fn: Callable,
    input_meta_data: Dict,
    train_input_fn: Callable,
    total_training_steps: int,
    steps_per_loop: int,
    optimizer: tf_keras.optimizers.Optimizer,
    learning_rate_fn: tf_keras.optimizers.schedules.LearningRateSchedule,
    eval_fn: Optional[Callable[[tf_keras.Model, int, tf.summary.SummaryWriter],
                               Any]] = None,
    metric_fn: Optional[Callable[[], tf_keras.metrics.Metric]] = None,
    init_checkpoint: Optional[Text] = None,
    init_from_transformerxl: Optional[bool] = False,
    model_dir: Optional[Text] = None,
    save_steps: Optional[int] = None,
    run_eagerly: Optional[bool] = False):
  """Runs a customized training loop.

  Args:
    strategy: Distribution strategy on which to run the low level training
      loop.
    model_fn: A function that returns a keras.Model.
    input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
      `n_layer`, `batch_size_per_core` and `d_model`.
    train_input_fn: A function that returns a tf.data.Dataset used for
      training.
    total_training_steps: Number of steps to train in total.
    steps_per_loop: Number of steps per graph-mode loop. In order to reduce
      communication in eager context, training logs are printed every
      `steps_per_loop` steps.
    optimizer: The optimizer for the model.
    learning_rate_fn: The learning rate schedule.
    eval_fn: An optional evaluation callback that takes a keras.Model, the
      current step and an evaluation summary writer.
    metric_fn: An optional function that returns a Keras Metric object used to
      record metric results on the training dataset.
    init_checkpoint: Optional checkpoint to load into the model returned by
      `model_fn`.
    init_from_transformerxl: Whether to load only the `transformerxl_model`
      sub-model of the model returned by `model_fn`.
    model_dir: The directory for the model (checkpoints, summaries).
    save_steps: The frequency at which to save checkpoints: every `save_steps`
      steps a model checkpoint is saved, and evaluation is run if `eval_fn` is
      provided.
    run_eagerly: Whether to run training eagerly.

  Returns:
    The trained keras.Model.

  Raises:
    TypeError: if the model directory is not specified.
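
  Example:
    A minimal sketch of how this loop might be driven. Everything named below
    (`my_model_fn`, `my_train_input_fn`, the `meta` values and the paths) is a
    placeholder supplied by the caller, not something defined in this module:

      meta = {"mem_len": 0, "lr_layer_decay_rate": 1.0, "n_layer": 12,
              "batch_size_per_core": 8, "d_model": 768}
      lr_fn = tf_keras.optimizers.schedules.PolynomialDecay(
          initial_learning_rate=1e-5, decay_steps=10000,
          end_learning_rate=0.0)
      trained_model = train(
          strategy=tf.distribute.MirroredStrategy(),
          model_fn=my_model_fn,
          input_meta_data=meta,
          train_input_fn=my_train_input_fn,
          total_training_steps=10000,
          steps_per_loop=100,
          optimizer=tf_keras.optimizers.Adam(learning_rate=lr_fn),
          learning_rate_fn=lr_fn,
          model_dir="/tmp/xlnet_model",
          save_steps=1000)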
""" required_arguments = [ train_input_fn, total_training_steps, steps_per_loop, optimizer, learning_rate_fn, save_steps ] if [arg for arg in required_arguments if arg is None]: raise ValueError("`train_input_fn`, `total_training_steps`, " "`steps_per_loop`, `optimizer`, `save_steps` and " "`learning_rate_fn` are required parameters.") if not model_dir: raise TypeError("Model directory must be specified.") train_iterator = data_utils.get_input_iterator(train_input_fn, strategy) if not tf.io.gfile.exists(model_dir): tf.io.gfile.mkdir(model_dir) # Create summary writers summary_dir = os.path.join(model_dir, "summaries") if not tf.io.gfile.exists(summary_dir): tf.io.gfile.mkdir(summary_dir) train_summary_writer = None eval_summary_writer = None if eval_fn: eval_summary_writer = tf.summary.create_file_writer( os.path.join(summary_dir, "eval")) if steps_per_loop >= _MIN_SUMMARY_STEPS: # Only writes summary when the stats are collected sufficiently over # enough steps. train_summary_writer = tf.summary.create_file_writer( os.path.join(summary_dir, "train")) with strategy.scope(): model = model_fn() if init_checkpoint: logging.info("restore from %s", init_checkpoint) if init_from_transformerxl: checkpoint = tf.train.Checkpoint( transformer_xl=model.transformerxl_model) else: checkpoint = tf.train.Checkpoint(model=model) checkpoint.restore(init_checkpoint) model.optimizer = optimizer if not hasattr(model, "optimizer"): raise ValueError("User should set optimizer attribute to model.") train_loss_metric = tf_keras.metrics.Mean("training_loss", dtype=tf.float32) train_metric = None if metric_fn: train_metric = metric_fn() def _replicated_step(inputs, mem=None): """Replicated training step.""" inputs["mems"] = mem with tf.GradientTape() as tape: mem, logits = model(inputs, training=True) loss = model.losses train_loss_metric.update_state(loss) if train_metric: train_metric.update_state(inputs["label_ids"], logits) scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync) # Collects training variables. tvars = model.trainable_variables grads = tape.gradient(scaled_loss, tvars) clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0) if input_meta_data["lr_layer_decay_rate"] != 1.0: n_layer = 0 for i in range(len(clipped)): m = re.search(r"model/transformer/layer_(\d+?)/", tvars[i].name) if not m: continue n_layer = max(n_layer, int(m.group(1)) + 1) for i in range(len(clipped)): for l in range(n_layer): if "model/transformer/layer_{}/".format(l) in tvars[i].name: abs_rate = input_meta_data["lr_layer_decay_rate"]**( n_layer - 1 - l) clipped[i] *= abs_rate logging.info("Apply mult {:.4f} to layer-{} grad of {}".format( abs_rate, l, tvars[i].name)) break optimizer.apply_gradients(zip(clipped, tvars)) if input_meta_data["mem_len"] > 0: return mem def train_steps(iterator, steps): """Performs distributed training steps in a loop. Args: iterator: the distributed iterator of training datasets. steps: an tf.int32 integer tensor to specify number of steps to run inside host training loop. Raises: ValueError: Any of the arguments or tensor shapes are invalid. Returns: logits: logits computed. """ if not isinstance(steps, tf.Tensor): raise ValueError("steps should be an Tensor. 
Python object may cause " "retracing.") def cache_fn(): """Initializes memory tensor used in XLNet pretraining.""" mems = [] if input_meta_data["mem_len"] > 0: for _ in range(input_meta_data["n_layer"]): zeros = tf.zeros([ input_meta_data["batch_size_per_core"], input_meta_data["mem_len"], input_meta_data["d_model"] ], dtype=tf.float32) mems.append(zeros) return mems if input_meta_data["mem_len"] > 0: mem = strategy.run(cache_fn) for _ in tf.range(steps): mem = strategy.run( _replicated_step, args=( next(iterator), mem, )) else: for _ in tf.range(steps): strategy.run(_replicated_step, args=(next(iterator),)) if not run_eagerly: train_steps = tf.function(train_steps) logging.info("Start training...") checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) latest_checkpoint_file = tf.train.latest_checkpoint(model_dir) if latest_checkpoint_file: logging.info("Checkpoint file %s found and restoring from checkpoint", latest_checkpoint_file) checkpoint.restore(latest_checkpoint_file) logging.info("Loading from checkpoint file completed") current_step = optimizer.iterations.numpy() checkpoint_name = "xlnet_step_{step}.ckpt" while current_step < total_training_steps: train_loss_metric.reset_states() if train_metric: train_metric.reset_states() steps = model_training_utils.steps_to_run(current_step, save_steps, steps_per_loop) train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32)) current_step += steps train_loss = _float_metric_value(train_loss_metric) log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % ( current_step, total_training_steps, learning_rate_fn(current_step), train_loss) if train_metric: log_stream += " / %s = %f" % (train_metric.name, _float_metric_value(train_metric)) logging.info(log_stream) if train_summary_writer: with train_summary_writer.as_default(): tf.summary.scalar( "learning_rate", learning_rate_fn(current_step), step=current_step) tf.summary.scalar( train_loss_metric.name, train_loss, step=current_step) if train_metric: tf.summary.scalar( train_metric.name, _float_metric_value(train_metric), step=current_step) train_summary_writer.flush() if model_dir and current_step % save_steps == 0: _save_checkpoint(checkpoint, model_dir, checkpoint_name.format(step=current_step)) if eval_fn and current_step % save_steps == 0: logging.info("Running evaluation after step: %s.", current_step) eval_fn(model, current_step, eval_summary_writer) if model_dir: _save_checkpoint(checkpoint, model_dir, checkpoint_name.format(step=current_step)) if eval_fn: logging.info("Running final evaluation after training is complete.") eval_metric = eval_fn(model, current_step, eval_summary_writer) training_summary = { "total_training_steps": total_training_steps, "train_loss": _float_metric_value(train_loss_metric), } if train_metric: training_summary["last_train_metrics"] = _float_metric_value(train_metric) if eval_fn: # eval_metric is supposed to be a float. training_summary["eval_metrics"] = eval_metric model_training_utils.write_txt_summary(training_summary, summary_dir) return model
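

# The sketch below is illustrative only: it shows callbacks matching the
# signatures `train` expects for `metric_fn` and `eval_fn`. The metric choice,
# the names and the placeholder evaluation value are assumptions for the sake
# of the example, not part of the original module.
def _example_metric_fn() -> tf_keras.metrics.Metric:
  """Returns a training metric; any Keras Metric with this interface works."""
  return tf_keras.metrics.SparseCategoricalAccuracy("train_accuracy",
                                                    dtype=tf.float32)


def _example_eval_fn(model: tf_keras.Model, current_step: int,
                     eval_summary_writer: tf.summary.SummaryWriter) -> float:
  """Placeholder evaluation callback; a real one would evaluate `model`."""
  del model  # A real implementation would run its evaluation dataset here.
  eval_metric = 0.0
  with eval_summary_writer.as_default():
    tf.summary.scalar("eval_metric", eval_metric, step=current_step)
    eval_summary_writer.flush()
  return eval_metric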