Spaces:
Sleeping
Sleeping
File size: 663 Bytes
801501a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
class PolyDecayScheduler(torch.optim.lr_scheduler.LambdaLR):
    """LambdaLR scheduler whose multiplier decays as ``power**step`` with a floor.

    At each scheduler step the multiplicative factor handed to ``LambdaLR``
    is ``max(power**step, lr_end / init_lr)``. Note that, despite the class
    name, the decay term is geometric in ``step`` (``power`` is the base).

    Args:
        optimizer: Optimizer forwarded to ``LambdaLR``.
        init_lr: Value used as the denominator of the floor ratio
            ``lr_end / init_lr``.
        power: Base of the per-step decay factor.
        lr_end: Numerator of the floor ratio.
        last_epoch: Forwarded unchanged to ``LambdaLR``.
    """

    def __init__(self, optimizer, init_lr, power=0.99, lr_end=1e-7, last_epoch=-1):
        def lr_lambda(step):
            decayed = power**step
            floor = lr_end / init_lr
            # Equivalent to max(decayed, floor): the floor only wins when
            # it is strictly greater than the decayed factor.
            if floor > decayed:
                return floor
            return decayed

        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
def get_dropout_mask(shape, dropout: float, device) -> torch.Tensor:
    """Return a random boolean mask of the given shape.

    Each element is ``True`` when a uniform sample from ``[0, 1)`` is
    strictly greater than ``dropout``, so a larger ``dropout`` yields
    more ``False`` entries.

    Args:
        shape: Shape of the mask to create (anything accepted by
            ``torch.zeros`` / ``torch.ones`` as a size).
        dropout: Threshold compared against the uniform samples. ``0``
            gives an all-``True`` mask and ``1`` gives an all-``False`` mask.
        device: Device on which the mask is allocated.

    Returns:
        A ``torch.bool`` tensor of shape ``shape`` on ``device``.
    """
    # Samples lie in [0, 1), so any threshold >= 1 can never be exceeded
    # and any threshold <= 0 needs no sampling; short-circuit both without
    # touching the RNG.
    if dropout >= 1:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    if dropout <= 0:
        # torch.ones (not ones_like): `shape` is a size, not a tensor.
        return torch.ones(shape, device=device, dtype=torch.bool)
    return torch.rand(shape, device=device) > dropout
|