Spaces:
Paused
Paused
from functools import partial | |
from typing import Callable | |
def linear_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int,
    decay_factor: float = 0.9,
) -> float:
    r"""Get linear warm up scheduler for LambdaLR.

    While ``step <= warm_up_steps`` the scale ramps linearly from 0 to 1
    (``step / warm_up_steps``). After that the scale is
    ``decay_factor ** (step // reduce_lr_steps)``, i.e. it is multiplied by
    ``decay_factor`` once every ``reduce_lr_steps`` steps.

    Args:
        step (int): global step
        warm_up_steps (int): number of warm-up steps; a value ``<= 0``
            disables warm-up
        reduce_lr_steps (int): interval, in steps, between two decays of the
            learning rate; must be non-zero
        decay_factor (float): multiplicative factor applied every
            ``reduce_lr_steps`` steps after warm-up. Default: 0.9

    .. code-block:: python

        >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler

    Raises:
        ZeroDivisionError: if ``reduce_lr_steps == 0`` and ``step`` is past
            the warm-up phase
    """
    # Guard on warm_up_steps > 0: with no warm-up, step 0 would otherwise hit
    # a division by zero instead of starting at the full learning rate.
    if warm_up_steps > 0 and step <= warm_up_steps:
        return step / warm_up_steps

    # The decay counter counts from step 0, not from the end of warm-up. So if
    # warm_up_steps >= reduce_lr_steps, the first post-warm-up scale is
    # already below 1.
    return decay_factor ** (step // reduce_lr_steps)
def constant_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get constant (staircase) warm up scheduler for LambdaLR.

    The scale is held constant within each of three warm-up stages of length
    ``warm_up_steps`` and grows tenfold from one stage to the next:

    * ``0 <= step < warm_up_steps``: 0.001
    * ``warm_up_steps <= step < 2 * warm_up_steps``: 0.01
    * ``2 * warm_up_steps <= step < 3 * warm_up_steps``: 0.1
    * otherwise (including negative steps): 1.0

    Args:
        step (int): global step
        warm_up_steps (int): length of each of the three warm-up stages
        reduce_lr_steps (int): unused. It is accepted only so the signature
            matches :func:`linear_warm_up`; no decay is applied after warm-up.

    .. code-block:: python

        >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    if 0 <= step < warm_up_steps:
        lr_scale = 0.001
    elif warm_up_steps <= step < 2 * warm_up_steps:
        lr_scale = 0.01
    elif 2 * warm_up_steps <= step < 3 * warm_up_steps:
        lr_scale = 0.1
    else:
        # Return a float (not the int 1) to honour the ``-> float`` contract.
        lr_scale = 1.0
    return lr_scale
def get_lr_lambda(
    lr_lambda_type: str,
    **kwargs
) -> Callable:
    r"""Get a learning rate lambda for ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        lr_lambda_type (str): scheduler name, either ``"constant_warm_up"``
            or ``"linear_warm_up"``
        **kwargs: must contain ``warm_up_steps`` and ``reduce_lr_steps``.
            Both are bound into the returned callable; other keys are ignored.

    Returns:
        lr_lambda_func (Callable): function mapping a global step to a
            learning rate scale

    Raises:
        NotImplementedError: if ``lr_lambda_type`` is not recognized
        KeyError: if ``warm_up_steps`` or ``reduce_lr_steps`` is missing
            from ``kwargs``
    """
    if lr_lambda_type == "constant_warm_up":
        scheduler = constant_warm_up
    elif lr_lambda_type == "linear_warm_up":
        scheduler = linear_warm_up
    else:
        raise NotImplementedError(
            f"Unknown lr_lambda_type {lr_lambda_type!r}; expected "
            f"'constant_warm_up' or 'linear_warm_up'."
        )

    # Both schedulers share the same keyword parameters, so one partial
    # covers every supported type.
    lr_lambda_func = partial(
        scheduler,
        warm_up_steps=kwargs["warm_up_steps"],
        reduce_lr_steps=kwargs["reduce_lr_steps"],
    )
    return lr_lambda_func