Spaces:
Runtime error
Runtime error
File size: 11,674 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet training utils."""
import os
import re
from typing import Any, Callable, Dict, Optional, Text
from absl import logging
import tensorflow as tf, tf_keras
from official.legacy.bert import model_training_utils
from official.legacy.xlnet import data_utils
# pytype: disable=attribute-error
# pylint: disable=g-bare-generic,unused-import
# `train` only creates the training summary writer when `steps_per_loop` is at
# least this value, so the logged loss/metric scalars are collected over
# enough steps to be meaningful.
_MIN_SUMMARY_STEPS = 10
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
  """Saves `checkpoint` into `model_dir` using the given file prefix.

  Args:
    checkpoint: Object exposing a `save(path)` method (a `tf.train.Checkpoint`
      in this module); its return value is logged as the written path.
    model_dir: Directory the checkpoint is written into.
    checkpoint_prefix: File prefix joined onto `model_dir`.

  Returns:
    Whatever `checkpoint.save` returned (the saved checkpoint path for a
    `tf.train.Checkpoint`). Previously this was discarded and None returned.
  """
  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
  saved_path = checkpoint.save(checkpoint_path)
  logging.info("Saving model as TF checkpoint: %s", saved_path)
  return saved_path
def _float_metric_value(metric):
  """Reads a Keras metric's current result as a float64 numpy value.

  Args:
    metric: Metric whose `result()` yields a tensor with a `.numpy()` method.

  Returns:
    The numpy value of `metric.result()` cast with `astype(float)`.
  """
  result_array = metric.result().numpy()
  return result_array.astype(float)
def _apply_layer_lr_decay(grads, tvars, decay_rate):
  """Rescales gradients of variables under `model/transformer/layer_<idx>/`.

  `n_layer` is one more than the largest `<idx>` found in the variable names.
  The gradient of a variable belonging to layer `l` is multiplied by
  `decay_rate ** (n_layer - 1 - l)`, so with `decay_rate < 1` layers with
  smaller indices receive proportionally smaller gradients. Gradients of all
  other variables are passed through unchanged.

  Args:
    grads: Gradients aligned index-by-index with `tvars`.
    tvars: Trainable variables; only their `.name` is read.
    decay_rate: Per-layer multiplicative decay factor.

  Returns:
    A new list of (possibly rescaled) gradients in the same order as `grads`.
  """
  n_layer = 0
  for var in tvars:
    match = re.search(r"model/transformer/layer_(\d+?)/", var.name)
    if match:
      n_layer = max(n_layer, int(match.group(1)) + 1)

  scaled = list(grads)
  for i, var in enumerate(tvars):
    for layer in range(n_layer):
      # The trailing "/" keeps e.g. layer_1/ from matching layer_12/.
      if "model/transformer/layer_{}/".format(layer) in var.name:
        abs_rate = decay_rate**(n_layer - 1 - layer)
        scaled[i] = scaled[i] * abs_rate
        logging.info("Apply mult %.4f to layer-%d grad of %s", abs_rate,
                     layer, var.name)
        break
  return scaled


def train(
    strategy: tf.distribute.Strategy,
    model_fn: Callable,
    input_meta_data: Dict,
    train_input_fn: Callable,
    total_training_steps: int,
    steps_per_loop: int,
    optimizer: tf_keras.optimizers.Optimizer,
    learning_rate_fn: tf_keras.optimizers.schedules.LearningRateSchedule,
    eval_fn: Optional[Callable[[tf_keras.Model, int, tf.summary.SummaryWriter],
                               Any]] = None,
    metric_fn: Optional[Callable[[], tf_keras.metrics.Metric]] = None,
    init_checkpoint: Optional[Text] = None,
    init_from_transformerxl: Optional[bool] = False,
    model_dir: Optional[Text] = None,
    save_steps: Optional[int] = None,
    run_eagerly: Optional[bool] = False):
  """Runs customized training.

  Args:
    strategy: Distribution strategy on which to run low level training loop.
    model_fn: Zero-argument function that returns a keras.Model. The model is
      called as `model(inputs, training=True)` and must return `(mem, logits)`
      and expose its loss through `model.losses`.
    input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
      `n_layer`, `batch_size_per_core` and `d_model`.
    train_input_fn: Function returns a tf.data.Dataset used for training.
    total_training_steps: Number of steps to train in total.
    steps_per_loop: Number of steps per graph-mode loop. In order to reduce
      communication in eager context, training logs are printed every
      steps_per_loop.
    optimizer: The optimizer for model.
    learning_rate_fn: the learning rate schedule; called with the current step
      for logging and summaries.
    eval_fn: A callback of evaluation function, that takes a keras.Model,
      current step and evaluation summary writer. It runs every `save_steps`
      and once more at the end unless the last step was already evaluated;
      its final return value is recorded as `eval_metrics`.
    metric_fn: Zero-argument function returning a Keras Metric, updated with
      `(inputs["label_ids"], logits)` on every training step.
    init_checkpoint: Optional checkpoint to restore into the model returned by
      `model_fn` (or into its `transformerxl_model`, see below).
    init_from_transformerxl: Whether to restore `init_checkpoint` into the
      `transformerxl_model` attribute of the model instead of the whole model.
    model_dir: The directory of model (checkpoints, summaries).
    save_steps: The frequency to save checkpoints. Every save_steps, we save a
      model checkpoint and run `eval_fn` if one is provided.
    run_eagerly: Whether to run training eagerly.

  Returns:
    The trained Keras model created by `model_fn`.

  Raises:
    ValueError: if any of `train_input_fn`, `total_training_steps`,
      `steps_per_loop`, `optimizer`, `learning_rate_fn` or `save_steps` is
      None.
    TypeError: if model directory is not specified.
  """
  required_arguments = [
      train_input_fn, total_training_steps, steps_per_loop, optimizer,
      learning_rate_fn, save_steps
  ]
  if any(arg is None for arg in required_arguments):
    raise ValueError("`train_input_fn`, `total_training_steps`, "
                     "`steps_per_loop`, `optimizer`, `save_steps` and "
                     "`learning_rate_fn` are required parameters.")
  if not model_dir:
    raise TypeError("Model directory must be specified.")
  train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
  # makedirs (unlike mkdir) also creates any missing parent directories.
  if not tf.io.gfile.exists(model_dir):
    tf.io.gfile.makedirs(model_dir)
  # Create summary writers
  summary_dir = os.path.join(model_dir, "summaries")
  if not tf.io.gfile.exists(summary_dir):
    tf.io.gfile.makedirs(summary_dir)
  train_summary_writer = None
  eval_summary_writer = None
  if eval_fn:
    eval_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, "eval"))
  if steps_per_loop >= _MIN_SUMMARY_STEPS:
    # Only writes summary when the stats are collected sufficiently over
    # enough steps.
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(summary_dir, "train"))

  with strategy.scope():
    model = model_fn()
    if init_checkpoint:
      logging.info("restore from %s", init_checkpoint)
      if init_from_transformerxl:
        checkpoint = tf.train.Checkpoint(
            transformer_xl=model.transformerxl_model)
      else:
        checkpoint = tf.train.Checkpoint(model=model)
      checkpoint.restore(init_checkpoint)

    model.optimizer = optimizer

    train_loss_metric = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
    train_metric = metric_fn() if metric_fn else None

    def _replicated_step(inputs, mem=None):
      """Runs one forward/backward/update pass on a single replica.

      Returns the new memory when `mem_len > 0`, otherwise None.
      """
      inputs["mems"] = mem
      with tf.GradientTape() as tape:
        mem, logits = model(inputs, training=True)
        loss = model.losses
        train_loss_metric.update_state(loss)
        if train_metric:
          train_metric.update_state(inputs["label_ids"], logits)
        # Divide by the replica count so the per-replica gradients add up to
        # the gradient of the global mean loss.
        scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)
      # Collects training variables.
      tvars = model.trainable_variables
      grads = tape.gradient(scaled_loss, tvars)
      clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
      decay_rate = input_meta_data["lr_layer_decay_rate"]
      if decay_rate != 1.0:
        clipped = _apply_layer_lr_decay(clipped, tvars, decay_rate)
      optimizer.apply_gradients(zip(clipped, tvars))
      if input_meta_data["mem_len"] > 0:
        return mem

    def train_steps(iterator, steps):
      """Performs distributed training steps in a loop.

      Args:
        iterator: the distributed iterator of training datasets.
        steps: an tf.int32 integer tensor to specify number of steps to run
          inside host training loop.

      Raises:
        ValueError: if `steps` is not a Tensor.
      """
      if not isinstance(steps, tf.Tensor):
        raise ValueError("steps should be an Tensor. Python object may cause "
                         "retracing.")

      def cache_fn():
        """Initializes memory tensor used in XLNet pretraining."""
        mems = []
        if input_meta_data["mem_len"] > 0:
          for _ in range(input_meta_data["n_layer"]):
            zeros = tf.zeros([
                input_meta_data["batch_size_per_core"],
                input_meta_data["mem_len"],
                input_meta_data["d_model"]
            ],
                             dtype=tf.float32)
            mems.append(zeros)
        return mems

      if input_meta_data["mem_len"] > 0:
        # Memory produced by one step is fed into the next.
        mem = strategy.run(cache_fn)
        for _ in tf.range(steps):
          mem = strategy.run(
              _replicated_step, args=(
                  next(iterator),
                  mem,
              ))
      else:
        for _ in tf.range(steps):
          strategy.run(_replicated_step, args=(next(iterator),))

    if not run_eagerly:
      train_steps = tf.function(train_steps)

    logging.info("Start training...")
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info("Checkpoint file %s found and restoring from checkpoint",
                   latest_checkpoint_file)
      checkpoint.restore(latest_checkpoint_file)
      logging.info("Loading from checkpoint file completed")

    current_step = optimizer.iterations.numpy()
    checkpoint_name = "xlnet_step_{step}.ckpt"
    # Track what the loop already did so the final save/eval is not
    # duplicated when training ends exactly on a `save_steps` boundary.
    last_saved_step = None
    last_eval_step = None
    eval_metric = None

    while current_step < total_training_steps:
      train_loss_metric.reset_states()
      if train_metric:
        train_metric.reset_states()
      steps = model_training_utils.steps_to_run(current_step, save_steps,
                                                steps_per_loop)
      train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      current_step += steps
      train_loss = _float_metric_value(train_loss_metric)
      log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
          current_step, total_training_steps, learning_rate_fn(current_step),
          train_loss)
      if train_metric:
        log_stream += " / %s = %f" % (train_metric.name,
                                      _float_metric_value(train_metric))
      logging.info(log_stream)
      if train_summary_writer:
        with train_summary_writer.as_default():
          tf.summary.scalar(
              "learning_rate",
              learning_rate_fn(current_step),
              step=current_step)
          tf.summary.scalar(
              train_loss_metric.name, train_loss, step=current_step)
          if train_metric:
            tf.summary.scalar(
                train_metric.name,
                _float_metric_value(train_metric),
                step=current_step)
          train_summary_writer.flush()
      if model_dir and current_step % save_steps == 0:
        _save_checkpoint(checkpoint, model_dir,
                         checkpoint_name.format(step=current_step))
        last_saved_step = current_step
      if eval_fn and current_step % save_steps == 0:
        logging.info("Running evaluation after step: %s.", current_step)
        eval_metric = eval_fn(model, current_step, eval_summary_writer)
        last_eval_step = current_step

    if model_dir and last_saved_step != current_step:
      _save_checkpoint(checkpoint, model_dir,
                       checkpoint_name.format(step=current_step))
    if eval_fn and last_eval_step != current_step:
      logging.info("Running final evaluation after training is complete.")
      eval_metric = eval_fn(model, current_step, eval_summary_writer)

    training_summary = {
        "total_training_steps": total_training_steps,
        "train_loss": _float_metric_value(train_loss_metric),
    }
    if train_metric:
      training_summary["last_train_metrics"] = _float_metric_value(
          train_metric)
    if eval_fn:
      # eval_metric is supposed to be a float.
      training_summary["eval_metrics"] = eval_metric

    model_training_utils.write_txt_summary(training_summary, summary_dir)

    return model
|