#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Two-phase cyclic learning-rate schedule (warm-up followed by decay)."""
import math


def anneal_linear(start, end, proportion):
    """Interpolate linearly from ``start`` to ``end``.

    Args:
        start: Value returned when ``proportion`` is 0.
        end: Value returned when ``proportion`` is 1.
        proportion: Progress through the interval.

    Returns:
        ``start + proportion * (end - start)``.
    """
    return start + proportion * (end - start)


def anneal_cosine(start, end, proportion):
    """Interpolate from ``start`` to ``end`` along a half cosine wave.

    Args:
        start: Value returned when ``proportion`` is 0.
        end: Value returned when ``proportion`` is 1.
        proportion: Progress through the interval.

    Returns:
        The cosine-annealed value between ``start`` and ``end``.
    """
    # cos(pi * p) + 1 goes from 2 (p == 0) down to 0 (p == 1).
    cos_val = math.cos(math.pi * proportion) + 1
    return end + (start - end) / 2 * cos_val


class Phase:
    """One segment of a schedule: anneal ``start`` -> ``end`` in ``n_iter`` steps."""

    def __init__(self, start, end, n_iter, cur_iter, anneal_fn):
        """Store the segment bounds, its length and the current position.

        Args:
            start: Value at the beginning of the phase.
            end: Value at the end of the phase.
            n_iter: Number of steps the phase lasts.
            cur_iter: Number of steps already taken in this phase.
            anneal_fn: Callable ``(start, end, proportion) -> value``.
        """
        self.start, self.end = start, end
        self.n_iter = n_iter
        self.anneal_fn = anneal_fn
        self.n = cur_iter

    def step(self):
        """Advance by one step and return the annealed value.

        A phase with a non-positive length has no interior, so it is treated
        as already complete (proportion 1) instead of dividing by zero.
        """
        self.n += 1
        proportion = self.n / self.n_iter if self.n_iter > 0 else 1.0
        return self.anneal_fn(self.start, self.end, proportion)

    def reset(self):
        """Rewind the phase to its first step."""
        self.n = 0

    @property
    def is_done(self):
        """Whether the phase has taken at least ``n_iter`` steps."""
        return self.n >= self.n_iter


class LinearWarmupCosineDecay(object):
    """Cyclic scheduler: warm up to ``lr_max``, then decay, then repeat.

    Phase 0 anneals from ``lr_max / divider`` up to ``lr_max``; phase 1
    anneals from ``lr_max`` down to ``lr_max / divider / 1e4``. When the
    last phase finishes, all phases are reset and the cycle starts again.
    """

    def __init__(
        self,
        optimizer,
        lr_max,
        n_iter,
        iteration=0,
        divider=25,
        warmup_proportion=0.3,
        phase=('linear', 'cosine'),
    ):
        """Build the two phases and position the schedule at ``iteration``.

        Args:
            optimizer: Object exposing ``param_groups``, an iterable of
                dict-like groups whose ``'lr'`` entry is overwritten on
                every :meth:`step`.
            lr_max: Peak learning rate, reached at the end of the warm-up.
            n_iter: Total number of steps in one full cycle.
            iteration: Number of steps already taken (for resuming). Values
                of ``n_iter`` or more wrap around, matching the state an
                uninterrupted run would be in after that many steps.
            divider: ``lr_max / divider`` is the warm-up starting rate.
            warmup_proportion: Fraction of ``n_iter`` spent in phase 0.
            phase: Pair of annealing names (``'linear'`` or ``'cosine'``)
                for phase 0 and phase 1 respectively.

        Raises:
            ValueError: If ``n_iter`` is not positive or
                ``warmup_proportion`` is outside ``[0, 1]``.
            KeyError: If a name in ``phase`` is not a known annealing.
        """
        if n_iter <= 0:
            raise ValueError('n_iter must be positive, got %r' % (n_iter,))
        if not 0 <= warmup_proportion <= 1:
            raise ValueError(
                'warmup_proportion must be in [0, 1], got %r'
                % (warmup_proportion,)
            )

        self.optimizer = optimizer

        phase1 = int(n_iter * warmup_proportion)
        phase2 = n_iter - phase1
        lr_min = lr_max / divider

        phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine}

        # The schedule is cyclic with period n_iter, so a resume point past
        # the end of a cycle is equivalent to its remainder.
        if iteration >= n_iter:
            iteration %= n_iter

        cur_iter_phase1 = min(iteration, phase1)
        cur_iter_phase2 = max(0, iteration - phase1)

        self.lr_phase = [
            Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]),
            Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]),
        ]

        if iteration < phase1:
            self.phase = 0
        else:
            self.phase = 1

    def _advance_phase(self):
        """Move past every finished phase, wrapping to a new cycle if needed.

        Zero-length phases are "done" as soon as they are entered, so they
        are skipped here rather than stepped. The loop is bounded because a
        validated schedule always has at least one non-empty phase.
        """
        for _ in range(len(self.lr_phase) + 1):
            if not self.lr_phase[self.phase].is_done:
                break
            self.phase += 1
            if self.phase >= len(self.lr_phase):
                for phase in self.lr_phase:
                    phase.reset()
                self.phase = 0

    def step(self):
        """Advance one step, write the new rate to every param group.

        Returns:
            The learning rate that was just applied.
        """
        lr = self.lr_phase[self.phase].step()

        for group in self.optimizer.param_groups:
            group['lr'] = lr

        self._advance_phase()

        return lr


if __name__ == '__main__':
    pass