File size: 4,227 Bytes
a256709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
""" TanH Scheduler
TanH schedule with warmup, cycle/restarts, noise.
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
class TanhLRScheduler(Scheduler):
    """Hyperbolic-tangent learning-rate decay with warmup and restarts.

    This is described in the paper https://arxiv.org/abs/1806.01593

    Inside a cycle the value for each parameter group is

        lr_min + 0.5 * (lr_max - lr_min) * (1 - tanh(lb * (1 - tr) + ub * tr))

    where ``tr`` is the fraction of the current cycle that has elapsed,
    ``lr_max`` is the group's base value scaled by ``decay_rate ** cycle`` and
    ``lr_min`` is scaled by the same factor.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        lb: float = -6.0,
        ub: float = 4.0,
        t_mul: float = 1.0,
        lr_min: float = 0.0,
        decay_rate: float = 1.0,
        warmup_t=0,
        warmup_lr_init=0,
        warmup_prefix=False,
        cycle_limit=0,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=42,
        initialize=True,
    ) -> None:
        """Configure the schedule.

        Args:
            optimizer: Optimizer handed to the ``Scheduler`` base class; the
                scheduled field is ``"lr"``.
            t_initial: Length of the first cycle, counted in epochs or in
                updates depending on ``t_in_epochs``. Must be > 0.
            lb: tanh argument used at the start of a cycle. Must be < ``ub``.
            ub: tanh argument used at the end of a cycle.
            t_mul: Factor by which each successive cycle length is multiplied.
            lr_min: Lower bound of the first cycle. Must be >= 0.
            decay_rate: Per-cycle multiplier applied to both the base values
                and ``lr_min``.
            warmup_t: Number of linear warmup steps. Must be >= 0.
            warmup_lr_init: Value at ``t == 0`` when warmup is enabled.
                Must be >= 0.
            warmup_prefix: When true, warmup ramps up to the base values and
                ``warmup_t`` is subtracted from ``t`` before the tanh schedule
                is evaluated. When false, warmup ramps up to the scheduled
                value at ``t == warmup_t``.
            cycle_limit: Maximum number of cycles; ``0`` means unlimited.
                Must be >= 0.
            t_in_epochs: Selects whether ``get_epoch_values`` (true) or
                ``get_update_values`` (false) produces values.
            noise_range_t, noise_pct, noise_std, noise_seed, initialize:
                Forwarded unchanged to ``Scheduler.__init__``.

        Raises:
            AssertionError: If one of the constraints above is violated
                (these checks are plain ``assert`` statements).
        """
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        assert lb < ub
        assert cycle_limit >= 0
        assert warmup_t >= 0
        assert warmup_lr_init >= 0

        self.lb = lb
        self.ub = ub
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        # NOTE: self.base_values is not assigned in this class; it is
        # presumably provided by the Scheduler base class -- confirm there.
        if self.warmup_t:
            # _get_lr(warmup_t) takes the non-warmup branch (t is not
            # < warmup_t), so it is safe to call before warmup_steps exists.
            t_v = (
                self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
            )
            # Per-step increment that reaches the target after warmup_t steps.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _locate_cycle(self, t):
        """Return ``(cycle index, cycle length, offset inside the cycle)`` for ``t``.

        For ``t_mul != 1`` the index comes from the closed form of the
        geometric series of cycle lengths. ``math.log(x, base)`` is not exact
        at exact powers of ``base`` (e.g. ``math.log(243, 3)`` evaluates to
        ``4.999999999999999``), so at a restart boundary the floored index can
        be off by one. The index is therefore checked against the cycle bounds
        and shifted by one when necessary; away from boundaries the result is
        the plain closed-form value.

        NOTE(review): for ``t_mul < 1`` the logarithm's argument becomes
        non-positive once ``t`` reaches ``t_initial / (1 - t_mul)`` and
        ``math.log`` raises ``ValueError``; that case is not handled here.
        """
        if self.t_mul == 1:
            i = t // self.t_initial
            return i, self.t_initial, t - (self.t_initial * i)

        def start_of(k):
            # Sum of the lengths of cycles 0 .. k-1.
            return (1 - self.t_mul ** k) / (1 - self.t_mul) * self.t_initial

        i = math.floor(
            math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)
        )
        if t - start_of(i) >= self.t_mul ** i * self.t_initial:
            # Logarithm rounded down across a boundary: t belongs to the next cycle.
            i += 1
        elif i > 0 and t < start_of(i):
            # Logarithm rounded up across a boundary: t belongs to the previous cycle.
            i -= 1

        t_i = self.t_mul ** i * self.t_initial
        t_curr = t - start_of(i)
        return i, t_i, t_curr

    def _get_lr(self, t):
        """Return the scheduled values at time ``t``, one per entry of ``base_values``."""
        if t < self.warmup_t:
            # Linear warmup starting from warmup_lr_init.
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]

        if self.warmup_prefix:
            # Warmup is a prefix: the tanh schedule starts counting after it.
            t = t - self.warmup_t

        i, t_i, t_curr = self._locate_cycle(t)

        if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
            gamma = self.decay_rate ** i
            lr_min = self.lr_min * gamma
            lr_max_values = [v * gamma for v in self.base_values]

            tr = t_curr / t_i  # fraction of the current cycle elapsed
            lrs = [
                lr_min
                + 0.5
                * (lr_max - lr_min)
                * (1 - math.tanh(self.lb * (1.0 - tr) + self.ub * tr))
                for lr_max in lr_max_values
            ]
        else:
            # Past the last permitted cycle: hold the decayed minimum.
            lrs = [
                self.lr_min * (self.decay_rate ** self.cycle_limit)
                for _ in self.base_values
            ]
        return lrs

    def get_epoch_values(self, epoch: int):
        """Return values for ``epoch`` when ``t_in_epochs`` is true, else ``None``."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        """Return values for ``num_updates`` when ``t_in_epochs`` is false, else ``None``."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        """Return the combined length of the first ``cycles`` cycles.

        A falsy ``cycles`` falls back to ``cycle_limit``; at least one cycle
        is always counted.
        """
        if not cycles:
            cycles = self.cycle_limit
        cycles = max(1, cycles)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        else:
            # Closed form of t_initial * sum(t_mul ** k for k in range(cycles)).
            return int(
                math.floor(
                    -self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)
                )
            )
|