Spaces:
Sleeping
Sleeping
File size: 522 Bytes
98e2ea5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
from torch.optim.lr_scheduler import LambdaLR
def get_inverse_square_root_decay(optimizer, num_warmup_steps=0, last_epoch=-1):
    """Build a LambdaLR schedule with linear warmup and 1/sqrt(step) decay.

    The multiplier applied to each parameter group's base learning rate is:

    * ``step / max(1, num_warmup_steps)`` while ``step < num_warmup_steps``
      (a linear ramp starting at 0);
    * ``sqrt(num_warmup_steps / step)`` afterwards when a warmup is
      configured, which equals 1.0 exactly at ``step == num_warmup_steps``;
    * ``sqrt(1 / (step + 1))`` when ``num_warmup_steps`` is not positive,
      which equals 1.0 at step 0.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warmup steps.
        last_epoch: Forwarded to ``LambdaLR`` as its ``last_epoch`` argument.

    Returns:
        A ``LambdaLR`` instance wrapping ``optimizer``.
    """

    def lr_lambda(current_step):
        # Warmup phase: ramp linearly towards 1.0.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Past the warmup. current_step >= num_warmup_steps > 0 here, so
        # the division cannot hit zero.
        if num_warmup_steps > 0:
            return (num_warmup_steps / current_step) ** 0.5
        # No warmup configured: shift by one so step 0 is well-defined.
        return (1 / (current_step + 1)) ** 0.5

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|