import math
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR


class _BaseLRScheduler(_LRScheduler):
    """Shared plumbing for the step-based schedulers in this module.

    Subclasses must assign every attribute their ``get_lr`` reads *before*
    calling ``super().__init__``: the parent constructor performs an initial
    ``step()``, which calls ``get_lr``.
    """

    def __init__(self, optimizer, last_epoch: int = -1, verbose: bool = False):
        # PyTorch >= 2.2 deprecates ``verbose`` and emits a warning whenever it is
        # passed explicitly (even ``False``); forward it only when it was requested.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return all scheduler attributes except the optimizer reference."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def _check_get_lr_call(self):
        """Raise ``RuntimeError`` if ``get_lr`` is called outside of ``step()``."""
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

    def _broadcast(self, lr):
        """Return ``lr`` once for every optimizer parameter group."""
        return [lr] * len(self.optimizer.param_groups)


class ConstantLRScheduler(_BaseLRScheduler):
    """Keep every parameter group at ``init_lr`` for the whole run.

    Note that the optimizer's own ``lr`` is ignored: all groups are set to
    ``init_lr`` (default ``0.``) starting from construction.
    """

    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 ):
        """
        This is an implementation of constant learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate
        """
        self.init_lr = init_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return ``init_lr`` for every parameter group."""
        self._check_get_lr_call()
        return self._broadcast(self.init_lr)


class CosineAnnealingLRScheduler(_BaseLRScheduler):
    """Linear warmup followed by a single cosine decay, then constant.

    * steps ``0 .. warmup_steps``: linear from ``init_lr`` to ``max_lr``;
    * next ``cosine_steps`` steps: half-cosine from ``max_lr`` to ``final_lr``;
    * afterwards: ``final_lr``.
    """

    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 cosine_steps: int = 10000,
                 ):
        """
        This is an implementation of cosine annealing learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay; kept for all later steps

            warmup_steps: Number of steps for warmup (``0`` disables warmup)

            cosine_steps: Number of steps for cosine annealing (``0`` jumps
                straight to ``final_lr`` after warmup)
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the learning rate for the current step for every group."""
        self._check_get_lr_call()
        step_no = self.last_epoch
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            # Linear warmup; guarded so warmup_steps == 0 does not divide by zero.
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
        else:
            # Fraction of the cosine phase completed. Clamped at 1 so the learning
            # rate stays at final_lr once decay finishes, instead of rising again
            # along the next half-period of the cosine.
            if self.cosine_steps > 0:
                progress = min(1.0, (step_no - self.warmup_steps) / self.cosine_steps)
            else:
                progress = 1.0
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
                * (1 + math.cos(math.pi * progress))
        return self._broadcast(lr)


class Esm2LRScheduler(_BaseLRScheduler):
    """ESM2-style schedule: warmup, hold, linear decay, then constant.

    * steps ``0 .. warmup_steps``: linear from ``init_lr`` to ``max_lr``;
    * up to ``start_decay_after_n_steps``: ``max_lr``;
    * up to ``end_decay_after_n_steps``: linear from ``max_lr`` to ``final_lr``;
    * afterwards: ``final_lr``.
    """

    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 start_decay_after_n_steps: int = 500000,
                 end_decay_after_n_steps: int = 5000000,
                 on_use: bool = True,
                 ):
        """
        This is an implementation of ESM2's learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup (``0`` disables warmup)

            start_decay_after_n_steps: Start decay after this number of steps

            end_decay_after_n_steps: End decay after this number of steps

            on_use: Whether to use this scheduler. If ``False``, the scheduler does not change
                the learning rate: every parameter group keeps its entry in ``base_lrs`` (its
                initial learning rate), and ``init_lr`` is not used. Default: ``True``
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the learning rate for the current step for every group."""
        self._check_get_lr_call()
        if not self.on_use:
            # Disabled: leave each group at its base learning rate.
            return list(self.base_lrs)

        step_no = self.last_epoch
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            # Linear warmup; guarded so warmup_steps == 0 does not divide by zero.
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
        elif step_no <= self.start_decay_after_n_steps:
            lr = self.max_lr
        elif step_no <= self.end_decay_after_n_steps:
            # Reachable only when end > start, so the denominator is positive.
            portion = (step_no - self.start_decay_after_n_steps) \
                / (self.end_decay_after_n_steps - self.start_decay_after_n_steps)
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)
        else:
            lr = self.final_lr
        return self._broadcast(lr)