# Copyright 2023 by zhongying
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.keras.optimizers.schedules import ExponentialDecay

class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Transformer learning rate schedule ("Attention Is All You Need"):

        lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The rate grows linearly for the first `warmup_steps` steps, then decays
    with the inverse square root of the step count. An optional `max_lr`
    clips the schedule from above.
    """

    def __init__(self, d_model, init_steps=0, warmup_steps=4000, max_lr=None):
        super(TransformerLRSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps
        self.init_steps = init_steps

    def __call__(self, step):
        # lr = (d_model^-0.5) * min(step^-0.5, step * (warmup_steps^-1.5))
        step = tf.cast(step, tf.float32) + self.init_steps
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        lr = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
        if self.max_lr is not None:
            return tf.math.minimum(self.max_lr, lr)
        return lr

    def get_config(self):
        return {
            "d_model": self.d_model,
            "init_steps": self.init_steps,
            "warmup_steps": self.warmup_steps,
            "max_lr": self.max_lr
        }

class SANSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Self-attention network schedule:

        lr = (lamb / sqrt(d_model)) * min(step / warmup_steps^1.5, 1 / sqrt(step))
    """

    def __init__(self, lamb, d_model, warmup_steps=4000):
        super(SANSchedule, self).__init__()
        self.lamb = tf.cast(lamb, tf.float32)
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = step / (self.warmup_steps ** 1.5)
        arg2 = 1 / tf.math.sqrt(step)
        return (self.lamb / tf.math.sqrt(self.d_model)) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        return {
            "lamb": self.lamb,
            "d_model": self.d_model,
            "warmup_steps": self.warmup_steps
        }

class BoundExponentialDecay(ExponentialDecay):
    """ExponentialDecay whose learning rate never falls below `min_lr`."""

    def __init__(self, min_lr=0.0, **kwargs):
        super().__init__(**kwargs)
        self.min_lr = min_lr

    def __call__(self, step):
        with ops.name_scope_v2(self.name or "ExponentialDecay") as name:
            initial_learning_rate = ops.convert_to_tensor(
                self.initial_learning_rate, name="initial_learning_rate")
            dtype = initial_learning_rate.dtype
            decay_steps = math_ops.cast(self.decay_steps, dtype)
            decay_rate = math_ops.cast(self.decay_rate, dtype)

            global_step_recomp = math_ops.cast(step, dtype)
            p = global_step_recomp / decay_steps
            if self.staircase:
                p = math_ops.floor(p)
            new_lr = math_ops.multiply(
                initial_learning_rate, math_ops.pow(decay_rate, p), name=name)
            # Clip from below so the decayed rate never drops under min_lr.
            return math_ops.maximum(self.min_lr, new_lr)

    def get_config(self):
        config = super().get_config()
        config.update({"min_lr": self.min_lr})
        return config
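

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test, assuming d_model=512 and the warmup/decay hyper-parameters
# shown below; it evaluates each schedule at a few steps and wires the Transformer
# schedule into tf.keras.optimizers.Adam as in the original paper.
if __name__ == "__main__":
    transformer_lr = TransformerLRSchedule(d_model=512, warmup_steps=4000)
    san_lr = SANSchedule(lamb=0.5, d_model=512, warmup_steps=4000)
    bounded_lr = BoundExponentialDecay(
        min_lr=1e-5,
        initial_learning_rate=1e-3,
        decay_steps=10000,
        decay_rate=0.5,
        staircase=False)

    # Print the learning rate produced by each schedule at a few step counts.
    for step in (1, 4000, 100000):
        step_t = tf.constant(step, dtype=tf.float32)
        print(step,
              float(transformer_lr(step_t)),
              float(san_lr(step_t)),
              float(bounded_lr(step_t)))

    # Any of the schedules can be passed directly to a Keras optimizer.
    optimizer = tf.keras.optimizers.Adam(
        transformer_lr, beta_1=0.9, beta_2=0.98, epsilon=1e-9)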