File size: 18,517 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate schedule classes."""

import math
from typing import Mapping, Any, Union, Optional

import tensorflow as tf, tf_keras


def _make_offset_wrapper(new_class_name: str, base_lr_class):
  """Generates an offset wrapper of a learning rate schedule.

  It returns a subclass of `base_lr_class` whose constructor additionally
  takes an `offset` argument. When an instance of the new class is called, the
  behavior is:
    new_class_object(step) = base_lr_class_object(step - offset)

  Example:
    CosineDecayWithOffset = _make_offset_wrapper(
                     'CosineDecayWithOffset',
                     tf_keras.optimizers.schedules.CosineDecay)
    # Use the lr:
    lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
                               decay_steps=1000)
    lr(101) # equals to keras.optimizers.schedules.CosineDecay(...)(101-100)

  Args:
    new_class_name: the name of the new class.
    base_lr_class: the base learning rate schedule class. Should be subclass of
      tf_keras.optimizers.schedules.LearningRateSchedule

  Returns:
    A new class (subclass of the base_lr_class) that can take an offset. Its
    `get_config()` returns the base class config plus an `offset` entry, so
    `from_config(get_config())` reproduces the offset as well.
  """
  assert issubclass(base_lr_class,
                    tf_keras.optimizers.schedules.LearningRateSchedule), (
                        "base_lr_class should be subclass of keras "
                        f"LearningRateSchedule, got {base_lr_class}")

  # pylint: disable=protected-access,pointless-statement
  def offset_learning_rate_init(self, offset=0, **kwargs):
    """Construct learning rate schedule object.

    When this object is called, its behavior is
       self.__call__(step) == base_lr_class.__call__(step - offset)
    Args:
      self: this object.
      offset: The offset when computing the learning rate schedule.
      **kwargs: Pass through to base learning rate class constructor.
    """
    base_lr_class.__init__(self, **kwargs)
    self._offset = offset

  def offset_learning_rate_call(self, step):
    # Shift the step first; the base schedule only ever sees `step - offset`.
    step = tf.cast(step - self._offset, tf.float32)
    return base_lr_class.__call__(self, step)

  def offset_learning_rate_get_config(self):
    """Returns the base schedule config extended with the `offset`.

    Without this, the inherited `get_config` omits `offset` and a schedule
    rebuilt via `from_config` would silently use `offset=0`.
    """
    config = dict(base_lr_class.get_config(self))
    config["offset"] = self._offset
    return config

  # pylint: enable=protected-access,pointless-statement

  return type(
      new_class_name, (base_lr_class,), {
          "base_lr_class": base_lr_class,
          "__init__": offset_learning_rate_init,
          "__call__": offset_learning_rate_call,
          "get_config": offset_learning_rate_get_config,
      })


# Offset-aware variants of the stock Keras schedules: each one evaluates the
# wrapped schedule at `step - offset` (see `_make_offset_wrapper`).
PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
    "PiecewiseConstantDecayWithOffset",
    tf_keras.optimizers.schedules.PiecewiseConstantDecay,
)
PolynomialDecayWithOffset = _make_offset_wrapper(
    "PolynomialDecayWithOffset",
    tf_keras.optimizers.schedules.PolynomialDecay,
)
ExponentialDecayWithOffset = _make_offset_wrapper(
    "ExponentialDecayWithOffset",
    tf_keras.optimizers.schedules.ExponentialDecay,
)
CosineDecayWithOffset = _make_offset_wrapper(
    "CosineDecayWithOffset",
    tf_keras.optimizers.schedules.CosineDecay,
)


class LinearWarmup(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup schedule."""

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf_keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               warmup_learning_rate: float,
               name: Optional[str] = None):
    """Add linear warmup schedule to a learning rate schedule.

    `warmup_learning_rate` is the learning rate at step 0. The learning rate
    reached at the end of warmup (`final_warmup_lr`) is the after-warmup
    schedule evaluated at `warmup_steps`, or the constant itself. During
    warmup the learning rate increases linearly according to:
      learning_rate = warmup_learning_rate + step / warmup_steps
                    * (final_warmup_lr - warmup_learning_rate).
    For `step >= warmup_steps` the after-warmup schedule (or constant) is used
    unchanged.

    Args:
      after_warmup_lr_sched: tf_keras.optimizers.schedules.LearningRateSchedule
        or a constant.
      warmup_steps: Number of the warmup steps.
      warmup_learning_rate: Initial learning rate for the warmup.
      name: Optional, name of warmup schedule. Also used as the op name scope.
    """
    super().__init__()
    self._name = name
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._warmup_steps = warmup_steps
    self._init_warmup_lr = warmup_learning_rate
    if isinstance(after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
    else:
      self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)

  def __call__(self, step: int):
    """Returns the learning rate at `step`."""
    # Scope the ops under the schedule name, like the other schedules here.
    with tf.name_scope(self._name or "LinearWarmup"):
      global_step = tf.cast(step, dtype=tf.float32)

      linear_warmup_lr = (
          self._init_warmup_lr + global_step / self._warmup_steps *
          (self._final_warmup_lr - self._init_warmup_lr))

      if isinstance(self._after_warmup_lr_sched,
                    tf_keras.optimizers.schedules.LearningRateSchedule):
        after_warmup_lr = self._after_warmup_lr_sched(step)
      else:
        after_warmup_lr = tf.cast(self._after_warmup_lr_sched,
                                  dtype=tf.float32)

      return tf.cond(global_step < self._warmup_steps,
                     lambda: linear_warmup_lr,
                     lambda: after_warmup_lr)

  def get_config(self) -> Mapping[str, Any]:
    if isinstance(self._after_warmup_lr_sched,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      config = {
          "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()}  # pytype: disable=attribute-error
    else:
      config = {"after_warmup_lr_sched": self._after_warmup_lr_sched}  # pytype: disable=attribute-error

    config.update({
        "warmup_steps": self._warmup_steps,
        "warmup_learning_rate": self._init_warmup_lr,
        "name": self._name
    })
    return config


class PolynomialWarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Applies polynomial warmup schedule on a given learning rate decay schedule.

  While `step < warmup_steps` the learning rate is
  `target_lr * (max(step, 1) / warmup_steps) ** power`, where `target_lr` is
  the after-warmup schedule evaluated at `warmup_steps` (or the constant).
  From `warmup_steps` on, the after-warmup schedule (or constant) is used.
  """

  def __init__(self,
               after_warmup_lr_sched: Union[
                   tf_keras.optimizers.schedules.LearningRateSchedule, float],
               warmup_steps: int,
               power: float = 1.0,
               name: str = "PolynomialWarmup"):
    """Initializes the polynomial warmup wrapper.

    Args:
      after_warmup_lr_sched: A tf_keras LearningRateSchedule or a constant
        learning rate used once warmup is over.
      warmup_steps: Number of warmup steps. Non-positive disables the ramp.
      power: Exponent of the warmup ramp; 1.0 gives a linear ramp.
      name: Name of the schedule, also used as the op name scope.
    """
    super().__init__()
    schedule_cls = tf_keras.optimizers.schedules.LearningRateSchedule
    if not isinstance(after_warmup_lr_sched, schedule_cls):
      target_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)
    else:
      target_lr = after_warmup_lr_sched(warmup_steps)
    self._initial_learning_rate = target_lr

    self._warmup_steps = warmup_steps
    self._power = power
    self._after_warmup_lr_sched = after_warmup_lr_sched
    self._name = name

  def __call__(self, step):
    with tf.name_scope(self._name or "PolynomialWarmUp") as scope_name:
      step_f = tf.cast(step, tf.float32)
      warmup_f = tf.cast(self._warmup_steps, tf.float32)

      if self._warmup_steps > 0:
        # A zero `step` may cause Inf, so clamp it to at least 1.
        fraction_done = tf.math.maximum(step_f, 1.0) / warmup_f
      else:
        fraction_done = 1.0

      ramp_lr = (
          self._initial_learning_rate *
          tf.math.pow(fraction_done, self._power))

      post_sched = self._after_warmup_lr_sched
      if not isinstance(post_sched,
                        tf_keras.optimizers.schedules.LearningRateSchedule):
        post_lr = tf.cast(post_sched, dtype=tf.float32)
      else:
        post_lr = post_sched(step)

      return tf.cond(
          step_f < warmup_f,
          lambda: ramp_lr,
          lambda: post_lr,
          name=scope_name)

  def get_config(self) -> Mapping[str, Any]:
    sched_config = self._after_warmup_lr_sched
    if isinstance(sched_config,
                  tf_keras.optimizers.schedules.LearningRateSchedule):
      sched_config = sched_config.get_config()  # pytype: disable=attribute-error
    return {
        "after_warmup_lr_sched": sched_config,
        "warmup_steps": self._warmup_steps,
        "power": self._power,
        "name": self._name,
    }


class DirectPowerDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule that returns lr * max(step, 1) ** power."""

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               name: str = "DirectPowerDecay"):
    """Stores the schedule hyper-parameters.

    Args:
      initial_learning_rate: Learning rate multiplier applied to step**power.
      power: Exponent applied to the (clamped) step.
      name: Optional, name of learning rate schedule (used as name scope).
    """
    super().__init__()
    self._name = name
    self._power = power
    self._initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    with tf.name_scope(self._name or "DirectPowerDecay"):
      # A zero `step` may cause Inf, so clamp the step to at least 1.
      clamped_step = tf.math.maximum(tf.cast(step, tf.float32), 1.0)
      return self._initial_learning_rate * tf.math.pow(clamped_step,
                                                       self._power)

  def get_config(self):
    """Returns the constructor arguments as a config dict."""
    return dict(
        initial_learning_rate=self._initial_learning_rate,
        power=self._power,
        name=self._name,
    )


class PowerAndLinearDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Power-law learning rate multiplied by a linear decay at the end.

  The schedule has the following behavior. Let
  offset_step = step - offset and
  decay_start = total_decay_steps * (1 - linear_decay_fraction).
  1) offset_step < 0: the learning rate equals initial_learning_rate (for
     linear_decay_fraction <= 1).
  2) offset_step <= decay_start: the learning rate equals
     lr * max(offset_step, 1)^power.
  3) decay_start <= offset_step < total_decay_steps: the learning rate equals
     lr * max(offset_step, 1)^power * (total_decay_steps - offset_step) /
     (total_decay_steps * linear_decay_fraction).
  4) offset_step >= total_decay_steps: the learning rate equals zero.
  The linear factor and the clamp at zero are only applied when
  total_decay_steps * linear_decay_fraction > 0.
  """

  def __init__(self,
               initial_learning_rate: float,
               total_decay_steps: int,
               power: float = 1.0,
               linear_decay_fraction: float = 0.1,
               offset: int = 0,
               name: str = "PowerAndLinearDecay"):
    """Stores the schedule hyper-parameters.

    Args:
      initial_learning_rate: Learning rate multiplier applied to step**power.
      total_decay_steps: The total number of steps for power + linear decay.
      power: Exponent applied to the (clamped) offset step.
      linear_decay_fraction: Fraction of `total_decay_steps`, at the end, over
        which the learning rate is additionally multiplied by a linear decay.
      offset: The offset subtracted from the step.
      name: Optional, name of learning rate schedule (used as name scope).
    """
    super().__init__()
    self._name = name
    self._offset = offset
    self._linear_decay_fraction = linear_decay_fraction
    self._power = power
    self._total_decay_steps = total_decay_steps
    self._initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerAndLinearDecay"):
      shifted_step = tf.cast(step - self._offset, tf.float32)
      # A zero `step` may cause Inf, so clamp the step to at least 1.
      lr = self._initial_learning_rate * tf.math.pow(
          tf.math.maximum(shifted_step, 1.0), self._power)

      linear_span = self._total_decay_steps * self._linear_decay_fraction
      if not linear_span > 0:
        return lr

      remaining_ratio = (self._total_decay_steps - shifted_step) / linear_span
      lr *= tf.minimum(1.0, remaining_ratio)
      return tf.maximum(0.0, lr)

  def get_config(self):
    """Returns the constructor arguments as a config dict."""
    return dict(
        initial_learning_rate=self._initial_learning_rate,
        total_decay_steps=self._total_decay_steps,
        power=self._power,
        linear_decay_fraction=self._linear_decay_fraction,
        offset=self._offset,
        name=self._name,
    )


class PowerDecayWithOffset(tf_keras.optimizers.schedules.LearningRateSchedule):
  """Power learning rate decay with offset.

  The learning rate equals `pre_offset_learning_rate` while `step <= offset`.
  Afterwards it equals lr * max(step - offset, 1)^power. In both cases the
  result is capped at `pre_offset_learning_rate`.
  """

  def __init__(self,
               initial_learning_rate: float,
               power: float = 1.0,
               offset: int = 0,
               pre_offset_learning_rate: float = 1.0e6,
               name: str = "PowerDecayWithOffset"):
    """Stores the schedule hyper-parameters.

    Args:
      initial_learning_rate: Learning rate multiplier applied after the offset.
      power: Exponent applied to the (clamped) post-offset step.
      offset: Step after which the power decay starts.
      pre_offset_learning_rate: Learning rate used up to `offset`; also the
        upper bound of the returned learning rate.
      name: Optional, name of learning rate schedule (used as name scope).
    """
    super().__init__()
    self._name = name
    self._pre_offset_lr = pre_offset_learning_rate
    self._offset = offset
    self._power = power
    self._initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    with tf.name_scope(self._name or "PowerDecayWithOffset"):
      step = tf.cast(step, tf.float32)
      clamped_elapsed = tf.math.maximum(step - self._offset, 1.0)
      decayed_lr = tf.math.pow(clamped_elapsed, self._power) * (
          self._initial_learning_rate)

      # 1.0 once past the offset, else 0.0; used to blend the two regimes.
      past_offset = tf.cast(step > self._offset, tf.float32)
      blended_lr = ((1.0 - past_offset) * self._pre_offset_lr +
                    past_offset * decayed_lr)
      # The power term can blow up, so never exceed the pre-offset LR.
      return tf.math.minimum(blended_lr, self._pre_offset_lr)

  def get_config(self):
    """Returns the constructor arguments as a config dict."""
    return dict(
        initial_learning_rate=self._initial_learning_rate,
        power=self._power,
        offset=self._offset,
        pre_offset_learning_rate=self._pre_offset_lr,
        name=self._name,
    )


class StepCosineDecayWithOffset(
    tf_keras.optimizers.schedules.LearningRateSchedule):
  """Stepwise cosine learning rate decay with offset.

  Learning rate is equivalent to one or more cosine decay(s), one per level.
  With `s = step - offset`:

  * Level 0 is used while `s < boundaries[1]` (always, for a single level). It
    cosine-decays from `values[0]` to `values[1]` over
    `boundaries[1] - boundaries[0]` steps counted from `s = 0`, so
    `boundaries[0]` is normally 0.
  * Level `i >= 1` is used once `s >= boundaries[i]` and cosine-decays from
    `values[i]` to `values[i + 1]` over `boundaries[i + 1] - boundaries[i]`
    steps.
  * The last level has no following boundary (zero length), so it holds
    `values[-1]`.

  Example:

    ```python
    boundaries: [0, 100000, 110000]
    values: [1.0, 0.5, 0.0]
    lr_decayed_fn = (
    lr_schedule.StepCosineDecayWithOffset(
        boundaries,
        values))
    ```

    from 0 to 100000 step, it will cosine decay from 1.0 to 0.5
    from 100000 to 110000 step, it cosine decay from 0.5 to 0.0
    from 110000 step on, it stays at 0.0
  """

  def __init__(self,
               boundaries,
               values,
               offset: int = 0,
               name: str = "StepCosineDecayWithOffset"):
    """Initialize configuration of the learning rate schedule.

    Args:
      boundaries: A list of `int`s with strictly increasing entries giving the
        start step of each level (see the class docstring).
      values: A list of `float`s, the learning rate at the start of each level.
        It must have the same length as `boundaries`.
      offset: The offset subtracted from the step before evaluation.
      name: Optional, name of learning rate schedule.

    Raises:
      ValueError: If `values` is empty or its length differs from
        `boundaries`.
    """
    super().__init__()
    self.values = values
    self.boundaries = boundaries
    self.offset = offset
    self.name = name

    if len(self.values) < 1:
      raise ValueError(f"Expect non empty {self.values}")
    if len(self.boundaries) != len(self.values):
      raise ValueError(
          "Boundaries length should be equal to learning rate levels length: "
          f"{len(self.boundaries)} != {len(self.values)}")

    # Length of each level in steps; the last level has no end boundary.
    self.total_steps = (
        [boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)
        ] + [0])

  @staticmethod
  def _cosine_interpolation(elapsed, duration, start_lr, end_lr):
    """Cosine-interpolates from `start_lr` to `end_lr` over `duration` steps.

    A zero `duration` (the last level) yields `start_lr` instead of the NaN a
    plain division by zero would produce.
    """
    angle = tf.math.divide_no_nan(
        tf.constant(math.pi) * elapsed, tf.cast(duration, tf.float32))
    return (start_lr - end_lr) * (tf.cos(angle) + 1.0) / 2.0 + end_lr

  def __call__(self, global_step):
    with tf.name_scope(self.name or "StepCosineDecayWithOffset"):
      global_step = tf.cast(global_step - self.offset, tf.float32)
      lr_levels = self.values
      lr_steps = self.boundaries
      level_total_steps = self.total_steps
      num_levels = len(lr_levels)

      def target_lr(level):
        # Each level decays towards the next level's value; the last one to 0.
        return lr_levels[level + 1] if level + 1 < num_levels else 0.

      # Level 0 is measured from step 0, not from boundaries[0].
      learning_rate = self._cosine_interpolation(
          global_step, level_total_steps[0], lr_levels[0], target_lr(0))

      for i in range(1, num_levels):
        start_step = lr_steps[i]
        level_lr = self._cosine_interpolation(
            global_step - start_step, level_total_steps[i], lr_levels[i],
            target_lr(i))
        learning_rate = tf.where(global_step >= start_step, level_lr,
                                 learning_rate)

      return learning_rate

  def get_config(self):
    return {
        "boundaries": self.boundaries,
        "values": self.values,
        "offset": self.offset,
        "name": self.name
    }