HoneyTian's picture
update
e27a095
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import math
def anneal_linear(start, end, proportion):
    """Interpolate linearly between ``start`` and ``end``.

    Args:
        start: Value returned when ``proportion`` is 0.
        end: Value returned when ``proportion`` is 1.
        proportion: Fraction of the way from ``start`` to ``end``. It is not
            clamped, so values outside ``[0, 1]`` extrapolate.

    Returns:
        ``start + proportion * (end - start)``.
    """
    span = end - start
    return start + proportion * span
def anneal_cosine(start, end, proportion):
    """Interpolate between ``start`` and ``end`` along a half cosine wave.

    Args:
        start: Value returned when ``proportion`` is 0.
        end: Value returned when ``proportion`` is 1.
        proportion: Fraction of the way through the anneal. It is not
            clamped; past 1 the cosine turns back towards ``start``.

    Returns:
        ``end + (start - end) / 2 * (cos(pi * proportion) + 1)``.
    """
    half_range = (start - end) / 2
    # cos(pi * p) + 1 goes from 2 down to 0 as p goes from 0 to 1, so the
    # result moves from `start` to `end`.
    return end + half_range * (math.cos(math.pi * proportion) + 1)
class Phase:
    """One segment of a schedule that anneals a value from ``start`` to ``end``.

    Args:
        start: Value at the beginning of the phase.
        end: Value reached once ``n_iter`` steps have been taken.
        n_iter: Number of steps the phase lasts.
        cur_iter: Number of steps already taken (initial value of ``n``).
        anneal_fn: Callable ``(start, end, proportion) -> value`` used to
            interpolate between ``start`` and ``end``.
    """

    def __init__(self, start, end, n_iter, cur_iter, anneal_fn):
        self.start, self.end = start, end
        self.n_iter = n_iter
        self.anneal_fn = anneal_fn
        self.n = cur_iter

    def step(self):
        """Advance one step and return the annealed value for that step.

        The step counter is incremented before the value is computed, so the
        first call on a fresh phase evaluates at ``1 / n_iter``. The
        proportion is not clamped: stepping a finished phase passes a
        proportion greater than 1 to ``anneal_fn``.
        """
        self.n += 1
        if self.n_iter > 0:
            proportion = self.n / self.n_iter
        else:
            # A zero-length phase is already complete (see ``is_done``);
            # report its end value instead of dividing by zero.
            proportion = 1.0
        return self.anneal_fn(self.start, self.end, proportion)

    def reset(self):
        """Rewind the phase to its first step."""
        self.n = 0

    @property
    def is_done(self):
        """Whether at least ``n_iter`` steps have been taken."""
        return self.n >= self.n_iter
class LinearWarmupCosineDecay(object):
    """Two-phase cyclic learning-rate schedule: warm up, then decay.

    With ``lr_min = lr_max / divider``, the first phase anneals the rate from
    ``lr_min`` to ``lr_max`` over ``int(n_iter * warmup_proportion)`` steps.
    The second phase anneals it from ``lr_max`` to ``lr_min / 1e4`` over the
    remaining steps. When the last phase finishes, every phase is reset and
    the schedule starts over.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of mappings
            whose ``'lr'`` entry is overwritten on every ``step()``
            (presumably a PyTorch optimizer; confirm against the caller).
        lr_max: Peak learning rate, reached at the end of the first phase.
        n_iter: Total number of steps in one cycle, both phases combined.
        iteration: Number of steps already taken, used to resume mid-cycle.
        divider: ``lr_max / divider`` is the rate at the start of warmup.
        warmup_proportion: Fraction of ``n_iter`` spent in the first phase.
        phase: Pair of annealing names, one per phase; each must be
            ``'linear'`` or ``'cosine'`` (anything else raises ``KeyError``).
    """

    def __init__(
            self,
            optimizer,
            lr_max,
            n_iter,
            iteration=0,
            divider=25,
            warmup_proportion=0.3,
            phase=('linear', 'cosine'),
    ):
        self.optimizer = optimizer

        phase1 = int(n_iter * warmup_proportion)
        phase2 = n_iter - phase1
        lr_min = lr_max / divider

        phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine}

        # When resuming, the second phase has consumed only the steps that
        # came after the warmup.
        cur_iter_phase1 = iteration
        cur_iter_phase2 = max(0, iteration - phase1)
        self.lr_phase = [
            Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]),
            Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]),
        ]

        if iteration < phase1:
            self.phase = 0
        else:
            self.phase = 1
        # The selected phase may have zero length (for example
        # warmup_proportion == 1.0 while resuming past the warmup).
        self._skip_empty_phases()

    def _skip_empty_phases(self):
        """Move ``self.phase`` to the next phase that has steps to run.

        Running past the last phase resets every phase and wraps to the
        first one. If no phase has a positive length there is nothing to
        skip to; the index is only kept within range.
        """
        n_phases = len(self.lr_phase)
        has_runnable_phase = any(p.n_iter > 0 for p in self.lr_phase)
        while True:
            if self.phase >= n_phases:
                for phase in self.lr_phase:
                    phase.reset()
                self.phase = 0
            if not has_runnable_phase or self.lr_phase[self.phase].n_iter > 0:
                return
            self.phase += 1

    def step(self):
        """Advance the schedule one step and apply the new learning rate.

        Returns:
            The learning rate written to every entry of
            ``optimizer.param_groups``.
        """
        lr = self.lr_phase[self.phase].step()

        for group in self.optimizer.param_groups:
            group['lr'] = lr

        if self.lr_phase[self.phase].is_done:
            self.phase += 1
            # Stepping a zero-length phase would divide by zero, so skip any
            # empty phases and wrap around when the cycle is complete.
            self._skip_empty_phases()

        return lr
# Running this file directly currently does nothing; it only provides the
# definitions above for import.
if __name__ == '__main__':
    pass