|
|
|
|
|
"""Learning rate policy.""" |
|
|
|
import math |
|
|
|
|
|
def get_lr_at_epoch(cfg, cur_epoch):
    """
    Compute the learning rate for `cur_epoch`, with optional linear warm-up.

    The configured policy (`cfg.SOLVER.LR_POLICY`) is always evaluated at
    `cur_epoch`. When `cur_epoch` is below `cfg.SOLVER.WARMUP_EPOCHS`, that
    value is replaced by a linear ramp that starts at
    `cfg.SOLVER.WARMUP_START_LR` for epoch 0 and reaches the policy's value
    at `cfg.SOLVER.WARMUP_EPOCHS`.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        The learning rate to use at `cur_epoch`.
    """
    solver = cfg.SOLVER
    schedule = get_lr_func(solver.LR_POLICY)

    # Evaluated unconditionally, so an invalid policy configuration surfaces
    # even while the warm-up ramp is still in effect.
    scheduled_lr = schedule(cfg, cur_epoch)

    warmup_epochs = solver.WARMUP_EPOCHS
    if cur_epoch < warmup_epochs:
        warmup_start_lr = solver.WARMUP_START_LR
        # The ramp ends exactly where the regular schedule takes over.
        warmup_end_lr = schedule(cfg, warmup_epochs)
        slope = (warmup_end_lr - warmup_start_lr) / warmup_epochs
        return cur_epoch * slope + warmup_start_lr

    return scheduled_lr
|
|
|
|
|
def lr_func_cosine(cfg, cur_epoch):
    """
    Cosine-annealed learning rate at `cur_epoch`.

    The rate equals `cfg.SOLVER.BASE_LR` at epoch 0 and follows half a cosine
    period down to `cfg.SOLVER.COSINE_END_LR` at `cfg.SOLVER.MAX_EPOCH`.
    Details can be found in:
    Ilya Loshchilov, and Frank Hutter
    SGDR: Stochastic Gradient Descent With Warm Restarts.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        float: the learning rate at `cur_epoch`.
    """
    solver = cfg.SOLVER
    end_lr = solver.COSINE_END_LR
    base_lr = solver.BASE_LR
    # The schedule must decay: the final rate has to be strictly below the
    # base rate.
    assert end_lr < base_lr

    lr_range = base_lr - end_lr
    # Ranges from 2.0 (epoch 0) down to 0.0 (epoch == MAX_EPOCH).
    cosine_term = math.cos(math.pi * cur_epoch / solver.MAX_EPOCH) + 1.0
    return end_lr + lr_range * cosine_term * 0.5
|
|
|
|
|
def lr_func_steps_with_relative_lrs(cfg, cur_epoch):
    """
    Step learning-rate schedule expressed as multipliers of the base rate.

    The step index for `cur_epoch` is obtained from `get_step_index`, and the
    matching entry of `cfg.SOLVER.LRS` is multiplied by `cfg.SOLVER.BASE_LR`.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        The learning rate at `cur_epoch`.
    """
    step_index = get_step_index(cfg, cur_epoch)
    solver = cfg.SOLVER
    relative_lr = solver.LRS[step_index]
    return relative_lr * solver.BASE_LR
|
|
|
|
|
def get_step_index(cfg, cur_epoch):
    """
    Retrieve the index of the lr step that `cur_epoch` falls into.

    Step `i` starts at epoch `cfg.SOLVER.STEPS[i]` and lasts until the next
    boundary; the last step has no upper limit. An epoch that precedes the
    first boundary (or any epoch when `STEPS` is empty) is assigned to step 0
    instead of producing a negative index, which callers would otherwise use
    to silently index from the end of their per-step tables.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.
    Returns:
        int: the zero-based step index; never negative.
    """
    # Copy into a list so any sequence type (list or tuple) is accepted.
    boundaries = list(cfg.SOLVER.STEPS)

    # Position of the first boundary that lies strictly after `cur_epoch`.
    # No MAX_EPOCH sentinel is needed: every epoch at or past the last
    # boundary maps to the last step regardless of the training length.
    first_later = next(
        (idx for idx, start in enumerate(boundaries) if cur_epoch < start),
        len(boundaries),
    )

    # The active step is the one just before that boundary, clamped at 0.
    return max(first_later - 1, 0)
|
|
|
|
|
def get_lr_func(lr_policy):
    """
    Look up the learning-rate function that implements `lr_policy`.

    Policies are resolved by name: the module-level object called
    `"lr_func_" + lr_policy` is returned.
    Args:
        lr_policy (string): the learning rate policy to use for the job.
    Returns:
        The module-level object registered under `"lr_func_" + lr_policy`.
    Raises:
        NotImplementedError: if this module defines no such name.
    """
    module_scope = globals()
    func_name = "lr_func_" + lr_policy
    if func_name in module_scope:
        return module_scope[func_name]
    raise NotImplementedError("Unknown LR policy: {}".format(lr_policy))
|
|