File size: 3,274 Bytes
a256709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler
from .tanh_lr import TanhLRScheduler
from .step_lr import StepLRScheduler
from .plateau_lr import PlateauLRScheduler
def _noise_kwargs(args, num_epochs):
    """Collect the LR-noise keyword arguments shared by every scheduler.

    ``args.lr_noise`` (optional) is scaled by ``num_epochs``; presumably it is
    given as fraction(s) of the run -- TODO confirm against the CLI definition.
    A list/tuple is scaled element-wise, and a single-element result is
    unwrapped to a scalar. Any other value is multiplied directly. When
    ``lr_noise`` is absent or ``None``, ``noise_range_t`` is ``None``.

    Args:
        args: Namespace-like object; ``lr_noise``, ``lr_noise_pct``,
            ``lr_noise_std`` and ``seed`` are all optional.
        num_epochs: Number of training epochs used to scale ``lr_noise``.

    Returns:
        Dict with keys ``noise_range_t``, ``noise_pct`` (default 0.67),
        ``noise_std`` (default 1.0) and ``noise_seed`` (default 42).
    """
    lr_noise = getattr(args, "lr_noise", None)
    if lr_noise is None:
        noise_range = None
    elif isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        if len(noise_range) == 1:
            # A single-element list is unwrapped to a scalar value.
            noise_range = noise_range[0]
    else:
        noise_range = lr_noise * num_epochs

    return {
        "noise_range_t": noise_range,
        "noise_pct": getattr(args, "lr_noise_pct", 0.67),
        "noise_std": getattr(args, "lr_noise_std", 1.0),
        "noise_seed": getattr(args, "seed", 42),
    }


def create_scheduler(args, optimizer):
    """Build a learning-rate scheduler from parsed training arguments.

    ``args.sched`` selects the scheduler:

    * ``"cosine"``  -> ``CosineLRScheduler``
    * ``"tanh"``    -> ``TanhLRScheduler``
    * ``"step"``    -> ``StepLRScheduler``
    * ``"plateau"`` -> ``PlateauLRScheduler``

    Any other value yields ``None`` for the scheduler (no LR scheduling).

    Args:
        args: Namespace-like object. ``epochs`` and ``sched`` are always read.
            Depending on the scheduler, these are also read: ``min_lr``,
            ``decay_rate``, ``warmup_lr``, ``warmup_epochs``,
            ``cooldown_epochs``, ``decay_epochs`` and ``patience_epochs``.
            Optional attributes (read with defaults): ``lr_noise``,
            ``lr_noise_pct``, ``lr_noise_std``, ``seed``, ``lr_cycle_mul``,
            ``lr_cycle_limit`` and ``eval_metric``.
        optimizer: Optimizer handed unchanged to the scheduler constructor.

    Returns:
        Tuple ``(lr_scheduler, num_epochs)``. For ``"cosine"`` and ``"tanh"``,
        ``num_epochs`` is ``lr_scheduler.get_cycle_length()`` plus
        ``args.cooldown_epochs``. For every other ``sched`` value it is
        ``args.epochs``.
    """
    num_epochs = args.epochs
    # Identical noise settings for every scheduler type, built once.
    noise_args = _noise_kwargs(args, num_epochs)

    lr_scheduler = None
    if args.sched == "cosine":
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, "lr_cycle_mul", 1.0),
            lr_min=args.min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, "lr_cycle_limit", 1),
            t_in_epochs=True,
            **noise_args,
        )
        # Cyclic schedule: total length depends on t_mul / cycle_limit.
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == "tanh":
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, "lr_cycle_mul", 1.0),
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, "lr_cycle_limit", 1),
            t_in_epochs=True,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == "step":
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif args.sched == "plateau":
        # ``eval_metric`` may be missing OR explicitly None; ``"loss" in None``
        # would raise TypeError, so normalise both cases to "".
        eval_metric = getattr(args, "eval_metric", None) or ""
        # Metrics named like a loss are minimised; everything else maximised.
        mode = "min" if "loss" in eval_metric else "max"
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_args,
        )

    return lr_scheduler, num_epochs
|