import math


def _apply_lr(optimizer, lr: float) -> None:
    """Write *lr* into the ``"lr"`` entry of every group in ``optimizer.param_groups``."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


class CosineSchedule:
    """Cosine learning-rate decay from ``init_lr`` down to ``min_lr``.

    Calling the schedule with an optimizer and an epoch computes the rate
    for that epoch and stores it in each of the optimizer's parameter groups.

    Args:
        min_lr: Rate reached once ``epoch`` equals ``max_epochs``.
        init_lr: Rate used at epoch 0.
        decay_rate: Stored on the instance but not used by the cosine
            formula; kept so the constructor signature stays unchanged.
        max_epochs: Number of epochs over which the rate falls from
            ``init_lr`` to ``min_lr``.
    """

    def __init__(
        self,
        min_lr: float,
        init_lr: float,
        decay_rate: float,
        max_epochs: int,
    ) -> None:
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.decay_rate = decay_rate
        self.max_epochs = max_epochs

    def __call__(self, optimizer, epoch: float) -> None:
        """Decay the learning rate.

        Args:
            optimizer: Object exposing ``param_groups``, an iterable of
                mappings whose ``"lr"`` entry is overwritten.
            epoch: Current epoch; 0 yields ``init_lr``.
        """
        if self.max_epochs == 0:
            # A zero-length schedule has nothing to decay over; treat it as
            # already finished instead of dividing by zero.
            phase = math.pi
        else:
            # Cap the phase at pi so the rate stays at min_lr after
            # max_epochs rather than climbing back up the cosine curve.
            phase = min(math.pi * epoch / self.max_epochs, math.pi)

        lr = (self.init_lr - self.min_lr) * 0.5 * (
            1.0 + math.cos(phase)
        ) + self.min_lr
        _apply_lr(optimizer, lr)


class StepSchedule:
    """Exponential learning-rate decay with a lower bound.

    The rate for an epoch is ``init_lr * decay_rate ** epoch``, never going
    below ``min_lr``.

    Args:
        min_lr: Lower bound applied to the computed rate.
        init_lr: Rate used at epoch 0.
        decay_rate: Multiplicative factor applied once per epoch.
    """

    def __init__(self, min_lr: float, init_lr: float, decay_rate: float) -> None:
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.decay_rate = decay_rate

    def __call__(self, optimizer, epoch: float) -> None:
        """Set the decayed learning rate on every parameter group.

        Args:
            optimizer: Object exposing ``param_groups``, an iterable of
                mappings whose ``"lr"`` entry is overwritten.
            epoch: Current epoch, used as the exponent of ``decay_rate``.
        """
        lr = max(self.min_lr, self.init_lr * (self.decay_rate**epoch))
        _apply_lr(optimizer, lr)