# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradient transformations and other optax utilities."""

import operator

import big_vision.utils as u
import jax
import jax.numpy as jnp
import optax


def find_states(opt_state, cls):
  leaves = jax.tree.leaves(
      opt_state, is_leaf=lambda node: isinstance(node, cls))
  return [leaf for leaf in leaves if isinstance(leaf, cls)]


def get_count(opt_state, jittable=False):
  """Returns `ScaleByScheduleState.count` from `opt_state` as an integer."""
  counts = [
      state.count
      for state in find_states(opt_state, optax.ScaleByScheduleState)
  ]
  if jittable:
    return counts[0]
  else:
    counts = {int(c) for c in counts}
    assert len(counts) == 1, f"Expected exactly 1 ScaleByScheduleState: {counts}"
    return next(iter(counts))
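
# Illustrative usage (not part of this module's API surface): the step count
# can be read back from an optimizer state produced by `make` below, e.g. for
# checkpoint naming. `opt_state` here is a hypothetical state obtained from
# `tx.init(params)` / `tx.update(...)`:
#
#   step = get_count(opt_state)                 # concrete int, outside jit
#   step = get_count(opt_state, jittable=True)  # traced value, inside jit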


def replace_frozen(schedule, pytree, replacement, log=None):
  """Replaces values matching frozen params in `pytree` with `replacement`."""
  if not isinstance(schedule, (list, tuple)):
    return pytree
  masks, scheds = _make_mask_trees(pytree, schedule, log=log)
  frozen_mask, _, _ = _split_frozen(masks, scheds)
  return jax.tree.map(
      lambda v, f: replacement if f else v, pytree, frozen_mask)


def clip_by_per_example_global_norm(
    max_norm: float,
) -> optax.GradientTransformation:
  """Clips the norm of per-example gradients."""

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates)
    clipped, _ = optax.per_example_global_norm_clip(grads_flat, max_norm)
    return jax.tree_util.tree_unflatten(grads_treedef, clipped), state

  return optax.GradientTransformation(init_fn, update_fn)
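
# Note: the leaves of `updates` are expected to carry a leading per-example
# (batch) axis; `optax.per_example_global_norm_clip` clips each example's
# gradient to `max_norm` and sums the result over that axis. A minimal,
# hypothetical chain (names are illustrative) could look like:
#
#   tx = optax.chain(
#       clip_by_per_example_global_norm(max_norm=1.0),
#       optax.sgd(learning_rate=1e-3),
#   )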


def make(config, params, *, sched_kw):
  """Returns gradient transform and learning rate functions."""
  # Global schedule. No schedule means frozen.
  schedule = config.get("schedule", {})
  if not isinstance(schedule, (tuple, list)):
    schedule = [(".*", schedule)]
  masks, scheds = _make_mask_trees(params, schedule, "config.schedule")
  frozen_mask, masks, scheds = _split_frozen(masks, scheds)
  not_frozen_mask = jax.tree.map(operator.not_, frozen_mask)

  def create_schedule(mult=1.0, **kw):
    assert "base" not in kw, kw
    return u.create_learning_rate_schedule(base=mult, **kw)

  schedule_fns = [create_schedule(**sched_kw, **sched) for sched in scheds]
  schedule_txs = [
      optax.masked(optax.scale_by_schedule(schedule_fn), mask)
      for schedule_fn, mask in zip(schedule_fns, masks)
  ] + [
      # Removes weight decay updates. Note that weight decay already has an
      # independent mask (which cannot be combined easily with a second mask),
      # so instead we multiply updates for frozen params with zero.
      optax.masked(optax.set_to_zero(), frozen_mask)
  ]

  # Gradient clipping.
  if clip_norm := config.get("grad_clip_norm"):
    if config.get("grad_clip_per_example"):
      clip_tx = clip_by_per_example_global_norm(clip_norm)
    else:
      clip_tx = optax.clip_by_global_norm(clip_norm)
    grad_clip_norm_tx = optax.masked(clip_tx, not_frozen_mask)
  else:
    grad_clip_norm_tx = optax.identity()

  # Optimizer updates.
  tx_func = operator.attrgetter(config.optax_name)(optax)
  opt_txs = [optax.masked(tx_func(**config.get("optax", {})), not_frozen_mask)]
  assert "optim" not in config, "Deprecated option, use config.optax."

  # Learning rate multipliers. Defaults to 1.0.
  lr_mult_txs = [optax.scale(config.lr)]
  if config.get("lr_mults"):
    masks, mults = _make_mask_trees(params, config.lr_mults, "config.lr_mults")
    assert all(mult > 0 for mult in mults), (
        f"Use schedule=None for parameter freezing instead of lr_mults={mults}")
    lr_mult_txs += [
        optax.masked(optax.scale(mult), mask)
        for mult, mask in zip(mults, masks)
    ]

  # Weight decay. Defaults to 0.0.
  # Weight decay is not gradient-based but instead uses "params side-input".
  # Hence, weight decay is additive and independent of previous gradient-based
  # updates.
  assert "weight_decay" not in config, "Deprecated option. Use wd and schedule."
  assert config.get("weight_decay_decouple", True), (
      "Coupled weight decay not supported anymore.")
  if config.get("wd"):
    wd_mults = config.get("wd_mults", [(".*/kernel$", 1.0)])
    masks, mults = _make_mask_trees(params, wd_mults, "config.wd_mults")
    weight_decay_txs = [
        optax.add_decayed_weights(config.wd * mult, mask)
        for mult, mask in zip(mults, masks)
    ]
  else:
    weight_decay_txs = []

  # Combine gradient updates and learning rate schedules.
  return optax.chain(
      grad_clip_norm_tx,
      *opt_txs,
      *lr_mult_txs,
      *weight_decay_txs,
      *schedule_txs,
      optax.scale(-1.0)), schedule_fns
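
# Illustrative sketch (hypothetical config values, not taken from this file):
# `make` is typically driven by an ml_collections-style config, roughly:
#
#   config.lr = 1e-3
#   config.wd = 1e-4
#   config.optax_name = "scale_by_adam"   # attribute looked up on `optax`
#   config.optax = dict(b2=0.95)          # kwargs for that transform
#   config.schedule = [
#       (".*frozen.*", None),             # `None` freezes matching params
#       (".*", dict(decay_type="cosine", warmup_steps=1000)),
#   ]
#
#   tx, sched_fns = make(config, params, sched_kw=dict(
#       total_steps=10_000, batch_size=512, data_size=1_000_000))
#
# The returned `tx` is an `optax.GradientTransformation`; `sched_fns` holds one
# learning-rate schedule function per non-frozen schedule entry.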


def _make_mask_trees(params, patterns_values, log):
  patterns, values = zip(*patterns_values)
  masks = u.make_mask_trees(params, patterns, log=log)
  return masks, values


def _split_frozen(masks, scheds):
  """Computes `frozen_mask` and updates `masks` and `scheds`."""
  # Specifying `None` as a scheduler freezes params.
  all_false = jax.tree.map(lambda *bools: not any(bools), *masks)
  not_covered = [k for k, v in u.tree_flatten_with_names(all_false)[0] if v]
  assert not not_covered, (
      f"All params must be covered (use `None` for freezing): {not_covered}")
  frozen_masks = [
      mask for mask, sched in zip(masks, scheds) if sched is None]
  frozen_mask = jax.tree.map(
      lambda *bools: any(bools), *frozen_masks,
      all_false)  # `all_false` is required when `frozen_masks==[]`.
  masks, scheds = zip(*(
      (mask, sched) for mask, sched in zip(masks, scheds) if sched is not None))
  return frozen_mask, masks, scheds


############ Custom BigVision optimizers #######################################

# Currently there's only one custom optimizer and we don't foresee new ones in
# the near future, so we opt not to create a new optimizer folder/module for
# just one isolated case. If more optimizers are added, we can consider moving
# them into individual files in a subfolder.

# A dummy object to allow for foo.bar access syntax, see
# https://stackoverflow.com/a/19476841/2366315
optax.big_vision = type("", (), {})()


def scale_by_adafactor(min_dim_size_to_factor=32,
                       decay_rate=0.8, decay_offset=0,
                       beta2_cap=0.999,
                       clipping_threshold=None,
                       momentum=0.9, dtype_momentum=jnp.bfloat16,
                       eps=1e-30):
  """The BigVision variant of Adafactor optimizer."""

  def _decay_rate_pow(i, exponent):
    """Second-order moment decay schedule."""
    t = jnp.array(i, jnp.float32) + 1.0
    return jnp.minimum(beta2_cap, 1.0 - t**(-exponent))

  scale_by_rms = optax.scale_by_factored_rms(
      factored=True,
      decay_rate=decay_rate,
      step_offset=decay_offset,
      min_dim_size_to_factor=min_dim_size_to_factor,
      epsilon=eps,
      decay_rate_fn=_decay_rate_pow)
  clip = (optax.clip_by_block_rms(clipping_threshold) if clipping_threshold
          else optax.identity())
  mom = (optax.ema(momentum, debias=False, accumulator_dtype=dtype_momentum)
         if momentum else optax.identity())
  return optax.chain(scale_by_rms, clip, mom)


optax.big_vision.scale_by_adafactor = scale_by_adafactor  # pytype: disable=module-attr
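
# With the registration above, the optimizer can be selected from a config by
# name, e.g. `config.optax_name = "big_vision.scale_by_adafactor"`, which
# `make` resolves via `operator.attrgetter(config.optax_name)(optax)`.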


# A few more aliases we use frequently:
def momentum_hp(momentum=0.9, dtype=jnp.bfloat16, nesterov=False):
  """SGD-Momentum with half-precision accumulator."""
  return optax.trace(decay=momentum, accumulator_dtype=dtype, nesterov=nesterov)


optax.big_vision.momentum_hp = momentum_hp  # pytype: disable=module-attr
optax.big_vision.sgd = optax.identity  # pytype: disable=module-attr
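
# The aliases above are picked up the same way, e.g.
# `config.optax_name = "big_vision.momentum_hp"` for half-precision
# SGD-Momentum, or `"big_vision.sgd"`, which is an identity transform since the
# learning rate and schedule are applied separately in `make`'s chain.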