# Scraped source metadata: file size 573 bytes, commit 59b7eeb.
import torch
import torch.optim as optim

def WarmupLR(optimizer, warmup_step=0, down_step=5e4, max_lr=1e-4, min_lr=1e-5, **kwargs):
    """Build a ``LambdaLR`` with quadratic warmup followed by linear decay.

    The factor returned for scheduler step ``step`` is:

    * ``step < warmup_step``: quadratic ramp ``1e-5 + alpha * step**2``,
      rising from ``1e-5`` at step 0 towards ``max_lr`` at ``warmup_step``;
    * ``warmup_step <= step < warmup_step + down_step``: linear interpolation
      from ``max_lr`` (at ``warmup_step``) down to ``min_lr``;
    * afterwards: constant ``min_lr``.

    ``LambdaLR`` multiplies each param group's initial lr by this factor, so
    the values above are the effective learning rates only when the optimizer
    was constructed with ``lr=1.0``.
    NOTE(review): confirm callers rely on that convention.

    Args:
        optimizer: The optimizer whose param groups are scheduled.
        warmup_step: Number of warmup steps. ``0`` (the default) or less
            disables warmup, so the schedule starts directly at ``max_lr``.
        down_step: Number of steps over which the factor decays linearly from
            ``max_lr`` to ``min_lr``. ``0`` or less skips the decay phase.
        max_lr: Peak factor, reached at the end of warmup.
        min_lr: Final factor, held after the decay phase.
        **kwargs: Accepted and ignored (presumably so a larger config dict can
            be splatted in — verify against callers).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    init_lr = 1e-5  # starting factor of the warmup ramp (fixed)
    s1 = warmup_step
    s2 = warmup_step + down_step
    # Choose alpha so the ramp reaches exactly max_lr at step == warmup_step.
    # With warmup_step <= 0 (including the default 0) the old code divided by
    # zero; the warmup branch is never taken for non-negative steps in that
    # case, so alpha's value is irrelevant there.
    alpha = (max_lr - init_lr) / warmup_step**2 if warmup_step > 0 else 0.0

    def lr_lambda(step):
        if step < s1:
            return init_lr + alpha * step**2
        if step < s2:
            # step >= s1 here; interpolate max_lr (at s1) -> min_lr (at s2).
            # Division is safe: this branch implies s1 < s2, i.e. down_step > 0.
            return max_lr + (min_lr - max_lr) * (step - s1) / down_step
        return min_lr

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)