Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Flags and common definitions for Ranking Models.""" | |
from absl import flags | |
import tensorflow as tf, tf_keras | |
from official.common import flags as tfm_flags | |
FLAGS = flags.FLAGS


def define_flags() -> None:
  """Registers the command-line flags used to train the Ranking model.

  First registers the shared flags from `official.common.flags`. It then
  changes the defaults of the `experiment` and `mode` flags so that a bare
  invocation runs the `dlrm_criteo` experiment in `train_and_eval` mode.
  Finally it adds the ranking-specific `seed` and `profile_steps` flags.
  """
  tfm_flags.define_flags()

  # Override defaults of flags registered by tfm_flags.define_flags().
  default_overrides = (
      ('experiment', 'dlrm_criteo'),
      ('mode', 'train_and_eval'),
  )
  for flag_name, default_value in default_overrides:
    FLAGS.set_default(name=flag_name, value=default_value)

  flags.DEFINE_integer(
      name='seed',
      default=None,
      help='This value will be used to seed both NumPy and TensorFlow.')

  profile_help = (
      'Save profiling data to model dir at given range of global steps. '
      'The value must be a comma separated pair of positive integers, '
      'specifying the first and last step to profile. For example, '
      '"--profile_steps=2,4" triggers the profiler to process 3 steps, starting'
      ' from the 2nd step. Note that profiler has a non-trivial performance '
      'overhead, and the output file can be gigantic if profiling many steps.')
  flags.DEFINE_string(
      name='profile_steps',
      default='20,40',
      help=profile_help)
class WarmUpAndPolyDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate callable for the embeddings.

  The schedule has three phases:
    * Linear warmup on [0, warmup_steps].
    * Constant on [warmup_steps, decay_start_steps].
    * Polynomial decay on [decay_start_steps, decay_start_steps + decay_steps].

  The base `learning_rate` is scaled linearly by `batch_size / 2048`. The
  decayed rate is floored at 1e-4, and the final rate is floored at 1e-2. The
  returned value therefore never drops below 0.01, including at step 0 of
  warmup and after the decay window has ended.

  Setting `warmup_steps=0` skips the warmup phase; decay still applies.
  Setting `decay_steps=0` disables decay, so the rate stays constant after
  warmup.
  """

  def __init__(self,
               batch_size: int,
               decay_exp: float = 2.0,
               learning_rate: float = 40.0,
               warmup_steps: int = 8000,
               decay_steps: int = 12000,
               decay_start_steps: int = 10000):
    """Initializes the schedule.

    Args:
      batch_size: Batch size used to scale the learning rate by
        `batch_size / 2048`.
      decay_exp: Exponent of the polynomial decay.
      learning_rate: Base learning rate before batch-size scaling.
      warmup_steps: Number of linear warmup steps; 0 disables warmup.
      decay_steps: Length of the polynomial decay window; 0 disables decay.
      decay_start_steps: Step after which polynomial decay begins.
    """
    super().__init__()
    self.batch_size = batch_size
    self.decay_exp = decay_exp
    self.learning_rate = learning_rate
    self.warmup_steps = warmup_steps
    self.decay_steps = decay_steps
    self.decay_start_steps = decay_start_steps

  def __call__(self, step):
    """Returns the learning rate for `step` as a float32 tensor.

    Args:
      step: Current training step (any numeric tensor or Python number).

    Returns:
      The scheduled learning rate, floored at 0.01.
    """
    adj_lr = self.learning_rate * (self.batch_size / 2048)

    global_step = tf.cast(step, tf.float32)
    warmup_steps = tf.cast(self.warmup_steps, tf.float32)
    decay_steps = tf.cast(self.decay_steps, tf.float32)
    decay_start_step = tf.cast(self.decay_start_steps, tf.float32)

    # divide_no_nan keeps both divisions finite when warmup_steps == 0 or
    # decay_steps == 0. In those cases the corresponding branch is never
    # selected by tf.where below. This replaces the old early return, which
    # skipped decay entirely whenever warmup was disabled.
    warmup_lr = tf.math.divide_no_nan(global_step, warmup_steps) * adj_lr

    steps_since_decay_start = global_step - decay_start_step
    # Clamp so the rate stops decaying once the decay window is exhausted.
    already_decayed_steps = tf.minimum(steps_since_decay_start, decay_steps)
    remaining_fraction = tf.math.divide_no_nan(
        decay_steps - already_decayed_steps, decay_steps)
    decay_lr = tf.maximum(0.0001, adj_lr * remaining_fraction**self.decay_exp)

    lr = tf.where(
        global_step < warmup_steps, warmup_lr,
        tf.where(
            tf.logical_and(decay_steps > 0, global_step > decay_start_step),
            decay_lr, adj_lr))
    # Global floor. Note that it dominates the 1e-4 floor applied to decay_lr.
    return tf.maximum(0.01, lr)

  def get_config(self):
    """Returns the constructor arguments so the schedule can be re-created."""
    return {
        'batch_size': self.batch_size,
        'decay_exp': self.decay_exp,
        'learning_rate': self.learning_rate,
        'warmup_steps': self.warmup_steps,
        'decay_steps': self.decay_steps,
        'decay_start_steps': self.decay_start_steps
    }