# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataclasses for optimizer configs."""
import dataclasses
from typing import List, Optional

from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, gradients will be clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, gradients will be clipped when
      their absolute value exceeds this value.
    global_clipnorm: float >= 0 or None. If not None, the gradients of all
      weights are clipped so that their global norm is no higher than this
      value.
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
  global_clipnorm: Optional[float] = None
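

# Illustrative sketch (not part of the library API): the clipping fields above
# are inherited by every optimizer config in this module, so callers typically
# set `clipnorm`, `clipvalue` or `global_clipnorm` on a concrete subclass. The
# helper below and its value are hypothetical.
def _example_global_clipping_config() -> BaseOptimizerConfig:
  """Builds a hypothetical base config that clips gradients by global norm."""
  return BaseOptimizerConfig(global_clipnorm=1.0)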


@dataclasses.dataclass
class SGDConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes of this class match the arguments of tf_keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    decay: learning rate decay for the SGD optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum value for the SGD optimizer.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
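

# Usage sketch (hypothetical values): SGDConfig is a dataclass, so it can be
# constructed with keyword arguments and combined with the gradient-clipping
# fields inherited from BaseOptimizerConfig. Illustrative only.
def _example_sgd_config() -> SGDConfig:
  """Builds a hypothetical momentum-SGD config with gradient clipping."""
  return SGDConfig(momentum=0.9, nesterov=True, clipnorm=1.0)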


# TODO(b/216129465): Merge this config with SGDConfig after the experimental
# optimizer graduates.
@dataclasses.dataclass
class SGDExperimentalConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes of this class match the arguments of
  `tf_keras.optimizers.experimental.SGD`.

  Attributes:
    name: name of the optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum for SGD optimizer.
    jit_compile: if True, jit compile will be used.
  """
  name: str = "SGD"
  nesterov: bool = False
  momentum: float = 0.0
  jit_compile: bool = False


@dataclasses.dataclass
class RMSPropConfig(BaseOptimizerConfig):
  """Configuration for RMSProp optimizer.

  The attributes of this class match the arguments of
  tf_keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for the RMSprop optimizer; helps with numerical
      stability.
    centered: whether to use the centered variant, in which gradients are
      normalized by the estimated variance of the gradient.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False


@dataclasses.dataclass
class AdagradConfig(BaseOptimizerConfig):
  """Configuration for Adagrad optimizer.

  The attributes of this class match the arguments of
  tf_keras.optimizers.Adagrad.

  Attributes:
    name: name of the optimizer.
    initial_accumulator_value: A floating point value. Starting value for the
      accumulators, must be non-negative.
    epsilon: A small floating point value to avoid zero denominator.
  """
  name: str = "Adagrad"
  initial_accumulator_value: float = 0.1
  epsilon: float = 1e-07


@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer.

  The attributes of this class match the arguments of
  tf_keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
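

# Usage sketch (hypothetical values): the fields above mirror the keyword
# arguments of the Keras Adam optimizer, so tuning the config mirrors tuning
# the optimizer itself. Illustrative only.
def _example_adam_config() -> AdamConfig:
  """Builds a hypothetical Adam config with a larger epsilon."""
  return AdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-6)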


@dataclasses.dataclass
class AdamExperimentalConfig(BaseOptimizerConfig):
  """Configuration for experimental Adam optimizer.

  The attributes of this class match the arguments of
  `tf_keras.optimizers.experimental.Adam`.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
    jit_compile: if True, jit compile will be used.
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  jit_compile: bool = False


@dataclasses.dataclass
class AdamWeightDecayConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to include
      in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to not
      include in weight decay.
    gradient_clip_norm: A positive float. Clips the gradients to this maximum
      L2-norm. Defaults to 1.0.
  """
  name: str = "AdamWeightDecay"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0
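

# Usage sketch (hypothetical patterns): `include_in_weight_decay` and
# `exclude_from_weight_decay` are matched against variable names; BERT-style
# setups commonly exclude layer-norm and bias variables. Illustrative only.
def _example_adamw_config() -> AdamWeightDecayConfig:
  """Builds a hypothetical AdamW config that skips decay for norm/bias vars."""
  return AdamWeightDecayConfig(
      weight_decay_rate=0.01,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])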


@dataclasses.dataclass
class AdamWeightDecayExperimentalConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
    weight_decay: float. Weight decay rate. Defaults to 0.
    global_clipnorm: A positive float. Clips the gradients to this maximum
      L2-norm. Defaults to 1.0.
    jit_compile: if True, jit compile will be used.
  """
  name: str = "AdamWeightDecayExperimental"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay: float = 0.0
  global_clipnorm: float = 1.0
  jit_compile: bool = False


@dataclasses.dataclass
class LAMBConfig(BaseOptimizerConfig):
  """Configuration for LAMB optimizer.

  The attributes of this class match the arguments of the LAMB optimizer.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Default to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded from
      weight decay. Variables whose names contain a substring matching the
      pattern will be excluded.
    exclude_from_layer_adaptation: List of regex patterns of variables excluded
      from layer adaptation. Variables whose names contain a substring matching
      the pattern will be excluded.
  """
  name: str = "LAMB"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None


@dataclasses.dataclass
class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.

  Attributes:
    name: 'str', name of the optimizer.
    trainable_weights_only: 'bool', if True, only model trainable weights will
      be updated. Otherwise, all model weights will be updated. This mainly
      affects batch normalization parameters.
    average_decay: 'float', average decay value.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to apply dynamic decay or not.
  """
  name: str = "ExponentialMovingAverage"
  trainable_weights_only: bool = True
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True
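

# Usage sketch (hypothetical values): the EMA optimizer keeps an exponential
# moving average of the model weights; `average_decay` controls how closely
# the average tracks the latest weights. Illustrative only.
def _example_ema_config() -> EMAConfig:
  """Builds a hypothetical EMA config with a slower-moving average."""
  return EMAConfig(average_decay=0.9999, start_step=1000)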


@dataclasses.dataclass
class LARSConfig(BaseOptimizerConfig):
  """Layer-wise adaptive rate scaling config.

  Attributes:
    name: 'str', name of the optimizer.
    momentum: `float` hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0.9.
    eeta: `float` LARS coefficient as used in the paper. Defaults to the value
      from the paper. (eeta / weight_decay) determines the highest scaling
      factor in LARS.
    weight_decay_rate: `float` for weight decay.
    nesterov: 'boolean' for whether to use Nesterov momentum.
    classic_momentum: `boolean` for whether to use classic (or popular)
      momentum. In classic momentum, the learning rate is applied during the
      momentum update; in popular momentum, it is applied after the momentum
      update.
    exclude_from_weight_decay: A list of `string` for variable screening: if
      any of the strings appears in a variable's name, the variable will be
      excluded from weight decay. For example, one could specify the list like
      ['batch_normalization', 'bias'] to exclude BN and bias from weight decay.
    exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
      for layer adaptation. If None, it defaults to exclude_from_weight_decay.
  """
  name: str = "LARS"
  momentum: float = 0.9
  eeta: float = 0.001
  weight_decay_rate: float = 0.0
  nesterov: bool = False
  classic_momentum: bool = True
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
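

# Usage sketch: as the docstring above notes, variables whose names contain
# any of the listed strings are excluded from weight decay. The values below
# are hypothetical and illustrative only.
def _example_lars_config() -> LARSConfig:
  """Builds a hypothetical LARS config excluding BN and bias from decay."""
  return LARSConfig(
      momentum=0.9,
      weight_decay_rate=1e-4,
      exclude_from_weight_decay=["batch_normalization", "bias"])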


@dataclasses.dataclass
class SLIDEConfig(BaseOptimizerConfig):
  """Configuration for SLIDE optimizer.

  Details coming soon.
  """
  name: str = "SLIDE"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  weight_decay_type: str = "inner"
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
  include_in_sparse_layer_adaptation: Optional[List[str]] = None
  sparse_layer_learning_rate: float = 0.1
  do_gradient_rescaling: bool = True
  norm_type: str = "layer"
  ratio_clip_norm: float = 1e5


@dataclasses.dataclass
class AdafactorConfig(BaseOptimizerConfig):
  """Configuration for Adafactor optimizer.

  The attributes of this class match the arguments of the Adafactor
  implementation.
  """
  name: str = "Adafactor"
  factored: bool = True
  multiply_by_parameter_scale: bool = True
  beta1: Optional[float] = None
  decay_rate: float = 0.8
  step_offset: int = 0
  clipping_threshold: float = 1.0
  min_dim_size_to_factor: int = 128
  epsilon1: float = 1e-30
  epsilon2: float = 1e-3
  weight_decay: Optional[float] = None
  include_in_weight_decay: Optional[str] = None


@dataclasses.dataclass
class AdafactorKerasConfig(BaseOptimizerConfig):
  """Configuration for AdafactorKeras optimizer.

  The attributes of this class match the arguments of the Adafactor
  implementation provided by Keras.

  Attributes:
    name: name of the optimizer.
    learning_rate: Initial value for the learning rate: either a floating point
      value, or a `tf_keras.optimizers.schedules.LearningRateSchedule`
      instance. Defaults to 0.001.
    beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
    epsilon_1: float, defaults to 1e-30. A small offset to keep the denominator
      away from 0.
    epsilon_2: float, defaults to 1e-3. A small offset to avoid the learning
      rate becoming too small over time.
    clip_threshold: float, defaults to 1.0. Clipping threshold. This is a part
      of the Adafactor algorithm, independent from `clipnorm`, `clipvalue` and
      `global_clipnorm`.
    relative_step: bool, defaults to True. If `learning_rate` is a constant and
      `relative_step=True`, the learning rate will be adjusted based on the
      current iteration. This is the default learning rate decay in Adafactor.
  """
  name: str = "Adafactor"
  learning_rate: float = 0.001
  beta_2_decay: float = -0.8
  epsilon_1: float = 1e-30
  epsilon_2: float = 1e-3
  clip_threshold: float = 1.0
  relative_step: bool = True
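

# Usage sketch (hypothetical values): per the docstring above, with a constant
# `learning_rate` and `relative_step=True` the Keras Adafactor optimizer
# adjusts the learning rate based on the current step. Illustrative only.
def _example_adafactor_keras_config() -> AdafactorKerasConfig:
  """Builds a hypothetical Keras Adafactor config overriding some defaults."""
  return AdafactorKerasConfig(learning_rate=1e-3, clip_threshold=1.0)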