Spaces:
Runtime error
Runtime error
File size: 10,539 Bytes
5672777 93528c6 5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizer factory class."""
from typing import Callable, List, Optional, Tuple, Union
import gin
import tensorflow as tf, tf_keras
from official.modeling.optimization import slide_optimizer
from official.modeling.optimization import adafactor_optimizer
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lamb
from official.modeling.optimization import lars
from official.modeling.optimization import legacy_adamw
from official.modeling.optimization import lr_schedule
from official.modeling.optimization.configs import optimization_config as opt_cfg
# Optimizer classes registered under the same key for both the legacy and the
# new optimizer lookup tables below.
SHARED_OPTIMIZERS = {
    'sgd_experimental': tf_keras.optimizers.experimental.SGD,
    'adam_experimental': tf_keras.optimizers.experimental.Adam,
    'adamw': legacy_adamw.AdamWeightDecay,
    'adamw_experimental': tf_keras.optimizers.experimental.AdamW,
    'lamb': lamb.LAMB,
    'lars': lars.LARS,
    'slide': slide_optimizer.SLIDE,
    'adafactor': adafactor_optimizer.Adafactor,
    'adafactor_keras': tf_keras.optimizers.Adafactor,
}

# Lookup table used when `use_legacy_optimizer=True`: the basic optimizer keys
# map to `tf_keras.optimizers.legacy`, followed by every shared entry.
LEGACY_OPTIMIZERS_CLS = {
    'sgd': tf_keras.optimizers.legacy.SGD,
    'adam': tf_keras.optimizers.legacy.Adam,
    'rmsprop': tf_keras.optimizers.legacy.RMSprop,
    'adagrad': tf_keras.optimizers.legacy.Adagrad,
    **SHARED_OPTIMIZERS,
}

# Lookup table used when `use_legacy_optimizer=False`: the same basic keys map
# to `tf_keras.optimizers.experimental`, followed by every shared entry.
NEW_OPTIMIZERS_CLS = {
    'sgd': tf_keras.optimizers.experimental.SGD,
    'adam': tf_keras.optimizers.experimental.Adam,
    'rmsprop': tf_keras.optimizers.experimental.RMSprop,
    'adagrad': tf_keras.optimizers.experimental.Adagrad,
    **SHARED_OPTIMIZERS,
}

# Learning rate schedule classes, keyed by the learning rate config type.
LR_CLS = {
    'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
    'polynomial': lr_schedule.PolynomialDecayWithOffset,
    'exponential': lr_schedule.ExponentialDecayWithOffset,
    'cosine': lr_schedule.CosineDecayWithOffset,
    'power': lr_schedule.DirectPowerDecay,
    'power_linear': lr_schedule.PowerAndLinearDecay,
    'power_with_offset': lr_schedule.PowerDecayWithOffset,
    'step_cosine_with_offset': lr_schedule.StepCosineDecayWithOffset,
}

# Warmup schedule classes, keyed by the warmup config type.
WARMUP_CLS = {
    'linear': lr_schedule.LinearWarmup,
    'polynomial': lr_schedule.PolynomialWarmUp,
}
def register_optimizer_cls(key: str,
                           optimizer_config_cls: Union[
                               tf_keras.optimizers.Optimizer,
                               tf_keras.optimizers.legacy.Optimizer,
                               tf_keras.optimizers.experimental.Optimizer
                           ],
                           use_legacy_optimizer: bool = True):
  """Adds a custom optimizer class to one of the optimizer registries.

  Registering the class here is not sufficient on its own: the user still needs
  to subclass the data classes in configs.optimization_config for the new
  optimizer to be usable through OptimizerFactory.

  Args:
    key: The string under which `optimizer_config_cls` is registered.
    optimizer_config_cls: A class which inherits tf_keras.optimizers.Optimizer.
    use_legacy_optimizer: If True, the class is added to
      LEGACY_OPTIMIZERS_CLS; otherwise it is added to NEW_OPTIMIZERS_CLS.

  Raises:
    ValueError: If `key` is already present in the selected registry.
  """
  # Select the target registry together with its duplicate-key message, then
  # run one shared check-and-insert.
  if use_legacy_optimizer:
    registry = LEGACY_OPTIMIZERS_CLS
    duplicate_msg = '%s already registered in LEGACY_OPTIMIZERS_CLS.'
  else:
    registry = NEW_OPTIMIZERS_CLS
    duplicate_msg = '%s already registered in NEW_OPTIMIZERS_CLS.'

  if key in registry:
    raise ValueError(duplicate_msg % key)
  registry[key] = optimizer_config_cls
class OptimizerFactory:
  """Optimizer factory class.

  This class builds learning rate and optimizer based on an optimization config.
  To use this class, you need to do the following:
  (1) Define optimization config, this includes optimizer, and learning rate
      schedule.
  (2) Initialize the class using the optimization config.
  (3) Build learning rate.
  (4) Build optimizer.

  This is a typical example for using this class:

  ```
  params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {'boundaries': [10000, 20000],
                         'values': [0.1, 0.01, 0.001]}
        },
        'warmup': {
            'type': 'linear',
            'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
        }
  }
  opt_config = OptimizationConfig(params)
  opt_factory = OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()
  optimizer = opt_factory.build_optimizer(lr)
  ```
  """

  def __init__(self, config: opt_cfg.OptimizationConfig):
    """Initializing OptimizerFactory.

    Args:
      config: OptimizationConfig instance containing the optimization config.

    Raises:
      ValueError: If `config.optimizer.get()` returns None, or if
        `config.learning_rate.type` is None.
    """
    self._config = config
    self._optimizer_config = config.optimizer.get()
    self._optimizer_type = config.optimizer.type

    # EMA wrapping is enabled whenever an `ema` config is present.
    self._use_ema = config.ema is not None
    self._ema_config = config.ema

    if self._optimizer_config is None:
      raise ValueError('Optimizer type must be specified')

    self._lr_config = config.learning_rate.get()
    self._lr_type = config.learning_rate.type

    if self._lr_type is None:
      raise ValueError('Learning rate type must be specified')

    self._warmup_config = config.warmup.get()
    self._warmup_type = config.warmup.type

  def build_learning_rate(self):
    """Build learning rate.

    Builds learning rate from config. Learning rate schedule is built according
    to the learning rate config. If learning rate type is constant,
    lr_config.learning_rate is used as the base learning rate. If a warmup
    config is set, the base learning rate is wrapped in the warmup schedule
    selected by the warmup type.

    Returns:
      The value produced by the selected LR_CLS / WARMUP_CLS classes. If
      learning rate type is constant and no warmup is configured,
      lr_config.learning_rate is returned as is.
    """
    if self._lr_type == 'constant':
      lr = self._lr_config.learning_rate
    else:
      lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())

    if self._warmup_config:
      lr = WARMUP_CLS[self._warmup_type](lr, **self._warmup_config.as_dict())

    return lr

  @gin.configurable
  def build_optimizer(
      self,
      lr: Union[tf_keras.optimizers.schedules.LearningRateSchedule, float],
      gradient_aggregator: Optional[Callable[
          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
                                                          tf.Tensor]]]] = None,
      gradient_transformers: Optional[List[Callable[
          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
                                                          tf.Tensor]]]]] = None,
      postprocessor: Optional[Callable[[tf_keras.optimizers.Optimizer],
                                       tf_keras.optimizers.Optimizer]] = None,
      use_legacy_optimizer: bool = True):
    """Build optimizer.

    Builds optimizer from config. It takes learning rate as input, and builds
    the optimizer according to the optimizer config. Typically, the learning
    rate built using self.build_learning_rate() is passed as an argument to
    this method.

    Args:
      lr: A floating point value, or a
        tf_keras.optimizers.schedules.LearningRateSchedule instance.
      gradient_aggregator: Optional function to overwrite gradient aggregation.
      gradient_transformers: Optional list of functions to use to transform
        gradients before applying updates to Variables. The functions are
        applied after gradient_aggregator. The functions should accept and
        return a list of (gradient, variable) tuples. clipvalue, clipnorm,
        global_clipnorm should not be set when gradient_transformers is passed.
      postprocessor: An optional function for postprocessing the optimizer. It
        takes an optimizer and returns an optimizer.
      use_legacy_optimizer: A boolean that indicates if using legacy optimizers.

    Returns:
      `tf_keras.optimizers.legacy.Optimizer` or
      `tf_keras.optimizers.experimental.Optimizer` instance.

    Raises:
      ValueError: If `use_legacy_optimizer` is False and the optimizer config
        contains `decay`, or if EMA is configured while `use_legacy_optimizer`
        is False.
      TypeError: If the final object (after the optional postprocessor) is not
        an instance of any of the Keras optimizer base classes checked below.
    """
    optimizer_dict = self._optimizer_config.as_dict()

    # Remove clipping options that are None so they are not forwarded to the
    # optimizer constructor. A key that the config does not define at all is
    # skipped instead of raising KeyError.
    for clip_key in ('clipnorm', 'clipvalue', 'global_clipnorm'):
      if clip_key in optimizer_dict and optimizer_dict[clip_key] is None:
        del optimizer_dict[clip_key]

    optimizer_dict['learning_rate'] = lr
    if gradient_aggregator is not None:
      optimizer_dict['gradient_aggregator'] = gradient_aggregator
    if gradient_transformers is not None:
      optimizer_dict['gradient_transformers'] = gradient_transformers

    if use_legacy_optimizer:
      optimizer = LEGACY_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
    else:
      if 'decay' in optimizer_dict:
        raise ValueError(
            '`decay` is deprecated in new Keras optimizer, please reflect the '
            'decay logic in `lr` or set `use_legacy_optimizer=True` to use the '
            'legacy optimizer.')
      optimizer = NEW_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)

    if self._use_ema:
      if not use_legacy_optimizer:
        raise ValueError(
            'EMA can only work with the legacy optimizer, please set '
            '`use_legacy_optimizer=True`.')
      optimizer = ema_optimizer.ExponentialMovingAverage(
          optimizer, **self._ema_config.as_dict())
    if postprocessor:
      optimizer = postprocessor(optimizer)

    if isinstance(optimizer, tf_keras.optimizers.Optimizer):
      return optimizer
    # The following check makes sure the function won't break in older TF
    # version because of missing the experimental/legacy package.
    if hasattr(tf_keras.optimizers, 'experimental'):
      if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
        return optimizer
    if hasattr(tf_keras.optimizers, 'legacy'):
      if isinstance(optimizer, tf_keras.optimizers.legacy.Optimizer):
        return optimizer

    raise TypeError('OptimizerFactory.build_optimizer returning a '
                    'non-optimizer object: {}'.format(optimizer))
|