# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class CosineLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.min_lr"
        },
    )
    lr: List[float] = field(
        default=II("optimization.lr"),
        metadata={"help": "max learning rate, must be more than cfg.min_lr"},
    )
    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
    t_mult: float = field(
        default=1.0, metadata={"help": "factor to grow the length of each period"}
    )
    lr_period_updates: float = field(
        default=-1, metadata={"help": "initial number of updates per period"}
    )
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )
    # Not strictly required; used to infer lr_period_updates when it is not set
    max_update: int = II("optimization.max_update")
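
# A minimal configuration sketch (hypothetical values). In normal use the
# II("optimization.*") interpolations above are resolved from the top-level
# fairseq config rather than passed explicitly:
#
#     cfg = CosineLRScheduleConfig(
#         warmup_updates=500,
#         warmup_init_lr=1e-7,
#         lr=[5e-4],
#         min_lr=1e-5,
#         t_mult=2.0,
#         lr_period_updates=5000,
#         lr_shrink=0.5,
#         max_update=50000,
#     )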


@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig)
class CosineLRSchedule(FairseqLRScheduler):
    """Assign LR based on a cyclical schedule that follows the cosine function.

    See https://arxiv.org/pdf/1608.03983.pdf for details.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    max learning rate (``--lr``).

    During warmup::

      lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
      lr = lrs[update_num]

    After warmup::

      lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(pi * t_curr / t_i))

    where ``t_curr`` is the number of updates completed within the current
    period and ``t_i`` is the length of the current period, which is scaled
    by ``t_mult`` after every restart. At each restart both ``cfg.lr`` and
    ``cfg.min_lr`` are additionally multiplied by ``lr_shrink``.
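
    For example, with ``lr=1.0``, ``min_lr=0``, ``lr_period_updates=1000``,
    ``t_mult=1`` and ``lr_shrink=0.1``, the LR decays from 1.0 to 0 over the
    first 1000 updates, restarts at 0.1, decays to 0 again over the next
    1000 updates, and so on.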
    """

    def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with cosine."
                f" Consider --lr-scheduler=fixed instead. ({cfg.lr})"
            )

        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        assert (
            self.max_lr > cfg.min_lr
        ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})"

        warmup_end_lr = self.max_lr
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = cfg.min_lr

        self.t_mult = cfg.t_mult
        self.period = cfg.lr_period_updates

        if self.period <= 0:
            assert (
                cfg.max_update > 0
            ), "Either --max_update or --lr-period-updates must be set"
            self.period = cfg.max_update - cfg.warmup_updates

        if cfg.warmup_updates > 0:
            # linearly warmup for the first cfg.warmup_updates
            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
        else:
            self.lr_step = 1  # unused when there is no warmup phase

        self.warmup_updates = cfg.warmup_updates
        self.lr_shrink = cfg.lr_shrink

        # initial learning rate
        self.lr = cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
        else:
            curr_updates = num_updates - self.cfg.warmup_updates
            if self.t_mult != 1:
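                # Period lengths grow geometrically: period, period * t_mult,
                # period * t_mult^2, ...  The first i periods consume
                # period * (1 - t_mult^i) / (1 - t_mult) updates (a geometric
                # series), so inverting that sum yields the index i of the
                # period containing curr_updates: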
                i = math.floor(
                    math.log(
                        1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult
                    )
                )
                t_i = self.t_mult ** i * self.period
                t_curr = (
                    curr_updates
                    - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
                )
            else:
                i = math.floor(curr_updates / self.period)
                t_i = self.period
                t_curr = curr_updates - (self.period * i)

            lr_shrink = self.lr_shrink ** i
            min_lr = self.cfg.min_lr * lr_shrink
            max_lr = self.max_lr * lr_shrink

            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
                1 + math.cos(math.pi * t_curr / t_i)
            )

        self.optimizer.set_lr(self.lr)
        return self.lr
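

# A hypothetical fairseq-train invocation using this scheduler; flag names
# follow fairseq's convention of mapping the dataclass fields above to
# dashed CLI options (path and values are illustrative only):
#
#   fairseq-train data-bin/my-dataset \
#       --optimizer adam --lr 5e-4 --lr-scheduler cosine \
#       --warmup-updates 500 --warmup-init-lr 1e-7 \
#       --min-lr 1e-5 --lr-period-updates 5000 \
#       --t-mult 2 --lr-shrink 0.5 --max-update 50000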