# File-viewer residue: Spaces status "Sleeping"; file size 2,342 bytes; ref 5caedb4.
import math
from typing import Any, List
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from transformers import get_constant_schedule_with_warmup
def constant_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, **kwargs
) -> LambdaLR:
    """Build a scheduler via ``get_constant_schedule_with_warmup``.

    Thin adapter around the ``transformers`` helper of the same purpose.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of warmup steps, forwarded unchanged
        **kwargs: accepted and ignored; presumably this lets the function be
            called with the same keyword arguments as the other schedule
            builders in this module (e.g. ``num_training_steps``)

    Returns:
        The scheduler returned by ``get_constant_schedule_with_warmup``
    """
    # Only ``optimizer`` and ``num_warmup_steps`` are forwarded; nothing from
    # ``kwargs`` reaches the wrapped helper.
    return get_constant_schedule_with_warmup(
        optimizer=optimizer, num_warmup_steps=num_warmup_steps
    )
# adjusted from transformers
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_learning_rate_ratio: float = 0.0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """Build a ``LambdaLR`` with linear warmup followed by a cosine curve.

    The returned scheduler multiplies the optimizer's base learning rate by a
    step-dependent factor:

    * while ``current_step < num_warmup_steps`` the factor is
      ``current_step / max(1, num_warmup_steps)``;
    * afterwards it is ``0.5 * (1 + cos(pi * num_cycles * 2 * progress))``,
      where ``progress`` is the fraction of post-warmup steps completed,
      bounded below by ``min_learning_rate_ratio``.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of steps in the linear warmup phase
        num_training_steps: total number of training steps
        min_learning_rate_ratio: lower bound for the factor after warmup
        num_cycles: number of cosine cycles over the post-warmup phase
            (``0.5`` goes from 1 down to 0 once)
        last_epoch: passed through to ``LambdaLR``

    Returns:
        The configured ``LambdaLR`` scheduler
    """

    def lr_lambda(current_step: int) -> float:
        # Warmup: ramp linearly from 0; the floor is not applied here.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # ``max(1, ...)`` guards against division by zero when there are no
        # post-warmup steps.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            min_learning_rate_ratio,
            0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
# adjusted from transformers
def get_linear_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_learning_rate_ratio: float = 0.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """Build a ``LambdaLR`` with linear warmup followed by linear decay.

    The returned scheduler multiplies the optimizer's base learning rate by a
    step-dependent factor:

    * while ``current_step < num_warmup_steps`` the factor is
      ``current_step / max(1, num_warmup_steps)``;
    * afterwards it is
      ``(num_training_steps - current_step)
      / max(1, num_training_steps - num_warmup_steps)``,
      bounded below by ``min_learning_rate_ratio``.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of steps in the linear warmup phase
        num_training_steps: total number of training steps
        min_learning_rate_ratio: lower bound for the factor after warmup
        last_epoch: passed through to ``LambdaLR``

    Returns:
        The configured ``LambdaLR`` scheduler
    """

    def lr_lambda(current_step: int) -> float:
        # Warmup: ramp linearly from 0; the floor is not applied here.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Remaining steps over the length of the decay phase. The value goes
        # negative past ``num_training_steps``, so the floor also covers
        # stepping beyond the planned run.
        return max(
            min_learning_rate_ratio,
            float(num_training_steps - current_step)
            / float(max(1, num_training_steps - num_warmup_steps)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
class Schedulers:
    """Schedulers factory."""

    # Registry mapping each scheduler name to the function that builds it.
    _schedulers = {
        "Cosine": get_cosine_schedule_with_warmup,
        "Linear": get_linear_schedule_with_warmup,
        "Constant": constant_schedule_with_warmup,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered scheduler names in sorted order."""
        return sorted(cls._schedulers.keys())

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Schedulers.

        Args:
            name: scheduler name

        Returns:
            The builder function registered under ``name``, or ``None`` when
            no scheduler with that name is registered
        """
        # ``dict.get`` is used deliberately: unknown names yield ``None``
        # instead of raising ``KeyError``.
        return cls._schedulers.get(name)
|