import copy
import math
import random
from typing import Callable, Iterable, Tuple

import torch
import numpy as np
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_scheduler(optimizer, total_steps, scheduler_config):
    """
    Build a learning rate scheduler from a configuration mapping.

    The factory is looked up by name among this module's ``get_<name>`` functions and called with
    ``optimizer``, ``num_training_steps=total_steps`` and the remaining configuration entries as keyword
    arguments.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        total_steps (:obj:`int`):
            The total number of training steps, forwarded as ``num_training_steps``.
        scheduler_config (:obj:`dict`):
            Must contain a ``'name'`` key (e.g. ``'linear_schedule_with_warmup'``); every other entry is
            forwarded to the factory. The mapping is deep-copied and never mutated.

    Return:
        The scheduler returned by the selected factory.

    Raises:
        KeyError: If ``scheduler_config`` has no ``'name'`` entry.
        ValueError: If no ``get_<name>`` factory exists in this module.
    """
    # Work on a copy so that popping 'name' does not alter the caller's config.
    scheduler_config = copy.deepcopy(scheduler_config)
    scheduler_name = scheduler_config.pop('name')

    # Plain namespace lookup instead of eval(): the name comes from configuration and must never be
    # executed as an arbitrary Python expression.
    factory = globals().get(f'get_{scheduler_name}')
    if not callable(factory) or factory is get_scheduler:
        raise ValueError(f"Unknown scheduler name: {scheduler_name!r} (no 'get_{scheduler_name}' factory found)")

    return factory(optimizer, num_training_steps=total_steps, **scheduler_config)


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        num_cycles (:obj:`int`, `optional`, defaults to 1):
            The number of hard restarts to use.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        # The modulo restarts the half-cosine at the beginning of every cycle.
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_sqrt_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule whose learning rate multiplier is ``1 / sqrt(current_step)`` after a warmup period during
    which it increases linearly from 0 to the initial lr set in the optimizer.

    Note:
        The multiplier is not continuous at the end of the warmup: it goes from just below 1 to
        ``1 / sqrt(num_warmup_steps)`` at step ``num_warmup_steps``.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps. Not used by this schedule; accepted so that every factory can
            be called the same way by :func:`get_scheduler`.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Clamp to 1 so that step 0 with no warmup does not divide by zero
        # (LambdaLR evaluates the lambda at step 0 on construction).
        return 1.0 / math.sqrt(max(1, current_step, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_constant_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a constant learning rate (the initial lr set in the optimizer), after a warmup period
    during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps. Not used by this schedule; accepted so that every factory can
            be called the same way by :func:`get_scheduler`.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_noam_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, last_epoch=-1, model_size=768
):
    """
    Create a schedule whose learning rate multiplier, after a warmup period during which it increases linearly
    from 0 to the initial lr set in the optimizer, is
    ``model_size ** -0.5 * min(step ** -0.5, step * num_warmup_steps ** -1.5)``.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps. Not used by this schedule; accepted so that every factory can
            be called the same way by :func:`get_scheduler`.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.
        model_size (:obj:`int`, `optional`, defaults to 768):
            The value whose inverse square root scales the decayed multiplier.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Clamp both quantities to 1: with num_warmup_steps == 0 the negative powers of zero would raise
        # ZeroDivisionError at step 0. For num_warmup_steps >= 1 the values are unchanged.
        step = max(1, current_step)
        warmup = max(1, num_warmup_steps)
        return model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by `lr_end`, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        lr_end (:obj:`float`, `optional`, defaults to 1e-7):
            The end LR.
        power (:obj:`float`, `optional`, defaults to 1.0):
            Power factor.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Note: `power` defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If ``lr_end`` is not strictly smaller than the optimizer's initial lr.
    """

    lr_init = optimizer.defaults["lr"]
    # Explicit raise rather than assert: asserts are stripped when Python runs with -O.
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            # Guard against a zero-length decay phase (num_training_steps == num_warmup_steps).
            decay_steps = max(1, num_training_steps - num_warmup_steps)
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining ** power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)