# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate schedule classes."""

import math
from typing import Mapping, Any, Union, Optional

import tensorflow as tf, tf_keras


def _make_offset_wrapper(new_class_name: str, base_lr_class):
  """Generates an offset wrapper of a learning rate schedule.

  It returns a subclass of `base_lr_class` that takes an `offset` argument in
  the constructor. When the new class instance is called, the behavior is:
    new_class_object(step) = base_lr_class_object(step - offset)

  Example:

    CosineDecayWithOffset = _make_offset_wrapper(
        'CosineDecayWithOffset', tf_keras.optimizers.schedules.CosineDecay)
    # Use the lr:
    lr = CosineDecayWithOffset(
        offset=100, initial_learning_rate=0.1, decay_steps=1000)
    lr(101)  # equals keras.optimizers.schedules.CosineDecay(...)(101 - 100)

  Args:
    new_class_name: the name of the new class.
    base_lr_class: the base learning rate schedule class. Should be a subclass
      of tf_keras.optimizers.schedules.LearningRateSchedule.

  Returns:
    A new class (subclass of the base_lr_class) that can take an offset.
  """
  assert issubclass(
      base_lr_class, tf_keras.optimizers.schedules.LearningRateSchedule), (
          "base_lr_class should be a subclass of keras "
          f"LearningRateSchedule, got {base_lr_class}")

  # pylint: disable=protected-access,pointless-statement
  def offset_learning_rate_init(self, offset=0, **kwargs):
    """Construct learning rate schedule object.

    When this object is called, its behavior is
      self.__call__(step) == base_lr_class.__call__(step - offset)

    Args:
      self: this object.
      offset: The offset when computing the learning rate schedule.
      **kwargs: Pass through to base learning rate class constructor.
    """
    base_lr_class.__init__(self, **kwargs)
    self._offset = offset

  def offset_learning_rate_call(self, step):
    step = tf.cast(step - self._offset, tf.float32)
    return base_lr_class.__call__(self, step)

  # pylint: enable=protected-access,pointless-statement

  return type(
      new_class_name, (base_lr_class,), {
          "base_lr_class": base_lr_class,
          "__init__": offset_learning_rate_init,
          "__call__": offset_learning_rate_call
      })


PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
    "PiecewiseConstantDecayWithOffset",
    tf_keras.optimizers.schedules.PiecewiseConstantDecay)
PolynomialDecayWithOffset = _make_offset_wrapper(
    "PolynomialDecayWithOffset", tf_keras.optimizers.schedules.PolynomialDecay)
ExponentialDecayWithOffset = _make_offset_wrapper(
    "ExponentialDecayWithOffset",
    tf_keras.optimizers.schedules.ExponentialDecay)
CosineDecayWithOffset = _make_offset_wrapper(
    "CosineDecayWithOffset",
    tf_keras.optimizers.schedules.CosineDecay,
)
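

# Illustrative usage of the generated offset wrappers (a sketch only; the
# numbers below are arbitrary). Each wrapper simply evaluates the wrapped
# Keras schedule at `step - offset`, so the decay effectively starts `offset`
# steps later:
#
#   lr_fn = PiecewiseConstantDecayWithOffset(
#       offset=1000, boundaries=[5000, 10000], values=[0.1, 0.01, 0.001])
#   lr_fn(1000)  # same as PiecewiseConstantDecay(...)(0)
#   lr_fn(6000)  # same as PiecewiseConstantDecay(...)(5000)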


class LinearWarmup(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf_keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               warmup_learning_rate: float,
               name: Optional[str] = None):
    """Add linear warmup schedule to a learning rate schedule.

    `warmup_learning_rate` is the initial learning rate of the warmup period,
    and the final learning rate of the warmup period is the learning rate that
    `after_warmup_lr_sched` produces at `warmup_steps`. During warmup the
    learning rate increases linearly at each step according to:
      learning_rate = warmup_lr + step / warmup_steps
                      * (final_warmup_lr - warmup_lr)
    For the first `warmup_steps` steps this warmup value overrides the wrapped
    learning rate schedule.

    Args:
      after_warmup_lr_sched: tf_keras.optimizers.schedules
        .LearningRateSchedule or a constant.
      warmup_steps: Number of the warmup steps.
      warmup_learning_rate: Initial learning rate for the warmup.
      name: Optional, name of warmup schedule.
    """
    super().__init__()
    self._name = name
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    if isinstance(after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
    else:
      self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)

  def __call__(self, step: int):

    global_step = tf.cast(step, dtype=tf.float32)

    linear_warmup_lr = (
        self._init_warmup_lr + global_step / self._warmup_steps *
        (self._final_warmup_lr - self._init_warmup_lr))

    if isinstance(self._after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      after_warmup_lr = self._after_warmup_lr_sched(step)
    else:
      after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

    lr = tf.cond(global_step < self._warmup_steps,
                 lambda: linear_warmup_lr,
                 lambda: after_warmup_lr)
    return lr

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name
    })
    return config
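

# A minimal sketch of composing warmup with a decay schedule (all numbers are
# arbitrary): before step 500 the learning rate ramps linearly from 0.0 toward
# the value the wrapped cosine schedule takes at step 500; afterwards the
# wrapped schedule is used directly.
#
#   warmup_then_cosine = LinearWarmup(
#       after_warmup_lr_sched=tf_keras.optimizers.schedules.CosineDecay(
#           initial_learning_rate=0.1, decay_steps=10000),
#       warmup_steps=500,
#       warmup_learning_rate=0.0)
#   warmup_then_cosine(0)     # -> 0.0
#   warmup_then_cosine(1000)  # -> CosineDecay(...)(1000)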


class PolynomialWarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Applies polynomial warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf_keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    super().__init__()
    if isinstance(after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
    else:
      self._initial_learning_rate = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)

    self._warmup_steps = warmup_steps
    self._power = power
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PolynomialWarmUp") as name:
      # Implements polynomial warmup, i.e. with power=1.0, if global_step <
      # warmup_steps, the learning rate is
      # `global_step/num_warmup_steps * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)

      if self._warmup_steps <= 0:
        warmup_percent_done = 1.0
      else:
        # A zero `step` may cause Inf. So make `step` positive.
        step_non_zero = tf.math.maximum(global_step_float, 1.0)
        warmup_percent_done = step_non_zero / warmup_steps_float

      warmup_learning_rate = (
          self._initial_learning_rate *
          tf.math.pow(warmup_percent_done, self._power))

      if isinstance(self._after_warmup_lr_sched,
                    tf_keras.optimizers.schedules.LearningRateSchedule):
        after_warmup_lr = self._after_warmup_lr_sched(step)
      else:
        after_warmup_lr = tf.cast(
            self._after_warmup_lr_sched, dtype=tf.float32)

      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: after_warmup_lr,
          name=name)

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name
    })
    return config


class DirectPowerDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule that follows lr * (step)^power."""

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               name: str = "DirectPowerDecay"):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "DirectPowerDecay"):
      step = tf.cast(step, tf.float32)
      learning_rate = self._initial_learning_rate
      # A zero `step` may cause Inf. So make `step` positive.
      step_non_zero = tf.math.maximum(step, 1.0)
      learning_rate *= tf.math.pow(step_non_zero, self._power)
      return learning_rate

  def get_config(self):
    """Get the configuration of the learning rate schedule."""
    return {
        "initial_learning_rate": self._initial_learning_rate,
        "power": self._power,
        "name": self._name,
    }


class PowerAndLinearDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Power decay schedule multiplied by a linear decay at the end.

  The schedule has the following behavior. Let offset_step = step - offset.
  1) offset_step < 0, the actual learning rate equals initial_learning_rate.
  2) 0 <= offset_step <= total_decay_steps * (1 - linear_decay_fraction), the
     actual learning rate equals lr * offset_step^power.
  3) total_decay_steps * (1 - linear_decay_fraction) <= offset_step <
     total_decay_steps, the actual learning rate equals lr * offset_step^power
     * (total_decay_steps - offset_step) / (total_decay_steps *
     linear_decay_fraction).
  4) offset_step >= total_decay_steps, the actual learning rate equals zero.
  """

  def __init__(self,
               initial_learning_rate: float,
               total_decay_steps: int,
               power: float = 1.0,
               linear_decay_fraction: float = 0.1,
               offset: int = 0,
               name: str = "PowerAndLinearDecay"):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      total_decay_steps: The total number of steps for power + linear decay.
      power: The order of the polynomial.
      linear_decay_fraction: In the last `linear_decay_fraction` of the steps,
        the learning rate will be multiplied by a linear decay.
      offset: The offset applied to steps.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self._initial_learning_rate = initial_learning_rate
    self._total_decay_steps = total_decay_steps
    self._power = power
    self._linear_decay_fraction = linear_decay_fraction
    self._offset = offset
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerAndLinearDecay"):
      step = tf.cast(step - self._offset, tf.float32)
      learning_rate = self._initial_learning_rate
      # A zero `step` may cause Inf. So make `step` positive.
      step_non_zero = tf.math.maximum(step, 1.0)
      learning_rate *= tf.math.pow(step_non_zero, self._power)
      if self._total_decay_steps * self._linear_decay_fraction > 0:
        learning_rate *= tf.minimum(
            1.0, (self._total_decay_steps - step) /
            (self._total_decay_steps * self._linear_decay_fraction))
        learning_rate = tf.maximum(0.0, learning_rate)
      return learning_rate

  def get_config(self):
    """Get the configuration of the learning rate schedule."""
    return {
        "initial_learning_rate": self._initial_learning_rate,
        "total_decay_steps": self._total_decay_steps,
        "power": self._power,
        "linear_decay_fraction": self._linear_decay_fraction,
        "offset": self._offset,
        "name": self._name,
    }


class PowerDecayWithOffset(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate decay with offset.

  Learning rate equals `pre_offset_learning_rate` if `step` < `offset`.
  Otherwise, learning rate equals lr * (step - offset)^power.
  """

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               offset: int = 0,
               pre_offset_learning_rate: float = 1.0e6,
               name: str = "PowerDecayWithOffset"):
    """Initialize configuration of the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate.
      power: The order of the polynomial.
      offset: The offset when computing the power decay.
      pre_offset_learning_rate: The constant learning rate used before `offset`
        steps; it also caps the learning rate afterwards.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self._initial_learning_rate = initial_learning_rate
    self._power = power
    self._offset = offset
    self._pre_offset_lr = pre_offset_learning_rate
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerDecayWithOffset"):
      step = tf.cast(step, tf.float32)
      lr_after_offset = tf.math.pow(
          tf.math.maximum(step - self._offset, 1.0), self._power) * (
              self._initial_learning_rate)

      sign = tf.cast(step > self._offset, tf.float32)
      lr_combined = (1.0 - sign) * self._pre_offset_lr + sign * lr_after_offset
      # Power may give infinitely large LR. So cap it with pre_offset_lr.
      return tf.math.minimum(lr_combined, self._pre_offset_lr)

  def get_config(self):
    """Get the configuration of the learning rate schedule."""
    return {
        "initial_learning_rate": self._initial_learning_rate,
        "power": self._power,
        "offset": self._offset,
        "pre_offset_learning_rate": self._pre_offset_lr,
        "name": self._name,
    }


class StepCosineDecayWithOffset(
    tf_keras.optimizers.schedules.LearningRateSchedule):
  """Stepwise cosine learning rate decay with offset.

  The learning rate is given by one or more cosine decays, each one starting
  at an interval boundary and ending at the next one.

  Example:

  ```python
  boundaries: [100000, 110000]
  values: [1.0, 0.5]
  lr_decayed_fn = (
      lr_schedule.StepCosineDecayWithOffset(
          boundaries,
          values))
  ```

  From step 0 to 100000, it cosine decays from 1.0 to 0.5; from step 100000 to
  110000, it cosine decays from 0.5 to 0.0.
  """

  def __init__(self,
               boundaries,
               values,
               offset: int = 0,
               name: str = "StepCosineDecayWithOffset"):
    """Initialize configuration of the learning rate schedule.

    Args:
      boundaries: A list of `Tensor`s or `int`s with strictly increasing
        entries, and with all elements having the same type as the optimizer
        step.
      values: A list of `Tensor`s or `float`s that specifies the values for the
        intervals defined by `boundaries`. It should have the same number of
        elements as `boundaries`, and all elements should have the same type.
      offset: The offset applied to the step before computing the decay.
      name: Optional, name of learning rate schedule.
    """
    super().__init__()
    self.values = values
    self.boundaries = boundaries
    self.offset = offset
    self.name = name

    if len(self.values) < 1:
      raise ValueError(
          f"Expected a non-empty list of values, got {self.values}")

    if len(self.boundaries) != len(self.values):
      raise ValueError(
          "Boundaries length must be equal to learning rate levels length: "
          f"{len(self.boundaries)} != {len(self.values)}")

    self.total_steps = (
        [boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)
        ] + [0])

  def __call__(self, global_step):
    with tf.name_scope(self.name or "StepCosineDecayWithOffset"):
      global_step = tf.cast(global_step - self.offset, tf.float32)
      lr_levels = self.values
      lr_steps = self.boundaries
      level_total_steps = self.total_steps
      num_levels = len(lr_levels)

      init_lr = lr_levels[0]
      next_init_lr = lr_levels[1] if num_levels > 1 else 0.
      init_total_steps = level_total_steps[0]

      cosine_learning_rate = ((init_lr - next_init_lr) * (tf.cos(
          tf.constant(math.pi) * (global_step) / (init_total_steps)) + 1.0) /
                              2.0 + next_init_lr)
      learning_rate = cosine_learning_rate

      for i in range(1, num_levels):
        next_init_lr = lr_levels[i]
        next_start_step = lr_steps[i]
        next_total_steps = level_total_steps[i]
        next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.

        next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) * (
            tf.cos(
                tf.constant(math.pi) * (global_step - next_start_step) /
                (next_total_steps)) + 1.0) / 2.0 + next_next_init_lr)
        learning_rate = tf.where(global_step >= next_start_step,
                                 next_cosine_learning_rate, learning_rate)

      return learning_rate

  def get_config(self):
    return {
        "boundaries": self.boundaries,
        "values": self.values,
        "offset": self.offset,
        "name": self.name
    }
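

# Optional, illustrative smoke test (not part of the library API; the schedule
# parameters below are arbitrary example values). Running this module directly
# prints a few sampled learning rates so the schedule shapes can be inspected
# quickly, e.g. `power_and_linear` follows 1/sqrt(step) and then ramps
# linearly to zero over the last 10% of its 1000 decay steps.
if __name__ == "__main__":
  _demo_schedules = {
      "linear_warmup_then_constant":
          LinearWarmup(
              after_warmup_lr_sched=0.1,
              warmup_steps=100,
              warmup_learning_rate=0.0),
      "power_and_linear":
          PowerAndLinearDecay(
              initial_learning_rate=1.0,
              total_decay_steps=1000,
              power=-0.5,
              linear_decay_fraction=0.1),
      "cosine_with_offset":
          CosineDecayWithOffset(
              offset=100, initial_learning_rate=0.1, decay_steps=1000),
  }
  for _sched_name, _sched in _demo_schedules.items():
    _samples = [(int(_step), float(_sched(_step).numpy()))
                for _step in (0, 50, 100, 500, 1000)]
    print(_sched_name, _samples)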