# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim import FairseqOptimizer

class FairseqLRScheduler(object):
    """Base class for fairseq learning rate schedulers."""

    def __init__(self, cfg, optimizer):
        super().__init__()
        if optimizer is not None and not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.cfg = cfg
        self.optimizer = optimizer
        self.best = None

    @classmethod
    def add_args(cls, parser):
        """Add arguments to the parser for this LR scheduler."""
        # If the scheduler defines a config dataclass, generate its CLI
        # arguments from the dataclass fields.
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {"best": self.best}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.best = state_dict["best"]

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        pass

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            # Track the best validation loss seen so far.
            if self.best is None:
                self.best = val_loss
            else:
                self.best = min(self.best, val_loss)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.get_lr()

class LegacyFairseqLRScheduler(FairseqLRScheduler):
    """Base class for schedulers that are still configured with an argparse
    Namespace rather than a dataclass config; it skips the parent constructor
    because it stores ``args`` instead of ``cfg``."""

    def __init__(self, args: Namespace, optimizer):
        if not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.args = args
        self.optimizer = optimizer
        self.best = None
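
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): a minimal
# concrete scheduler showing how subclasses typically hook into the base
# class. The class name and the 1/epoch decay rule are hypothetical; real
# fairseq schedulers are additionally registered via the lr_scheduler
# registry, which is omitted here. Only get_lr()/set_lr() on FairseqOptimizer
# are assumed, and get_lr() is already relied on by the base class above.
# ---------------------------------------------------------------------------
class ExampleInverseEpochLRScheduler(FairseqLRScheduler):
    def __init__(self, cfg, optimizer):
        super().__init__(cfg, optimizer)
        # Remember the configured starting learning rate.
        self.base_lr = optimizer.get_lr()

    def step(self, epoch, val_loss=None):
        """Decay the learning rate to base_lr / epoch at the end of each epoch."""
        # Let the base class keep tracking the best validation loss.
        super().step(epoch, val_loss)
        self.optimizer.set_lr(self.base_lr / max(1, epoch))
        return self.optimizer.get_lr()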