Spaces:
Sleeping
Sleeping
from torch.optim.lr_scheduler import ReduceLROnPlateau | |
class WarmupThenReduceLROnPlateau(ReduceLROnPlateau):
    """ReduceLROnPlateau preceded by a linear learning-rate warmup.

    For the first ``warmup_steps`` calls to :meth:`step`, the learning rate of
    every parameter group is ramped linearly from 0 up to the rate the
    optimizer was constructed with (its "base" rate). Once the warmup is over,
    calls to :meth:`step` that carry a metric are forwarded to
    ``ReduceLROnPlateau.step``.
    """

    def __init__(self, optimizer, warmup_steps, *args, **kwargs):
        """
        Args:
            optimizer (Optimizer): Optimizer to wrap
            warmup_steps: number of steps before reaching base learning rate;
                0 disables the warmup and leaves the optimizer's rates as-is
            *args: Arguments for ReduceLROnPlateau
            **kwargs: Arguments for ReduceLROnPlateau

        Raises:
            ValueError: if ``warmup_steps`` is negative.
        """
        if warmup_steps < 0:
            raise ValueError(
                f"warmup_steps must be non-negative, got {warmup_steps!r}"
            )
        super().__init__(optimizer, *args, **kwargs)
        self.warmup_steps = warmup_steps
        self.steps_taken = 0
        # Remember the target rates before they are overwritten below.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        # Start the ramp: with steps_taken == 0 this sets every rate to 0
        # (or rewrites the unchanged rates when warmup is disabled).
        self._set_lrs(self.get_lr())

    def _set_lrs(self, lrs):
        """Write ``lrs`` into the optimizer's parameter groups, one per group."""
        lrs = list(lrs)
        for param_group, lr in zip(self.optimizer.param_groups, lrs):
            param_group["lr"] = lr
        # Keep the cached "last lr" list in sync with what was just written.
        # NOTE(review): recent torch releases appear to return this attribute
        # from get_last_lr(); confirm against the installed torch version.
        self._last_lr = lrs

    def get_lr(self):
        """Return the learning rate for each parameter group.

        During warmup this is ``base_lr * steps_taken / warmup_steps``. When
        warmup is disabled (``warmup_steps == 0``) or already finished, the
        rates currently stored in the optimizer are returned unchanged.
        """
        in_warmup = self.warmup_steps > 0 and self.steps_taken <= self.warmup_steps
        if in_warmup:
            scale = self.steps_taken / self.warmup_steps
            return [base_lr * scale for base_lr in self.base_lrs]
        return [group["lr"] for group in self.optimizer.param_groups]

    def step(self, metrics=None):
        """Advance the schedule by one step.

        Args:
            metrics: value forwarded to ``ReduceLROnPlateau.step`` once the
                warmup is over. It is ignored during warmup, and a
                post-warmup call without a metric does nothing besides
                counting the step.
        """
        self.steps_taken += 1
        if self.steps_taken <= self.warmup_steps:
            self._set_lrs(self.get_lr())
        elif metrics is not None:
            super().step(metrics)