from torch.optim.lr_scheduler import _LRScheduler
class PolynomialLR(_LRScheduler):
    """Polynomial learning-rate decay with optional linear warmup.

    The learning rate is recomputed only every ``step_size`` scheduler
    steps:

    * during warmup (``last_epoch < iter_warmup``) the rate ramps up
      linearly toward the decayed value at the end of warmup;
    * afterwards it decays as ``(1 - t / iter_max) ** power`` toward
      ``min_lr``, reaching ``min_lr`` exactly at ``iter_max``;
    * past ``iter_max`` the rate is frozen (see ``get_lr``).
    """

    def __init__(
        self,
        optimizer,
        step_size,
        iter_warmup,
        iter_max,
        power,
        min_lr=0,
        last_epoch=-1,
    ):
        """
        Args:
            optimizer: wrapped ``torch.optim.Optimizer``.
            step_size: update the rate only every ``step_size`` steps.
            iter_warmup: number of warmup iterations (0 disables warmup).
            iter_max: total number of scheduled iterations.
            power: exponent of the polynomial decay.
            min_lr: floor value the rate decays toward (default 0).
            last_epoch: index of the last iteration (-1 = fresh start).
        """
        self.step_size = step_size
        self.iter_warmup = int(iter_warmup)
        self.iter_max = int(iter_max)
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def polynomial_decay(self, lr):
        """Return *lr* scaled by the warmup/decay coefficient at the
        current iteration (``self.last_epoch``)."""
        iter_cur = float(self.last_epoch)
        if iter_cur < self.iter_warmup:
            # Linear ramp from 0 toward the decayed value reached at the
            # end of warmup, so warmup joins the decay curve smoothly.
            coef = iter_cur / self.iter_warmup
            coef *= (1 - self.iter_warmup / self.iter_max) ** self.power
        else:
            coef = (1 - iter_cur / self.iter_max) ** self.power
        return (lr - self.min_lr) * coef + self.min_lr

    def get_lr(self):
        # Keep the current rates at iteration 0, on off-step iterations,
        # and once the schedule has run past iter_max (this also avoids
        # raising a negative base to a fractional power there).
        if (
            (self.last_epoch == 0)
            or (self.last_epoch % self.step_size != 0)
            or (self.last_epoch > self.iter_max)
        ):
            return [group["lr"] for group in self.optimizer.param_groups]
        return [self.polynomial_decay(lr) for lr in self.base_lrs]

    def step_update(self, num_updates):
        """Per-update hook; advances the schedule by one step.

        ``num_updates`` is accepted for interface compatibility
        (presumably with trainers that drive schedulers per update —
        TODO confirm against caller) but is unused: the schedule is
        driven purely by the internal step counter.
        """
        self.step()