deanna-emery's picture
updates
93528c6
raw
history blame
4 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags and common definitions for Ranking Models."""
from absl import flags
import tensorflow as tf, tf_keras
from official.common import flags as tfm_flags
FLAGS = flags.FLAGS
def define_flags() -> None:
  """Registers the command-line flags used to train the Ranking model.

  Calls `tfm_flags.define_flags()` first, then overrides the defaults of the
  `experiment` and `mode` flags and finally declares the `seed` and
  `profile_steps` flags.
  """
  tfm_flags.define_flags()

  # Ranking-specific defaults. These flags are presumably registered by the
  # `tfm_flags.define_flags()` call above -- confirm if the call order changes.
  default_overrides = (
      ('experiment', 'dlrm_criteo'),
      ('mode', 'train_and_eval'),
  )
  for flag_name, flag_value in default_overrides:
    FLAGS.set_default(name=flag_name, value=flag_value)

  seed_help = 'This value will be used to seed both NumPy and TensorFlow.'
  flags.DEFINE_integer(name='seed', default=None, help=seed_help)

  profile_steps_help = (
      'Save profiling data to model dir at given range of global steps. '
      'The value must be a comma separated pair of positive integers, '
      'specifying the first and last step to profile. For example, '
      '"--profile_steps=2,4" triggers the profiler to process 3 steps, starting'
      ' from the 2nd step. Note that profiler has a non-trivial performance '
      'overhead, and the output file can be gigantic if profiling many steps.')
  flags.DEFINE_string(
      name='profile_steps', default='20,40', help=profile_steps_help)
@tf_keras.utils.register_keras_serializable(package='RANKING')
class WarmUpAndPolyDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule for the embeddings: warmup, plateau, poly decay.

  The peak rate is `learning_rate * batch_size / 2048`. For a given step:

    * `step < warmup_steps`: linear ramp, `step / warmup_steps * peak`.
    * `step > decay_start_steps` (and `decay_steps > 0`): polynomial decay
      `peak * (remaining / decay_steps) ** decay_exp`, where `remaining`
      shrinks from `decay_steps` to 0 over `decay_steps` steps; the decayed
      value is floored at 0.0001.
    * otherwise: the constant peak rate.

  The selected value is finally floored at 0.01.

  Special case: when `warmup_steps == 0` the peak rate is returned directly
  as a plain Python number, without any decay or floor being applied.
  """

  def __init__(self,
               batch_size: int,
               decay_exp: float = 2.0,
               learning_rate: float = 40.0,
               warmup_steps: int = 8000,
               decay_steps: int = 12000,
               decay_start_steps: int = 10000):
    """Stores the schedule hyper-parameters.

    Args:
      batch_size: Batch size; the peak rate is scaled by `batch_size / 2048`.
      decay_exp: Exponent of the polynomial decay.
      learning_rate: Base learning rate before batch-size scaling.
      warmup_steps: Number of linear warmup steps; 0 disables the schedule
        and yields the constant peak rate.
      decay_steps: Number of steps over which the rate decays.
      decay_start_steps: Step after which the decay begins.
    """
    super().__init__()
    self.batch_size = batch_size
    self.decay_exp = decay_exp
    self.learning_rate = learning_rate
    self.warmup_steps = warmup_steps
    self.decay_steps = decay_steps
    self.decay_start_steps = decay_start_steps

  def __call__(self, step):
    """Returns the learning rate to use at `step`.

    Args:
      step: Current training step (anything `tf.cast` accepts).

    Returns:
      A float32 tensor holding the scheduled rate, or the plain Python peak
      rate when `warmup_steps == 0`.
    """
    peak_lr = self.learning_rate * (self.batch_size / 2048)

    # No warmup configured: short-circuit with the constant peak rate.
    if self.warmup_steps == 0:
      return peak_lr

    # Linear ramp; computed from the raw `step` before it is cast below.
    ramp_lr = tf.cast(step / self.warmup_steps * peak_lr, tf.float32)

    current_step = tf.cast(step, tf.float32)
    total_decay_steps = tf.cast(self.decay_steps, tf.float32)
    decay_begin = tf.cast(self.decay_start_steps, tf.float32)

    # Clamp so the decay stops once `decay_steps` steps have elapsed.
    elapsed_decay_steps = tf.minimum(
        current_step - decay_begin, total_decay_steps)
    remaining_fraction = (
        (total_decay_steps - elapsed_decay_steps) / total_decay_steps)
    decayed_lr = tf.maximum(
        0.0001, peak_lr * remaining_fraction**self.decay_exp)

    # All branches are computed eagerly; `tf.where` picks the active phase.
    in_warmup = current_step < self.warmup_steps
    in_decay = tf.logical_and(
        total_decay_steps > 0, current_step > decay_begin)
    post_warmup_lr = tf.where(in_decay, decayed_lr, peak_lr)
    scheduled_lr = tf.where(in_warmup, ramp_lr, post_warmup_lr)

    # Global floor applied to every phase, including the start of warmup.
    return tf.maximum(0.01, scheduled_lr)

  def get_config(self):
    """Returns the constructor arguments as a dict, for serialization."""
    field_names = (
        'batch_size',
        'decay_exp',
        'learning_rate',
        'warmup_steps',
        'decay_steps',
        'decay_start_steps',
    )
    return {field: getattr(self, field) for field in field_names}