|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Script that loads a model and only runs evaluators.""" |
|
|
|
from functools import partial |
|
import importlib |
|
|
|
import os |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import big_vision.evaluators.common as eval_common |
|
import big_vision.utils as u |
|
from clu import parameter_overview |
|
import flax |
|
import flax.jax_utils as flax_utils |
|
import jax |
|
import jax.numpy as jnp |
|
from ml_collections import config_flags |
|
from tensorflow.io import gfile |
|
|
|
|
|
config_flags.DEFINE_config_file( |
|
"config", None, "Training configuration.", lock_config=True) |
|
|
|
flags.DEFINE_string("workdir", default=None, help="Work unit directory.") |
|
flags.DEFINE_boolean("cleanup", default=False, |
|
help="Delete workdir (only) after successful completion.") |
|
|
|
|
|
jax.config.parse_flags_with_absl() |
|
|
|
|
|
def main(argv): |
|
del argv |
|
|
|
config = flags.FLAGS.config |
|
workdir = flags.FLAGS.workdir |
|
logging.info("Workdir: %s", workdir) |
|
|
|
|
|
for m in config.get("pp_modules", ["ops_general", "ops_image"]): |
|
importlib.import_module(f"big_vision.pp.{m}") |
|
|
|
|
|
|
|
xid, wid = -1, -1 |
|
def write_note(note): |
|
if jax.process_index() == 0: |
|
logging.info("NOTE: %s", note) |
|
|
|
mw = u.BigVisionMetricWriter(xid, wid, workdir, config) |
|
u.chrono.inform(measure=mw.measure, write_note=write_note) |
|
|
|
write_note(f"Initializing {config.model_name} model...") |
|
assert config.get("model.reinit") is None, ( |
|
"I don't think you want any part of the model to be re-initialized.") |
|
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}") |
|
model_kw = dict(config.get("model", {})) |
|
if "num_classes" in config: |
|
model_kw["num_classes"] = config.num_classes |
|
model = model_mod.Model(**model_kw) |
|
|
|
|
|
|
|
|
|
@partial(jax.jit, backend="cpu") |
|
def init(rng): |
|
input_shapes = config.get("init_shapes", [(1, 224, 224, 3)]) |
|
input_types = config.get("init_types", [jnp.float32] * len(input_shapes)) |
|
dummy_inputs = [jnp.zeros(s, t) for s, t in zip(input_shapes, input_types)] |
|
things = flax.core.unfreeze(model.init(rng, *dummy_inputs)) |
|
return things.get("params", {}) |
|
|
|
with u.chrono.log_timing("z/secs/init"): |
|
params_cpu = init(jax.random.PRNGKey(42)) |
|
if jax.process_index() == 0: |
|
parameter_overview.log_parameter_overview(params_cpu, msg="init params") |
|
num_params = sum(p.size for p in jax.tree.leaves(params_cpu)) |
|
mw.measure("num_params", num_params) |
|
|
|
|
|
if config.get("model_init"): |
|
write_note(f"Initialize model from {config.model_init}...") |
|
params_cpu = model_mod.load( |
|
params_cpu, config.model_init, config.get("model"), |
|
**config.get("model_load", {})) |
|
if jax.process_index() == 0: |
|
parameter_overview.log_parameter_overview(params_cpu, msg="loaded params") |
|
|
|
write_note("Replicating...") |
|
params_repl = flax_utils.replicate(params_cpu) |
|
|
|
def predict_fn(params, *a, **kw): |
|
return model.apply({"params": params}, *a, **kw) |
|
|
|
evaluators = eval_common.from_config( |
|
config, {"predict": predict_fn, "model": model}, |
|
lambda s: write_note(f"Initializing evaluator: {s}..."), |
|
lambda key, cfg: 1, |
|
) |
|
|
|
|
|
|
|
|
|
for s in range(config.get("eval_repeats", 1)): |
|
mw.step_start(s) |
|
for (name, evaluator, _, prefix) in evaluators: |
|
write_note(f"{name} evaluation step {s}...") |
|
with u.profile(name, noop=name in config.get("no_profile", [])): |
|
with u.chrono.log_timing(f"z/secs/eval/{name}"): |
|
for key, value in evaluator.run(params_repl): |
|
mw.measure(f"{prefix}{key}", value) |
|
u.sync() |
|
u.chrono.flush_timings() |
|
mw.step_end() |
|
|
|
write_note("Done!") |
|
mw.close() |
|
|
|
|
|
u.sync() |
|
|
|
if workdir and flags.FLAGS.cleanup and jax.process_index() == 0: |
|
gfile.rmtree(workdir) |
|
try: |
|
gfile.remove(os.path.join(workdir, "..")) |
|
except tf.errors.OpError: |
|
pass |
|
|
|
|
|
if __name__ == "__main__": |
|
app.run(main) |
|
|