Spaces:
Paused
Paused
File size: 8,953 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math
from bisect import bisect_right
from typing import List
import torch
from fvcore.common.param_scheduler import (
CompositeParamScheduler,
ConstantParamScheduler,
LinearParamScheduler,
ParamScheduler,
)
try:
from torch.optim.lr_scheduler import LRScheduler
except ImportError:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
logger = logging.getLogger(__name__)
class WarmupParamScheduler(CompositeParamScheduler):
    """
    Composite scheduler that runs a warmup stage first and then hands over
    to the wrapped scheduler.
    """

    def __init__(
        self,
        scheduler: ParamScheduler,
        warmup_factor: float,
        warmup_length: float,
        warmup_method: str = "linear",
        rescale_interval: bool = False,
    ):
        """
        Args:
            scheduler: the scheduler that the warmup stage is prepended to
            warmup_factor: multiplier applied to ``scheduler(0.0)`` to obtain the
                value at the very start of warmup, e.g. 0.001
            warmup_length: fraction (in [0, 1]) of the whole training that is spent
                in warmup, e.g. 0.01
            warmup_method: either "linear" or "constant"
            rescale_interval: whether the interval of ``scheduler`` is rescaled
                after warmup

        Raises:
            ValueError: if ``warmup_method`` is neither "linear" nor "constant".
        """
        # Value that the warmup stage ramps towards. With a rescaled interval the
        # wrapped scheduler restarts from its own beginning after warmup;
        # otherwise it is picked up at the point where warmup ends.
        if rescale_interval:
            warmup_target = scheduler(0.0)
        else:
            warmup_target = scheduler(warmup_length)
        warmup_start = warmup_factor * scheduler(0.0)

        if warmup_method == "linear":
            warmup_stage = LinearParamScheduler(warmup_start, warmup_target)
        elif warmup_method == "constant":
            warmup_stage = ConstantParamScheduler(warmup_start)
        else:
            raise ValueError("Unknown warmup method: {}".format(warmup_method))

        # The warmup stage always spans its own rescaled interval; only the
        # wrapped scheduler's scaling mode depends on ``rescale_interval``.
        main_scaling = "rescaled" if rescale_interval else "fixed"
        super().__init__(
            [warmup_stage, scheduler],
            interval_scaling=["rescaled", main_scaling],
            lengths=[warmup_length, 1 - warmup_length],
        )
class LRMultiplier(LRScheduler):
    """
    A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
    learning rate of each param in the optimizer.
    Every step, the learning rate of each parameter becomes its initial value
    multiplied by the output of the given :class:`ParamScheduler`.

    The absolute learning rate value of each parameter can be different.
    This scheduler can be used as long as the relative scale among them does
    not change during training.

    Examples:
    ::
        LRMultiplier(
            opt,
            WarmupParamScheduler(
                MultiStepParamScheduler(
                    [1, 0.1, 0.01],
                    milestones=[60000, 80000],
                    num_updates=90000,
                ), 0.001, 100 / 90000
            ),
            max_iter=90000
        )
    """

    # NOTES: in the most general case, every LR can use its own scheduler.
    # Supporting this requires interaction with the optimizer when its parameter
    # group is initialized. For example, classyvision implements its own optimizer
    # that allows different schedulers for every parameter group.
    # To avoid this complexity, we use this class to support the most common cases
    # where the relative scale among all LRs stays unchanged during training. In this
    # case we only need a total of one scheduler that defines the relative LR multiplier.

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        multiplier: ParamScheduler,
        max_iter: int,
        last_iter: int = -1,
    ):
        """
        Args:
            optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``.
                ``last_iter`` is the same as ``last_epoch``.
            multiplier: a fvcore ParamScheduler that defines the multiplier on
                every LR of the optimizer
            max_iter: the total number of training iterations

        Raises:
            ValueError: if ``multiplier`` is not a fvcore ``ParamScheduler``.
        """
        if not isinstance(multiplier, ParamScheduler):
            raise ValueError(
                "LRMultiplier(multiplier=) must be an instance of fvcore "
                f"ParamScheduler. Got {multiplier} instead."
            )
        # These must be assigned before calling the base constructor, which is
        # expected to perform an initial step that already calls ``get_lr()``.
        self._multiplier = multiplier
        self._max_iter = max_iter
        super().__init__(optimizer, last_epoch=last_iter)

    def state_dict(self):
        """Return the checkpointable state: only ``base_lrs`` and ``last_epoch``."""
        # fvcore schedulers are stateless. Only keep pytorch scheduler states
        return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch}

    def get_lr(self) -> List[float]:
        """
        Returns:
            one learning rate per entry of ``self.base_lrs``: the base value
            scaled by the multiplier evaluated at ``last_epoch / max_iter``.
        """
        # The ParamScheduler is queried with the relative training progress.
        multiplier = self._multiplier(self.last_epoch / self._max_iter)
        return [base_lr * multiplier for base_lr in self.base_lrs]
"""
Content below is no longer needed!
"""
# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
# only on epoch boundaries. We typically use iteration based schedules instead.
# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
# "iteration" instead.
# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.
class WarmupMultiStepLR(LRScheduler):
    """
    Deprecated multi-step schedule with warmup.

    At iteration ``last_epoch`` every learning rate equals
    ``base_lr * warmup_factor_at_iter * gamma ** k`` where ``k`` is the number of
    milestones that are less than or equal to ``last_epoch``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        milestones: List[int],
        gamma: float = 0.1,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: the optimizer whose learning rates are scheduled.
            milestones: sorted iteration indices; the LR is multiplied by
                ``gamma`` once for every milestone that has been reached.
            gamma: multiplicative decay applied per reached milestone.
            warmup_factor: base warmup factor passed to
                ``_get_warmup_factor_at_iter``.
            warmup_iters: number of warmup iterations.
            warmup_method: warmup method passed to ``_get_warmup_factor_at_iter``.
            last_epoch: forwarded to the base scheduler constructor.

        Raises:
            ValueError: if ``milestones`` is not sorted.
        """
        logger.warning(
            "WarmupMultiStepLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
        )
        if list(milestones) != sorted(milestones):
            # Format the offending value into the message instead of passing it
            # as a second exception argument.
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the current learning rate for every entry of ``self.base_lrs``."""
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # bisect_right counts the milestones <= last_epoch; this does not depend
        # on the parameter group, so compute the decay once.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()
class WarmupCosineLR(LRScheduler):
    """
    Deprecated half-cosine schedule with warmup.

    At iteration ``last_epoch`` every learning rate equals
    ``base_lr * warmup_factor_at_iter * 0.5 * (1 + cos(pi * last_epoch / max_iters))``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: the optimizer whose learning rates are scheduled.
            max_iters: the iteration count that divides ``last_epoch`` inside the
                cosine term.
            warmup_factor: base warmup factor passed to
                ``_get_warmup_factor_at_iter``.
            warmup_iters: number of warmup iterations.
            warmup_method: warmup method passed to ``_get_warmup_factor_at_iter``.
            last_epoch: forwarded to the base scheduler constructor.
        """
        logger.warning(
            "WarmupCosineLR is deprecated! Use LRMultiplier with fvcore ParamScheduler instead!"
        )
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the current learning rate for every entry of ``self.base_lrs``."""
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # Different definitions of half-cosine with warmup are possible. For
        # simplicity we multiply the standard half-cosine schedule by the warmup
        # factor. An alternative is to start the period of the cosine at warmup_iters
        # instead of at 0. In the case that warmup_iters << max_iters the two are
        # very close to each other.
        # The cosine term is the same for every parameter group, so compute it once.
        cosine_term = 1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)
        return [base_lr * warmup_factor * 0.5 * cosine_term for base_lr in self.base_lrs]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()
def _get_warmup_factor_at_iter(
method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
"""
Return the learning rate warmup factor at a specific iteration.
See :paper:`ImageNet in 1h` for more details.
Args:
method (str): warmup method; either "constant" or "linear".
iter (int): iteration at which to calculate the warmup factor.
warmup_iters (int): the number of warmup iterations.
warmup_factor (float): the base warmup factor (the meaning changes according
to the method used).
Returns:
float: the effective warmup factor at the given iteration.
"""
if iter >= warmup_iters:
return 1.0
if method == "constant":
return warmup_factor
elif method == "linear":
alpha = iter / warmup_iters
return warmup_factor * (1 - alpha) + alpha
else:
raise ValueError("Unknown warmup method: {}".format(method))
|