# Spaces: Sleeping  (page-header text captured together with this file; not part of the module)
import math
from typing import Any, Callable, Dict, List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from transformers import get_constant_schedule_with_warmup
def constant_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, **kwargs
) -> LambdaLR:
    """Build a constant learning-rate schedule preceded by a warmup phase.

    Thin wrapper around ``transformers.get_constant_schedule_with_warmup``.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of warmup steps forwarded to transformers
        **kwargs: accepted and ignored; presumably so this function can be
            called with the same keyword arguments as the other schedulers
            in this module (e.g. ``num_training_steps``)

    Returns:
        The ``LambdaLR`` scheduler created by transformers
    """
    return get_constant_schedule_with_warmup(
        optimizer=optimizer, num_warmup_steps=num_warmup_steps
    )


# adjusted from transformers
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_learning_rate_ratio: float = 0.0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """Build a cosine learning-rate schedule with linear warmup.

    The learning-rate multiplier rises linearly from 0 to 1 during the first
    ``num_warmup_steps`` steps and afterwards follows a cosine curve over the
    remaining steps. After warmup the multiplier never drops below
    ``min_learning_rate_ratio``.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of linear warmup steps
        num_training_steps: total number of training steps
        min_learning_rate_ratio: lower bound of the multiplier after warmup
        num_cycles: number of cosine cycles over the post-warmup phase
            (0.5 means a single half-wave from 1 down to 0)
        last_epoch: index of the last epoch, forwarded to ``LambdaLR``

    Returns:
        A ``LambdaLR`` scheduler applying the multiplier described above
    """
    # Both spans are clamped to at least 1 to avoid division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        progress = float(current_step - num_warmup_steps) / decay_span
        return max(
            min_learning_rate_ratio,
            0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


# adjusted from transformers
def get_linear_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_learning_rate_ratio: float = 0.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """Build a linear-decay learning-rate schedule with linear warmup.

    The learning-rate multiplier rises linearly from 0 to 1 during the first
    ``num_warmup_steps`` steps and then decreases linearly, reaching 0 at
    ``num_training_steps``. After warmup the multiplier never drops below
    ``min_learning_rate_ratio``.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of linear warmup steps
        num_training_steps: total number of training steps
        min_learning_rate_ratio: lower bound of the multiplier after warmup
        last_epoch: index of the last epoch, forwarded to ``LambdaLR``

    Returns:
        A ``LambdaLR`` scheduler applying the multiplier described above
    """
    # Both spans are clamped to at least 1 to avoid division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        return max(
            min_learning_rate_ratio,
            float(num_training_steps - current_step) / decay_span,
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


class Schedulers:
    """Schedulers factory."""

    # Registry mapping the public scheduler name to its builder function.
    _schedulers: Dict[str, Callable[..., LambdaLR]] = {
        "Cosine": get_cosine_schedule_with_warmup,
        "Linear": get_linear_schedule_with_warmup,
        "Constant": constant_schedule_with_warmup,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered scheduler names in sorted order."""
        return sorted(cls._schedulers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Schedulers.

        Args:
            name: scheduler name

        Returns:
            The function that builds the scheduler, or ``None`` if ``name``
            is not registered
        """
        return cls._schedulers.get(name)