File size: 2,762 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import List
from omegaconf import II
from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
@dataclass
class TriangularLRScheduleConfig(FairseqDataclass):
max_lr: float = field(
default="???", metadata={"help": "max learning rate, must be more than cfg.lr"}
)
lr_period_updates: float = field(
default=5000,
metadata={"help": "initial number of updates per period (cycle length)"},
)
lr_shrink: float = field(
default=0.1, metadata={"help": "shrink factor for annealing"}
)
shrink_min: bool = field(
default=False, metadata={"help": "if set, also shrinks min lr"}
)
lr: List[float] = II("optimization.lr")
@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig)
class TriangularLRSchedule(FairseqLRScheduler):
"""Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details.
"""
def __init__(self, cfg: TriangularLRScheduleConfig, optimizer):
super().__init__(cfg, optimizer)
if len(cfg.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with triangular."
" Consider --lr-scheduler=fixed instead."
)
lr = cfg.lr[0]
assert cfg.max_lr > lr, "max_lr must be more than lr"
self.min_lr = lr
self.max_lr = cfg.max_lr
self.stepsize = cfg.lr_period_updates // 2
self.lr_shrink = cfg.lr_shrink
self.shrink_min = cfg.shrink_min
# initial learning rate
self.lr = self.min_lr
self.optimizer.set_lr(self.lr)
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
cycle = math.floor(num_updates / (2 * self.stepsize))
lr_shrink = self.lr_shrink ** cycle
max_lr = self.max_lr * lr_shrink
if self.shrink_min:
min_lr = self.min_lr * lr_shrink
else:
min_lr = self.min_lr
x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))
self.optimizer.set_lr(self.lr)
return self.lr
|