import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.keras.optimizers.schedules import ExponentialDecay


class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Transformer schedule from "Attention Is All You Need":
    lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5),
    optionally shifted by `init_steps` and capped at `max_lr`.
    """

    def __init__(self, d_model, init_steps=0, warmup_steps=4000, max_lr=None):
        super(TransformerLRSchedule, self).__init__()

        self.d_model = tf.cast(d_model, tf.float32)
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps
        self.init_steps = init_steps

    def __call__(self, step):
        # The optimizer passes `step` as an integer tensor; cast it before
        # the float-only ops (rsqrt) below.
        step = tf.cast(step, tf.float32) + self.init_steps
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        lr = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
        if self.max_lr is not None:
            return tf.math.minimum(self.max_lr, lr)
        return lr

    def get_config(self):
        return {
            "d_model": self.d_model,
            "init_steps": self.init_steps,
            "warmup_steps": self.warmup_steps,
            "max_lr": self.max_lr
        }
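
# A minimal usage sketch (the d_model value and the Adam hyperparameters are
# illustrative, following the original paper, not taken from this repository):
#
#   schedule = TransformerLRSchedule(d_model=512, warmup_steps=4000)
#   optimizer = tf.keras.optimizers.Adam(schedule, beta_1=0.9, beta_2=0.98,
#                                        epsilon=1e-9)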


class SANSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warmup schedule scaled by `lamb`:
    lr = (lamb / sqrt(d_model)) * min(step / warmup_steps**1.5, 1 / sqrt(step)).
    """

    def __init__(self, lamb, d_model, warmup_steps=4000):
        super(SANSchedule, self).__init__()

        self.lamb = tf.cast(lamb, tf.float32)
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        # Cast the integer step before the float division below.
        step = tf.cast(step, tf.float32)
        arg1 = step / (self.warmup_steps ** 1.5)
        arg2 = 1 / tf.math.sqrt(step)

        return (self.lamb / tf.math.sqrt(self.d_model)) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        return {
            "lamb": self.lamb,
            "d_model": self.d_model,
            "warmup_steps": self.warmup_steps
        }
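
# A minimal usage sketch (the lamb and d_model values are illustrative):
#
#   schedule = SANSchedule(lamb=0.1, d_model=512, warmup_steps=4000)
#   optimizer = tf.keras.optimizers.Adam(schedule)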


class BoundExponentialDecay(ExponentialDecay):
    """ExponentialDecay whose output never falls below `min_lr`."""

    def __init__(self, min_lr=0.0, **kwargs):
        super().__init__(**kwargs)
        self.min_lr = min_lr

    def __call__(self, step):
        # Same computation as ExponentialDecay.__call__, with a lower bound
        # applied to the result.
        with ops.name_scope_v2(self.name or "ExponentialDecay") as name:
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate, name="initial_learning_rate")
            dtype = initial_learning_rate.dtype
            decay_steps = math_ops.cast(self.decay_steps, dtype)
            decay_rate = math_ops.cast(self.decay_rate, dtype)

            global_step_recomp = math_ops.cast(step, dtype)
            p = global_step_recomp / decay_steps
            if self.staircase:
                p = math_ops.floor(p)
            new_lr = math_ops.multiply(
                initial_learning_rate, math_ops.pow(decay_rate, p), name=name)
            return math_ops.maximum(math_ops.cast(self.min_lr, dtype), new_lr)
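
# A minimal usage sketch (the decay values are illustrative; the remaining
# keywords are forwarded to tf.keras ExponentialDecay):
#
#   schedule = BoundExponentialDecay(
#       min_lr=1e-5, initial_learning_rate=1e-3,
#       decay_steps=10000, decay_rate=0.9)
#
# Note: get_config() is inherited unchanged, so `min_lr` is not serialized;
# an override would be needed if that matters for checkpointing.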