# Source listing header (page residue): author pranavSIT,
# commit 74e8f2f ("added pali inference").
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training loop for GIVT-style autoregressive and masked models."""
# pylint: disable=consider-using-from-import
import functools
import importlib
import multiprocessing.pool
import os
from absl import app
from absl import flags
from absl import logging
import big_vision.evaluators.common as eval_common
import big_vision.input_pipeline as input_pipeline
from big_vision.models.proj.givt import parallel_decode
import big_vision.models.proj.givt.decode as softar_decode
import big_vision.optax as bv_optax
import big_vision.sharding as bv_sharding
import big_vision.trainers.proj.givt.utils as trainer_utils
from big_vision.trainers.proj.uvim import panoptic_task
import big_vision.utils as u
from clu import parameter_overview
import flax
import jax
from jax.experimental import mesh_utils
from jax.experimental import multihost_utils
from jax.experimental.array_serialization import serialization as array_serial
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
import optax
import tensorflow as tf
from tensorflow.io import gfile
# pylint: disable=logging-fstring-interpolation
# Command-line flags of this trainer: the config file, the output directory
# and whether to delete that directory after a successful run.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", default=None, help="Work unit directory.")
flags.DEFINE_boolean("cleanup", default=False,
                     help="Delete workdir (only) after successful completion.")

# Adds jax flags to the program.
jax.config.parse_flags_with_absl()

# The transfer guard fails the program whenever data is transferred
# implicitly between a host and a device. This often catches subtle bugs that
# cause slowdowns and memory fragmentation. Explicit transfers are done
# with jax.device_put and jax.device_get.
jax.config.update("jax_transfer_guard", "disallow")
# Fixes design flaw in jax.random that may cause unnecessary d2d comms.
jax.config.update("jax_threefry_partitionable", True)

# Short aliases used throughout the trainer.
NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec
def main(argv):
  """Runs GIVT training and evaluation as described by the `config` flag.

  Args:
    argv: Positional command-line arguments; unused.
  """
  del argv
  # Initialize JAX's multi-process runtime (see jax.distributed.initialize).
  jax.distributed.initialize()

  # Make sure TF does not touch GPUs.
  tf.config.set_visible_devices([], "GPU")

  config = flags.FLAGS.config

  ##############################################################################
  # Set up logging
  ##############################################################################

  # Set up work directory and print welcome message.
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  # Checkpoints are only written when a workdir is given.
  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.bv")

  # The pool is used to perform misc operations such as logging in async way.
  # NOTE(review): within this file it is only closed/joined at the end.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text",
                                     "proj.uvim.pp_ops", "proj.givt.pp_ops"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # Setup up logging and experiment manager.
  # Fixed ids: presumably placeholders where no experiment manager is used.
  xid, wid = -1, -1

  def info(s, *a):
    # Logs `s % a` with a highlighted "NOTE" prefix, on every process.
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)

  def write_note(note):
    # Like `info`, but only emitted by process 0.
    if jax.process_index() == 0:
      info("%s", note)

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  # Allow for things like timings as early as possible!
  u.chrono.inform(measure=mw.measure, write_note=write_note)
################################################################################
# Set up Mesh
################################################################################

# We rely on jax mesh_utils to organize devices, such that communication
# speed is the fastest for the last dimension, second fastest for the
# penultimate dimension, etc.
# `mesh` is a list of (axis_name, axis_size) pairs; by default a single
# "data" axis spanning all devices.
config_mesh = config.get("mesh", [("data", jax.device_count())])

# Sharding rules with default
sharding_rules = config.get("sharding_rules", [("act_batch", "data")])

mesh_axes, mesh_size = tuple(zip(*config_mesh))

# mesh_utils does not accept a `-1` axis size, so let numpy's reshape resolve
# it against the actual device count first.
mesh_size = np.array(jax.devices()).reshape(mesh_size).shape

device_mesh = mesh_utils.create_device_mesh(mesh_size)

# Consistent device order is important to ensure correctness of various train
# loop components, such as input pipeline, update step, evaluators. The
# order prescribed by the `devices_flat` variable should be used throughout
# the program.
devices_flat = device_mesh.flatten()
################################################################################
# Input Pipeline
################################################################################

write_note("Initializing train dataset...")
batch_size = config.input.batch_size
# The global batch must split evenly over all devices.
if batch_size % jax.device_count() != 0:
  raise ValueError(f"Batch size ({batch_size}) must "
                   f"be divisible by device number ({jax.device_count()})")
info("Global batch size %d on %d hosts results in %d local batch size. With "
     "%d dev per host (%d dev total), that's a %d per-device batch size.",
     batch_size, jax.process_count(), batch_size // jax.process_count(),
     jax.local_device_count(), jax.device_count(),
     batch_size // jax.device_count())

train_ds, ntrain_img = input_pipeline.training(config.input)

total_steps = u.steps("total", config, ntrain_img, batch_size)

def get_steps(name, default=ValueError, cfg=config):
  # Resolves the step count configured for `name` in `cfg` via `u.steps`.
  # `default` is forwarded to `u.steps`; the `ValueError` class presumably
  # means "raise if unset" -- confirm in big_vision.utils.steps.
  return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                steps_per_epoch=ntrain_img / batch_size)

info("Running for %d steps, that means %f epochs",
     total_steps, total_steps * batch_size / ntrain_img)

# Start input pipeline as early as possible.
n_prefetch = config.get("prefetch_to_device", 1)
train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch)
################################################################################
# Create Model & Optimizer
################################################################################

# Stage 1 model (VAE) whose latents the GIVT model learns to generate.
write_note(f"Creating {config.vae.model_name} model...")
vae_mod = importlib.import_module(
    f"big_vision.models.{config.vae.model_name}")
vae = vae_mod.Model(**config.vae.get("model", {}))

write_note(f"Creating {config.model_name} model...")
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
model_config = config.get("model", {})
model = model_mod.Model(**model_config)

# Optional adaptor between VAE latents and model inputs/outputs (its
# `forward` is applied before the model, its `inverse` after sampling). Its
# channel count is tied to the model's `out_dim`.
if config.get("adaptor_name"):
  write_note(f"Creating {config.adaptor_name} model...")
  adaptor_mod = importlib.import_module(
      f"big_vision.models.{config.adaptor_name}")
  adaptor = adaptor_mod.Model(num_channels=model_config.out_dim,
                              **config.adaptor.model)
else:
  adaptor = None
def init(rng):
  """Initializes the GIVT (and optional adaptor) parameters.

  Dummy inputs are built from `train_ds.element_spec`, so only the shapes of
  the training data matter, not its values.

  Args:
    rng: PRNG key used for parameter initialization.

  Returns:
    A mutable dict of model parameters. When an adaptor is configured, its
    parameters are stored under the "params_adaptor" key of the same dict.

  Raises:
    ValueError: If the training data has no "image" feature; the batch size
      of the dummy latent codes is taken from it.
  """
  def _get_dummy_input(input_name, dtype=jnp.int64):
    # Features absent from the data become None (optional model inputs).
    if input_name in train_ds.element_spec:
      return jnp.zeros(train_ds.element_spec[input_name].shape, dtype=dtype)
    return None

  dummy_img = _get_dummy_input("image", dtype=jnp.float32)
  dummy_labels = _get_dummy_input("labels")
  dummy_cond_img = _get_dummy_input("cond_image", dtype=jnp.float32)
  if dummy_img is None:
    raise ValueError(
        "The training data must provide an 'image' feature to infer the "
        "batch size used for parameter initialization.")

  local_batch_size = dummy_img.shape[0]
  code_shape = (local_batch_size, model_config.seq_len, model_config.out_dim)
  dummy_code = jnp.zeros(code_shape, jnp.float32)
  input_mask = model.get_input_mask_training(
      jax.random.PRNGKey(0), (local_batch_size, model_config.seq_len))

  params = model.init(rng, dummy_code, dummy_labels, image=dummy_cond_img,
                      input_mask=input_mask)["params"]
  # Older flax versions return an immutable FrozenDict here; unfreeze so the
  # adaptor parameters can be added below (a no-op copy for plain dicts).
  params = flax.core.unfreeze(params)

  if adaptor is not None:
    _, rng_adaptor = jax.random.split(rng)
    adaptor_variables = adaptor.init(rng_adaptor, dummy_code)
    params_adaptor = flax.core.unfreeze(adaptor_variables["params"])
    params["params_adaptor"] = params_adaptor  # store in same dict
  return params
rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))

write_note("Inferring parameter shapes...")
rng, rng_init = jax.random.split(rng)
# eval_shape only traces `init`; no parameter arrays are materialized here.
params_shape = jax.eval_shape(init, rng_init)

write_note("Inferring optimizer state shapes...")
tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict(
    total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
opt_shape = jax.eval_shape(tx.init, params_shape)
# We jit this, such that the arrays are created on the CPU, not device[0].
sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns]
# Training a stage 2 model requires a pretrained stage 1 model. We treat this
# as a constant and do not shard the parameters.
# Raise (rather than assert) so the check also runs under `python -O`.
if "model_init" not in config.vae:
  raise ValueError(
      "config.vae.model_init must point to a pretrained stage 1 (VAE) "
      "checkpoint.")
params_vae = vae_mod.load(None, config.vae.model_init,
                          **config.vae.get("model_load", {}))

def vae_encode(images, rng=None, reparametrize=True):
  """Encodes images with the frozen stage 1 VAE.

  Args:
    images: Images passed to `vae.encode`.
    rng: Dict of PRNG keys for `vae.apply`; must contain "dropout" when
      `reparametrize` is True.
    reparametrize: If True, return `vae.reparametrize(mu, logvar)` evaluated
      with `rng`; otherwise return the encoder mean `mu`.

  Returns:
    The latent sequence.
  """
  mu, logvar = vae.apply({"params": params_vae}, images, method=vae.encode)
  if reparametrize:
    assert rng is not None and "dropout" in rng
    return vae.apply({"params": params_vae}, mu, logvar,
                     method=vae.reparametrize, rngs=rng)
  return mu
# Log the total parameter count (from the inferred shapes) once, on process 0.
if jax.process_index() == 0:
  # `jax.tree_leaves` is deprecated and removed in recent JAX releases;
  # `jax.tree_util.tree_leaves` is the stable spelling.
  num_params = sum(
      np.prod(p.shape) for p in jax.tree_util.tree_leaves(params_shape))
  mw.measure("num_params", num_params)
################################################################################
# Shard & Transfer
################################################################################

write_note("Creating device mesh...")
mesh = jax.sharding.Mesh(device_mesh, mesh_axes)
# An empty PartitionSpec replicates an array on every device of the mesh.
repl_sharding = jax.sharding.NamedSharding(mesh, P())

write_note("Inferring shardings...")
train_state_shape = {"params": params_shape, "opt": opt_shape}

# Default: replicate every parameter / optimizer-state leaf. Presumably a
# list of (name regex, rule) pairs -- see bv_sharding.infer_sharding.
strategy = config.get("sharding_strategy", [(".*", "replicate")])
train_state_sharding = bv_sharding.infer_sharding(
    train_state_shape, strategy=strategy, mesh=mesh)

write_note("Transferring train_state to devices...")
# RNG is always replicated
rng_init = u.reshard(rng_init, repl_sharding)

# Parameters and the optimizer are now global (distributed) jax arrays.
params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init)
opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params)

rng, rng_loop = jax.random.split(rng, 2)
rng_loop = u.reshard(rng_loop, repl_sharding)
del rng  # not used anymore, so delete it.

# At this point we have everything we need to form a train state. It contains
# all the parameters that are passed and updated by the main training step.
train_state = {"params": params, "opt": opt}
del params, opt  # Delete to avoid memory leak or accidental reuse.

write_note("Logging parameter overview...")
parameter_overview.log_parameter_overview(
    train_state["params"], msg="Init params",
    include_stats="global", jax_logging_process=0)
################################################################################
# Update Step
################################################################################

def loss_fn(params, images, labels, cond_images, rng):
  """Negative log-likelihood of the (adapted) VAE latents under the model.

  Args:
    params: Model parameters; adaptor parameters live under "params_adaptor".
    images: Images that the VAE encodes into the target latent sequence.
    labels: Labels passed to the model (may be None).
    cond_images: Conditioning images passed to the model (may be None).
    rng: PRNG key, split into dropout, masking and label-dropping keys.

  Returns:
    A `(loss, metrics)` tuple. `metrics["nll"]` is the unreduced NLL; models
    that produce an input mask also report `metrics["fraction_masked_out"]`.
  """
  rng, dropout_key = jax.random.split(rng, 2)
  rng, mask_key = jax.random.split(rng, 2)
  droplabels_key = jax.random.split(rng, 2)[1]
  dropout_rngs = {"dropout": dropout_key}

  seq = vae_encode(images, dropout_rngs)
  if adaptor is not None:
    # Use the (invertible) adaptor to map to a new latent sequence.
    seq = adaptor.apply({"params": params["params_adaptor"]},
                        seq, method=adaptor.forward)
  n_batch, n_seq, _ = seq.shape

  # None for the non-mask style; otherwise a (b, s) mask.
  mask = model.get_input_mask_training(mask_key, (n_batch, n_seq))
  drop_labels = model.get_drop_labels(droplabels_key, batch_size=n_batch)

  _, pdf = model.apply(
      {"params": params}, seq, labels,
      image=cond_images,
      train=True,
      input_mask=mask,
      drop_labels=drop_labels,
      rngs=dropout_rngs)

  # Either (B, L) or (B, L, out_dim); both cases are handled below.
  nll = -pdf.log_prob(seq)
  metrics = {"nll": nll}

  if mask is None:
    return nll.mean(), metrics

  metrics["fraction_masked_out"] = mask.astype(jnp.float32).mean(axis=1)
  if nll.ndim == 3:
    mask = mask[:, :, None]
  # `mask` is True where the input was replaced by the mask token, so only
  # those positions contribute. Averaging over them only keeps the loss scale
  # comparable across examples, like MaskGIT's sum(loss * mask) / sum(mask).
  masked_nll = jnp.where(mask, nll, 0.0)
  return masked_nll.mean(where=mask), metrics
@functools.partial(
    jax.jit,
    donate_argnums=(0,),
    out_shardings=(train_state_sharding, repl_sharding))
def update_fn(train_state, rng, batch):
  """Performs one optimizer step.

  Args:
    train_state: Dict with "params" and "opt"; donated to the computation.
    rng: Loop PRNG key; combined with the step count for a per-step key.
    batch: Training batch with "image" and optionally "labels"/"cond_image".

  Returns:
    `(new_train_state, measurements)`. The measurements include the loss,
    L2 norms of grads/params/updates and the means of `loss_fn`'s metrics
    under "train/...".
  """
  images = batch["image"]
  labels, cond_images = batch.get("labels"), batch.get("cond_image")

  step_count = bv_optax.get_count(train_state["opt"], jittable=True)
  # Fold in the step so every step draws different randomness.
  rng = jax.random.fold_in(rng, step_count)

  measurements = {}

  # Key handed to the loss (dropout, input masking, label dropping).
  _, rng_model = jax.random.split(rng, 2)

  params, opt = train_state["params"], train_state["opt"]
  (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      params, images, labels, cond_images, rng_model)
  updates, opt = tx.update(grads, opt, params)
  params = optax.apply_updates(params, updates)
  train_state = {"params": params, "opt": opt}

  def _l2_norm(tree):
    # Global L2 norm over all leaves. `jax.tree_util.tree_leaves` replaces
    # the deprecated (and in recent JAX removed) `jax.tree_leaves`; the loop
    # variable no longer shadows the `u` utils module.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum([jnp.vdot(x, x) for x in leaves]))

  measurements["training_loss"] = loss
  # Frozen parameters' gradients are zeroed out of the gradient norm.
  measurements["l2_grads"] = _l2_norm(
      bv_optax.replace_frozen(config.schedule, grads, 0.))
  measurements["l2_params"] = _l2_norm(params)
  measurements["l2_updates"] = _l2_norm(updates)
  if adaptor is not None:
    measurements["l2_params_adaptor"] = _l2_norm(params["params_adaptor"])
  measurements.update({f"train/{k}": v.mean() for k, v in metrics.items()})
  return train_state, measurements
################################################################################
# Set up Evals
################################################################################

def validation_fn(train_state, batch, seed=0):
  """Computes `loss_fn`'s metrics for one evaluation batch.

  Returns:
    Dict mapping each metric name to its values averaged over every axis
    except the first one.
  """
  local_rng = trainer_utils.get_local_rng(seed, batch)
  _, aux = loss_fn(
      train_state["params"], batch["image"], batch.get("labels"),
      batch.get("cond_image"), local_rng)
  per_example = {}
  for name, metric in aux.items():
    reduce_axes = tuple(range(1, metric.ndim))
    per_example[name] = jnp.mean(metric, axis=reduce_axes)
  return per_example
def predict_fn_teacher_forcing(train_state, batch, seed=0):
  """Decodes latents sampled from one teacher-forced pass of the model.

  The ground-truth image is encoded by the VAE (and mapped through the
  adaptor, if any). The model is applied once to that sequence with no input
  masked. A sequence is sampled from the returned distribution, mapped back
  to the VAE latent space and decoded.

  Args:
    train_state: Dict holding the model "params".
    batch: Evaluator batch with "image" and optionally "labels" and
      "cond_image".
    seed: Seed for the per-example rng.

  Returns:
    Dict with the VAE-decoded sample under "logits".
  """
  params = train_state["params"]
  image, labels = batch["image"], batch.get("labels")
  # Same conditioning input that `loss_fn` uses during training (None when
  # the data has no "cond_image").
  cond_image = batch.get("cond_image")
  local_rng = trainer_utils.get_local_rng(seed, batch)
  rng_dropout = {"dropout": local_rng}

  sequence = vae_encode(image, rng_dropout)
  if adaptor is not None:
    # Use the adaptor to map from VAE latent space to GIVT in/output space.
    sequence = adaptor.apply({"params": params["params_adaptor"]},
                             sequence, method=adaptor.forward)
  b, s, _ = sequence.shape

  # This is None for the non-mask style. Otherwise, shape (b, s) of zeros
  # (nothing masked).
  input_mask = model.get_input_mask_teacher_forced((b, s))
  _, pdf = model.apply(
      {"params": params}, sequence, labels,
      image=cond_image,
      train=True, input_mask=input_mask, rngs=rng_dropout)

  rng_sample, _ = jax.random.split(local_rng, 2)
  sampled_sequence = pdf.sample(seed=rng_sample)
  if adaptor is not None:
    # Use the adaptor inverse to map back to the VAE latent space
    sampled_sequence = adaptor.apply({"params": params["params_adaptor"]},
                                     sampled_sequence, method=adaptor.inverse)
  logits = vae.apply(
      {"params": params_vae}, sampled_sequence, method=vae.decode)
  return {"logits": logits}
def predict_fn_rep(train_state, image, seed=0):
  """Returns the AR model's representations of VAE-encoded images.

  Args:
    train_state: Dict holding the model "params".
    image: Either an evaluator batch dict with an "image" entry, or the image
      array itself.
    seed: Seed for the per-example rng used by the VAE encoding.

  Returns:
    The output of `model.decode(..., return_reps=True)`.
  """
  assert model.style == "ar"
  assert model.drop_labels_probability == 1.0
  params = train_state["params"]
  # `get_local_rng` is called with a batch dict everywhere else in this
  # trainer; build one when only the image array is given. (There is no
  # `batch` in scope here; the name previously resolved to `main`'s
  # training-loop variable.)
  batch = image if isinstance(image, dict) else {"image": image}
  local_rng = trainer_utils.get_local_rng(seed, batch)
  rng_dropout = {"dropout": local_rng}
  sequence = vae_encode(batch["image"], rng_dropout)
  # Placeholder labels; with drop_labels_probability == 1.0 their values are
  # presumably ignored -- confirm in the model's `decode`.
  placeholder_labels = jnp.zeros((sequence.shape[0],), dtype=jnp.int32)
  return model.apply({"params": params}, sequence, labels=placeholder_labels,
                     return_reps=True, method=model.decode)
def predict_fn_sampling(train_state, batch, seed=0):
  """Samples latent sequences from the model and decodes them with the VAE.

  Args:
    train_state: Dict holding the model "params".
    batch: Evaluator batch; "labels" and "cond_image" are optional inputs.
    seed: Seed for the per-example sampling rng.

  Returns:
    Dict with the VAE-decoded samples under "logits", plus "logprobs" of the
    sampled codes when the AR decoder returns them.

  Raises:
    NotImplementedError: For model styles other than "ar" and "masked".
  """
  params = train_state["params"]
  labels = batch.get("labels")
  local_rng = trainer_utils.get_local_rng(seed, batch)

  logprobs = None
  if model.style == "ar":
    num_samples = None
    if labels is None:
      # Without labels, take the number of samples from whichever image
      # input is present, else from the configured sampling batch size.
      for key in ("image", "cond_image"):
        if key in batch:
          num_samples = batch[key].shape[0]
          break
      else:
        num_samples = config.get("sampling_batch_size", 4)
    codes, logprobs = softar_decode.generate(
        params={"params": params},
        seed=local_rng,
        model=model,
        seq_len=config.model.seq_len,
        feature_dim=config.model.out_dim,
        labels=labels,
        cond_image=batch.get("cond_image"),
        batch_size=num_samples,
        config=config.get("ar_generation_config"),
    )
  elif model.style == "masked":
    # Masked decoding is not image-conditioned here.
    assert "cond_image" not in batch
    codes = parallel_decode.decode_masked(  # pytype: disable=wrong-arg-types
        rng=local_rng,
        labels=labels,
        seq_len=config.model.seq_len,
        feature_dim=config.model.out_dim,
        model=model,
        variables={"params": params},
        config=parallel_decode.MaskedGenerationConfig(
            **config.get("masked_generation_config", {})
        ),
    ).current_inputs_q
  else:
    raise NotImplementedError

  if adaptor is not None:
    # Use the adaptor inverse to map back to the VAE latent space.
    codes = adaptor.apply({"params": params["params_adaptor"]},
                          codes, method=adaptor.inverse)

  results = {
      "logits": vae.apply({"params": params_vae}, codes, method=vae.decode)
  }
  if logprobs is not None:
    results["logprobs"] = logprobs
  return results
def predict_fn_sampling_panoptic(
    train_state, batch, seed=0, min_fraction=0.0):
  """Samples from the model and turns the decoded logits into panoptic output.

  The decoded "logits" must contain "semantics" and "instances" entries.
  """
  decoded = predict_fn_sampling(train_state, batch, seed)["logits"]
  semantics, instances = decoded["semantics"], decoded["instances"]
  return panoptic_task.panoptic_predictions_from_logits(
      semantics, instances, min_fraction=min_fraction)
def predict_fn_sampling_depth(train_state, batch, seed=0):
  """Samples from the model and converts the binned depth output to depth.

  Returns:
    Dict with the unbinned depth under "depth".
  """
  decoded = predict_fn_sampling(train_state, batch, seed)["logits"]
  binned_depth = decoded["depth"]
  return {
      "depth": trainer_utils.unbin_depth(
          binned_depth,
          min_depth=config.min_depth,
          max_depth=config.max_depth,
          num_bins=config.vae.model.inout_specs["depth"][1]),
  }
# Only initialize evaluators when they are first needed.
@functools.lru_cache(maxsize=None)
def evaluators():
  """Builds the configured evaluators once; later calls reuse the result.

  The dict below lists the predict functions, by name, that are made
  available to `eval_common.from_config`.
  """
  predict_fns = {
      "validation": validation_fn,
      "sample_teacher_forced": predict_fn_teacher_forcing,
      "sample": predict_fn_sampling,
      "sample_panoptic": predict_fn_sampling_panoptic,
      "sample_depth": predict_fn_sampling_depth,
      "representation": predict_fn_rep,
  }

  def announce(s):
    write_note(f"Init evaluator: {s}…\n{u.chrono.note}")

  def eval_steps(key, cfg):
    return get_steps(key, default=None, cfg=cfg)

  return eval_common.from_config(
      config, predict_fns, announce, eval_steps, devices_flat)
# Decide how to initialize training. The order is important.
# 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
# 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
# 3. Initialize model from something, e.g. start a fine-tuning job.
# 4. Train from scratch.
resume_ckpt_path = None
if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"):
  resume_ckpt_path = save_ckpt_path
elif config.get("resume"):
  # Used verbatim: this trainer defines no `fillin` placeholder-substitution
  # helper (referencing one raised a NameError).
  resume_ckpt_path = config.resume

# A checkpoint manager is needed both for saving and for resuming.
ckpt_mngr = None
if save_ckpt_path or resume_ckpt_path:
  ckpt_mngr = array_serial.GlobalAsyncCheckpointManager()

if resume_ckpt_path:
  write_note(f"Resuming training from checkpoint {resume_ckpt_path}...")
  # Free the freshly initialized device buffers before loading.
  # (`jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.)
  jax.tree_util.tree_map(lambda x: x.delete(), train_state)
  del train_state
  # The checkpoint holds the train state plus the replicated chrono state.
  shardings = {
      **train_state_sharding,
      "chrono": jax.tree_util.tree_map(lambda _: repl_sharding,
                                       u.chrono.save()),
  }
  loaded = u.load_checkpoint_ts(
      resume_ckpt_path, tree=shardings, shardings=shardings)
  train_state = {key: loaded[key] for key in train_state_sharding.keys()}

  u.chrono.load(jax.device_get(loaded["chrono"]))
  del loaded
elif config.get("model_init"):
  write_note(f"Initialize model from {config.model_init}...")
  train_state["params"] = model_mod.load(
      train_state["params"], config.model_init, config.get("model"),
      **config.get("model_load", {}))

  # load has the freedom to return params not correctly sharded
  train_state["params"] = u.reshard(
      train_state["params"], train_state_sharding["params"])

  parameter_overview.log_parameter_overview(
      train_state["params"], msg="restored params",
      include_stats="global", jax_logging_process=0)
# At this point we need to know the current step to see whether to run evals.
write_note("Inferring the first step number...")
first_step_device = bv_optax.get_count(train_state["opt"], jittable=True)
first_step = int(jax.device_get(first_step_device))
u.chrono.inform(first_step=first_step)

# Note that training can be pre-empted during the final evaluation (i.e.
# just after the final checkpoint has been written to disc), in which case we
# want to run the evals.
if first_step in (total_steps, 0):
  write_note("Running initial or final evals...")
  mw.step_start(first_step)
  for (name, evaluator, _, prefix) in evaluators():
    if config.evals[name].get("skip_first") and first_step != total_steps:
      continue
    write_note(f"{name} evaluation...\n{u.chrono.note}")
    with u.chrono.log_timing(f"z/secs/eval/{name}"):
      with mesh, flax.linen.logical_axis_rules(sharding_rules):
        for key, value in evaluator.run(train_state):
          # Fetch explicitly, as the periodic evals do: implicit
          # device-to-host transfers are disallowed by the transfer guard.
          mw.measure(f"{prefix}{key}", jax.device_get(value))
################################################################################
# Train Loop
################################################################################

prof = None  # Keeps track of start/stop of profiler state.

write_note("Starting training loop, compiling the first step...")
for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
  # Skip training loop when running an eval-only config
  if config.get("eval_only", False):
    break

  mw.step_start(step)

  with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
    # Only the first step's duration (which includes compilation) is timed.
    with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
      with mesh, flax.linen.logical_axis_rules(sharding_rules):
        train_state, measurements = update_fn(train_state, rng_loop, batch)

  # On the first host, let's always profile a handful of early steps.
  if jax.process_index() == 0:
    prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

  # Report training progress
  if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
      or u.chrono.warmup and jax.process_index() == 0):
    for i, sched_fn_cpu in enumerate(sched_fns_cpu):
      mw.measure(f"global_schedule{i if i else ''}",
                 sched_fn_cpu(u.put_cpu(step - 1)))
    measurements = jax.device_get(measurements)
    for name, value in measurements.items():
      mw.measure(name, value)
    u.chrono.tick(step)
    if not np.isfinite(measurements["training_loss"]):
      raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                         f"[{step - get_steps('log_training')}, {step}]")

  # Checkpoint saving
  keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps
  if save_ckpt_path and (
      (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False))
      or u.itstime(step, get_steps("ckpt", None), total_steps, first=True)
  ):
    u.chrono.pause(wait_for=train_state)

    # Copy because we add extra stuff to the checkpoint.
    ckpt = {**train_state}

    # To save chrono state correctly and safely in a multihost setup, we
    # broadcast the state to all hosts and convert it to a global array.
    with jax.transfer_guard("allow"):
      chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save())
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.
    chrono_shardings = jax.tree_util.tree_map(
        lambda _: repl_sharding, chrono_ckpt)
    ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)}

    u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep)
    u.chrono.resume()

  for (name, evaluator, log_steps, prefix) in evaluators():
    if u.itstime(step, log_steps, total_steps, first=False, last=True):
      u.chrono.pause(wait_for=train_state)
      u.chrono.tick(step)  # Record things like epoch number, core hours etc.
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        with mesh, flax.linen.logical_axis_rules(sharding_rules):
          for key, value in evaluator.run(train_state):
            mw.measure(f"{prefix}{key}", jax.device_get(value))
      u.chrono.resume()
  mw.step_end()
# Give the profiler a chance to stop once the loop has ended (normally or via
# `break`). Exceptions raised in the loop skip this: there is no try/finally.
if jax.process_index() == 0 and prof is not None:
  u.startstop_prof(prof)

# Last note needs to happen before the pool's closed =)
write_note(f"Done!\n{u.chrono.note}")

pool.close()
pool.join()
mw.close()

# Wait for outstanding (asynchronous) checkpoint writes to finish.
if ckpt_mngr:
  ckpt_mngr.wait_until_finished()

# Make sure all hosts stay up until the end of main.
u.sync()

# Presumably deletes `workdir` only when --cleanup is set (see flag help).
u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)
# Script entry point: absl parses the command-line flags, then calls `main`.
if __name__ == "__main__":
  app.run(main)