# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for LAMB Optimizer."""

import numpy as np
from numpy import linalg
import tensorflow as tf
import tf_keras

from official.modeling.optimization import lamb


def lamb_update_numpy(param,
                      g_t,
                      t,
                      m,
                      v,
                      lr=0.001,
                      lamb_wd=0.0,
                      beta1=0.9,
                      beta2=0.999,
                      epsilon=1e-6):
  """Applies one LAMB update step in NumPy, as a reference for the tests.

  Args:
    param: Current parameter values.
    g_t: Gradient used for this step.
    t: Zero-based step index; bias correction uses `t + 1`.
    m: First-moment accumulator carried over from the previous step.
    v: Second-moment accumulator carried over from the previous step.
    lr: Learning rate applied to the update.
    lamb_wd: Weight decay coefficient; `lamb_wd * param` is added to the update.
    beta1: Decay rate of the first-moment accumulator.
    beta2: Decay rate of the second-moment accumulator.
    epsilon: Constant added to the denominator of the update.

  Returns:
    A tuple `(param_t, m_t, v_t)` holding the updated parameter and the
    updated first- and second-moment accumulators.
  """
  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = beta2 * v + (1 - beta2) * g_t * g_t

  # Bias-corrected moment estimates.
  m_t_hat = m_t / (1 - beta1**(t + 1))
  v_t_hat = v_t / (1 - beta2**(t + 1))
  update = m_t_hat / (np.sqrt(v_t_hat) + epsilon)

  update += lamb_wd * param

  w_norm = linalg.norm(param, ord=2)
  g_norm = linalg.norm(update, ord=2)
  # `np.where` evaluates `w_norm / g_norm` eagerly, even when the result is
  # masked out. Silence the divide/invalid warnings a zero norm would raise;
  # the selected value is unaffected.
  with np.errstate(divide="ignore", invalid="ignore"):
    ratio = np.where(
        w_norm > 0, np.where(g_norm > 0, (w_norm / g_norm), 1.0), 1.0)

  param_t = param - ratio * lr * update
  return param_t, m_t, v_t


def get_beta_accumulators(opt, dtype):
  """Returns `(beta_1 ** step, beta_2 ** step)` for the optimizer's next step.

  Args:
    opt: Optimizer exposing `iterations` and the "beta_1" / "beta_2"
      hyperparameters through `_get_hyper`.
    dtype: Dtype that the step counter and both betas are cast to.

  Returns:
    A tuple `(beta_1_power, beta_2_power)` where the exponent is
    `opt.iterations + 1`.
  """
  local_step = tf.cast(opt.iterations + 1, dtype)
  beta_1_t = tf.cast(opt._get_hyper("beta_1"), dtype)
  beta_1_power = tf.math.pow(beta_1_t, local_step)
  beta_2_t = tf.cast(opt._get_hyper("beta_2"), dtype)
  beta_2_power = tf.math.pow(beta_2_t, local_step)
  return (beta_1_power, beta_2_power)


class LAMBTest(tf.test.TestCase):
  """Checks `lamb.LAMB` against the NumPy reference and its configuration."""

  def test_sparse(self):
    """Compares `IndexedSlices` gradient updates with the NumPy reference."""
    dtype = tf.float32

    # Initialize variables for numpy implementation.
    m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
    var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
    grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype)
    var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
    grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype)

    var0 = tf.Variable(var0_np)
    var1 = tf.Variable(var1_np)

    # Only rows 0 and 2 carry a gradient; the dense numpy gradients above hold
    # 0.0 at index 1.
    grads0_np_indices = np.array([0, 2], dtype=np.int32)
    grads0 = tf.IndexedSlices(
        tf.constant(grads0_np[grads0_np_indices]),
        tf.constant(grads0_np_indices),
        tf.constant([3]),
    )
    grads1_np_indices = np.array([0, 2], dtype=np.int32)
    grads1 = tf.IndexedSlices(
        tf.constant(grads1_np[grads1_np_indices]),
        tf.constant(grads1_np_indices),
        tf.constant([3]),
    )
    opt = lamb.LAMB()

    # Fetch params to validate initial values
    np.testing.assert_allclose(np.asanyarray([1.0, 1.0, 2.0]), var0.numpy())
    np.testing.assert_allclose(np.asanyarray([3.0, 3.0, 4.0]), var1.numpy())

    # Run 3 steps of LAMB
    for t in range(3):
      beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
      self.assertAllClose(0.9 ** (t + 1), beta_1_power)
      self.assertAllClose(0.999 ** (t + 1), beta_2_power)

      opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

      var0_np, m0, v0 = lamb_update_numpy(var0_np, grads0_np, t, m0, v0)
      var1_np, m1, v1 = lamb_update_numpy(var1_np, grads1_np, t, m1, v1)

      # Validate updated params
      self.assertAllClose(var0_np, var0.numpy())
      self.assertAllClose(var1_np, var1.numpy())

  def test_basic_with_learning_rate_decay(self):
    """Compares dense updates under `decay` with the NumPy reference."""
    dtype = tf.float32

    # Initialize variables for numpy implementation.
    m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
    var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
    grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
    var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
    grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

    var0 = tf.Variable(var0_np, name="var0")
    var1 = tf.Variable(var1_np, name="var1")

    grads0 = tf.constant(grads0_np)
    grads1 = tf.constant(grads1_np)

    learning_rate = 0.001
    beta_1 = 0.9
    beta_2 = 0.999
    epsilon = 1e-7
    decay = 0.5
    lamb_wd = 0.01

    opt = lamb.LAMB(
        learning_rate=learning_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        weight_decay_rate=lamb_wd,
        decay=decay,
    )

    # The reference must use the same hyperparameters as the optimizer;
    # in particular `epsilon` here differs from the reference's default.
    reference_kwargs = dict(
        lamb_wd=lamb_wd, beta1=beta_1, beta2=beta_2, epsilon=epsilon)

    # Run 3 steps of LAMB
    for t in range(3):
      opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

      # Learning rate the reference uses at step `t` for the given `decay`.
      lr_np = learning_rate / (1 + decay * t)

      var0_np, m0, v0 = lamb_update_numpy(
          var0_np, grads0_np, t, m0, v0, lr=lr_np, **reference_kwargs)
      var1_np, m1, v1 = lamb_update_numpy(
          var1_np, grads1_np, t, m1, v1, lr=lr_np, **reference_kwargs)

      # Validate updated params
      self.assertAllClose(var0_np, var0.numpy())
      self.assertAllClose(var1_np, var1.numpy())

  def test_exclude_weight_decay(self):
    """Checks `exclude_from_weight_decay` against three variable names."""
    opt = lamb.LAMB(
        0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
    )
    self.assertTrue(opt._do_use_weight_decay("var0"))
    self.assertFalse(opt._do_use_weight_decay("var1"))
    self.assertFalse(opt._do_use_weight_decay("var1_weight"))

  def test_exclude_layer_adaptation(self):
    """Checks `exclude_from_layer_adaptation` against three variable names."""
    opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
    self.assertTrue(opt._do_layer_adaptation("var0"))
    self.assertFalse(opt._do_layer_adaptation("var1"))
    self.assertFalse(opt._do_layer_adaptation("var1_weight"))

  def test_serialization(self):
    """Checks that serialize/deserialize round-trips the optimizer config."""
    optimizer = lamb.LAMB(1e-4)
    config = tf_keras.optimizers.serialize(optimizer, use_legacy_format=True)
    new_optimizer = tf_keras.optimizers.deserialize(
        config, use_legacy_format=True
    )
    self.assertEqual(new_optimizer.get_config(), optimizer.get_config())


if __name__ == "__main__":
  tf.test.main()