Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Layer-wise Adaptive Moments (LAMB) optimizer. | |
See paper [Large Batch Optimization for Deep Learning: Training BERT in | |
76 minutes](https://arxiv.org/abs/1904.00962). | |
""" | |
import re | |
from typing import Optional, Union, Callable, List | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
# Type alias for float-valued hyperparameters: a TF tensor, a Python float,
# or a NumPy float16/float32 scalar.
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32]
class LAMB(tf_keras.optimizers.legacy.Optimizer):
  """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

  See paper [Large Batch Optimization for Deep Learning: Training BERT
  in 76 minutes](https://arxiv.org/abs/1904.00962).

  Each step computes the Adam-style update `m_hat / (sqrt(v_hat) + epsilon)`,
  optionally adds decoupled weight decay `weight_decay_rate * var`, rescales
  it by the per-variable trust ratio `||var|| / ||update||` (1.0 when either
  norm is zero or the variable is excluded from layer adaptation), and then
  subtracts it from the variable scaled by the learning rate.
  """

  def __init__(
      self,
      learning_rate: Union[FloatTensorLike, Callable] = 0.001,
      beta_1: FloatTensorLike = 0.9,
      beta_2: FloatTensorLike = 0.999,
      epsilon: FloatTensorLike = 1e-6,
      weight_decay_rate: FloatTensorLike = 0.0,
      exclude_from_weight_decay: Optional[List[str]] = None,
      exclude_from_layer_adaptation: Optional[List[str]] = None,
      name: str = "LAMB",
      **kwargs,
  ):
    """Construct a new LAMB optimizer.

    Args:
      learning_rate: A `Tensor`, a floating point value, or a
        `tf_keras.optimizers.schedules.LearningRateSchedule`. The learning
        rate.
      beta_1: A `float` value or a constant `float` tensor. The exponential
        decay rate for the 1st moment estimates.
      beta_2: A `float` value or a constant `float` tensor. The exponential
        decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. If falsy (e.g.
        `None` or `0`), `tf_keras.backend.epsilon()` is used instead.
      weight_decay_rate: Decoupled weight decay rate; `weight_decay_rate *
        var` is added to the Adam update before the trust ratio is applied.
      exclude_from_weight_decay: List of regex patterns of variables excluded
        from weight decay. Variables whose name contain a substring matching
        the pattern will be excluded.
      exclude_from_layer_adaptation: List of regex patterns of variables
        excluded from layer adaptation. Variables whose name contain a
        substring matching the pattern will be excluded. If `None` or empty,
        the patterns of `exclude_from_weight_decay` are used instead.
      name: Optional name for the operations created when applying gradients.
        Defaults to "LAMB".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
        `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is
        clip gradients by value, `decay` is included for backward
        compatibility to allow time inverse decay of learning rate. `lr` is
        included for backward compatibility, recommended to use
        `learning_rate` instead.
    """
    super().__init__(name, **kwargs)
    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters.
    self._set_hyper("weight_decay_rate", weight_decay_rate)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    # This is learning rate decay for using keras learning rate schedule.
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("beta_1", beta_1)
    self._set_hyper("beta_2", beta_2)
    # `tf.backend_config` is not a public TensorFlow module; the Keras fuzz
    # factor is exposed as `tf_keras.backend.epsilon()`.
    self.epsilon = epsilon or tf_keras.backend.epsilon()
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # Falls back to the weight-decay exclusions when the arg is None *or*
    # empty (any falsy value).
    self.exclude_from_layer_adaptation = (
        exclude_from_layer_adaptation or exclude_from_weight_decay
    )

  def _create_slots(self, var_list):
    """Creates the first ("m") and second ("v") moment slots."""
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, "m")
    for var in var_list:
      self.add_slot(var, "v")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Precomputes per-(device, dtype) coefficients used by the apply ops."""
    super()._prepare_local(var_device, var_dtype, apply_state)
    local_step = tf.cast(self.iterations + 1, var_dtype)
    beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
    beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
    weight_decay_rate = tf.identity(
        self._get_hyper("weight_decay_rate", var_dtype)
    )
    # Bias-correction terms beta^t, where t is the 1-based step count.
    beta_1_power = tf.pow(beta_1_t, local_step)
    beta_2_power = tf.pow(beta_2_t, local_step)
    apply_state[(var_device, var_dtype)].update(
        dict(
            weight_decay_rate=weight_decay_rate,
            epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            beta_2_power=beta_2_power,
            one_minus_beta_2_t=1 - beta_2_t,
        )
    )

  def _lamb_step(self, var, m_t, v_t, coefficients):
    """Returns the learning-rate- and trust-ratio-scaled step for `var`.

    Args:
      var: The variable being updated.
      m_t: Updated first-moment estimate for `var`.
      v_t: Updated second-moment estimate for `var`.
      coefficients: The per-(device, dtype) dict built by `_prepare_local`.

    Returns:
      The tensor to subtract from `var`.
    """
    # Bias-corrected moments.
    m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
    v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])
    update = m_t_hat / (tf.sqrt(v_t_hat) + coefficients["epsilon"])

    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
      update += coefficients["weight_decay_rate"] * var

    ratio = 1.0
    if self._do_layer_adaptation(var_name):
      w_norm = tf.norm(var, ord=2)
      g_norm = tf.norm(update, ord=2)
      # Trust ratio ||w|| / ||update||; fall back to 1.0 if either norm is 0.
      ratio = tf.where(
          tf.greater(w_norm, 0),
          tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
          1.0,
      )
    return ratio * coefficients["lr_t"] * update

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies one LAMB step to `var` from a dense gradient."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)
    ) or self._fallback_apply_state(var_device, var_dtype)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
    m_t = m.assign(m_t, use_locking=self._use_locking)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
    v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
    v_t = v.assign(v_t, use_locking=self._use_locking)

    var_update = var - self._lamb_step(var, m_t, v_t, coefficients)
    return var.assign(var_update, use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies one LAMB step to `var` from a sparse (indexed) gradient.

    The moment decay is applied to the whole slot before the scaled gradient
    rows are scatter-added at `indices`; weight decay and the norms are then
    computed over the whole of `var`, so the resulting step is dense.
    """
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)
    ) or self._fallback_apply_state(var_device, var_dtype)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
    with tf.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
    v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
    with tf.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    var_update = var.assign_sub(
        self._lamb_step(var, m_t, v_t, coefficients),
        use_locking=self._use_locking,
    )
    return tf.group(*[var_update, m_t, v_t])

  def get_config(self):
    """Returns the optimizer config.

    The exclusion pattern lists are included so that
    `LAMB.from_config(opt.get_config())` reproduces which variables skip
    weight decay and layer adaptation.
    """
    config = super().get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "weight_decay_rate": self._serialize_hyperparameter(
            "weight_decay_rate"
        ),
        "decay": self._serialize_hyperparameter("decay"),
        "beta_1": self._serialize_hyperparameter("beta_1"),
        "beta_2": self._serialize_hyperparameter("beta_2"),
        "epsilon": self.epsilon,
        "exclude_from_weight_decay": self.exclude_from_weight_decay,
        "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation,
    })
    return config

  @staticmethod
  def _matches_any(patterns, param_name):
    """Whether any regex in `patterns` is found anywhere in `param_name`.

    `None` or an empty sequence of patterns matches nothing.
    """
    return any(re.search(p, param_name) is not None for p in patterns or ())

  def _do_use_weight_decay(self, param_name):
    """Whether to apply weight decay to `param_name`."""
    return not self._matches_any(self.exclude_from_weight_decay, param_name)

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    return not self._matches_any(
        self.exclude_from_layer_adaptation, param_name
    )

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name.

    Strips a trailing `:<digits>` suffix, e.g. `"dense/kernel:0"` becomes
    `"dense/kernel"`; names without such a suffix are returned unchanged.
    """
    m = re.match(r"^(.*):\d+$", param_name)
    return param_name if m is None else m.group(1)