"""SM3 Implementation.""" |
|
|
|
import functools |
|
from typing import Any, NamedTuple |
|
|
|
import chex |
|
import jax |
|
import jax.numpy as jnp |
|
import optax |
|
|
|
from .quantization_utils import QuantizedValue |
|
|
|
|
|
class SM3State(NamedTuple):
    """Overall optimizer state: step count plus per-parameter statistics."""

    count: chex.Array
    stats: Any


class ParameterStats(NamedTuple):
    """State associated with each parameter of the model being trained."""

    diagonal_statistics: Any  # List with one accumulator vector per tensor axis.
    diagonal_momentum: QuantizedValue  # Momentum, stored quantized as int8.


def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization. Rohan Anil, Vineet Gupta,
    Tomer Koren, Yoram Singer. https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: the step size used to update the parameters; either a
        float or a callable schedule mapping a step count to a step size.
      beta1: momentum parameter.
      beta2: second-moment averaging parameter.
      diagonal_epsilon: epsilon added to the accumulated statistics before
        taking the inverse square root, for numerical stability.
      normalize_grads: whether to normalize gradients to unit L2 norm before
        the update. The authors find this useful when gradients have high
        variance.

    Returns:
      An optax.GradientTransformation.
    """
    def _quantize_momentum(momentum_statistics):
        # Store momentum in int8 to reduce optimizer memory.
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)

    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            # One accumulator vector per axis of the parameter.
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32),
            stats=jax.tree_util.tree_map(_init, params),
        )

    def _get_expanded_shape(shape, i):
        # Shape that broadcasts the i-th axis accumulator against the full
        # parameter, e.g. shape (3, 4) with i=1 -> [1, 4].
        rank = len(shape)
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)

    def _moving_averages(grad, accumulators):
        # Exponential moving average of squared gradients. For rank >= 2
        # parameters, the current estimate is the elementwise minimum over
        # the broadcast per-axis accumulators.
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2
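
    # Worked 2-D example for _moving_averages: the two accumulators arrive
    # reshaped to (m, 1) and (1, n), jnp.minimum broadcasts them to (m, n),
    # and entry (i, j) uses min(row_acc[i], col_acc[j]) as its estimate
    # (the SM3 min-over-covers rule; row_acc/col_acc are illustrative names).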
    def _moving_averages_momentum(grad, momentum):
        # Dequantize the stored momentum, then take an exponential moving average.
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad

    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
        # Project the dense per-entry statistics back to one vector per axis
        # by taking the max over all other axes; rank-1 parameters keep their
        # statistics exactly.
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics

    def update_fn(updates, state, params=None):
        del params
        stats = state.stats
        if normalize_grads:
            updates = jax.tree_util.tree_map(
                lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates
            )

        # Broadcast each per-axis accumulator back to the parameter's rank so
        # it can be combined entrywise with the gradient.
        expanded_diagonal_statistics = jax.tree_util.tree_map(
            lambda grad, param_stats: [
                jnp.reshape(
                    param_stats.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Second-moment update and the resulting diagonal preconditioner.
        new_diagonal_statistics = jax.tree_util.tree_map(
            _moving_averages, updates, expanded_diagonal_statistics
        )
        new_preconditioners = jax.tree_util.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_util.tree_map(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Momentum is accumulated on the preconditioned gradient.
        updated_momentum = jax.tree_util.tree_map(
            lambda preconditioned_grad, param_stats: _moving_averages_momentum(
                preconditioned_grad, param_stats.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Sketch the dense statistics back down to one accumulator per axis.
        updated_diagonal_statistics = jax.tree_util.tree_map(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        new_sm3_stats = jax.tree_util.tree_map(
            lambda momentum, diagonal_stats: ParameterStats(
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        # Support both a constant learning rate and a schedule.
        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)

        new_updates = jax.tree_util.tree_map(lambda pg: -lr * pg, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
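

# Minimal usage sketch (follows the standard optax init/update/apply_updates
# pattern; `params` and `grads` below stand in for a real model's parameter
# and gradient pytrees):
#
#   optimizer = sm3(learning_rate=0.1)
#   opt_state = optimizer.init(params)
#   updates, opt_state = optimizer.update(grads, opt_state)
#   params = optax.apply_updates(params, updates)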
|
|