|
""" |
|
Author: Luigi Piccinelli |
|
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) |
|
""" |
|
|
|
import numpy as np |
|
|
|
|
|
class CosineScheduler(object):
    """Warmup + cosine schedule for one hyper-parameter of every param group.

    For each entry of ``optimizer.param_groups`` a schedule of
    ``max(total_iters, warmup_iters)`` values is precomputed (for
    ``warmup_iters >= 0``):

    * warmup: ``warmup_iters`` values rising quadratically from the initial
      value to the base value (the last warmup value equals the base value);
    * decay: ``total_iters - warmup_iters`` values following half a cosine
      from the base value towards the final value. The last decay value
      corresponds to ``cos(pi * (n - 1) / n)``, so ``final_value`` is
      approached but not reached exactly.

    Each call to :meth:`step` advances an internal counter and writes the
    scheduled value into ``group[key]`` for every param group.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts
            (presumably a ``torch.optim.Optimizer``; only ``param_groups``
            is used here).
        warmup_iters: number of warmup iterations; 0 disables warmup.
        total_iters: total schedule length, warmup included.
        key: name of the param-group entry to schedule (e.g. ``"lr"``).
        overwrite: if True, the scheduler-level ``final_value`` is used for
            every group even when a group defines ``f"{key}_final"``.
            Per-group ``_init`` / ``_base`` overrides are still honoured.
        init_value: default warmup start value; a group's ``f"{key}_init"``
            entry takes precedence. Only required when ``warmup_iters > 0``.
        base_value: default peak value; a group's ``f"{key}_base"`` entry
            takes precedence.
        final_value: default decay target; a group's ``f"{key}_final"``
            entry takes precedence unless ``overwrite`` is set.
        step_init: initial counter value. :meth:`step` increments before
            reading, so the default ``-1`` makes the first call apply index
            0; set it to ``k - 1`` for the next call to apply index ``k``.
    """

    def __init__(
        self,
        optimizer,
        warmup_iters,
        total_iters,
        key,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.optimizer = optimizer
        self.base_value = base_value
        self.init_value = init_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.key = key
        # One precomputed NumPy schedule per param group, in group order.
        self.schedulers = [
            self.get_schedulers(group) for group in optimizer.param_groups
        ]

    def get_schedulers(self, group):
        """Build the full schedule array for a single param group.

        Group-level ``f"{key}_init"``, ``f"{key}_base"`` and
        ``f"{key}_final"`` entries override the scheduler defaults; with
        ``overwrite`` the scheduler's ``final_value`` always wins.

        Returns:
            np.ndarray: warmup values followed by cosine-decay values.
        """
        init_value = group.get(self.key + "_init", self.init_value)
        base_value = group.get(self.key + "_base", self.base_value)
        final_value = group.get(self.key + "_final", self.final_value)
        warmup_iters = self.warmup_iters
        total_iters = self.total_iters
        if self.overwrite:
            final_value = self.final_value

        # Warmup: normalize to [0, 1], square it, then map to [init, base].
        # Skipped entirely when there is no warmup so that ``init_value``
        # may be left as None in that case (it would otherwise raise a
        # TypeError in ``base_value - init_value``).
        if warmup_iters == 0:
            warmup_schedule = np.empty(0)
        else:
            normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
            normalized_schedule = np.power(normalized_schedule, 2)
            warmup_schedule = (
                (base_value - init_value) * normalized_schedule + init_value
            )

        # Decay: half cosine from base_value towards final_value over the
        # remaining iterations (empty when warmup covers everything).
        iters = np.arange(total_iters - warmup_iters)
        schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * iters / len(iters))
        )
        return np.concatenate((warmup_schedule, schedule))

    def step(self):
        """Advance one iteration and write the new value into every group.

        If the current ``group[key]`` is a tuple/list, only its first
        element is replaced and the result is stored as a tuple.
        """
        self.iter += 1
        values = self[self.iter]
        for group, value in zip(self.optimizer.param_groups, values):
            current = group[self.key]
            if isinstance(current, (tuple, list)):
                value = (value, *current[1:])
            group[self.key] = value

    def __getitem__(self, it):
        """Return the scheduled value at iteration ``it`` for each group.

        Indices past the end are clamped to ``total_iters - 1``; negative
        indices are not clamped and use NumPy negative indexing.
        """
        it = min(it, self.total_iters - 1)
        return [scheduler[it] for scheduler in self.schedulers]

    def get(self):
        """Return the current ``group[key]`` value of every param group."""
        return [group[self.key] for group in self.optimizer.param_groups]
|
|