import sys import abc import math import copy import logging from typing import Callable, Iterable, Tuple import torch import torch.nn as nn from torch.nn.utils import clip_grad_norm_ from torch.optim import Optimizer, Adam from torch.optim.lr_scheduler import LambdaLR from torch.optim.optimizer import required logger = logging.getLogger(__name__) def get_optimizer(model_params, total_steps, optimizer_config): optimizer_config = copy.deepcopy(optimizer_config) optimizer_name = optimizer_config.pop('name') optimizer = eval(f'get_{optimizer_name}')( model_params, total_steps=total_steps, **optimizer_config ) return optimizer def get_grouped_parameters(model_params): named_params = [] for m in model_params: named_params += list(m.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] grouped_parameters = [ {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, ] return grouped_parameters def get_BertAdam_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, **kwargs): grouped_parameters = get_grouped_parameters(model_params) optimizer = BertAdam(grouped_parameters, lr=lr, warmup=warmup_proportion, t_total=total_steps) return optimizer def get_AdamW_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, **kwargs): grouped_parameters = get_grouped_parameters(model_params) optimizer = Lamb(grouped_parameters, lr=lr, warmup=warmup_proportion, t_total=total_steps, adam=True, correct_bias=True, **kwargs) return optimizer def get_Lamb_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, **kwargs): grouped_parameters = get_grouped_parameters(model_params) optimizer = Lamb(grouped_parameters, lr=lr, warmup=warmup_proportion, t_total=total_steps, adam=False, correct_bias=False, **kwargs) return optimizer def get_Adam(model_params, lr=2e-4, **kwargs): params = [] for m in model_params: params += list(m.parameters()) return Adam(params, lr=lr, betas=(0.9, 0.999)) def get_AdamW(model_params, lr=2e-4, **kwargs): params = [] for m in model_params: params += list(m.parameters()) optimizer = AdamW(params, lr=lr) return optimizer def get_TorchOptim(model_params, torch_optim_name, **kwargs): params = [] for m in model_params: params += list(m.parameters()) Opt_class = getattr(torch.optim, torch_optim_name) kwargs.pop('total_steps') optim = Opt_class(params, **kwargs) return optim class AdamW(Optimizer): """ Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization `__. Parameters: params (:obj:`Iterable[torch.nn.parameter.Parameter]`): Iterable of parameters to optimize or dictionaries defining parameter groups. lr (:obj:`float`, `optional`, defaults to 1e-3): The learning rate to use. betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)): Adam's betas parameters (b1, b2). eps (:obj:`float`, `optional`, defaults to 1e-6): Adam's epsilon for numerical stability. weight_decay (:obj:`float`, `optional`, defaults to 0): Decoupled weight decay to apply. correct_bias (:obj:`bool`, `optional`, defaults to `True`): Whether ot not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`). """ def __init__( self, params: Iterable[torch.nn.parameter.Parameter], lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-7, weight_decay: float = 0.0, correct_bias: bool = True, ): if lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) super().__init__(params, defaults) def step(self, closure: Callable = None): """ Performs a single optimization step. Arguments: closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") state = self.state[p] # State initialization if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p.data) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] state["step"] += 1 # Decay the first and second moment running average coefficient # In-place operations to update the averages at the same time exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) denom = exp_avg_sq.sqrt().add_(group["eps"]) step_size = group["lr"] if group["correct_bias"]: # No bias correction for Bert bias_correction1 = 1.0 - beta1 ** state["step"] bias_correction2 = 1.0 - beta2 ** state["step"] step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 p.data.addcdiv_(exp_avg, denom, value=-step_size) # Just adding the square of the weights to the loss function is *not* # the correct way of using L2 regularization/weight decay with Adam, # since that will interact with the m and v parameters in strange ways. # # Instead we want to decay the weights in a manner that doesn't interact # with the m/v parameters. This is equivalent to adding the square # of the weights to the loss with plain (non-momentum) SGD. # Add weight decay at the end (fixed version) if group["weight_decay"] > 0.0: p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) return loss def get_lr(self): lr = [] for group in self.param_groups: for p in group['params']: state = self.state[p] if len(state) == 0: pass else: lr.append(group['lr']) return lr # For the following codes: """PyTorch optimization for BERT model.""" # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. if sys.version_info >= (3, 4): ABC = abc.ABC else: ABC = abc.ABCMeta('ABC', (), {}) class _LRSchedule(ABC): """ Parent of all LRSchedules here. """ warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense def __init__(self, warmup=0.002, t_total=-1, **kw): """ :param warmup: what fraction of t_total steps will be used for linear warmup :param t_total: how many training steps (updates) are planned :param kw: """ super(_LRSchedule, self).__init__(**kw) if t_total < 0: logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) if not 0.0 <= warmup < 1.0 and not warmup == -1: raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) warmup = max(warmup, 0.) self.warmup, self.t_total = float(warmup), float(t_total) self.warned_for_t_total_at_progress = -1 def get_lr(self, step, nowarn=False): """ :param step: which of t_total steps we're on :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps :return: learning rate multiplier for current update """ if self.t_total < 0: return 1. progress = float(step) / self.t_total ret = self.get_lr_(progress) # warning for exceeding t_total (only active with warmup_linear if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: logger.warning( "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." .format(ret, self.__class__.__name__)) self.warned_for_t_total_at_progress = progress # end warning return ret @abc.abstractmethod def get_lr_(self, progress): """ :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress :return: learning rate multiplier for current update """ return 1. class ConstantLR(_LRSchedule): def get_lr_(self, progress): return 1. class WarmupCosineSchedule(_LRSchedule): """ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. """ warn_t_total = True def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): """ :param warmup: see LRSchedule :param t_total: see LRSchedule :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. :param kw: """ super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) self.cycles = cycles def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup else: progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): """ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying learning rate (with hard restarts). """ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) assert(cycles >= 1.) def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup else: progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) return ret class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): """ All training progress is divided in `cycles` (default=1.) parts of equal length. Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., followed by a learning rate decreasing from 1. to 0. following a cosine curve. """ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): assert(warmup * cycles < 1.) warmup = warmup * cycles if warmup >= 0 else warmup super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) def get_lr_(self, progress): progress = progress * self.cycles % 1. if progress < self.warmup: return progress / self.warmup else: progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup ret = 0.5 * (1. + math.cos(math.pi * progress)) return ret class WarmupConstantSchedule(_LRSchedule): """ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. Keeps learning rate equal to 1. after warmup. """ def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup return 1. class WarmupLinearSchedule(_LRSchedule): """ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. """ warn_t_total = True def get_lr_(self, progress): if progress < self.warmup: return progress / self.warmup return max((progress - 1.) / (self.warmup - 1.), 0.) SCHEDULES = { None: ConstantLR, "none": ConstantLR, "warmup_cosine": WarmupCosineSchedule, "warmup_constant": WarmupConstantSchedule, "warmup_linear": WarmupLinearSchedule } class BertAdam(Optimizer): """Implements BERT version of Adam algorithm with weight decay fix. Params: lr: learning rate warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 t_total: total number of training steps for the learning rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 schedule: schedule to use for the warmup (see above). Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). If `None` or `'none'`, learning rate is always kept constant. Default : `'warmup_linear'` betas: Adams betas. Default: (0.9, 0.999) e: Adams epsilon. Default: 1e-6 weight_decay: Weight decay. Default: 0.01 max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 """ def __init__(self, params=None, lr='required', warmup=-1, t_total=-1, schedule='warmup_linear', betas=(0.9, 0.999), e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): if lr == 'required' or lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {} - should be in [0.0, 1.0[".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {} - should be in [0.0, 1.0[".format(betas[1])) if not e >= 0.0: raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) # initialize schedule object if not isinstance(schedule, _LRSchedule): schedule_type = SCHEDULES[schedule] schedule = schedule_type(warmup=warmup, t_total=t_total) else: if warmup != -1 or t_total != -1: logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " "Please specify custom warmup and t_total in _LRSchedule object.") defaults = dict(lr=lr, schedule=schedule, betas=betas, e=e, weight_decay=weight_decay, max_grad_norm=max_grad_norm) super(BertAdam, self).__init__(params, defaults) def get_lr(self): lr = [] for group in self.param_groups: for p in group['params']: state = self.state[p] if len(state) == 0: pass else: lr_scheduled = group['lr'] lr_scheduled *= group['schedule'].get_lr(state['step']) lr.append(lr_scheduled) return lr def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values state['next_m'] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state['next_v'] = torch.zeros_like(p.data) next_m, next_v = state['next_m'], state['next_v'] beta1, beta2 = group['betas'] # Add grad clipping if group['max_grad_norm'] > 0: clip_grad_norm_(p, group['max_grad_norm']) # Decay the first and second moment running average coefficient # In-place operations to update the averages at the same time next_m.mul_(beta1).add_(1 - beta1, grad) next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) update = next_m / (next_v.sqrt() + group['e']) # Just adding the square of the weights to the loss function is *not* # the correct way of using L2 regularization/weight decay with Adam, # since that will interact with the m and v parameters in strange ways. # # Instead we want to decay the weights in a manner that doesn't interact # with the m/v parameters. This is equivalent to adding the square # of the weights to the loss with plain (non-momentum) SGD. if group['weight_decay'] > 0.0: update += group['weight_decay'] * p.data lr_scheduled = group['lr'] lr_scheduled *= group['schedule'].get_lr(state['step']) update_with_lr = lr_scheduled * update p.data.add_(-update_with_lr) state['step'] += 1 # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 # No bias correction # bias_correction1 = 1 - beta1 ** state['step'] # bias_correction2 = 1 - beta2 ** state['step'] return loss class Lamb(Optimizer): r"""Implements Lamb algorithm. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) warmup (float, optional): portion of t_total for the warmup, -1 means no warmup. Default: -1 t_total (int, optional): total number of training steps for the learning rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 schedule (string, optional): schedule to use for the warmup (see above). Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). If `None` or `'none'`, learning rate is always kept constant. Default : `'warmup_linear'` betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) adam (bool, optional): always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes. Set to True for AdamW. correct_bias (bool, optional): adam-correction, no bias correction for Bert. Set to True for AdamW. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962 """ def __init__(self, params, lr=1e-3, warmup=-1, t_total=-1, schedule='warmup_linear', betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, adam=False, correct_bias=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) # initialize schedule object if not isinstance(schedule, _LRSchedule): schedule_type = SCHEDULES[schedule] schedule = schedule_type(warmup=warmup, t_total=t_total) else: if warmup != -1 or t_total != -1: logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " "Please specify custom warmup and t_total in _LRSchedule object.") defaults = dict(lr=lr, betas=betas, eps=eps, schedule=schedule, weight_decay=weight_decay, correct_bias=correct_bias) self.adam = adam super(Lamb, self).__init__(params, defaults) def get_lr(self): lr = [] for group in self.param_groups: for p in group['params']: state = self.state[p] if len(state) == 0: pass else: lr_scheduled = group['lr'] lr_scheduled *= group['schedule'].get_lr(state['step']) lr.append(lr_scheduled) return lr def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p.data) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] state['step'] += 1 # Decay the first and second moment running average coefficient # m_t exp_avg.mul_(beta1).add_(1 - beta1, grad) # v_t exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) # Paper v3 does not use debiasing. # bias_correction1 = 1 - beta1 ** state['step'] # bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 if group['correct_bias']: # No bias correction for Bert bias_correction1 = 1.0 - beta1 ** state['step'] bias_correction2 = 1.0 - beta2 ** state['step'] step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 lr_scheduled = step_size * group['schedule'].get_lr(state['step']) weight_norm = p.data.pow(2).sum().sqrt() adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) if group['weight_decay'] != 0: adam_step.add_(group['weight_decay'], p.data) adam_norm = adam_step.pow(2).sum().sqrt() if weight_norm == 0 or adam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm state['weight_norm'] = weight_norm state['adam_norm'] = adam_norm state['trust_ratio'] = trust_ratio if self.adam: trust_ratio = 1 p.data.add_(-lr_scheduled * trust_ratio, adam_step) return loss