"""Training loop for PaliGemma-style VLM."""

import functools
import importlib
import multiprocessing.pool
import os

from absl import app
from absl import flags
from absl import logging
import big_vision.datasets.core as ds_core
import big_vision.evaluators.common as eval_common
import big_vision.input_pipeline as input_pipeline
import big_vision.optax as bv_optax
import big_vision.sharding as bv_sharding
import big_vision.trainers.proj.paligemma.predict_fns as predict_fns
import big_vision.utils as u
from clu import parameter_overview
import flax
import flax.linen as nn
import jax
from jax.experimental import mesh_utils
from jax.experimental import multihost_utils
from jax.experimental.array_serialization import serialization as array_serial
import jax.numpy as jnp
import ml_collections as mlc
from ml_collections import config_flags
import numpy as np
import optax
import tensorflow as tf

from tensorflow.io import gfile


config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)

flags.DEFINE_string("workdir", default=None, help="Work unit directory.")
flags.DEFINE_boolean("cleanup", default=False,
                     help="Delete workdir (only) after successful completion.")

jax.config.parse_flags_with_absl()
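# The transfer guard turns any *implicit* host<->device transfer into an
# error. Implicit transfers are a common source of slowdowns and memory
# fragmentation; all intended transfers in this file are explicit
# (jax.device_put / jax.device_get / u.reshard) or locally wrapped in
# `with jax.transfer_guard("allow")`.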
jax.config.update("jax_transfer_guard", "disallow")


NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec


def main(argv):
  del argv

  # On multi-host runs, JAX must be initialized before first use.
  if os.environ.get("BV_JAX_INIT"):
    jax.distributed.initialize()

  # Make sure TF does not claim GPUs; it is only used for data I/O here.
  tf.config.set_visible_devices([], "GPU")
  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.bv")

  # The pool is used to perform misc operations, such as logging,
  # asynchronously.
  pool = multiprocessing.pool.ThreadPool()
  # Register preprocessing ops from the modules listed in `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # Stubs for the (internal) experiment manager.
  xid, wid = -1, -1
  fillin = lambda s: s

  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)

  def write_note(note):
    if jax.process_index() == 0:
      info("%s", note)

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  # Let the chrono helper report through the metric writer and note logger.
  u.chrono.inform(measure=mw.measure, write_note=write_note)
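  # The mesh is a logical grid of devices with named axes; arrays are sharded
  # by mapping their dimensions onto mesh axes. The default, a single "data"
  # axis spanning all devices, is plain data parallelism. A config may
  # override it, e.g. (hypothetical example, not from any shipped config):
  #   config.mesh = [("replica", 2), ("data", -1)]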
  config_mesh = config.get("mesh", [("data", jax.device_count())])

  # Logical-axis rules: by default, shard the activations' batch dimension
  # ("act_batch") along the "data" mesh axis.
  sharding_rules = config.get("sharding_rules", [("act_batch", "data")])

  mesh_axes, mesh_size = tuple(zip(*config_mesh))

  # Resolve any `-1` axis sizes via np.reshape over the flat device list.
  mesh_size = np.array(jax.devices()).reshape(mesh_size).shape

  device_mesh = mesh_utils.create_device_mesh(
      mesh_size, allow_split_physical_axes=config.get(
          "mesh_allow_split_physical_axes", False))

  # Consistent device order is important!
  devices_flat = device_mesh.flatten()
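################################################################################
#                                Input Pipeline                                #
################################################################################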
  write_note("Initializing train dataset...")
  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  train_ds, ntrain_img = input_pipeline.training(config.input)

  total_steps = u.steps("total", config, ntrain_img, batch_size)
  def get_steps(name, default=ValueError, cfg=config):
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)
  # Start prefetching to the devices already.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_global(
      train_ds, devices_flat, n_prefetch, warmup=n_prefetch > 0)

  # When training on a mixture, also report per-dataset progress.
  if isinstance(config.input.data.get("name"), str):  # Single dataset.
    measure_per_dataset_times = lambda step: None
  else:  # Mixture: `config.input.data` maps dataset name to mixture weight.
    nexamples = {
        name: ds_core.get(**config.input[name].data).total_examples
        for name in config.input.data
    }
    def measure_per_dataset_times(step):
      total = sum(config.input.data.values())
      for name, w in config.input.data.items():
        w = w / total
        mw.measure(f"examples_seen_{name}", u.chrono.accum_examples_seen * w)
        mw.measure(f"epoch_{name}", step * batch_size * w / nexamples[name])
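################################################################################
#                           Model & Optimizer Setup                            #
################################################################################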
  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(**mlc.FrozenConfigDict(config.get("model", {})))

  def init(rng, partial_params=None):
    # A dummy batch with the right shapes/dtypes suffices to init params.
    batch = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),
                         train_ds.element_spec)
    # Text and attention mask are shifted by one: the model predicts token
    # t+1 from tokens up to t.
    _, variables = model.apply(
        {"params": partial_params or {}},
        batch["image"], batch["text"][:, :-1], batch["mask_ar"][:, :-1],
        rngs={"params": rng, "dropout": rng},
        mutable=["params"])
    return flax.core.unfreeze(variables["params"])
  # This seed makes the JAX part of things (like model init) deterministic.
  rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))

  write_note("Inferring parameter shapes...")
  rng, rng_init = jax.random.split(rng)
  params_shape = jax.eval_shape(init, rng_init)
  params_shape = nn.unbox(params_shape)

  write_note("Inferring optimizer state shapes...")
  tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
  opt_shape = jax.eval_shape(tx.init, params_shape)

  # jit on CPU, such that schedule values live on the host, not on device[0].
  sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns]

  if jax.process_index() == 0:
    num_params = sum(np.prod(p.shape) for p in jax.tree.leaves(params_shape))
    mw.measure("num_params", num_params)
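################################################################################
#                               Mesh & Sharding                                #
################################################################################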
  write_note("Creating device mesh...")
  mesh = jax.sharding.Mesh(device_mesh, mesh_axes)
  repl_sharding = jax.sharding.NamedSharding(mesh, P())

  write_note("Inferring shardings...")
  train_state_shape = {"params": params_shape, "opt": opt_shape}

  strategy = config.get("sharding_strategy", [(".*", "replicate")])
  train_state_sharding = bv_sharding.infer_sharding(
      train_state_shape, strategy=strategy, mesh=mesh)
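  # Resume preference: this run's own latest checkpoint (signalled by the
  # `-LAST` marker written after a successful save) wins over an explicitly
  # configured `config.resume` path.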
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)

  if resume_ckpt_path:
    write_note(f"Resuming training from checkpoint {resume_ckpt_path}...")
    shardings = {
        **train_state_sharding,
        "chrono": jax.tree.map(lambda _: repl_sharding, u.chrono.save()),
    }
    loaded = u.load_checkpoint_ts(
        resume_ckpt_path, tree=shardings, shardings=shardings)
    train_state = {key: loaded[key] for key in train_state_sharding.keys()}
    u.chrono.load(jax.device_get(loaded["chrono"]))
    del loaded
  else:
    write_note(
        f"Initialize model from {config.get('model_init') or 'scratch'}...")

    if config.get("model_init"):
      # `load` returns the full params tree: loaded tensors as (host) arrays,
      # everything else as ShapeDtypeStruct placeholders still to initialize.
      params = model_mod.load(
          params_shape, config.model_init, config.get("model"),
          **config.get("model_load", {}))

      # Shard only the actually-loaded tensors onto the devices.
      mask = jax.tree.map(
          lambda x: not isinstance(x, jax.ShapeDtypeStruct), params)
      params = u.reshard(u.tree_filter(params, mask),
                         u.tree_filter(train_state_sharding["params"], mask))

      parameter_overview.log_parameter_overview(
          params, msg="Restored params",
          include_stats="global", jax_logging_process=0)
    else:
      params = {}

    # Initialize any remaining parameters on-device, directly with the final
    # sharding; donating `params` avoids keeping a second copy alive.
    rng_init = u.reshard(rng_init, repl_sharding)
    params = jax.jit(
        init, donate_argnums=1, out_shardings=train_state_sharding["params"])(
            rng_init, params)
    params = nn.unbox(params)

    opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params)
    train_state = {"params": params, "opt": opt}
    del params, opt
  parameter_overview.log_parameter_overview(
      train_state["params"], msg="Parameter overview",
      include_stats="global", jax_logging_process=0)

  rng, rng_loop = jax.random.split(rng, 2)
  rng_loop = u.reshard(rng_loop, repl_sharding)
  del rng, rng_init
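################################################################################
#                                 Update Step                                  #
################################################################################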
  @functools.partial(
      jax.jit,
      donate_argnums=(0,),
      out_shardings=(train_state_sharding, repl_sharding))
  def update_fn(train_state, rng, batch):
    """Update step."""

    # Fold the step count into the rng so every step sees fresh randomness.
    step_count = bv_optax.get_count(train_state["opt"], jittable=True)
    rng = jax.random.fold_in(rng, step_count)
    assert "mixup" not in config, "Mixup is not supported by this trainer."

    _, rng_model = jax.random.split(rng, 2)

    imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]
    def loss_fn(params):
      text_logits, _ = model.apply(
          {"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1],
          train=True, rngs={"dropout": rng_model})

      logp = jax.nn.log_softmax(text_logits, axis=-1)
      targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])
      off_value = config.get("label_smoothing", 0.0)
      if off_value > 0:
        # Smear `off_value` mass uniformly over the non-target classes.
        denom = text_logits.shape[-1] - 1
        targets = jnp.where(
            targets == 1.0, 1.0 - off_value, off_value / denom)

      # Per-token log-likelihood of the target token.
      token_pplx = jnp.sum(logp * targets, axis=-1)

      # Mean per-example loss over supervised (suffix) tokens only; prefix
      # and padding tokens are zeroed out by `mask_loss`.
      mask_loss = batch["mask_loss"][:, 1:]
      token_pplx = token_pplx * mask_loss
      pplx = -jnp.sum(token_pplx, axis=-1)
      pplx /= jnp.clip(jnp.sum(mask_loss, axis=-1), 1)  # Avoid division by 0.

      measurements = dict(
          training_loss=jnp.mean(pplx),
          avg_sup_seqlen=jnp.mean(jnp.sum(mask_loss, axis=-1)),
          max_sup_seqlen=jnp.max(jnp.sum(mask_loss, axis=-1)),
      )

      return measurements["training_loss"], measurements
    params, opt = train_state["params"], train_state["opt"]
    (_, measurements), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    # Norm diagnostics; frozen params contribute 0 to the gradient norm.
    gs = jax.tree.leaves(bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.sum(g * g) for g in gs]))
    ps = jax.tree.leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.sum(p * p) for p in ps]))
    us = jax.tree.leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(sum([jnp.sum(u * u) for u in us]))

    return {"params": params, "opt": opt}, measurements
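################################################################################
#                                  Evaluators                                  #
################################################################################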
  # `lru_cache` memoizes the evaluators, so the (potentially expensive)
  # construction happens at most once, and only if evals actually run.
  @functools.lru_cache(maxsize=None)
  def evaluators():
    return eval_common.from_config(
        config,
        predict_fns.get_all(model),
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
        devices_flat,
    )

  # The step count lives in the optimizer state, so resumed runs continue
  # counting where they left off.
  write_note("Inferring the first step number...")
  first_step_device = bv_optax.get_count(train_state["opt"], jittable=True)
  first_step = int(jax.device_get(first_step_device))
  u.chrono.inform(first_step=first_step)
  # Note that training can be pre-empted during the final evaluation (i.e.
  # just after the final checkpoint has been written to disk), in which case
  # we still want to run the evals.
  if first_step in (total_steps, 0):
    write_note("Running initial or final evals...")
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        with mesh, nn.logical_axis_rules(sharding_rules):
          for key, value in evaluator.run(train_state):
            mw.measure(f"{prefix}{key}", value)
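################################################################################
#                                  Train Loop                                  #
################################################################################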
  prof = None  # Keeps track of the currently running profiler trace, if any.
  ckpt_mngr = None  # Created lazily on first checkpoint save.

  write_note("Starting training loop, compiling the first step...")
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        with mesh, nn.logical_axis_rules(sharding_rules):
          train_state, measurements = update_fn(train_state, rng_loop, batch)
    # On the first host, start/stop the profiler around selected steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress.
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or u.chrono.warmup and jax.process_index() == 0):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}",
                   sched_fn_cpu(u.put_cpu(step - 1)))
      measurements = jax.device_get(measurements)
      for name, value in measurements.items():
        mw.measure(name, value)
      u.chrono.tick(step)
      measure_per_dataset_times(step)

      # Stop the run early if the training state blows up.
      for k in ("training_loss", "l2_params", "l2_grads"):
        if not np.isfinite(measurements.get(k, 0.0)):
          raise RuntimeError(f"{k} became nan or inf somewhere within steps "
                             f"[{step - get_steps('log_training')}, {step}]")
    # Checkpoint saving.
    keep_last = total_steps if get_steps("ckpt", None) else None
    keep_ckpt_steps = get_steps("keep_ckpt", None) or keep_last
    if save_ckpt_path and (
        (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False))
        or u.itstime(step, get_steps("ckpt", None), total_steps, first=True)
    ):
      u.chrono.pause(wait_for=train_state)

      ckpt = {**train_state}

      # The chrono state is host-local, not part of the sharded train_state;
      # broadcast process 0's copy so all hosts save identical values. The
      # broadcast needs an explicit exemption from the transfer guard.
      with jax.transfer_guard("allow"):
        chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save())
      chrono_shardings = jax.tree.map(lambda _: repl_sharding, chrono_ckpt)
      ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)}

      ckpt_mngr = ckpt_mngr or array_serial.GlobalAsyncCheckpointManager()
      u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep)
      u.chrono.resume()
    # Run evaluators whose schedule says it is time (always on the last step).
    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=train_state)
        u.chrono.tick(step)  # Record things like epoch number, core hours etc.
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          with mesh, nn.logical_axis_rules(sharding_rules):
            for key, value in evaluator.run(train_state):
              mw.measure(f"{prefix}{key}", jax.device_get(value))
        u.chrono.resume()
    mw.step_end()
  # Always give a chance to stop the profiler, no matter how things ended.
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  write_note(f"Done!\n{u.chrono.note}")

  pool.close()
  pool.join()
  mw.close()
  if ckpt_mngr:
    ckpt_mngr.wait_until_finished()

  # Make sure all hosts stay up until the end of main.
  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)


if __name__ == "__main__":
  app.run(main)