File size: 21,162 Bytes
74e8f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
# Copyright 2023 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contrastive training loop.

For models like:
- LiT (https://arxiv.org/abs/2111.07991)
- CLIP (https://arxiv.org/abs/2103.00020)
- SigLIP (https://arxiv.org/abs/2303.15343)
"""
# pylint: disable=consider-using-from-import
import functools
import importlib
import multiprocessing.pool
import os

from absl import app
from absl import flags
from absl import logging
import big_vision.evaluators.common as eval_common
import big_vision.input_pipeline as input_pipeline
import big_vision.optax as bv_optax
import big_vision.utils as u
from clu import parameter_overview
import flax
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
import optax
import tensorflow as tf

from tensorflow.io import gfile

# pylint: disable=logging-fstring-interpolation


# Command-line interface: --config (locked ml_collections config file),
# --workdir (output directory) and --cleanup (post-run workdir deletion).
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)

flags.DEFINE_string("workdir", None, "Work unit directory.")
flags.DEFINE_boolean(
    "cleanup", False, "Delete workdir (only) after successful completion.")

# Register JAX's own configuration options as absl flags as well.
jax.config.parse_flags_with_absl()


def clip(x, *, a_max=None, a_min=None):
  """Like jnp.clip, but allows all-None to mean don't clip.

  Args:
    x: array to clip.
    a_max: optional upper bound (None means unbounded above).
    a_min: optional lower bound (None means unbounded below).

  Returns:
    `x` itself when both bounds are None, otherwise the clipped array.
  """
  if a_max is None and a_min is None:
    return x
  # Pass bounds positionally: older JAX names them `a_min`/`a_max`, newer JAX
  # names them `min`/`max` and deprecates the `a_*` keywords.
  return jnp.clip(x, a_min, a_max)


def all_gather(z, roll=False, only_others=False):
  """All gather every leaf of `z` over the "batch" axis and flatten.

  Each leaf of shape (n, ...) held by each of the N participants of the
  "batch" axis becomes an array of shape (N * n, ...): the per-participant
  chunks stacked along the first dimension.

  Args:
    z: pytree of arrays; every leaf needs at least one dimension.
    roll: if True, each participant rotates the gathered chunks so its own
      chunk comes first (simplifies loss/accuracy computations).
    only_others: if True, also drops the participant's own chunk after the
      rotation, leaving (N - 1) * n rows. With a single participant the
      result has zero rows.

  Returns:
    A pytree with the same structure as `z` holding the flattened gathers.
  """
  def gather_flat(x):
    x = jax.lax.all_gather(x, "batch")  # Stacks chunks: (N, n, ...).
    if roll or only_others:
      # Each device moves "its" chunk to the beginning. Simplifies loss/acc calcs.
      x = jnp.roll(x, -jax.lax.axis_index("batch"), axis=0)
      if only_others:
        x = x[1:]
    # Fold in "device" and "batch" dims. An explicit reshape is well-defined
    # even when `only_others` leaves zero chunks.
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
  # `jax.tree_map` is a deprecated alias of `jax.tree_util.tree_map`.
  return jax.tree_util.tree_map(gather_flat, z)


def softmax_loss(zimg, ztxt, temperature):
  """Symmetric softmax (CLIP-style) loss, evaluated one direction at a time.

  Each direction scores this device's rows against the columns gathered from
  every device on the "batch" axis, which avoids materializing the full
  global-by-global similarity matrix.

  Args:
    zimg: image embeddings held by this device.
    ztxt: text embeddings held by this device; row i pairs with zimg[i].
    temperature: multiplier applied to the dot-product logits.

  Returns:
    (loss, extras). `loss` is 0.5 * (i2t + t2i) averaged over the "batch"
    axis. `extras` holds this device's per-direction accuracy and loss under
    the keys "i2t_acc", "i2t_loss", "t2i_acc", "t2i_loss".
  """

  def one_direction(queries, keys, t):
    # The rolled gather puts this device's own keys first, so the positive
    # for query i sits in column i of the score matrix.
    gathered_keys = all_gather(keys, roll=True)
    scores = jnp.dot(queries, gathered_keys.T) * t
    log_partition = jax.scipy.special.logsumexp(scores, axis=-1)
    nll = -(jnp.diag(scores) - log_partition)
    correct = jnp.argmax(scores, axis=1) == jnp.arange(queries.shape[0])
    return nll.mean(), correct.mean()

  extras = {}
  total = 0
  directions = (("i2t", zimg, ztxt), ("t2i", ztxt, zimg))
  for tag, queries, keys in directions:
    dir_loss, dir_acc = one_direction(queries, keys, temperature)
    total += 0.5 * dir_loss
    extras[f"{tag}_acc"] = dir_acc
    extras[f"{tag}_loss"] = dir_loss

  return jax.lax.pmean(total, "batch"), extras


def _avg_pos_logit(x_me):
  return jnp.mean(jnp.diag(x_me))


def _avg_neg_logit(x_me, x_ot=None):
  nom = jnp.sum(x_me) - jnp.sum(jnp.diag(x_me))
  den = x_me.size - len(x_me)
  if x_ot is not None and x_ot.size:
    nom += jnp.sum(x_ot)
    den += x_ot.size
  return nom / den


def sigmoid_loss(zimg, ztxt, temperature, bias=0.0):
  """Sigmoid loss from SigLIP: https://arxiv.org/abs/2303.15343.

  Image embeddings stay on the device that computed them ("me"). Text
  embeddings from all other devices on the "batch" axis are gathered ("ot"),
  so every local image is scored against every text in the global batch.

  Args:
    zimg: local image embeddings (presumably (n, D); row i pairs with ztxt[i]).
    ztxt: local text embeddings (presumably (n, D)).
    temperature: multiplier applied to the raw dot products.
    bias: additive offset applied after the temperature.

  Returns:
    (loss, metrics). `loss` is the per-image negative log-likelihood averaged
    over local images. `metrics` holds min/max/avg statistics of positive and
    negative logits.
  """
  def scaled_logits(texts):
    return jnp.dot(zimg, texts.T) * temperature + bias

  ztxt_others = all_gather(ztxt, only_others=True)  # Texts of other devices.
  local = scaled_logits(ztxt)          # (n, n): positives on the diagonal.
  remote = scaled_logits(ztxt_others)  # (n, N): negatives only.

  eye = jnp.eye(zimg.shape[0])
  # Positive and negative locations are known up front (+1 on the local
  # diagonal, -1 everywhere else), so each pair's loss is computed once
  # instead of once per label.
  signs = -jnp.ones_like(local) + 2 * eye
  nll_local = -jax.nn.log_sigmoid(signs * local).sum(axis=-1)
  nll_remote = -jax.nn.log_sigmoid(-remote).sum(axis=-1)
  # Each row has exactly one positive, so summing per row needs no extra
  # normalization; adding both means equals concatenating local/remote first.
  loss = nll_local.mean() + nll_remote.mean()

  positives = jnp.diag(local)
  diag_raised = local + 1e9 * eye   # Diagonal excluded from a min.
  diag_lowered = local - 1e9 * eye  # Diagonal excluded from a max.
  has_remote = remote.size
  # Local-device statistics only: jax.lax.pmin/pmax were found unusable here,
  # so what ends up reported is an average of per-device min/max/avg.
  return loss, {
      "pos_min_logit": jnp.min(positives),
      "pos_max_logit": jnp.max(positives),
      "pos_avg_logit": _avg_pos_logit(local),
      "local_neg_min_logit": jnp.min(diag_raised),
      "local_neg_max_logit": jnp.max(diag_lowered),
      "local_neg_avg_logit": _avg_neg_logit(local),
      "neg_min_logit": jnp.minimum(
          jnp.min(diag_raised),
          jnp.min(remote) if has_remote else jnp.inf),
      "neg_max_logit": jnp.maximum(
          jnp.max(diag_lowered),
          jnp.max(remote) if has_remote else -jnp.inf),
      "neg_avg_logit": _avg_neg_logit(local, remote),
  }


def _gather_from_device(x, device_id, axis_name="batch"):
  return jax.lax.psum((jax.lax.axis_index(axis_name) == device_id) * x,
                      axis_name)


def chunked_sigmoid_loss(zimg, ztxt, temperature, bias=0.0):
  """Loss computation from section 3.1 of arxiv.org/abs/2303.15343.

  Same objective as the sigmoid loss, but instead of gathering all texts at
  once it fetches one device's text embeddings at a time and accumulates the
  negative-pair loss chunk by chunk.

  Args:
    zimg: local image embeddings (presumably (n, D); row i pairs with ztxt[i]).
    ztxt: local text embeddings (presumably (n, D)).
    temperature: multiplier applied to the raw dot products.
    bias: additive offset applied after the temperature.

  Returns:
    (loss, metrics). `metrics` holds statistics of the local logits only.
  """
  def scale(texts):
    return jnp.dot(zimg, texts.T) * temperature + bias

  # Local block: the only place positives live (on its diagonal).
  logits_me = scale(ztxt)  # (n, D) . (D, n) -> (n, n)
  signs = -jnp.ones_like(logits_me) + 2 * jnp.eye(zimg.shape[0])
  nll_me = -jax.nn.log_sigmoid(signs * logits_me).sum(axis=-1).mean()

  my_index = jax.lax.axis_index("batch")

  def chunk_nll(device_id):
    # The fetch is a psum over "batch", so every device runs it for every
    # device_id (its own included); the own-chunk term is masked to zero.
    # (An all-reduce is used because a ppermute variant was memory hungry.)
    texts = _gather_from_device(ztxt, device_id)
    remote_logits = scale(texts)  # (n, D) . (D, n) -> (n, n)
    nll = -jnp.sum(jax.nn.log_sigmoid(-remote_logits), axis=-1).mean()
    is_other = jnp.not_equal(device_id, my_index)
    return is_other * nll

  nll_others = sum(chunk_nll(d) for d in range(jax.device_count()))

  eye = jnp.eye(zimg.shape[0])
  positives = jnp.diag(logits_me)
  return nll_me + nll_others, {
      "pos_min_logit": jnp.min(positives),
      "pos_max_logit": jnp.max(positives),
      "pos_avg_logit": _avg_pos_logit(logits_me),
      "local_neg_min_logit": jnp.min(logits_me + 1e9 * eye),
      "local_neg_max_logit": jnp.max(logits_me - 1e9 * eye),
      "local_neg_avg_logit": _avg_neg_logit(logits_me),
  }


def main(argv):
  del argv
  tf.config.experimental.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(  # pylint: disable=logging-fstring-interpolation
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  # See (internal link) for a fun read on what it takes.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  # These functions do more stuff internally, for OSS release we mock them by
  # trivial alternatives in order to minize disruptions in the code.
  xid, wid = -1, -1
  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  def write_note(note):
    if jax.process_index() == 0:
      info("%s", note)

  write_note("Initializing...")

  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  # First thing after above sanity checks, so we can log "start" ticks.
  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  write_note("Initializing train dataset...")
  train_ds, ntrain_img = input_pipeline.training(config.input)

  # Start prefetching already.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)

  total_steps = u.steps("total", config, ntrain_img, batch_size)
  def get_steps(name, default=ValueError, cfg=config):
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size,
                  measure=mw.measure, write_note=write_note)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(**config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @functools.partial(jax.jit, backend="cpu")
  def init(rng):
    bs = batch_size // jax.device_count()
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    no_image = jnp.zeros((bs,) + image_size, jnp.float32)
    text_size = tuple(train_ds.element_spec["labels"].shape[1:])
    no_text = jnp.zeros((bs,) + text_size, jnp.int32)
    params = flax.core.unfreeze(model.init(rng, no_image, no_text))["params"]
    return params

  rng, rng_init = jax.random.split(rng)
  with u.chrono.log_timing("z/secs/init"):
    params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_leaves(params_cpu))
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    mw.measure("num_params", num_params)

  write_note(f"Initializing {config.optax_name} optimizer...")
  tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))

  # We jit this, such that the arrays are created on the CPU, not device[0].
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

  @functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
  def update_fn(params, opt, rng, batch):
    """Update step."""
    assert "mixup" not in config, "We still have to figure out mixup."

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      zimg, ztxt, extras = model.apply(
          {"params": params}, images, labels,
          train=True, rngs={"dropout": rng_model_local})

      match config.get("loss_fn", "softmax"):
        case "softmax":
          l, l_extras = softmax_loss(zimg, ztxt, extras["t"])
        case "sigmoid":
          l, l_extras = sigmoid_loss(zimg, ztxt, extras["t"], bias=extras["b"])
        case "chunked_sigmoid":
          l, l_extras = chunked_sigmoid_loss(zimg, ztxt, extras["t"],
                                             bias=extras["b"])
        case _:
          raise NotImplementedError(f"Unrecognized loss {config.loss_fn=}")

      return l, {
          "t": extras["t"],
          "t/parameter": extras["t/parameter"],
          "train/nimg": jnp.mean(extras["img/norm"]),
          "train/ntxt": jnp.mean(extras["txt/norm"]),
          **{f"train/{k}": v for k, v in l_extras.items()},
      }

    (l, measurements), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(params, batch["image"], batch["labels"])
    l, measurements, grads = jax.lax.pmean((l, measurements, grads),
                                           axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    us = jax.tree_leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us]))

    return params, opt, rng, l, measurements

  # We require hashable function reference for evaluator.
  # We do not jit/pmap this function, because it is passed to evaluator that
  # does it later. We output as many intermediate tensors as possible for
  # maximal flexibility. Later `jit` will prune out things that are not needed.
  def predict_fn(params, image=None, text=None, **unused_kwargs):
    del unused_kwargs  # `unused_kwargs` is to be compatible with few-shot
    zimg, ztxt, out = model.apply({"params": params}, image, text)
    return zimg, ztxt, out

  # Only initialize evaluators when they are first needed.
  @functools.lru_cache(maxsize=None)
  def evaluators():
    return eval_common.from_config(
        config, {"predict": predict_fn},
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
    )

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e,g, start a fine-tuning job.
  # 4. Train from scratch.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(save_ckpt_path):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = config.resume.format(wid=xm_wu.id)
  if resume_ckpt_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {
        "params": params_cpu,
        "opt": opt_cpu,
        "chrono": u.chrono.save(),
    }
    checkpoint_tree = jax.tree_structure(checkpoint)
    loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"]
    u.chrono.load(checkpoint["chrono"])
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu = model_mod.load(
        params_cpu, config.model_init, config.get("model"),
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu, msg="restored params")

  write_note("Kicking off misc stuff...")
  first_step = bv_optax.get_count(opt_cpu)
  u.chrono.inform(first_step=first_step)
  prof = None  # Keeps track of start/stop of profiler state.

  write_note(f"Replicating...\n{u.chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax.jax_utils.replicate(rng_loop)
  ckpt_writer = None

  write_note(f"First step compilations...\n{u.chrono.note}")

  # Note that training can be pre-empted during the final evaluation (i.e.
  # just after the final checkpoint has been written to disc), in which case we
  # want to run the evals.
  if first_step in (total_steps, 0):
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        for key, value in evaluator.run(params_repl):
          mw.measure(f"{prefix}{key}", value)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn(
            params_repl, opt_repl, rngs_loop, batch)

    # On the first host, let's always profile a handful of early steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or u.chrono.warmup and jax.process_index() == 0):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      u.chrono.tick(step)
      if not np.isfinite(l):
        raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                           f"[{step - get_steps('log_training')}, {step}]")

    # Checkpoint saving
    if (save_ckpt_path and
        (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or
         u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))):
      u.chrono.pause(wait_for=(params_repl, opt_repl))
      u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1))
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see (internal link)). Also, takes device 0's params only.
      params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl)
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, get_steps("keep_ckpt", None), total_steps):
        copy_step = step

      ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": u.chrono.save()}
      ckpt_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_ckpt_path, copy_step))
      u.chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=params_repl)
        u.chrono.tick(step)  # Record things like epoch number, core hours etc.
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          for key, value in evaluator.run(params_repl):
            mw.measure(f"{prefix}{key}", value)
        u.chrono.resume()
    mw.step_end()

  # Always give a chance to stop the profiler, no matter how things ended.
  # TODO: can we also do this when dying of an exception like OOM?
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  # Last note needs to happen before the pool's closed =)
  write_note(f"Done!\n{u.chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  # Make sure all hosts stay up until the end of main.
  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)


if __name__ == "__main__":
  # Script entry point: let absl run `main`.
  app.run(main=main)