from functools import partial
from typing import Callable


def linear_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get linear warm up scheduler for LambdaLR.

    For ``step <= warm_up_steps`` the scale grows linearly from 0 to 1.
    Afterwards it decays by a factor of 0.9 every ``reduce_lr_steps`` steps,
    where the decay is counted from step 0 (not from the end of warm up).
    A non-positive ``warm_up_steps`` disables warm up entirely.

    Args:
        step (int): global step
        warm_up_steps (int): steps for warm up; <= 0 means no warm up
        reduce_lr_steps (int): reduce learning rate by a factor of 0.9 every
            ``reduce_lr_steps`` steps; must be non-zero

    .. code-block:: python

        >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    # Guard against warm_up_steps == 0: without it, step 0 would evaluate 0 / 0
    # and raise ZeroDivisionError instead of simply skipping warm up.
    if warm_up_steps > 0 and step <= warm_up_steps:
        return step / warm_up_steps

    # Step-wise exponential decay based on the raw global step.
    return 0.9 ** (step // reduce_lr_steps)


def constant_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get constant (staircase) warm up scheduler for LambdaLR.

    The learning rate scale is held at 0.001, 0.01 and 0.1 for three
    consecutive windows of ``warm_up_steps`` steps each, and is 1.0 from
    ``3 * warm_up_steps`` onwards. No decay is applied after warm up.

    Args:
        step (int): global step
        warm_up_steps (int): length of each of the three constant warm up stages
        reduce_lr_steps (int): unused; accepted so this function has the same
            signature as ``linear_warm_up``

    .. code-block:: python

        >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    if 0 <= step < warm_up_steps:
        lr_scale = 0.001

    elif warm_up_steps <= step < 2 * warm_up_steps:
        lr_scale = 0.01

    elif 2 * warm_up_steps <= step < 3 * warm_up_steps:
        lr_scale = 0.1

    else:
        # Return a float to honor the ``-> float`` annotation (was int ``1``).
        lr_scale = 1.0

    return lr_scale


def get_lr_lambda(
    lr_lambda_type: str,
    **kwargs
) -> Callable:
    r"""Get learning rate scheduler lambda.

    Args:
        lr_lambda_type (str): name of the schedule, "constant_warm_up" | "linear_warm_up"
        **kwargs: must contain ``warm_up_steps`` and ``reduce_lr_steps``;
            other keys are ignored

    Returns:
        lr_lambda_func (Callable): function of the global step returning the
            learning rate scale, suitable for ``LambdaLR``

    Raises:
        NotImplementedError: if ``lr_lambda_type`` is not a supported schedule
        KeyError: if a required keyword argument is missing
    """
    if lr_lambda_type == "constant_warm_up":
        schedule = constant_warm_up

    elif lr_lambda_type == "linear_warm_up":
        schedule = linear_warm_up

    else:
        raise NotImplementedError(
            f"Unsupported lr_lambda_type {lr_lambda_type!r}; "
            f'expected "constant_warm_up" or "linear_warm_up".'
        )

    # Both schedules share the same signature, so one partial suffices.
    lr_lambda_func = partial(
        schedule,
        warm_up_steps=kwargs["warm_up_steps"],
        reduce_lr_steps=kwargs["reduce_lr_steps"],
    )

    return lr_lambda_func