# UniK3D-demo / unik3d / ops / scheduler.py
# Author: Luigi Piccinelli
# Commit: init demo (1ea89dd)
import weakref
import numpy as np
class PlainCosineScheduler:
    """Linear-warmup + cosine-decay schedule applied to one attribute of an object.

    The whole schedule is precomputed at construction time. Every call to
    :meth:`step` advances the internal counter and writes the scheduled
    value onto ``klass`` with ``setattr(klass, key, value)``.

    Args:
        klass: Object whose attribute ``key`` is overwritten on each step.
        key: Name of the attribute to set on ``klass``.
        warmup_iters: Number of linear warmup entries, going from
            ``init_value`` to ``base_value``.
        total_iters: Index of the last schedule entry; later iterations are
            clamped to it.
        overwrite: Stored on the instance; not used by this class.
        init_value: Value at the start of warmup. Defaults to ``base_value``.
        base_value: Value reached at the end of warmup / start of decay.
        final_value: Value reached at the end of the cosine decay.
        step_init: Initial iteration counter; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        klass,
        key,
        warmup_iters,
        total_iters,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.base_value = base_value
        self.init_value = init_value if init_value is not None else base_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.key = key
        self.klass = klass
        # Kept as a one-element list so __getitem__ returns a list of values.
        self.schedulers = [self.get_scheduler()]

    def get_scheduler(self):
        """Return the full schedule as a 1-D NumPy array.

        The array holds ``warmup_iters`` warmup entries followed by
        ``total_iters - warmup_iters + 1`` decay entries, i.e.
        ``total_iters + 1`` values when ``total_iters >= warmup_iters``.
        """
        base_value = self.base_value
        final_value = self.final_value

        # Linear ramp in [0, 1], denormalized to [init_value, base_value].
        ramp = np.linspace(0, 1, self.warmup_iters, endpoint=True)
        warmup_schedule = (base_value - self.init_value) * ramp + self.init_value

        # Half-cosine from base_value down to final_value.
        iters = np.arange(self.total_iters - self.warmup_iters + 1)
        # With a single decay entry the span would be 0 and 0/0 yields NaN;
        # clamping to 1 makes that entry evaluate to base_value instead.
        span = max(len(iters) - 1, 1)
        schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * iters / span)
        )
        return np.concatenate((warmup_schedule, schedule))

    def step(self):
        """Advance one iteration and write the scheduled value to ``klass``."""
        self.iter = self.iter + 1
        for val in self[self.iter]:
            setattr(self.klass, self.key, val)

    def __getitem__(self, it):
        """Return the scheduled values at iteration ``it`` (clamped to ``total_iters``)."""
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]
class CosineScheduler:
    """Warmup / flat / cosine-decay schedule for one key of optimizer param groups.

    One schedule is precomputed per entry of ``optimizer.param_groups``.
    Every call to :meth:`step` advances the internal counter and writes the
    scheduled value into ``group[key]`` for each group.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            that each contain ``key``.
        warmup_iters: Number of linear warmup entries, going from the
            initial value to the base value.
        total_iters: Index of the last schedule entry; later iterations are
            clamped to it.
        key: Name of the param-group entry to schedule. A group may override
            the defaults below with ``key + "_init"``, ``key + "_base"`` and
            ``key + "_final"`` entries.
        overwrite: If true, ``final_value`` takes precedence over a group's
            ``key + "_final"`` entry.
        init_value: Default value at the start of warmup. When neither this
            nor the group provides one, the group's base value is used.
        base_value: Default value at the end of warmup and during the flat
            phase.
        final_value: Default value at the end of the cosine decay.
        flat_iters: Number of entries held at the base value between warmup
            and decay.
        step_init: Initial iteration counter; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        optimizer,
        warmup_iters,
        total_iters,
        key,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        flat_iters=0,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.optimizer = optimizer
        self.base_value = base_value
        self.init_value = init_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.flat_iters = flat_iters
        self.key = key
        self.schedulers = [
            self.get_schedulers(group) for group in optimizer.param_groups
        ]

    def get_schedulers(self, group):
        """Return the full schedule for one param group as a 1-D NumPy array.

        The array is the concatenation of ``warmup_iters`` warmup entries,
        ``flat_iters`` flat entries and
        ``total_iters - warmup_iters - flat_iters + 1`` decay entries.
        """
        init_value = group.get(self.key + "_init", self.init_value)
        base_value = group.get(self.key + "_base", self.base_value)
        final_value = group.get(self.key + "_final", self.final_value)
        if self.overwrite:
            final_value = self.final_value
        # Same fallback as PlainCosineScheduler: without an explicit initial
        # value, warmup starts at the base value rather than failing on None.
        if init_value is None:
            init_value = base_value

        # Linear ramp in [0, 1], denormalized to [init_value, base_value].
        ramp = np.linspace(0, 1, self.warmup_iters, endpoint=True)
        warmup_schedule = (base_value - init_value) * ramp + init_value

        # flat scheduling
        flat_schedule = np.ones(self.flat_iters) * base_value

        # decay scheduling: half-cosine from base_value down to final_value
        decay_iters = np.arange(
            self.total_iters - self.warmup_iters - self.flat_iters + 1
        )
        # With a single decay entry the span would be 0 and 0/0 yields NaN;
        # clamping to 1 makes that entry evaluate to base_value instead.
        span = max(len(decay_iters) - 1, 1)
        decay_schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * decay_iters / span)
        )
        return np.concatenate((warmup_schedule, flat_schedule, decay_schedule))

    def step(self):
        """Advance one iteration and write the scheduled value into every group."""
        self.iter = self.iter + 1
        vals = self[self.iter]
        for group, val in zip(self.optimizer.param_groups, vals):
            current = group[self.key]
            if isinstance(current, (tuple, list)):
                # Only the first element is scheduled; the rest is preserved.
                val = (val, *current[1:])
            group[self.key] = val

    def __getitem__(self, it):
        """Return the per-group scheduled values at iteration ``it`` (clamped to ``total_iters``)."""
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]

    def get(self):
        """Return the current ``group[key]`` of every param group."""
        return [group[self.key] for group in self.optimizer.param_groups]