# dockformerpp/utils/lr_schedulers.py
import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
""" Implements the learning rate schedule defined in the AlphaFold 2
supplement. A linear warmup is followed by a plateau at the maximum
learning rate and then exponential decay.
Note that the initial learning rate of the optimizer in question is
ignored; use this class' base_lr parameter to specify the starting
point of the warmup.
"""
def __init__(self,
optimizer,
last_epoch: int = -1,
verbose: bool = False,
base_lr: float = 0.,
max_lr: float = 0.001,
warmup_no_steps: int = 1000,
start_decay_after_n_steps: int = 50000,
decay_every_n_steps: int = 50000,
decay_factor: float = 0.95,
):
step_counts = {
"warmup_no_steps": warmup_no_steps,
"start_decay_after_n_steps": start_decay_after_n_steps,
}

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

self.optimizer = optimizer
self.last_epoch = last_epoch
self.verbose = verbose
self.base_lr = base_lr
self.max_lr = max_lr
self.warmup_no_steps = warmup_no_steps
self.start_decay_after_n_steps = start_decay_after_n_steps
self.decay_every_n_steps = decay_every_n_steps
self.decay_factor = decay_factor

        super().__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        # last_epoch is incremented by _LRScheduler.step(), so it acts as the
        # global training-step counter here.
        step_no = self.last_epoch

        if step_no <= self.warmup_no_steps:
            # Linear warmup: with the default base_lr of 0, the learning rate
            # ramps from 0 up to max_lr over warmup_no_steps steps.
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif step_no > self.start_decay_after_n_steps:
            # Exponential decay: an additional factor of decay_factor is
            # applied every decay_every_n_steps steps once the plateau ends.
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:
            # Plateau at the maximum learning rate between warmup and decay.
            lr = self.max_lr

        return [lr for group in self.optimizer.param_groups]
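

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# It shows how the scheduler is typically driven once per optimizer step so
# that last_epoch counts training steps. The dummy model, the optimizer
# choice, and the shortened step counts below are assumptions for the demo,
# not values taken from DockFormerPP configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)
    # The optimizer's initial lr is ignored by the scheduler (see docstring).
    optimizer = torch.optim.Adam(model.parameters(), lr=0.)
    scheduler = AlphaFoldLRScheduler(
        optimizer,
        base_lr=0.,
        max_lr=1e-3,
        warmup_no_steps=10,            # shortened from the 1000-step default
        start_decay_after_n_steps=50,  # shortened from the 50000-step default
        decay_every_n_steps=10,
        decay_factor=0.95,
    )

    for step in range(100):
        # ... forward / backward pass and optimizer.zero_grad() would go here ...
        optimizer.step()
        scheduler.step()
        if step % 20 == 0:
            print(f"step {step:3d}: lr = {scheduler.get_last_lr()[0]:.2e}")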