from torch.optim.lr_scheduler import _LRScheduler


class PolyLr(_LRScheduler):
    """Polynomial-decay learning-rate schedule with optional linear warmup.

    Outside warmup, the learning rate of each param group at step ``t`` is::

        (base_lr - minimum_lr) * max(0, 1 - t / max_iteration) ** gamma + minimum_lr

    The base is clamped at 0, so the rate settles at ``minimum_lr`` once
    ``t >= max_iteration`` instead of going complex or rising again.

    During the first ``warmup_iteration`` steps, the rate ramps linearly from
    ``base_lr / 10`` towards ``base_lr``. It is never allowed to exceed the
    polynomial value for the same step.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Exponent of the polynomial decay.
        max_iteration: Step at which the decay reaches ``minimum_lr``. Used as
            a divisor, so it must be non-zero.
        minimum_lr: Floor the learning rate decays to.
        warmup_iteration: Number of warmup steps (0 disables warmup).
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0,
                 warmup_iteration=0, last_epoch=-1):
        # Set these before the base constructor runs. In PyTorch it performs
        # an initial step(), which calls get_lr() and reads these attributes.
        # ``last_epoch`` and ``base_lrs`` are initialised by the base class.
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.minimum_lr = minimum_lr
        self.warmup_iteration = warmup_iteration
        super().__init__(optimizer, last_epoch)

    def poly_lr(self, base_lr, step):
        """Return the polynomially decayed rate for ``base_lr`` at ``step``.

        The base ``1 - step / max_iteration`` is clamped at 0. Without the
        clamp, a negative base past the end of the schedule produces a complex
        number for fractional ``gamma``, or a rising rate for even ``gamma``.
        """
        remaining = max(0.0, 1 - (step / self.max_iteration))
        return (base_lr - self.minimum_lr) * (remaining ** self.gamma) + self.minimum_lr

    def warmup_lr(self, base_lr, alpha):
        """Interpolate linearly from ``base_lr / 10`` (alpha=0) to ``base_lr`` (alpha=1)."""
        return base_lr * (1 / 10.0 * (1 - alpha) + alpha)

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``."""
        step = self.last_epoch
        if step < self.warmup_iteration:
            # Warmup: ramp up linearly, but never above the decay curve.
            alpha = step / self.warmup_iteration
            return [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, step))
                    for base_lr in self.base_lrs]
        return [self.poly_lr(base_lr, step) for base_lr in self.base_lrs]