# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet training utils.""" | |
import os | |
import re | |
from typing import Any, Callable, Dict, Optional, Text | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.legacy.bert import model_training_utils | |
from official.legacy.xlnet import data_utils | |
# pytype: disable=attribute-error | |
# pylint: disable=g-bare-generic,unused-import | |
_MIN_SUMMARY_STEPS = 10 | |


def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Saves the model with the provided checkpoint prefix."""
  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
  saved_path = checkpoint.save(checkpoint_path)
  logging.info("Saving model as TF checkpoint: %s", saved_path)
  return
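
# Note: `tf.train.Checkpoint.save()` appends an auto-incrementing save counter
# to the prefix it is given, so the `saved_path` logged above has the form
# "<model_dir>/<checkpoint_prefix>-<N>" rather than the bare prefix.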


def _float_metric_value(metric):
  """Gets the value of a float-valued Keras metric."""
  return metric.result().numpy().astype(float)


def train(
    strategy: tf.distribute.Strategy,
    model_fn: Callable,
    input_meta_data: Dict,
    train_input_fn: Callable,
    total_training_steps: int,
    steps_per_loop: int,
    optimizer: tf_keras.optimizers.Optimizer,
    learning_rate_fn: tf_keras.optimizers.schedules.LearningRateSchedule,
    eval_fn: Optional[Callable[[tf_keras.Model, int, tf.summary.SummaryWriter],
                               Any]] = None,
    metric_fn: Optional[Callable[[], tf_keras.metrics.Metric]] = None,
    init_checkpoint: Optional[Text] = None,
    init_from_transformerxl: Optional[bool] = False,
    model_dir: Optional[Text] = None,
    save_steps: Optional[int] = None,
    run_eagerly: Optional[bool] = False):
"""Runs customized training. | |
Args: | |
strategy: Distribution strategy on which to run low level training loop. | |
model_fn: The function returns a keras.Model. | |
input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`, | |
`n_layer`, `batch_size_per_core` and `d_model`. | |
train_input_fn: Function returns a tf.data.Dataset used for training. | |
total_training_steps: Number of steps to train in total. | |
steps_per_loop: Number of steps per graph-mode loop. In order to reduce | |
communication in eager context, training logs are printed every | |
steps_per_loop. | |
optimizer: The optimizer for model. | |
learning_rate_fn: the learning rate schedule. | |
eval_fn: A callback of evaluation function, that takes a keras.Model, | |
current step and evaluation summary writer. | |
metric_fn: A metrics function returns a Keras Metric object to record | |
evaluation result using evaluation dataset or with training dataset | |
after every epoch. | |
init_checkpoint: Optional checkpoint to load to `sub_model` returned by | |
`model_fn`. | |
init_from_transformerxl: Whether to load to `transformerxl_model` of | |
`model_fn`. | |
model_dir: The directory of model (checkpoints, summaries). | |
save_steps: The frequency to save checkpoints. Every save_steps, we save a | |
model checkpoint. Model checkpoint will be saved and evaluation will be | |
conducted if evaluation dataset is provided. | |
run_eagerly: Whether to run training eagerly. | |
Returns: | |
Last training step logits if training happens, otherwise returns None. | |
Raises: | |
TypeError: if model directory is not specified. | |
""" | |
  required_arguments = [
      train_input_fn, total_training_steps, steps_per_loop, optimizer,
      learning_rate_fn, save_steps
  ]
  if any(arg is None for arg in required_arguments):
    raise ValueError("`train_input_fn`, `total_training_steps`, "
                     "`steps_per_loop`, `optimizer`, `save_steps` and "
                     "`learning_rate_fn` are required parameters.")
  if not model_dir:
    raise TypeError("Model directory must be specified.")

  train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)

  if not tf.io.gfile.exists(model_dir):
    tf.io.gfile.mkdir(model_dir)
  # Create summary writers.
  summary_dir = os.path.join(model_dir, "summaries")
  if not tf.io.gfile.exists(summary_dir):
    tf.io.gfile.mkdir(summary_dir)
  train_summary_writer = None
  eval_summary_writer = None
  if eval_fn:
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, "eval"))
  if steps_per_loop >= _MIN_SUMMARY_STEPS:
    # Only write train summaries when statistics are aggregated over
    # sufficiently many steps.
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, "train"))
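
  # The summaries written below can be inspected with TensorBoard, e.g.
  # `tensorboard --logdir <model_dir>/summaries`.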

  with strategy.scope():
    model = model_fn()

    if init_checkpoint:
      logging.info("restore from %s", init_checkpoint)
      if init_from_transformerxl:
        checkpoint = tf.train.Checkpoint(
            transformer_xl=model.transformerxl_model)
      else:
        checkpoint = tf.train.Checkpoint(model=model)
      checkpoint.restore(init_checkpoint)

    model.optimizer = optimizer

    if not hasattr(model, "optimizer"):
      raise ValueError("User should set optimizer attribute to model.")

    train_loss_metric = tf_keras.metrics.Mean(
        "training_loss", dtype=tf.float32)
    train_metric = None
    if metric_fn:
      train_metric = metric_fn()

    def _replicated_step(inputs, mem=None):
      """Replicated training step."""

      inputs["mems"] = mem
      with tf.GradientTape() as tape:
        mem, logits = model(inputs, training=True)
        loss = model.losses
        train_loss_metric.update_state(loss)
        if train_metric:
          train_metric.update_state(inputs["label_ids"], logits)
        scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)

      # Collects training variables.
      tvars = model.trainable_variables
      grads = tape.gradient(scaled_loss, tvars)
      clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)

      if input_meta_data["lr_layer_decay_rate"] != 1.0:
        n_layer = 0
        for i in range(len(clipped)):
          m = re.search(r"model/transformer/layer_(\d+?)/", tvars[i].name)
          if not m:
            continue
          n_layer = max(n_layer, int(m.group(1)) + 1)

        for i in range(len(clipped)):
          for l in range(n_layer):
            if "model/transformer/layer_{}/".format(l) in tvars[i].name:
              abs_rate = input_meta_data["lr_layer_decay_rate"]**(
                  n_layer - 1 - l)
              clipped[i] *= abs_rate
              logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
                  abs_rate, l, tvars[i].name))
              break
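
      # The layer-wise learning-rate decay above scales the gradient of layer
      # `l` by lr_layer_decay_rate ** (n_layer - 1 - l), so lower layers get
      # proportionally smaller effective learning rates. For example (with
      # hypothetical values lr_layer_decay_rate=0.75 and n_layer=12), layer 11
      # keeps its gradient unchanged (0.75 ** 0 = 1.0) while layer 0 is scaled
      # by 0.75 ** 11, roughly 0.042.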
      optimizer.apply_gradients(zip(clipped, tvars))
      if input_meta_data["mem_len"] > 0:
        return mem

    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of the training dataset.
        steps: a tf.int32 tensor specifying the number of steps to run inside
          the host training loop.

      Raises:
        ValueError: If `steps` is not a tf.Tensor.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError("steps should be a Tensor. Python objects may cause "
                         "retracing.")

      def cache_fn():
        """Initializes the memory tensors used in XLNet pretraining."""
        mems = []
        if input_meta_data["mem_len"] > 0:
          for _ in range(input_meta_data["n_layer"]):
            zeros = tf.zeros([
                input_meta_data["batch_size_per_core"],
                input_meta_data["mem_len"],
                input_meta_data["d_model"]
            ], dtype=tf.float32)
            mems.append(zeros)
        return mems
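
      # `mems` holds the cached hidden states from previous segments: one
      # tensor per transformer layer with shape
      # [batch_size_per_core, mem_len, d_model], created per replica via
      # `strategy.run` and threaded through successive `_replicated_step`
      # calls.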
      if input_meta_data["mem_len"] > 0:
        mem = strategy.run(cache_fn)
        for _ in tf.range(steps):
          mem = strategy.run(
              _replicated_step, args=(next(iterator), mem))
      else:
        for _ in tf.range(steps):
          strategy.run(_replicated_step, args=(next(iterator),))

    if not run_eagerly:
      train_steps = tf.function(train_steps)
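
    # Note: `train_steps` requires `steps` to be a tf.Tensor because, once
    # wrapped in tf.function, every distinct Python integer argument would
    # trigger a retrace of the graph; passing a tensor reuses a single
    # concrete function.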
logging.info("Start training...") | |
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) | |
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir) | |
if latest_checkpoint_file: | |
logging.info("Checkpoint file %s found and restoring from checkpoint", | |
latest_checkpoint_file) | |
checkpoint.restore(latest_checkpoint_file) | |
logging.info("Loading from checkpoint file completed") | |
current_step = optimizer.iterations.numpy() | |
checkpoint_name = "xlnet_step_{step}.ckpt" | |
while current_step < total_training_steps: | |
train_loss_metric.reset_states() | |
if train_metric: | |
train_metric.reset_states() | |
steps = model_training_utils.steps_to_run(current_step, save_steps, | |
steps_per_loop) | |
train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32)) | |
current_step += steps | |
train_loss = _float_metric_value(train_loss_metric) | |
log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % ( | |
current_step, total_training_steps, learning_rate_fn(current_step), | |
train_loss) | |
if train_metric: | |
log_stream += " / %s = %f" % (train_metric.name, | |
_float_metric_value(train_metric)) | |
logging.info(log_stream) | |
if train_summary_writer: | |
with train_summary_writer.as_default(): | |
tf.summary.scalar( | |
"learning_rate", | |
learning_rate_fn(current_step), | |
step=current_step) | |
tf.summary.scalar( | |
train_loss_metric.name, train_loss, step=current_step) | |
if train_metric: | |
tf.summary.scalar( | |
train_metric.name, | |
_float_metric_value(train_metric), | |
step=current_step) | |
train_summary_writer.flush() | |
if model_dir and current_step % save_steps == 0: | |
_save_checkpoint(checkpoint, model_dir, | |
checkpoint_name.format(step=current_step)) | |
if eval_fn and current_step % save_steps == 0: | |
logging.info("Running evaluation after step: %s.", current_step) | |
eval_fn(model, current_step, eval_summary_writer) | |

    if model_dir:
      _save_checkpoint(checkpoint, model_dir,
                       checkpoint_name.format(step=current_step))
    if eval_fn:
      logging.info("Running final evaluation after training is complete.")
      eval_metric = eval_fn(model, current_step, eval_summary_writer)

    training_summary = {
        "total_training_steps": total_training_steps,
        "train_loss": _float_metric_value(train_loss_metric),
    }
    if train_metric:
      training_summary["last_train_metrics"] = _float_metric_value(
          train_metric)
    if eval_fn:
      # eval_metric is supposed to be a float.
      training_summary["eval_metrics"] = eval_metric

    model_training_utils.write_txt_summary(training_summary, summary_dir)

    return model
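

# Example usage (a minimal sketch, not part of the original module). The names
# `my_model_fn`, `my_train_input_fn`, `run_evaluation`, and the literal
# hyperparameter values below are hypothetical placeholders; only `train` and
# its keyword arguments come from this file.
#
#   strategy = tf.distribute.MirroredStrategy()
#   optimizer = tf_keras.optimizers.Adam(learning_rate=1e-5)
#   schedule = tf_keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=1e-5, decay_steps=10000, end_learning_rate=0.0)
#   model = train(
#       strategy=strategy,
#       model_fn=my_model_fn,
#       input_meta_data={
#           "mem_len": 0,
#           "lr_layer_decay_rate": 1.0,
#           "n_layer": 12,
#           "batch_size_per_core": 8,
#           "d_model": 768,
#       },
#       train_input_fn=my_train_input_fn,
#       total_training_steps=10000,
#       steps_per_loop=100,
#       optimizer=optimizer,
#       learning_rate_fn=schedule,
#       eval_fn=run_evaluation,
#       model_dir="/tmp/xlnet_model",
#       save_steps=1000)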