"""Contrastive training loop. |
|
|
|
For contrastive image-text models like:
|
- LiT (https://arxiv.org/abs/2111.07991) |
|
- CLIP (https://arxiv.org/abs/2103.00020) |
|
- SigLIP (https://arxiv.org/abs/2303.15343) |
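
A typical invocation looks roughly like this (module and config paths are
illustrative and may differ in your checkout):

  python -m big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/image_text/lit_coco.py \
    --workdir gs://<your-bucket>/contrastive_workdir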
|
""" |
|
|
|
import functools |
|
import importlib |
|
import multiprocessing.pool |
|
import os |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import big_vision.evaluators.common as eval_common |
|
import big_vision.input_pipeline as input_pipeline |
|
import big_vision.optax as bv_optax |
|
import big_vision.utils as u |
|
from clu import parameter_overview |
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
from ml_collections import config_flags |
|
import numpy as np |
|
import optax |
|
import tensorflow as tf |
|
|
|
from tensorflow.io import gfile |
|
|
|
|
|
|
|
|
|
config_flags.DEFINE_config_file( |
|
"config", None, "Training configuration.", lock_config=True) |
|
|
|
flags.DEFINE_string("workdir", default=None, help="Work unit directory.") |
|
flags.DEFINE_boolean("cleanup", default=False, |
|
help="Delete workdir (only) after successful completion.") |
|
|
|
|
|
jax.config.parse_flags_with_absl() |
|
|
|
|
|
def clip(x, *, a_max=None, a_min=None): |
|
"""Like jnp.clip, but allows all-None to mean don't clip.""" |
|
if a_max is None and a_min is None: |
|
return x |
|
return jnp.clip(x, a_max=a_max, a_min=a_min) |
|
|
|
|
|
def all_gather(z, roll=False, only_others=False): |
|
"""All gather and flatten first two dims.""" |
|
def gather_flat(x): |
|
x = jax.lax.all_gather(x, "batch") |
|
if roll or only_others: |
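      # Each device rolls the gathered chunks so that its own chunk comes
      # first; this keeps the positive pairs on the local diagonal and lets
      # `only_others` simply drop chunk 0.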
|
|
|
x = jnp.roll(x, -jax.lax.axis_index("batch"), axis=0) |
|
if only_others: |
|
x = x[1:] |
|
return jnp.concatenate(x, 0) |
|
return jax.tree_map(gather_flat, z) |
|
|
|
|
|
def softmax_loss(zimg, ztxt, temperature): |
|
"""Softmax loss following the CLIP paper. Factorized to reduce memory cost.""" |
|
|
|
def unidirectional_loss(z1, z2, t): |
|
z2 = all_gather(z2, roll=True) |
|
logits = jnp.dot(z1, z2.T) * t |
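    # Standard softmax cross-entropy with the matching pair on the diagonal:
    # -(positive logit - logsumexp over all logits), written via logsumexp
    # for numerical stability.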
|
|
|
|
|
loss = -(jnp.diag(logits) - jax.scipy.special.logsumexp(logits, axis=-1)) |
|
acc = jnp.argmax(logits, axis=1) == jnp.arange(z1.shape[0]) |
|
return loss.mean(), acc.mean() |
|
|
|
extras = {} |
|
loss = 0 |
|
for name, row, col in [("i2t", zimg, ztxt), ("t2i", ztxt, zimg)]: |
|
loss_dir, acc_dir = unidirectional_loss(row, col, temperature) |
|
loss += 0.5 * loss_dir |
|
extras[f"{name}_acc"] = acc_dir |
|
extras[f"{name}_loss"] = loss_dir |
|
|
|
loss = jax.lax.pmean(loss, "batch") |
|
return loss, extras |
|
|
|
|
|
def _avg_pos_logit(x_me): |
|
return jnp.mean(jnp.diag(x_me)) |
|
|
|
|
|
def _avg_neg_logit(x_me, x_ot=None): |
|
nom = jnp.sum(x_me) - jnp.sum(jnp.diag(x_me)) |
|
den = x_me.size - len(x_me) |
|
if x_ot is not None and x_ot.size: |
|
nom += jnp.sum(x_ot) |
|
den += x_ot.size |
|
return nom / den |
|
|
|
|
|
def sigmoid_loss(zimg, ztxt, temperature, bias=0.0): |
|
"""Sigmoid loss from SigLIP: https://arxiv.org/abs/2303.15343.""" |
|
|
|
|
|
ztxt_me = ztxt |
|
ztxt_ot = all_gather(ztxt, only_others=True) |
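  # `ztxt_ot` holds the text embeddings gathered from all *other* devices, so
  # each device scores its local images against every text in the global
  # batch; positives only occur within the local chunk `ztxt_me`.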
|
|
|
logits_me = jnp.dot(zimg, ztxt_me.T) |
|
logits_ot = jnp.dot(zimg, ztxt_ot.T) |
|
logits_me = logits_me * temperature + bias |
|
logits_ot = logits_ot * temperature + bias |
|
|
|
eye = jnp.eye(zimg.shape[0]) |
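
  # SigLIP labels are +1 on the diagonal (matching image-text pairs) and -1
  # everywhere else. `m1_diag1` encodes exactly that, so
  # log_sigmoid(m1_diag1 * logits) is the per-pair log-likelihood.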
|
|
|
|
|
|
|
|
|
m1_diag1 = -jnp.ones_like(logits_me) + 2 * eye |
|
loglik_me = jax.nn.log_sigmoid(m1_diag1 * logits_me) |
|
loglik_ot = jax.nn.log_sigmoid(-logits_ot) |
|
|
|
|
|
nll_me = -loglik_me.sum(axis=-1) |
|
nll_ot = -loglik_ot.sum(axis=-1) |
|
l = nll_me.mean() + nll_ot.mean() |
|
|
|
return l, { |
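      # Logit statistics, purely for monitoring. Adding/subtracting 1e9 * eye
      # masks out the diagonal (positives) when taking min/max over negatives.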
|
|
|
|
|
|
|
"pos_min_logit": jnp.min(jnp.diag(logits_me)), |
|
"pos_max_logit": jnp.max(jnp.diag(logits_me)), |
|
"pos_avg_logit": _avg_pos_logit(logits_me), |
|
"local_neg_min_logit": jnp.min(logits_me + 1e9 * eye), |
|
"local_neg_max_logit": jnp.max(logits_me - 1e9 * eye), |
|
"local_neg_avg_logit": _avg_neg_logit(logits_me), |
|
"neg_min_logit": jnp.minimum( |
|
jnp.min(logits_me + 1e9 * eye), |
|
jnp.min(logits_ot) if logits_ot.size else jnp.inf), |
|
"neg_max_logit": jnp.maximum( |
|
jnp.max(logits_me - 1e9 * eye), |
|
jnp.max(logits_ot) if logits_ot.size else -jnp.inf), |
|
"neg_avg_logit": _avg_neg_logit(logits_me, logits_ot), |
|
} |
|
|
|
|
|
def _gather_from_device(x, device_id, axis_name="batch"): |
|
return jax.lax.psum((jax.lax.axis_index(axis_name) == device_id) * x, |
|
axis_name) |
|
|
|
|
|
def chunked_sigmoid_loss(zimg, ztxt, temperature, bias=0.0): |
|
"""Loss computation from section 3.1 of arxiv.org/abs/2303.15343.""" |
|
|
|
|
|
logits_me = jnp.dot(zimg, ztxt.T) |
|
logits_me = logits_me * temperature + bias |
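
  # `logits_me` scores the local images against the local text chunk, which
  # contains all the positive pairs plus the "local" negatives. Negatives
  # against the other devices' text chunks are accumulated chunk by chunk
  # below, so the full global similarity matrix is never materialized.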
|
m1_diag1 = -jnp.ones_like(logits_me) + 2 * jnp.eye(zimg.shape[0]) |
|
loglik_me = jax.nn.log_sigmoid(m1_diag1 * logits_me) |
|
nll_me = -loglik_me.sum(axis=-1).mean() |
|
|
|
def negative_loss(ztxt_other_device): |
|
logits_ot = jnp.dot(zimg, ztxt_other_device.T) |
|
logits_ot = logits_ot * temperature + bias |
|
loglik_ot = jax.nn.log_sigmoid(-logits_ot) |
|
return -jnp.sum(loglik_ot, axis=-1).mean() |
|
|
|
me = jax.lax.axis_index("batch") |
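
  # Accumulate the negative loss against every other device's text chunk,
  # fetched one device at a time. Our own chunk contributes nothing here
  # because `skip` is 0 when device_id == me; it was already handled
  # (positives included) in `nll_me` above.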
|
|
|
|
|
|
|
nll_others = 0 |
|
for device_id in range(jax.device_count()): |
|
skip = jnp.not_equal(device_id, me) |
|
nll_others += skip * negative_loss(_gather_from_device(ztxt, device_id)) |
|
|
|
eye = jnp.eye(zimg.shape[0]) |
|
return nll_me + nll_others, { |
|
"pos_min_logit": jnp.min(jnp.diag(logits_me)), |
|
"pos_max_logit": jnp.max(jnp.diag(logits_me)), |
|
"pos_avg_logit": _avg_pos_logit(logits_me), |
|
"local_neg_min_logit": jnp.min(logits_me + 1e9 * eye), |
|
"local_neg_max_logit": jnp.max(logits_me - 1e9 * eye), |
|
"local_neg_avg_logit": _avg_neg_logit(logits_me),} |
|
|
|
|
|
def main(argv): |
|
del argv |
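
  # Hide GPUs from TensorFlow: tf.data should run on the CPU and must not
  # grab accelerator memory that JAX will use.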
|
tf.config.experimental.set_visible_devices([], "GPU") |
|
|
|
config = flags.FLAGS.config |
|
workdir = flags.FLAGS.workdir |
|
logging.info( |
|
f"\u001b[33mHello from process {jax.process_index()} holding " |
|
f"{jax.local_device_count()}/{jax.device_count()} devices and " |
|
f"writing to workdir {workdir}.\u001b[0m") |
|
|
|
save_ckpt_path = None |
|
if workdir: |
|
gfile.makedirs(workdir) |
|
save_ckpt_path = os.path.join(workdir, "checkpoint.npz") |
|
|
|
|
|
pool = multiprocessing.pool.ThreadPool() |
|
|
|
|
|
  # Register the preprocessing ops from the modules listed in the config.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
|
importlib.import_module(f"big_vision.pp.{m}") |
|
|
|
|
|
|
|
|
|
|
|
rng = jax.random.PRNGKey(config.get("seed", 0)) |
|
|
|
|
|
|
|
xid, wid = -1, -1 |
|
def info(s, *a): |
|
logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a) |
|
def write_note(note): |
|
if jax.process_index() == 0: |
|
info("%s", note) |
|
|
|
write_note("Initializing...") |
|
|
|
batch_size = config.input.batch_size |
|
if batch_size % jax.device_count() != 0: |
|
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by the number of devices ({jax.device_count()}).")
|
info("Global batch size %d on %d hosts results in %d local batch size. With " |
|
"%d dev per host (%d dev total), that's a %d per-device batch size.", |
|
batch_size, jax.process_count(), batch_size // jax.process_count(), |
|
jax.local_device_count(), jax.device_count(), |
|
batch_size // jax.device_count()) |
|
|
|
|
|
mw = u.BigVisionMetricWriter(xid, wid, workdir, config) |
|
|
|
write_note("Initializing train dataset...") |
|
train_ds, ntrain_img = input_pipeline.training(config.input) |
|
|
|
|
|
n_prefetch = config.get("prefetch_to_device", 1) |
|
train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch) |
|
|
|
total_steps = u.steps("total", config, ntrain_img, batch_size) |
|
def get_steps(name, default=ValueError, cfg=config): |
|
return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default) |
|
|
|
u.chrono.inform(total_steps=total_steps, global_bs=batch_size, |
|
steps_per_epoch=ntrain_img / batch_size, |
|
measure=mw.measure, write_note=write_note) |
|
|
|
info("Running for %d steps, that means %f epochs", |
|
total_steps, total_steps * batch_size / ntrain_img) |
|
|
|
write_note(f"Initializing {config.model_name} model...") |
|
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") |
|
model = model_mod.Model(**config.get("model", {})) |
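
  # Create the parameters in host RAM (note the cpu-backed jit below); they
  # are replicated onto the accelerator devices later, after potentially
  # loading a checkpoint into them.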
|
|
|
|
|
|
|
|
|
@functools.partial(jax.jit, backend="cpu") |
|
def init(rng): |
|
bs = batch_size // jax.device_count() |
|
image_size = tuple(train_ds.element_spec["image"].shape[1:]) |
|
no_image = jnp.zeros((bs,) + image_size, jnp.float32) |
|
text_size = tuple(train_ds.element_spec["labels"].shape[1:]) |
|
no_text = jnp.zeros((bs,) + text_size, jnp.int32) |
|
params = flax.core.unfreeze(model.init(rng, no_image, no_text))["params"] |
|
return params |
|
|
|
rng, rng_init = jax.random.split(rng) |
|
with u.chrono.log_timing("z/secs/init"): |
|
params_cpu = init(rng_init) |
|
|
|
if jax.process_index() == 0: |
|
num_params = sum(p.size for p in jax.tree_leaves(params_cpu)) |
|
parameter_overview.log_parameter_overview(params_cpu, msg="init params") |
|
mw.measure("num_params", num_params) |
|
|
|
write_note(f"Initializing {config.optax_name} optimizer...") |
|
tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict( |
|
total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img)) |
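
  # Like the parameters, the optimizer state lives in host RAM (cpu-backed
  # jit). The schedule functions are jitted on the CPU too, since they are
  # only evaluated here for logging.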
|
|
|
|
|
opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu) |
|
sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns] |
|
|
|
@functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1)) |
|
def update_fn(params, opt, rng, batch): |
|
"""Update step.""" |
|
assert "mixup" not in config, "We still have to figure out mixup." |
|
|
|
|
|
rng, rng_model = jax.random.split(rng, 2) |
|
rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch")) |
|
|
|
def loss_fn(params, images, labels): |
|
zimg, ztxt, extras = model.apply( |
|
{"params": params}, images, labels, |
|
train=True, rngs={"dropout": rng_model_local}) |
|
|
|
match config.get("loss_fn", "softmax"): |
|
case "softmax": |
|
l, l_extras = softmax_loss(zimg, ztxt, extras["t"]) |
|
case "sigmoid": |
|
l, l_extras = sigmoid_loss(zimg, ztxt, extras["t"], bias=extras["b"]) |
|
case "chunked_sigmoid": |
|
l, l_extras = chunked_sigmoid_loss(zimg, ztxt, extras["t"], |
|
bias=extras["b"]) |
|
case _: |
|
raise NotImplementedError(f"Unrecognized loss {config.loss_fn=}") |
|
|
|
return l, { |
|
"t": extras["t"], |
|
"t/parameter": extras["t/parameter"], |
|
"train/nimg": jnp.mean(extras["img/norm"]), |
|
"train/ntxt": jnp.mean(extras["txt/norm"]), |
|
**{f"train/{k}": v for k, v in l_extras.items()}, |
|
} |
|
|
|
(l, measurements), grads = jax.value_and_grad( |
|
loss_fn, has_aux=True)(params, batch["image"], batch["labels"]) |
|
l, measurements, grads = jax.lax.pmean((l, measurements, grads), |
|
axis_name="batch") |
|
updates, opt = tx.update(grads, opt, params) |
|
params = optax.apply_updates(params, updates) |
|
|
|
gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.)) |
|
measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])) |
|
ps = jax.tree_leaves(params) |
|
measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])) |
|
us = jax.tree_leaves(updates) |
|
measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us])) |
|
|
|
return params, opt, rng, l, measurements |
|
|
|
|
|
|
|
|
|
|
|
def predict_fn(params, image=None, text=None, **unused_kwargs): |
|
del unused_kwargs |
|
zimg, ztxt, out = model.apply({"params": params}, image, text) |
|
return zimg, ztxt, out |
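
  # Only construct the evaluators the first time they are needed; lru_cache
  # then reuses the same instances on every call in the training loop.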
|
|
|
|
|
@functools.lru_cache(maxsize=None) |
|
def evaluators(): |
|
return eval_common.from_config( |
|
config, {"predict": predict_fn}, |
|
lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"), |
|
lambda key, cfg: get_steps(key, default=None, cfg=cfg), |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
  # Decide how to initialize training. The order expresses precedence: an
  # existing checkpoint in the workdir wins over `config.resume`, which in
  # turn wins over `config.model_init`.
  resume_ckpt_path = None
|
if save_ckpt_path and gfile.exists(save_ckpt_path): |
|
resume_ckpt_path = save_ckpt_path |
|
elif config.get("resume"): |
|
    resume_ckpt_path = config.resume.format(wid=wid)  # `wid` placeholder defined above.
|
if resume_ckpt_path: |
|
write_note("Resume training from checkpoint...") |
|
checkpoint = { |
|
"params": params_cpu, |
|
"opt": opt_cpu, |
|
"chrono": u.chrono.save(), |
|
} |
|
checkpoint_tree = jax.tree_structure(checkpoint) |
|
loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree) |
|
|
|
checkpoint = jax.tree_map(u.recover_dtype, loaded) |
|
params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"] |
|
u.chrono.load(checkpoint["chrono"]) |
|
elif config.get("model_init"): |
|
write_note(f"Initialize model from {config.model_init}...") |
|
params_cpu = model_mod.load( |
|
params_cpu, config.model_init, config.get("model"), |
|
**config.get("model_load", {})) |
|
if jax.process_index() == 0: |
|
parameter_overview.log_parameter_overview( |
|
params_cpu, msg="restored params") |
|
|
|
write_note("Kicking off misc stuff...") |
|
first_step = bv_optax.get_count(opt_cpu) |
|
u.chrono.inform(first_step=first_step) |
|
prof = None |
|
|
|
write_note(f"Replicating...\n{u.chrono.note}") |
|
params_repl = flax.jax_utils.replicate(params_cpu) |
|
opt_repl = flax.jax_utils.replicate(opt_cpu) |
|
|
|
rng, rng_loop = jax.random.split(rng, 2) |
|
rngs_loop = flax.jax_utils.replicate(rng_loop) |
|
ckpt_writer = None |
|
|
|
write_note(f"First step compilations...\n{u.chrono.note}") |
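
  # Run the evaluators up front when starting from scratch (step 0) or when
  # resuming exactly at the final step (e.g. after a pre-emption during the
  # last evaluation), honoring each evaluator's `skip_first` setting.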
|
|
|
|
|
|
|
|
|
if first_step in (total_steps, 0): |
|
mw.step_start(first_step) |
|
for (name, evaluator, _, prefix) in evaluators(): |
|
if config.evals[name].get("skip_first") and first_step != total_steps: |
|
continue |
|
write_note(f"{name} evaluation...\n{u.chrono.note}") |
|
with u.chrono.log_timing(f"z/secs/eval/{name}"): |
|
for key, value in evaluator.run(params_repl): |
|
mw.measure(f"{prefix}{key}", value) |
|
|
|
|
|
|
|
for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter): |
|
mw.step_start(step) |
|
|
|
with jax.profiler.StepTraceAnnotation("train_step", step_num=step): |
|
with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1): |
|
params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn( |
|
params_repl, opt_repl, rngs_loop, batch) |
|
|
|
|
|
if jax.process_index() == 0: |
|
prof = u.startstop_prof(prof, step, first_step, get_steps("log_training")) |
|
|
|
|
|
if (u.itstime(step, get_steps("log_training"), total_steps, host=0) |
|
or u.chrono.warmup and jax.process_index() == 0): |
|
for i, sched_fn_cpu in enumerate(sched_fns_cpu): |
|
mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1)) |
|
l = mw.measure("training_loss", loss_value[0]) |
|
for name, value in measurements.items(): |
|
mw.measure(name, value[0]) |
|
u.chrono.tick(step) |
|
if not np.isfinite(l): |
|
raise RuntimeError(f"The loss became nan or inf somewhere within steps " |
|
f"[{step - get_steps('log_training')}, {step}]") |
|
|
|
|
|
if (save_ckpt_path and |
|
(u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or |
|
u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))): |
|
u.chrono.pause(wait_for=(params_repl, opt_repl)) |
|
u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1)) |
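      # Copy device 0's replicated params/opt back to host RAM now, so the
      # async checkpoint writer does not hold on to device buffers that the
      # next update step will donate and overwrite.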
|
|
|
|
|
|
|
params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl) |
|
opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) |
|
|
|
|
|
copy_step = None |
|
if u.itstime(step, get_steps("keep_ckpt", None), total_steps): |
|
copy_step = step |
|
|
|
ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": u.chrono.save()} |
|
ckpt_writer = pool.apply_async( |
|
u.save_checkpoint, (ckpt, save_ckpt_path, copy_step)) |
|
u.chrono.resume() |
|
|
|
for (name, evaluator, log_steps, prefix) in evaluators(): |
|
if u.itstime(step, log_steps, total_steps, first=False, last=True): |
|
u.chrono.pause(wait_for=params_repl) |
|
u.chrono.tick(step) |
|
write_note(f"{name} evaluation...\n{u.chrono.note}") |
|
with u.chrono.log_timing(f"z/secs/eval/{name}"): |
|
for key, value in evaluator.run(params_repl): |
|
mw.measure(f"{prefix}{key}", value) |
|
u.chrono.resume() |
|
mw.step_end() |
|
|
|
|
|
|
|
if jax.process_index() == 0 and prof is not None: |
|
u.startstop_prof(prof) |
|
|
|
|
|
write_note(f"Done!\n{u.chrono.note}") |
|
|
|
pool.close() |
|
pool.join() |
|
mw.close() |
|
|
|
|
|
u.sync() |
|
|
|
u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info) |
|
|
|
|
|
if __name__ == "__main__": |
|
app.run(main) |
|
|