from torch.optim.lr_scheduler import LambdaLR


def get_inverse_square_root_decay(optimizer, num_warmup_steps=0, last_epoch=-1):
    """Inverse square root decay with optional linear warmup.

    The learning-rate factor rises linearly over num_warmup_steps,
    then decays proportionally to 1 / sqrt(step).
    """
    def lr_lambda(current_step):
        # Linear warmup: scale the factor from 0 up to 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Inverse square root decay; equals 1.0 exactly at the end of warmup.
        if num_warmup_steps > 0:
            return (num_warmup_steps / current_step) ** 0.5
        # No warmup: shift by one so the factor is defined at step 0.
        return (1.0 / (current_step + 1)) ** 0.5

    return LambdaLR(optimizer, lr_lambda, last_epoch)
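

# Usage sketch (not part of the original snippet): attaches the scheduler to a
# hypothetical Linear model with plain SGD; loss computation and backward are
# elided, so optimizer.step() is a no-op here and only the schedule is shown.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 2)  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = get_inverse_square_root_decay(optimizer, num_warmup_steps=100)

    for step in range(1000):
        optimizer.step()   # gradients omitted for brevity
        scheduler.step()   # advance the schedule by one step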