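# Non-code residue captured along with this file (apparently page chrome from
# the site it was copied from); kept as a comment for provenance only:
# deanna-emery's picture / updates / 93528c6 / raw / history blame / 3.89 kB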
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Legacy functions and classes related to optimization."""
from absl import logging
import gin
import tensorflow as tf, tf_keras
from official.modeling.optimization import lamb
from official.modeling.optimization import legacy_adamw
# Module-level aliases for the optimizer classes that `create_optimizer`
# instantiates, so they can be referenced by their short names.
LAMB = lamb.LAMB
AdamWeightDecay = legacy_adamw.AdamWeightDecay
class WarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule that ramps up, then defers to a decay schedule.

  While `step < warmup_steps` the rate is
  `initial_learning_rate * (step / warmup_steps) ** power`. From then on the
  rate is whatever `decay_schedule_fn(step)` returns; the step is handed to the
  decay schedule unchanged (it is not shifted by `warmup_steps`).
  """

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    """Stores the warmup configuration.

    Args:
      initial_learning_rate: Learning rate that is scaled by the warmup
        fraction raised to `power`.
      decay_schedule_fn: Callable taking a step and returning the learning rate
        to use once warmup is over.
      warmup_steps: Number of steps over which the warmup ramp is applied.
      power: Exponent applied to the warmup fraction; 1.0 gives a linear ramp.
      name: Optional name scope for the ops; defaults to 'WarmUp' when falsy.
    """
    super().__init__()
    self.name = name
    self.power = power
    self.warmup_steps = warmup_steps
    self.initial_learning_rate = initial_learning_rate
    self.decay_schedule_fn = decay_schedule_fn

  def __call__(self, step):
    """Returns the learning rate to use at `step`."""
    with tf.name_scope(self.name or 'WarmUp') as scope_name:
      # Polynomial warmup: the fraction of warmup completed, raised to
      # `power`, scales the initial learning rate.
      step_f32 = tf.cast(step, tf.float32)
      warmup_f32 = tf.cast(self.warmup_steps, tf.float32)
      fraction_done = step_f32 / warmup_f32
      ramp = tf.math.pow(fraction_done, self.power)
      warmup_rate = self.initial_learning_rate * ramp

      def _during_warmup():
        return warmup_rate

      def _after_warmup():
        return self.decay_schedule_fn(step)

      still_warming_up = step_f32 < warmup_f32
      return tf.cond(
          still_warming_up,
          true_fn=_during_warmup,
          false_fn=_after_warmup,
          name=scope_name)

  def get_config(self):
    """Returns the constructor arguments as a dict keyed by argument name."""
    field_names = (
        'initial_learning_rate',
        'decay_schedule_fn',
        'warmup_steps',
        'power',
        'name',
    )
    return {field: getattr(self, field) for field in field_names}
@gin.configurable
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9,
                     poly_power=1.0):
  """Creates an optimizer with a polynomial-decay learning rate schedule.

  The learning rate decays from `init_lr` to `end_lr` over `num_train_steps`
  following `PolynomialDecay` with exponent `poly_power`. If `num_warmup_steps`
  is truthy, that schedule is wrapped in `WarmUp` so the rate first ramps up
  over `num_warmup_steps` steps.

  Args:
    init_lr: Initial learning rate of the decay schedule (and the rate the
      warmup ramp scales).
    num_train_steps: Number of steps over which the learning rate decays.
    num_warmup_steps: Number of warmup steps; a falsy value (0 or None)
      disables warmup.
    end_lr: Learning rate at the end of the decay.
    optimizer_type: Either 'adamw' or 'lamb'.
    beta_1: `beta_1` value forwarded to the optimizer.
    poly_power: Exponent of the polynomial decay; 1.0 gives linear decay.

  Returns:
    An `AdamWeightDecay` or `LAMB` optimizer using the schedule above.

  Raises:
    ValueError: If `optimizer_type` is neither 'adamw' nor 'lamb'.
  """
  # Polynomial decay of the learning rate (linear when poly_power == 1.0).
  lr_schedule = tf_keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr,
      power=poly_power)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)

  # Both supported optimizers are configured with the same hyperparameters.
  optimizer_kwargs = dict(
      learning_rate=lr_schedule,
      weight_decay_rate=0.01,
      beta_1=beta_1,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])

  if optimizer_type == 'adamw':
    logging.info('using Adamw optimizer')
    optimizer = AdamWeightDecay(**optimizer_kwargs)
  elif optimizer_type == 'lamb':
    logging.info('using Lamb optimizer')
    optimizer = LAMB(**optimizer_kwargs)
  else:
    # Build a single message string: passing the value as a second positional
    # argument would make the exception render as a tuple.
    raise ValueError(f'Unsupported optimizer type: {optimizer_type}')
  return optimizer