""" Scheduler Factory
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List, Optional, Union

from torch.optim import Optimizer

from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler


def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
    """ cfg/argparse to kwargs helper
    Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
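
    Example (illustrative, assuming a train-script style argparse `args` and an
    existing `optimizer`)::

        kwargs = scheduler_kwargs(args)
        scheduler, num_epochs = create_scheduler_v2(optimizer, **kwargs)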
    """
    eval_metric = getattr(cfg, 'eval_metric', 'top1')
    if decreasing_metric is not None:
        plateau_mode = 'min' if decreasing_metric else 'max'
    else:
        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
    kwargs = dict(
        sched=cfg.sched,
        num_epochs=getattr(cfg, 'epochs', 100),
        decay_epochs=getattr(cfg, 'decay_epochs', 30),
        decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
        warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
        cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
        patience_epochs=getattr(cfg, 'patience_epochs', 10),
        decay_rate=getattr(cfg, 'decay_rate', 0.1),
        min_lr=getattr(cfg, 'min_lr', 0.),
        warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
        warmup_prefix=getattr(cfg, 'warmup_prefix', False),
        noise=getattr(cfg, 'lr_noise', None),
        noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
        noise_std=getattr(cfg, 'lr_noise_std', 1.),
        noise_seed=getattr(cfg, 'seed', 42),
        cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
        cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
        cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
        k_decay=getattr(cfg, 'lr_k_decay', 1.0),
        plateau_mode=plateau_mode,
        step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
    )
    return kwargs


def create_scheduler(
        args,
        optimizer: Optimizer,
        updates_per_epoch: int = 0,
):
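    """ Create a schedule from argparse style args (legacy entrypoint).

    A thin compatibility wrapper that converts `args` to kwargs via
    `scheduler_kwargs()` and forwards them to `create_scheduler_v2()`;
    new code can call `create_scheduler_v2()` directly.
    """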
    return create_scheduler_v2(
        optimizer=optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )


def create_scheduler_v2(
        optimizer: Optimizer,
        sched: str = 'cosine',
        num_epochs: int = 300,
        decay_epochs: int = 90,
        decay_milestones: List[int] = [90, 180, 270],
        cooldown_epochs: int = 0,
        patience_epochs: int = 10,
        decay_rate: float = 0.1,
        min_lr: float = 0,
        warmup_lr: float = 1e-5,
        warmup_epochs: int = 0,
        warmup_prefix: bool = False,
        noise: Optional[Union[float, List[float]]] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.,
        noise_seed: int = 42,
        cycle_mul: float = 1.,
        cycle_decay: float = 0.1,
        cycle_limit: int = 1,
        k_decay: float = 1.0,
        plateau_mode: str = 'max',
        step_on_epochs: bool = True,
        updates_per_epoch: int = 0,
):
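    """ Create a learning rate scheduler and resolve the total training length.

    Builds one of the supported schedulers ('cosine', 'tanh', 'step', 'multistep',
    'plateau', 'poly') and returns it along with the number of epochs to train,
    recalculated for cycle-based schedulers to include extra cycles and cooldown.
    When `step_on_epochs` is False, schedule lengths are converted from epochs to
    optimizer update steps using `updates_per_epoch`.

    Returns:
        Tuple of (lr_scheduler, num_epochs). `lr_scheduler` is None if `sched`
        does not name a known scheduler type.

    Example (an illustrative sketch, assuming an existing torch.optim `optimizer`)::

        scheduler, num_epochs = create_scheduler_v2(
            optimizer,
            sched='cosine',
            num_epochs=100,
            warmup_epochs=5,
            min_lr=1e-5,
        )
        for epoch in range(num_epochs):
            ...  # train one epoch
            scheduler.step(epoch + 1)  # timm schedulers step on the epoch index
            # for per-update stepping (step_on_epochs=False), use scheduler.step_update(num_updates)
    """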
    t_initial = num_epochs
    warmup_t = warmup_epochs
    decay_t = decay_epochs
    cooldown_t = cooldown_epochs

    if not step_on_epochs:
        # stepping per optimizer update, convert schedule lengths from epochs to update counts
        assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
        t_initial = t_initial * updates_per_epoch
        warmup_t = warmup_t * updates_per_epoch
        decay_t = decay_t * updates_per_epoch
        decay_milestones = [d * updates_per_epoch for d in decay_milestones]
        cooldown_t = cooldown_t * updates_per_epoch

    # warmup args
    warmup_args = dict(
        warmup_lr_init=warmup_lr,
        warmup_t=warmup_t,
        warmup_prefix=warmup_prefix,
    )

    # setup noise args for supporting schedulers
    # noise range is specified as a fraction (or list of fractions) of the total schedule length
    if noise is not None:
        if isinstance(noise, (list, tuple)):
            noise_range = [n * t_initial for n in noise]
            if len(noise_range) == 1:
                noise_range = noise_range[0]
        else:
            noise_range = noise * t_initial
    else:
        noise_range = None
    noise_args = dict(
        noise_range_t=noise_range,
        noise_pct=noise_pct,
        noise_std=noise_std,
        noise_seed=noise_seed,
    )

    # setup cycle args for supporting schedulers
    cycle_args = dict(
        cycle_mul=cycle_mul,
        cycle_decay=cycle_decay,
        cycle_limit=cycle_limit,
    )

    lr_scheduler = None  # remains None if sched does not match any branch below
    if sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
            k_decay=k_decay,
        )
    elif sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_t,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=decay_milestones,
            decay_rate=decay_rate,
            t_in_epochs=step_on_epochs,
            **warmup_args,
            **noise_args,
        )
    elif sched == 'plateau':
        assert step_on_epochs, 'Plateau LR only supports step per epoch.'
        warmup_args.pop('warmup_prefix', False)  # not supported by PlateauLRScheduler
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=decay_rate,
            patience_t=patience_epochs,
            cooldown_t=0,
            **warmup_args,
            lr_min=min_lr,
            mode=plateau_mode,
            **noise_args,
        )
    elif sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=t_initial,
            lr_min=min_lr,
            t_in_epochs=step_on_epochs,
            k_decay=k_decay,
            **cycle_args,
            **warmup_args,
            **noise_args,
        )

    if hasattr(lr_scheduler, 'get_cycle_length'):
        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
        if step_on_epochs:
            num_epochs = t_with_cycles_and_cooldown
        else:
            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch

    return lr_scheduler, num_epochs