""" Scheduler Factory

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List, Optional, Union

from torch.optim import Optimizer

from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler


def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
    """ cfg/argparse to kwargs helper
    Convert scheduler args from an argparse Namespace or cfg-like (dot-attribute access)
    object to keyword args for create_scheduler_v2().
    """
    eval_metric = getattr(cfg, 'eval_metric', 'top1')
    if decreasing_metric is not None:
        plateau_mode = 'min' if decreasing_metric else 'max'
    else:
        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
    kwargs = dict(
        sched=cfg.sched,
        num_epochs=getattr(cfg, 'epochs', 100),
        decay_epochs=getattr(cfg, 'decay_epochs', 30),
        decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
        warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
        cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
        patience_epochs=getattr(cfg, 'patience_epochs', 10),
        decay_rate=getattr(cfg, 'decay_rate', 0.1),
        min_lr=getattr(cfg, 'min_lr', 0.),
        warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
        warmup_prefix=getattr(cfg, 'warmup_prefix', False),
        noise=getattr(cfg, 'lr_noise', None),
        noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
        noise_std=getattr(cfg, 'lr_noise_std', 1.),
        noise_seed=getattr(cfg, 'seed', 42),
        cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
        cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
        cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
        k_decay=getattr(cfg, 'lr_k_decay', 1.0),
        plateau_mode=plateau_mode,
        step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
    )
    return kwargs
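

# Usage sketch: `scheduler_kwargs` works with any object exposing the fields above via
# attribute access; the SimpleNamespace below is purely illustrative, not part of this module.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(sched='cosine', epochs=100, eval_metric='loss')
#   kwargs = scheduler_kwargs(cfg)
#   # kwargs['plateau_mode'] == 'min' since the eval metric name contains 'loss'

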
def create_scheduler(
        args,
        optimizer: Optimizer,
        updates_per_epoch: int = 0,
):
    """ Create a learning rate scheduler from argparse-style args
    (thin wrapper around scheduler_kwargs() + create_scheduler_v2()).
    """
    return create_scheduler_v2(
        optimizer=optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )


def create_scheduler_v2(
        optimizer: Optimizer,
        sched: str = 'cosine',
        num_epochs: int = 300,
        decay_epochs: int = 90,
        decay_milestones: List[int] = (90, 180, 270),
        cooldown_epochs: int = 0,
        patience_epochs: int = 10,
        decay_rate: float = 0.1,
        min_lr: float = 0,
        warmup_lr: float = 1e-5,
        warmup_epochs: int = 0,
        warmup_prefix: bool = False,
        noise: Optional[Union[float, List[float]]] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.,
        noise_seed: int = 42,
        cycle_mul: float = 1.,
        cycle_decay: float = 0.1,
        cycle_limit: int = 1,
        k_decay: float = 1.0,
        plateau_mode: str = 'max',
        step_on_epochs: bool = True,
        updates_per_epoch: int = 0,
):
    """ Create a learning rate scheduler from explicit keyword args.

    All *_epochs args are specified in epochs; when step_on_epochs=False they are
    converted to update counts using updates_per_epoch.

    Returns:
        Tuple of (lr_scheduler, num_epochs), where num_epochs may be adjusted to
        cover scheduler cycles and cooldown.
    """
    t_initial = num_epochs
    warmup_t = warmup_epochs
    decay_t = decay_epochs
    cooldown_t = cooldown_epochs

    if not step_on_epochs:
        # stepping per-update: convert all epoch-based intervals into update (step) counts
        assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
        t_initial = t_initial * updates_per_epoch
        warmup_t = warmup_t * updates_per_epoch
        decay_t = decay_t * updates_per_epoch
        decay_milestones = [d * updates_per_epoch for d in decay_milestones]
        cooldown_t = cooldown_t * updates_per_epoch
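        # e.g. (illustrative numbers) num_epochs=300, warmup_epochs=5, updates_per_epoch=1000
        # -> t_initial=300000 and warmup_t=5000 scheduler steps
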
    warmup_args = dict(
        warmup_lr_init=warmup_lr,
        warmup_t=warmup_t,
        warmup_prefix=warmup_prefix,
    )

    if noise is not None:
        if isinstance(noise, (list, tuple)):
            # noise start/end specified as fractions of the total schedule length
            noise_range = [n * t_initial for n in noise]
            if len(noise_range) == 1:
                noise_range = noise_range[0]
        else:
            noise_range = noise * t_initial
    else:
        noise_range = None
    noise_args = dict(
        noise_range_t=noise_range,
        noise_pct=noise_pct,
        noise_std=noise_std,
        noise_seed=noise_seed,
    )
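    # Worked example (illustrative values): with t_initial=300 and noise=[0.4, 0.9],
    # noise_range_t becomes [120.0, 270.0], i.e. LR noise is applied between those
    # points in the schedule.
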
    cycle_args = dict(
        cycle_mul=cycle_mul,
        cycle_decay=cycle_decay,
        cycle_limit=cycle_limit,
    )
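    # Cycle behaviour sketch (assumes the SGDR-style restart semantics of the cosine/tanh/poly
    # schedulers used here): cycle_limit caps the number of cycles, cycle_mul scales each
    # successive cycle length, cycle_decay scales the peak LR of each new cycle.
    # e.g. cycle_limit=3, cycle_mul=2. gives cycles of length t_initial, 2*t_initial, 4*t_initial.
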
    lr_scheduler = None
    if sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
            k_decay=k_decay,
        )
    elif sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_t,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=decay_milestones,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'plateau':
        assert step_on_epochs, 'Plateau LR only supports step per epoch.'
        # PlateauLRScheduler does not accept a warmup_prefix arg, drop it
        warmup_args.pop('warmup_prefix', False)
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=decay_rate,
            patience_t=patience_epochs,
            cooldown_t=0,
            **warmup_args,
            lr_min=min_lr,
            mode=plateau_mode,
            **noise_args,
        )
    elif sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            k_decay=k_decay,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )

    if hasattr(lr_scheduler, 'get_cycle_length'):
        # for cycle-based schedulers, the full schedule may span multiple cycles plus a
        # cooldown period, so recompute the total epoch count to cover it
        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
        if step_on_epochs:
            num_epochs = t_with_cycles_and_cooldown
        else:
            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch

    return lr_scheduler, num_epochs
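

# Usage sketch (illustrative, not part of this module): typical wiring in a training
# script, stepping per optimizer update. `model` and `train_loader` are placeholders.
#
#   import torch
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler, num_epochs = create_scheduler_v2(
#       optimizer,
#       sched='cosine',
#       num_epochs=100,
#       warmup_epochs=5,
#       step_on_epochs=False,
#       updates_per_epoch=len(train_loader),
#   )
#   num_updates = 0
#   for epoch in range(num_epochs):
#       for batch in train_loader:
#           ...  # forward / backward / optimizer.step()
#           num_updates += 1
#           scheduler.step_update(num_updates=num_updates)
#       scheduler.step(epoch + 1)  # pass a metric here when sched='plateau'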