""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler
from .tanh_lr import TanhLRScheduler
from .step_lr import StepLRScheduler
from .plateau_lr import PlateauLRScheduler


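def create_scheduler(args, optimizer):
    """Build a learning-rate scheduler from parsed training arguments.

    ``args.sched`` selects one of "cosine", "tanh", "step" or "plateau";
    any other value leaves the returned scheduler as ``None``.

    Returns:
        (lr_scheduler, num_epochs): the scheduler (or ``None``) and the total
        number of epochs to train. For the cyclic "cosine" and "tanh"
        schedules, ``num_epochs`` is the cycle length reported by the
        scheduler plus ``args.cooldown_epochs``; otherwise it is ``args.epochs``.
    """
    num_epochs = args.epochs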

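    # ``lr_noise`` is given as a fraction (or a list of fractions) of the total
    # epoch count; convert it to absolute epochs for the scheduler.
    if getattr(args, "lr_noise", None) is not None: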
        lr_noise = getattr(args, "lr_noise")
        if isinstance(lr_noise, (list, tuple)):
            noise_range = [n * num_epochs for n in lr_noise]
            if len(noise_range) == 1:
                noise_range = noise_range[0]
        else:
            noise_range = lr_noise * num_epochs
    else:
        noise_range = None

    lr_scheduler = None
    if args.sched == "cosine":
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, "lr_cycle_mul", 1.0),
            lr_min=args.min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, "lr_cycle_limit", 1),
            t_in_epochs=True,
            noise_range_t=noise_range,
            noise_pct=getattr(args, "lr_noise_pct", 0.67),
            noise_std=getattr(args, "lr_noise_std", 1.0),
            noise_seed=getattr(args, "seed", 42),
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == "tanh":
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, "lr_cycle_mul", 1.0),
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, "lr_cycle_limit", 1),
            t_in_epochs=True,
            noise_range_t=noise_range,
            noise_pct=getattr(args, "lr_noise_pct", 0.67),
            noise_std=getattr(args, "lr_noise_std", 1.0),
            noise_seed=getattr(args, "seed", 42),
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == "step":
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            noise_range_t=noise_range,
            noise_pct=getattr(args, "lr_noise_pct", 0.67),
            noise_std=getattr(args, "lr_noise_std", 1.0),
            noise_seed=getattr(args, "seed", 42),
        )
    elif args.sched == "plateau":
        # Metrics whose name contains "loss" are treated as lower-is-better.
        mode = "min" if "loss" in getattr(args, "eval_metric", "") else "max"
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            noise_range_t=noise_range,
            noise_pct=getattr(args, "lr_noise_pct", 0.67),
            noise_std=getattr(args, "lr_noise_std", 1.0),
            noise_seed=getattr(args, "seed", 42),
        )

    return lr_scheduler, num_epochs
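The `lr_noise` handling at the top of `create_scheduler` turns fractions of the run into epoch counts. The sketch below copies that logic into a hypothetical standalone helper, `noise_range_from_fraction`, which is not part of this module. It lets you check the three cases (scalar, multi-element list, single-element list) separately.

```python
from typing import List, Optional, Sequence, Union

Number = Union[int, float]


def noise_range_from_fraction(
    lr_noise: Optional[Union[Number, Sequence[Number]]], num_epochs: int
) -> Optional[Union[Number, List[Number]]]:
    """Hypothetical helper mirroring the lr_noise -> noise_range conversion."""
    if lr_noise is None:
        return None  # LR noise disabled
    if isinstance(lr_noise, (list, tuple)):
        # Each entry is a fraction of the run, e.g. [0.4, 0.9] -> epochs 40..90.
        noise_range = [n * num_epochs for n in lr_noise]
        if len(noise_range) == 1:
            # A single-element list is unwrapped to a scalar start epoch.
            return noise_range[0]
        return noise_range
    # A scalar fraction becomes a scalar start epoch.
    return lr_noise * num_epochs


print(noise_range_from_fraction([0.4, 0.9], 100))  # [40.0, 90.0]
print(noise_range_from_fraction([0.5], 100))       # 50.0
print(noise_range_from_fraction(None, 100))        # None
```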
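In the `cosine` and `tanh` branches, the returned `num_epochs` is `lr_scheduler.get_cycle_length() + args.cooldown_epochs`, not `args.epochs`. The scheduler classes are defined outside this file. The sketch below therefore assumes that `get_cycle_length()` returns the combined length of `cycle_limit` restart cycles, where each cycle is `t_mul` times longer than the previous one and the first lasts `t_initial` epochs. `total_epochs` is a hypothetical helper for illustration only.

```python
import math


def total_epochs(t_initial: int, t_mul: float, cycle_limit: int,
                 cooldown_epochs: int) -> int:
    """Assumed semantics: sum of all restart cycles plus the cooldown tail."""
    cycles = max(1, cycle_limit)
    # Cycle i (0-based) is assumed to last t_initial * t_mul**i epochs.
    cycle_total = sum(t_initial * t_mul ** i for i in range(cycles))
    return math.floor(cycle_total) + cooldown_epochs


print(total_epochs(100, 1.0, 1, 10))  # 110
print(total_epochs(10, 2.0, 3, 10))   # 80
```

Under that assumption, the defaults used in `create_scheduler` (`lr_cycle_mul=1.0`, `lr_cycle_limit=1`) reduce the result to `args.epochs + args.cooldown_epochs`.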
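The `cosine` branch passes `t_initial`, `lr_min`, `warmup_t` and `warmup_lr_init` to `CosineLRScheduler`. The sketch below shows the textbook single-cycle cosine-annealing formula with a linear warmup, using the same parameter names. `cosine_lr` and `base_lr` are hypothetical names. The exact arithmetic inside `CosineLRScheduler` is not shown in this file, so it may differ, for example in how warmup epochs are counted against `t_initial`.

```python
import math


def cosine_lr(t: int, base_lr: float, lr_min: float, t_initial: int,
              warmup_t: int = 0, warmup_lr_init: float = 0.0) -> float:
    """Single-cycle cosine annealing with linear warmup (illustrative sketch)."""
    if t < warmup_t:
        # Linear ramp from warmup_lr_init toward base_lr during warmup.
        return warmup_lr_init + (base_lr - warmup_lr_init) * t / warmup_t
    if t >= t_initial:
        # Past the end of the cycle: hold at the floor.
        return lr_min
    # Cosine decay from base_lr (t = 0) down to lr_min (t = t_initial).
    return lr_min + 0.5 * (base_lr - lr_min) * (1 + math.cos(math.pi * t / t_initial))


for epoch in (0, 25, 50, 75, 100):
    print(epoch, round(cosine_lr(epoch, base_lr=0.1, lr_min=0.0, t_initial=100), 5))
```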
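The `step` branch passes `decay_t=args.decay_epochs` and `decay_rate=args.decay_rate` to `StepLRScheduler`. Assuming the usual step-decay rule, the learning rate is multiplied by `decay_rate` once every `decay_t` epochs. The hypothetical `step_lr` below models that rule and leaves out warmup and noise.

```python
def step_lr(t: int, base_lr: float, decay_t: int, decay_rate: float) -> float:
    """Step decay (illustrative sketch): one multiplication per completed interval."""
    # t // decay_t counts how many full decay intervals have elapsed.
    return base_lr * decay_rate ** (t // decay_t)


for epoch in (0, 29, 30, 65):
    print(epoch, round(step_lr(epoch, base_lr=0.1, decay_t=30, decay_rate=0.1), 6))
```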
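The `plateau` branch infers the metric direction from `args.eval_metric` and passes `patience_t`, `decay_rate` and `lr_min` to `PlateauLRScheduler`. The sketch below illustrates the general reduce-on-plateau idea those parameters describe. Once the metric has gone more than `patience` epochs without improving, the LR is multiplied by `decay_rate`, and it never drops below `lr_min`. `PlateauSketch` and `plateau_mode` are hypothetical. This is a simplified model with no improvement threshold and no cooldown, not the implementation of `PlateauLRScheduler`.

```python
import math


def plateau_mode(eval_metric: str) -> str:
    # Same heuristic as create_scheduler: metrics named "*loss*" are minimized.
    return "min" if "loss" in eval_metric else "max"


class PlateauSketch:
    """Simplified reduce-on-plateau rule (illustrative, not the library class)."""

    def __init__(self, lr: float, decay_rate: float, patience: int,
                 lr_min: float = 0.0, mode: str = "max") -> None:
        self.lr = lr
        self.decay_rate = decay_rate
        self.patience = patience
        self.lr_min = lr_min
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf
        self.bad_epochs = 0

    def step(self, metric: float) -> float:
        improved = metric < self.best if self.mode == "min" else metric > self.best
        if improved:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Patience exhausted: decay the LR (clamped at lr_min) and reset.
                self.lr = max(self.lr * self.decay_rate, self.lr_min)
                self.bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=0.1, decay_rate=0.5, patience=2, mode=plateau_mode("top1"))
print([sched.step(m) for m in (70.0, 71.0, 71.0, 71.0, 71.0)])
```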