import weakref

import numpy as np


def _cosine_schedule(
    init_value, base_value, final_value, warmup_iters, total_iters, flat_iters=0
):
    """Build a warm-up / flat / cosine-decay schedule as a 1-D array.

    The returned array is the concatenation of three segments:

    * ``warmup_iters`` values going linearly from ``init_value`` to
      ``base_value`` (both endpoints included),
    * ``flat_iters`` values equal to ``base_value``,
    * ``total_iters - warmup_iters - flat_iters + 1`` values following half a
      cosine period from ``base_value`` down to ``final_value``.

    When the three phases fit inside ``total_iters`` the result therefore has
    ``total_iters + 1`` entries, so index ``total_iters`` is valid.
    """
    # Linear warm-up: normalize to [0, 1], then rescale to [init, base].
    ramp = np.linspace(0, 1, warmup_iters, endpoint=True)
    warmup = (base_value - init_value) * ramp + init_value

    flat = np.ones(flat_iters) * base_value

    decay_steps = np.arange(total_iters - warmup_iters - flat_iters + 1)
    # A one-step decay segment would otherwise evaluate 0 / 0 and put NaN in
    # the schedule; clamping the divisor makes that single step equal to
    # ``base_value`` (cos(0) == 1).
    span = max(len(decay_steps) - 1, 1)
    decay = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * decay_steps / span)
    )
    return np.concatenate((warmup, flat, decay))


class PlainCosineScheduler(object):
    """Cosine schedule with linear warm-up applied to an attribute of an object.

    Each call to :meth:`step` advances the internal counter and assigns the
    scheduled value to ``getattr(klass, key)`` via ``setattr``.

    Args:
        klass: Object whose attribute ``key`` is updated on every step.
        key: Name of the attribute to set.
        warmup_iters: Number of linear warm-up iterations.
        total_iters: Iteration index after which the schedule stays at its
            last value.
        overwrite: Stored on the instance; not used by this class.
        init_value: Value at the start of warm-up. Defaults to ``base_value``
            when ``None``.
        base_value: Value reached at the end of warm-up.
        final_value: Value reached at the end of the cosine decay.
        step_init: Initial counter value; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        klass,
        key,
        warmup_iters,
        total_iters,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.base_value = base_value
        self.init_value = init_value if init_value is not None else base_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.key = key
        self.klass = klass
        self.schedulers = [self.get_scheduler()]

    def get_scheduler(self):
        """Return the full schedule (warm-up followed by cosine decay)."""
        return _cosine_schedule(
            self.init_value,
            self.base_value,
            self.final_value,
            self.warmup_iters,
            self.total_iters,
        )

    def step(self):
        """Advance one iteration and write the scheduled value to the target."""
        self.iter = self.iter + 1
        for val in self[self.iter]:
            setattr(self.klass, self.key, val)

    def __getitem__(self, it):
        """Return the scheduled values at iteration ``it``.

        ``it`` is clamped from above to ``total_iters`` so that stepping past
        the end keeps returning the last value.
        """
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]


class CosineScheduler(object):
    """Cosine schedule with warm-up and flat phases for optimizer param groups.

    One schedule is precomputed per entry of ``optimizer.param_groups``. A
    group may override the defaults through the keys ``<key>_init``,
    ``<key>_base`` and ``<key>_final``.

    Args:
        optimizer: Object exposing ``param_groups``, a sequence of dict-like
            groups that contain ``key``.
        warmup_iters: Number of linear warm-up iterations.
        total_iters: Iteration index after which the schedule stays at its
            last value.
        key: Name of the param-group entry to schedule.
        overwrite: When true, ``final_value`` takes precedence over a
            group's ``<key>_final`` entry.
        init_value: Default value at the start of warm-up. When neither this
            nor the group's ``<key>_init`` is set, the base value is used.
        base_value: Default value reached at the end of warm-up.
        final_value: Default value reached at the end of the cosine decay.
        flat_iters: Number of iterations held at the base value between
            warm-up and decay.
        step_init: Initial counter value; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        optimizer,
        warmup_iters,
        total_iters,
        key,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        flat_iters=0,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.optimizer = optimizer
        self.base_value = base_value
        self.init_value = init_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.flat_iters = flat_iters
        self.key = key
        self.schedulers = [
            self.get_schedulers(group) for group in optimizer.param_groups
        ]

    def get_schedulers(self, group):
        """Return the schedule array for a single param group."""
        key = self.key
        init_value = group.get(key + "_init", self.init_value)
        base_value = group.get(key + "_base", self.base_value)
        final_value = group.get(key + "_final", self.final_value)
        if self.overwrite:
            final_value = self.final_value
        # Same fallback as PlainCosineScheduler: without an explicit start
        # value the warm-up phase simply stays at the base value.
        if init_value is None:
            init_value = base_value

        return _cosine_schedule(
            init_value,
            base_value,
            final_value,
            self.warmup_iters,
            self.total_iters,
            self.flat_iters,
        )

    def step(self):
        """Advance one iteration and update ``key`` in every param group."""
        self.iter = self.iter + 1
        vals = self[self.iter]
        for group, val in zip(self.optimizer.param_groups, vals):
            # For tuple/list entries only the first element is scheduled;
            # the remaining elements are carried over unchanged.
            if isinstance(group[self.key], (tuple, list)):
                val = (val, *group[self.key][1:])
            group[self.key] = val

    def __getitem__(self, it):
        """Return the per-group scheduled values at iteration ``it``.

        ``it`` is clamped from above to ``total_iters`` so that stepping past
        the end keeps returning the last value.
        """
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]

    def get(self):
        """Return the current value of ``key`` for every param group."""
        return [group[self.key] for group in self.optimizer.param_groups]