File size: 522 Bytes
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.optim.lr_scheduler import LambdaLR


def get_inverse_square_root_decay(optimizer, num_warmup_steps=0, last_epoch=-1):
    """Create a ``LambdaLR`` with linear warmup and inverse-square-root decay.

    The multiplier applied to the optimizer's base learning rate is:

    * ``step / max(1, num_warmup_steps)`` while ``step < num_warmup_steps``;
    * ``sqrt(num_warmup_steps / step)`` afterwards, when warmup is enabled
      (``num_warmup_steps > 0``), which equals 1.0 at the end of warmup;
    * ``sqrt(1 / (step + 1))`` when there is no warmup, so step 0 yields 1.0.

    Args:
        optimizer: Optimizer handed to ``LambdaLR``.
        num_warmup_steps: Number of steps over which the multiplier ramps up
            linearly before the decay starts.
        last_epoch: Forwarded unchanged to ``LambdaLR``.

    Returns:
        The configured ``LambdaLR`` instance.
    """

    def _multiplier(step):
        # Warmup phase: linear ramp from 0 towards 1.
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))

        # Decay phase: choose the ratio whose square root is the multiplier.
        if num_warmup_steps > 0:
            # step >= num_warmup_steps > 0 here, so the division is safe.
            ratio = num_warmup_steps / step
        else:
            # Shift by one so that step 0 does not divide by zero.
            ratio = 1 / (step + 1)
        return ratio ** 0.5

    scheduler = LambdaLR(optimizer, _multiplier, last_epoch)
    return scheduler