File size: 17,090 Bytes
f14e74e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 |
# Copyright © 2023 Apple Inc.
import math
from typing import List, Tuple

import mlx.core as mx
from mlx.utils import tree_map
class OptimizerState(dict):
    """A dictionary whose missing keys materialize as nested states.

    Indexing an absent key with ``state[key]`` inserts a fresh, empty
    :class:`OptimizerState` under that key and returns it, so arbitrarily
    deep paths can be created on demand (a recursively defined
    :class:`collections.defaultdict`).

    .. note::
        Unlike :meth:`dict.get`, :meth:`OptimizerState.get` stores
        ``default`` under ``key`` when the key is absent, and then returns
        the stored value.
    """

    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError:
            # Absent key: create, remember and hand back an empty sub-state.
            child = OptimizerState()
            self[key] = child
            return child

    def get(self, key, default):
        """Return ``self[key]``, first storing ``default`` there if ``key`` is absent."""
        return self.setdefault(key, default)
class Optimizer:
    """Base class for optimizers that work one parameter at a time.

    Subclasses implement :meth:`apply_single`. This class walks a tree of
    gradients together with the matching parameters and per-parameter
    optimizer state, and calls :meth:`apply_single` on every leaf.

    Attributes:
        state (OptimizerState): The optimizer's (nested) state dictionary.
    """

    def __init__(self):
        self.state = OptimizerState()

    def update(self, model: "mlx.nn.Module", gradients: dict):
        """Compute new parameters from ``gradients`` and load them into ``model``.

        Args:
            model (mlx.nn.Module): The mlx module whose parameters are updated.
            gradients (dict): A Python tree of gradients, most likely computed
                via :func:`mlx.nn.value_and_grad`.
        """
        new_parameters = self.apply_gradients(gradients, model)
        model.update(new_parameters)

    def apply_gradients(self, gradients: dict, model: dict):
        """Return the parameters obtained by applying ``gradients`` to ``model``.

        This is the building block of :meth:`Optimizer.update`, which is
        simply ``model.update(opt.apply_gradients(grads, model))``.

        Args:
            gradients (dict): A Python tree of gradients.
            model (dict): A Python tree of parameters. It may contain more
                entries than ``gradients``; the returned tree follows the
                structure of ``gradients``.
        """
        step = self.apply_single
        return tree_map(step, gradients, model, self.state)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Hook for subclasses: return the updated ``parameter``.

        Implementations may read and write their per-parameter ``state``.
        """
        raise NotImplementedError()
class SGD(Optimizer):
    r"""Stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    If ``weight_decay`` is non-zero, ``weight_decay * w_t`` is added to
    :math:`g_t` before the update. This applies both with and without
    momentum. With ``momentum <= 0`` the update reduces to
    :math:`w_{t+1} = w_t - \lambda g_t` and no state is stored.

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``

    Raises:
        ValueError: If ``nesterov`` is enabled without a positive
            ``momentum`` or with a non-zero ``dampening``.
    """

    def __init__(
        self,
        learning_rate: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        dampening: float = 0.0,
        nesterov: bool = False,
    ):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening."
            )
        super().__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state (only when momentum is used)."""
        # Apply the L2 penalty before branching so that it also takes effect
        # for plain, momentum-free SGD. Build a new array instead of using
        # ``+=`` so the caller's gradient tree is never modified in place.
        if self.weight_decay != 0:
            gradient = gradient + self.weight_decay * parameter

        if self.momentum <= 0:
            return parameter - self.learning_rate * gradient

        v = state.get("v", mx.zeros_like(gradient))
        v = self.momentum * v
        if self.dampening > 0:
            v = v + (1 - self.dampening) * gradient
        else:
            v = v + gradient

        # Nesterov looks ahead along the freshly updated velocity.
        if self.nesterov:
            update = gradient + self.momentum * v
        else:
            update = v

        state["v"] = v
        return parameter - self.learning_rate * update
class RMSprop(Optimizer):
    r"""Implementation of the RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

    .. math::

        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        alpha (float, optional): The smoothing constant :math:`\alpha`.
          Default: ``0.99``
        eps (float, optional): The term :math:`\epsilon` added to the denominator
          to improve numerical stability. Default: ``1e-8``

    Raises:
        ValueError: If ``alpha`` or ``eps`` is negative.
    """

    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
        super().__init__()
        self.learning_rate = learning_rate
        self.alpha = alpha
        self.eps = eps
        if self.alpha < 0.0:
            raise ValueError(
                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
            )
        # Zero is accepted, so the message states ">=0" to match the check.
        if self.eps < 0.0:
            raise ValueError(
                f"RMSprop epsilon should be >=0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
        lr = self.learning_rate
        alpha = self.alpha
        eps = self.eps

        v = state.get("v", mx.zeros_like(gradient))
        v = alpha * v + (1 - alpha) * mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)
class Adagrad(Optimizer):
    r"""Implementation of the Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
    for online learning and stochastic optimization. JMLR 2011.

    .. math::

        v_{t+1} &= v_t + g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``

    Raises:
        ValueError: If ``eps`` is negative.
    """

    def __init__(self, learning_rate: float, eps: float = 1e-8):
        super().__init__()
        self.learning_rate = learning_rate
        self.eps = eps
        # Zero is accepted, so the message states ">=0" to match the check.
        if self.eps < 0.0:
            raise ValueError(
                f"Adagrad epsilon should be >=0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adagrad parameter update and stores :math:`v` in the
        optimizer state."""
        lr = self.learning_rate
        eps = self.eps

        # ``v`` accumulates squared gradients without decay.
        v = state.get("v", mx.zeros_like(gradient))
        v = v + mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)
class AdaDelta(Optimizer):
    r"""Implementation of the AdaDelta optimizer with learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.

    .. math::

        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
        w_{t+1} &= w_t - \lambda \Delta w_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        rho (float, optional): The coefficient :math:`\rho` used for computing a
            running average of squared gradients. Default: ``0.9``
        eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
          numerical stability. Default: ``1e-6``

    Raises:
        ValueError: If ``rho`` or ``eps`` is negative.
    """

    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
        super().__init__()
        self.learning_rate = learning_rate
        self.rho = rho
        self.eps = eps
        if self.rho < 0.0:
            raise ValueError(
                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
            )
        # Zero is accepted, so the message states ">=0" to match the check.
        if self.eps < 0.0:
            raise ValueError(
                f"AdaDelta epsilon should be >=0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdaDelta parameter update and stores :math:`v` and
        :math:`u` in the optimizer state."""
        lr = self.learning_rate
        rho = self.rho
        eps = self.eps

        v = state.get("v", mx.zeros_like(gradient))
        # Read ``u`` from the same key it is written to below. Reading any
        # other key would restart the accumulator at zero on every step.
        u = state.get("u", mx.zeros_like(gradient))

        v = rho * v + (1 - rho) * mx.square(gradient)
        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
        u = rho * u + (1 - rho) * mx.square(d)

        state["v"] = v
        state["u"] = u
        return parameter - lr * d
class Adam(Optimizer):
    r"""Implementation of the Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. When no state exists
    yet for a parameter, :math:`m` and :math:`v` start from the current
    gradient :math:`g` and :math:`g^2` rather than from zero. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        # An immutable default prevents every instance from sharing (and
        # potentially mutating) one list object stored on ``self.betas``.
        super().__init__()
        self.learning_rate = learning_rate
        self.betas = betas
        self.eps = eps

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        # Seeding the moments with the first gradient, instead of zeros,
        # stands in for the omitted bias correction.
        m = state.get("m", gradient)
        v = state.get("v", mx.square(gradient))
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (mx.sqrt(v) + eps)
class AdamW(Adam):
    r"""Implementation of the AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
    with a weight_decay (:math:`\lambda`) value:

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\alpha`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
        weight_decay (float, optional): The weight decay :math:`\lambda`.
          Default: ``0.01``.
    """

    def __init__(
        self,
        learning_rate: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdamW parameter update by shrinking the parameter by
        ``1 - learning_rate * weight_decay`` and then applying the Adam step.
        """
        # Decoupled decay: scale the weights first, then run the plain Adam update.
        decayed = parameter * (1 - self.learning_rate * self.weight_decay)
        return super().apply_single(gradient, decayed, state)
class Adamax(Adam):
    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
    on the infinity norm [1].

    Our Adamax implementation follows the original paper and omits the bias
    correction in the first moment estimate. Both :math:`m` and :math:`v`
    start from zero. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for the running average of the
          gradient and the exponentially weighted infinity norm of the
          gradient, respectively. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        super().__init__(learning_rate, betas, eps)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adamax parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        m = state.get("m", mx.zeros_like(gradient))
        v = state.get("v", mx.zeros_like(gradient))
        m = b1 * m + (1 - b1) * gradient
        # Infinity-norm estimate: decayed running max of |g|.
        v = mx.maximum(b2 * v, mx.abs(gradient))
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (v + eps)
class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper, except
    that the momentum :math:`m` starts from the first gradient seen rather
    than from zero. In detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing the gradient
          momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: float,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        # Immutable default so instances never share one mutable ``betas`` list.
        super().__init__()
        self.learning_rate = learning_rate
        self.betas = betas
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state.get("m", gradient)
        # ``c`` (interpolated with beta_1) gives the update direction, while the
        # stored momentum is updated separately with beta_2.
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient

        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)
|