import math
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class CosineLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
    lr: List[float] = field(
        default=II("optimization.lr"),
        metadata={"help": "max learning rate, must be more than cfg.min_lr"},
    )
    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
    t_mult: float = field(
        default=1.0, metadata={"help": "factor to grow the length of each period"}
    )
    lr_period_updates: float = field(
        default=-1, metadata={"help": "initial number of updates per period"}
    )
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )

    max_update: int = II("optimization.max_update")
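    # Illustrative sketch (an assumption, not taken from this file): constructing
    # the config programmatically with explicit values; any omitted field keeps
    # the defaults above (``lr`` and ``max_update`` are normally interpolated
    # from the optimization config via ``II``). The numbers are arbitrary.
    #
    #     cfg = CosineLRScheduleConfig(
    #         warmup_updates=4000,
    #         warmup_init_lr=1e-7,
    #         lr=[5e-4],
    #         min_lr=1e-9,
    #         t_mult=1.0,
    #         lr_period_updates=20000,
    #         lr_shrink=0.75,
    #     )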


@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig)
class CosineLRSchedule(FairseqLRScheduler):
    """Assign LR based on a cyclical schedule that follows the cosine function.

    See https://arxiv.org/pdf/1608.03983.pdf for details.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    max learning rate (``--lr``).

    During warmup::

        lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
        lr = lrs[update_num]

    After warmup::

        lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(pi * t_curr / t_i))

    where ``t_curr`` is the number of updates since the start of the current
    period and ``t_i`` is the length of the current period, which is scaled by
    ``t_mult`` after every restart.
    """

    def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with cosine."
                f" Consider --lr-scheduler=fixed instead. ({cfg.lr})"
            )

        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        assert (
            self.max_lr > cfg.min_lr
        ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})"

        warmup_end_lr = self.max_lr
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = cfg.min_lr

        self.t_mult = cfg.t_mult
        self.period = cfg.lr_period_updates

        if self.period <= 0:
            assert (
                cfg.max_update > 0
            ), "Either --max-update or --lr-period-updates must be set"
            self.period = cfg.max_update - cfg.warmup_updates

        if cfg.warmup_updates > 0:
            # linearly warm up for the first cfg.warmup_updates
            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
        else:
            self.lr_step = 1

        self.warmup_updates = cfg.warmup_updates
        self.lr_shrink = cfg.lr_shrink

        # start at the warmup (or min) learning rate
        self.lr = cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
        else:
            curr_updates = num_updates - self.cfg.warmup_updates
            if self.t_mult != 1:
                i = math.floor(
                    math.log(
                        1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult
                    )
                )
                t_i = self.t_mult ** i * self.period
                t_curr = (
                    curr_updates
                    - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
                )
            else:
                i = math.floor(curr_updates / self.period)
                t_i = self.period
                t_curr = curr_updates - (self.period * i)
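            # At this point ``i`` is the index of the current cosine period,
            # ``t_i`` its length in updates, and ``t_curr`` the offset within it.
            # When ``t_mult != 1`` the period lengths grow geometrically
            # (period, period * t_mult, period * t_mult**2, ...), so ``i`` is
            # recovered by inverting the cumulative-length sum
            # period * (1 - t_mult**i) / (1 - t_mult) <= curr_updates.
            # Every completed period also shrinks both LR bounds by ``lr_shrink``.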
            lr_shrink = self.lr_shrink ** i
            min_lr = self.cfg.min_lr * lr_shrink
            max_lr = self.max_lr * lr_shrink

            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
                1 + math.cos(math.pi * t_curr / t_i)
            )

        self.optimizer.set_lr(self.lr)
        return self.lr
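# Illustrative usage sketch (an assumption added for clarity, not part of the
# fairseq API surface): the trainer drives this scheduler once per optimizer
# update, roughly as follows, where ``optimizer`` is a FairseqOptimizer and
# ``total_updates`` is a hypothetical update budget.
#
#     scheduler = CosineLRSchedule(cfg, optimizer)
#     for num_updates in range(total_updates):
#         ...  # forward / backward / optimizer step
#         scheduler.step_update(num_updates)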