# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Legacy functions and classes related to optimization."""

from absl import logging
import gin
import tensorflow as tf, tf_keras

from official.modeling.optimization import lamb
from official.modeling.optimization import legacy_adamw

AdamWeightDecay = legacy_adamw.AdamWeightDecay
LAMB = lamb.LAMB


class WarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements polynomial warmup: while global_step < warmup_steps, the
      # learning rate is
      # `initial_learning_rate * (global_step / warmup_steps)**power`.
      # After warmup, the wrapped decay schedule takes over.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'power': self.power,
        'name': self.name
    }


@gin.configurable
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9,
                     poly_power=1.0):
  """Creates an optimizer with a learning rate schedule.

  The learning rate decays polynomially from `init_lr` to `end_lr` over
  `num_train_steps` steps. If `num_warmup_steps` is nonzero, the decay is
  preceded by a warmup phase handled by `WarmUp`.

  Args:
    init_lr: Peak learning rate reached at the end of warmup.
    num_train_steps: Total number of training steps used for the decay.
    num_warmup_steps: Number of warmup steps; 0 or None disables warmup.
    end_lr: Final learning rate at the end of decay.
    optimizer_type: Either 'adamw' or 'lamb'.
    beta_1: The beta_1 parameter of the underlying optimizer.
    poly_power: Power of the polynomial decay.

  Returns:
    A `tf_keras` optimizer instance configured with the schedule.

  Raises:
    ValueError: If `optimizer_type` is not 'adamw' or 'lamb'.
  """
  # Implements linear (or polynomial, for poly_power != 1) decay of the
  # learning rate.
  lr_schedule = tf_keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr,
      power=poly_power)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)
  if optimizer_type == 'adamw':
    logging.info('Using AdamW optimizer')
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  elif optimizer_type == 'lamb':
    logging.info('Using LAMB optimizer')
    optimizer = LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
    )
  else:
    raise ValueError(f'Unsupported optimizer type: {optimizer_type}')

  return optimizer
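

if __name__ == '__main__':
  # Minimal usage sketch: build an AdamW optimizer whose learning rate warms
  # up over the first 1,000 steps and then decays linearly to `end_lr` across
  # 10,000 total steps. The learning rate and step counts are placeholder
  # values chosen for illustration only, not recommendations.
  example_optimizer = create_optimizer(
      init_lr=2e-5,
      num_train_steps=10000,
      num_warmup_steps=1000,
      optimizer_type='adamw')
  logging.info('Created example optimizer: %s', example_optimizer)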