"""Learning rate schedule classes.""" |
|
|
|
from typing import Mapping, Any, Union, Optional |
|
|
|
import tensorflow as tf |
|
|
|
|
|
class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule): |
|
"""Linear warmup schedule.""" |
|
|
|
def __init__(self, after_warmup_lr_sched: Union[ |
|
tf.keras.optimizers.schedules.LearningRateSchedule, float], |
|
warmup_steps: int, warmup_learning_rate: float, |
|
name: Optional[str] = None): |
|
"""Add linear warmup schedule to a learning rate schedule. |
|
|
|
warmup_lr is the initial learning rate, the final learning rate of the |
|
init_warmup period is the initial learning rate of lr_schedule in use. |
|
The learning rate at each step linearly increased according to the following |
|
formula: |
|
learning_rate = warmup_lr + step / warmup_steps |
|
* (final_warmup_lr - warmup_lr). |
|
Using warmup overrides the learning rate schedule by the number of warmup |
|
steps. |
|
|
|
Args: |
|
after_warmup_lr_sched: tf.keras.optimizers.schedules |
|
.LearningRateSchedule or a constant. |
|
warmup_steps: int. number of the warmup steps. |
|
warmup_learning_rate: floating point number. Initial learning rate for the |
|
warmup. |
|
name: Optional, name of warmup schedule. |
|
""" |
|
    super(LinearWarmup, self).__init__()
    self._name = name
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
    else:
      self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)

  def __call__(self, step: int):
    global_step = tf.cast(step, dtype=tf.float32)

    # Learning rate increases linearly from the initial warmup learning rate
    # to the value the wrapped schedule produces at `warmup_steps`.
    linear_warmup_lr = (
        self._init_warmup_lr + global_step / self._warmup_steps *
        (self._final_warmup_lr - self._init_warmup_lr))

    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      after_warmup_lr = self._after_warmup_lr_sched(step)
    else:
      after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)

    # Use the warmup rate during warmup and the wrapped schedule afterwards.
    lr = tf.cond(global_step < self._warmup_steps,
                 lambda: linear_warmup_lr,
                 lambda: after_warmup_lr)
    return lr

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}

    config.update({
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name
    })
    return config
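

# Example usage (a sketch, not part of the original module): LinearWarmup is
# typically used to wrap a decay schedule so training starts from a small
# learning rate. The decay schedule and hyperparameter values below are
# illustrative assumptions.
#
#   decay = tf.keras.optimizers.schedules.ExponentialDecay(
#       initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.96)
#   schedule = LinearWarmup(
#       after_warmup_lr_sched=decay,
#       warmup_steps=1000,
#       warmup_learning_rate=0.0)
#   optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)

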
class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a polynomial warmup schedule on a given learning rate decay
  schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf.keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    """Adds a polynomial warmup schedule on top of a learning rate schedule.

    Args:
      after_warmup_lr_sched: A
        `tf.keras.optimizers.schedules.LearningRateSchedule` or a constant
        learning rate to use after the warmup period.
      warmup_steps: Number of warmup steps.
      power: Power of the polynomial used during warmup. Defaults to 1.0
        (linear warmup).
      name: Optional name of the warmup schedule.
    """
    super(PolynomialWarmUp, self).__init__()
    if isinstance(after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
    else:
      self._initial_learning_rate = tf.cast(
          after_warmup_lr_sched, dtype=tf.float32)

    self._warmup_steps = warmup_steps
    self._power = power
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PolynomialWarmUp") as name:
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self._initial_learning_rate *
          tf.math.pow(warmup_percent_done, self._power))

      if isinstance(self._after_warmup_lr_sched,
                    tf.keras.optimizers.schedules.LearningRateSchedule):
        after_warmup_lr = self._after_warmup_lr_sched(step)
      else:
        after_warmup_lr = tf.cast(
            self._after_warmup_lr_sched, dtype=tf.float32)

      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: after_warmup_lr,
          name=name)

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf.keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}

    config.update({
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name
    })
    return config
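

if __name__ == "__main__":
  # Minimal smoke test / usage sketch, not part of the original module. The
  # constant after-warmup rate and step values below are illustrative
  # assumptions; both schedules also accept a LearningRateSchedule instead of
  # a float. Assumes TF2 eager execution (the default).
  linear = LinearWarmup(
      after_warmup_lr_sched=0.01, warmup_steps=100, warmup_learning_rate=0.0)
  poly = PolynomialWarmUp(
      after_warmup_lr_sched=0.01, warmup_steps=100, power=2.0)
  for s in (0, 50, 100, 200):
    # Both schedules warm up towards 0.01 over 100 steps, then return the
    # constant after-warmup rate.
    print(s, float(linear(s)), float(poly(s)))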