# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import List

import torch.optim.lr_scheduler
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass):
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )
    lr_threshold: float = field(
        default=1e-4,
        metadata={
            "help": (
                "threshold for measuring the new optimum, to only focus on "
                "significant changes"
            )
        },
    )
    lr_patience: int = field(
        default=0,
        metadata={
            "help": (
                "number of epochs with no improvement after which learning rate will "
                "be reduced"
            )
        },
    )
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
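    # II(...) is an omegaconf interpolation: these values are resolved at runtime
    # from the top-level "optimization" / "checkpoint" sections of the config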
    lr: List[float] = II("optimization.lr")
    maximize_best_checkpoint_metric: bool = II(
        "checkpoint.maximize_best_checkpoint_metric"
    )


@register_lr_scheduler(
    "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig
)
class ReduceLROnPlateauLRSchedule(FairseqLRScheduler):
    """
    Decay the LR by a factor every time the validation loss plateaus.
    Also comes with an optional warmup phase, during which the learning
    rate is increased linearly from some initial learning rate
    (``--warmup-init-lr``) until the configured learning rate
    (``--lr``). Thereafter the LR is adjusted according to the original
    reduce_on_plateau scheme.

    During warmup::

      lrs = torch.linspace(
          cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates
      )
      lr = lrs[update_num]
    """

    def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        if len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
                " Consider --lr-scheduler=fixed instead."
            )
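        # plateau detection is delegated to PyTorch's ReduceLROnPlateau, applied
        # to the underlying torch optimizer wrapped by the fairseq optimizer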
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer,
            patience=cfg.lr_patience,
            factor=cfg.lr_shrink,
            mode="max" if cfg.maximize_best_checkpoint_metric else "min",
            threshold=cfg.lr_threshold,
        )
        warmup_end_lr = cfg.lr[0]
        # if there is no warmup, set the initial lr to cfg.lr[0]
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr

        # warm up linearly for the first cfg.warmup_updates updates
        if cfg.warmup_updates > 0:
            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates

        # this flag is set here when there is no warmup, or by step_update()
        # once warmup finishes
        self.warmup_end = cfg.warmup_updates <= 0

        # initial learning rate; self.lr is only used during initialization
        # and the warmup period
        self.lr = cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {
            "best": self.lr_scheduler.best,
            "last_epoch": self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.lr_scheduler.best = state_dict["best"]
        if "last_epoch" in state_dict:
            self.lr_scheduler.last_epoch = state_dict["last_epoch"]

    def step(self, epoch, val_loss=None):
        """
        Update the learning rate at the end of the given epoch if warmup has
        finished; otherwise the LR is not updated on epoch boundaries.
        """
        if val_loss is not None and self.warmup_end:
            self.lr_scheduler.step(val_loss)
        else:
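            # no val_loss available, or still warming up: just keep the wrapped
            # scheduler's epoch counter in sync without stepping it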
            self.lr_scheduler.last_epoch = epoch
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """
        Update the learning rate after each update."""
        # if there is warmup
        if self.cfg.warmup_updates > 0:
            if num_updates <= self.cfg.warmup_updates:
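                # still inside warmup: move linearly from warmup_init_lr toward lr[0]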
                self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
                self.optimizer.set_lr(self.lr)
            elif not self.warmup_end:
                self.warmup_end = True
        # else do nothing
        return self.optimizer.get_lr()
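
# Illustrative usage sketch, assuming fairseq's usual mapping of dataclass
# fields to dashed command-line flags (flag names below follow that convention):
#
#   fairseq-train ... \
#       --lr-scheduler reduce_lr_on_plateau \
#       --lr 1e-3 --lr-shrink 0.5 --lr-patience 2 \
#       --warmup-updates 4000 --warmup-init-lr 1e-7
#
# This warms the LR up linearly over the first 4000 updates, then multiplies it
# by 0.5 whenever the validation metric plateaus (with a patience of 2 epochs).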