# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Legacy functions and classes related to optimization.""" | |
from absl import logging | |
import gin | |
import tensorflow as tf, tf_keras | |
from official.modeling.optimization import lamb | |
from official.modeling.optimization import legacy_adamw | |
AdamWeightDecay = legacy_adamw.AdamWeightDecay | |
LAMB = lamb.LAMB | |
class WarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements polynomial warmup: while global_step < warmup_steps, the
      # learning rate is `(global_step / warmup_steps) ** power * init_lr`,
      # which reduces to linear warmup for the default `power=1.0`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'power': self.power,
        'name': self.name
    }
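

# Illustrative sketch (not part of the original module) of how `WarmUp` is
# typically composed with a decay schedule; the learning rate and step counts
# below are hypothetical values chosen only for the example.
#
#   decay_fn = tf_keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.0)
#   lr_fn = WarmUp(initial_learning_rate=1e-4,
#                  decay_schedule_fn=decay_fn,
#                  warmup_steps=1000)
#   # lr_fn(step) ramps from 0 to 1e-4 over the first 1000 steps (linearly,
#   # since power=1.0), then follows `decay_fn` for the remaining steps.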


# NOTE: the gin registration below is assumed from the otherwise-unused
# `gin` import at the top of this module.
@gin.configurable
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9,
                     poly_power=1.0):
  """Creates an optimizer with learning rate schedule."""
  # Implements linear decay of the learning rate (polynomial decay when
  # `poly_power != 1.0`).
  lr_schedule = tf_keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr,
      power=poly_power)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)
  if optimizer_type == 'adamw':
    logging.info('using AdamW optimizer')
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  elif optimizer_type == 'lamb':
    logging.info('using LAMB optimizer')
    optimizer = LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  else:
    raise ValueError(f'Unsupported optimizer type: {optimizer_type}')
  return optimizer
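

# Minimal usage sketch (not part of the original module). It assumes an
# environment where the imports above resolve; the learning rate and step
# counts are hypothetical values chosen only for illustration.
if __name__ == '__main__':
  for demo_type in ('adamw', 'lamb'):
    demo_optimizer = create_optimizer(
        init_lr=2e-5,
        num_train_steps=10000,
        num_warmup_steps=1000,
        optimizer_type=demo_type)
    logging.info('Built %s optimizer: %s', demo_type,
                 type(demo_optimizer).__name__)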