File size: 3,945 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import LegacyFairseqLRScheduler, register_lr_scheduler
import logging
import ast
# Module-level logger. Its own level is pinned to WARNING, so the
# ``logger.info`` diagnostics in this module are filtered out at this logger
# unless something lowers its level explicitly; warnings still get through.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)
@register_lr_scheduler("manual")
class ManualSchedule(LegacyFairseqLRScheduler):
    """Set the LR from manually specified per-epoch and per-update schedules.

    ``--epoch2lr`` is applied in :meth:`step_begin_epoch` and ``--update2lr``
    in :meth:`step_update`. Each is a Python dict literal mapping a key spec
    to an LR, where a key spec is an integer (``5`` or ``"5"``), a
    comma-separated list (``"1,3,5"``), an inclusive range (``"1-4"``), or a
    comma-separated mix of the two (``"1,3-5"``).

    At epoch/update ``t`` the LR of the largest configured key ``<= t`` is
    applied; if no key is ``<= t`` the optimizer's current LR is left as is.
    """

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)

        self.epoch2lr = self.parse_manuallr_args(args.epoch2lr)
        self.update2lr = self.parse_manuallr_args(args.update2lr)
        logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr))
        logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr))

        # Starting LR: an explicit entry for epoch 1 wins, then one for
        # update 1; otherwise fall back to the first --lr value.
        if 1 in self.epoch2lr:
            self.lr = self.epoch2lr[1]
        elif 1 in self.update2lr:
            self.lr = self.update2lr[1]
        else:
            self.lr = args.lr[0]
        self.optimizer.set_lr(self.lr)  # Set the beginning of the epoch.

    def parse_manuallr_args(self, lr_args_str):
        """Parse a manual-LR dict literal into ``{int step: float lr}``.

        Args:
            lr_args_str (str): a Python dict literal such as
                ``"{'1-3': 0.1, '4,6-8': 0.05, 10: 0.01}"``. Spaces are
                stripped before parsing. Keys use the spec formats described
                in the class docstring; values must be convertible to float.

        Returns:
            dict: one ``int -> float`` entry per expanded step. If a step is
            covered by several keys, the key that comes later in the literal
            wins.

        Raises:
            ValueError: if the literal does not evaluate to a dict, or a key
                or value cannot be converted. Errors raised by
                :func:`ast.literal_eval` on a malformed literal propagate.
        """
        # literal_eval only accepts Python literals, so no code is executed.
        lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
        if not isinstance(lr_dict, dict):
            raise ValueError(
                "epoch2lr/update2lr must be able to be evaluated to a dict"
            )

        lr_args = {}
        logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict))
        for key, val in lr_dict.items():
            lr = float(val)
            for step in self._expand_key(key):
                lr_args[step] = lr
        return lr_args

    @staticmethod
    def _expand_key(key):
        """Return the list of integer steps covered by one key spec."""
        if isinstance(key, int):
            # Bare integer keys, e.g. {1: 0.1}, need no string parsing.
            return [key]
        steps = []
        for part in str(key).split(","):
            start, sep, end = part.partition("-")
            if sep:
                # Inclusive range "start-end".
                steps.extend(range(int(start), int(end) + 1))
            else:
                steps.append(int(part))
        return steps

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument(
            "--epoch2lr",
            type=str,
            metavar="DICT",
            default="{}",
            help="a dictionary used to set lr for each epoch manually",
        )
        parser.add_argument(
            "--update2lr",
            type=str,
            metavar="DICT",
            default="{}",
            help="a dictionary used to set lr for each update manually",
        )
        # fmt: on

    def state_dict(self):
        """Return the scheduler state: the most recently applied LR."""
        return {"lr": self.lr}

    def load_state_dict(self, state_dict):
        """Restore ``self.lr`` from ``state_dict`` when it has an ``"lr"`` key.

        The optimizer is not touched here; the next epoch/update step
        re-applies the scheduled LR.
        """
        if "lr" in state_dict:
            self.lr = state_dict["lr"]

    def _lookup_lr(self, schedule, step, step_name, schedule_name):
        """Return the LR for ``step`` from ``schedule`` (``{int: float}``).

        Uses the entry with the largest key ``<= step``. When there is none,
        the optimizer's current LR is returned unchanged. A warning is logged
        in that case only if ``schedule`` is non-empty: an empty schedule just
        means this mode was not configured (the ``"{}"`` default), and warning
        on every call would flood the log.
        """
        best_key = max((k for k in schedule if k <= step), default=None)
        if best_key is not None:
            return schedule[best_key]
        if schedule:
            logger.warning(
                "@@@ %s=%s does not exist in manual lr input. %s=%s...",
                step_name,
                step,
                schedule_name,
                list(schedule.items())[:10],
            )
        return self.optimizer.get_lr()

    def get_next_lr(self, epoch):
        """Return the LR scheduled by ``--epoch2lr`` for ``epoch``."""
        return self._lookup_lr(self.epoch2lr, epoch, "epoch", "epoch2lr")

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        # Keep self.lr in sync so state_dict() reports the LR actually applied.
        self.lr = self._lookup_lr(
            self.update2lr, num_updates, "num_updates", "update2lr"
        )
        self.optimizer.set_lr(self.lr)
        return self.optimizer.get_lr()
|