# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import math
from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from composer.utils import dist
from torch.optim.optimizer import Optimizer

log = logging.getLogger(__name__)


class DecoupledLionW(Optimizer):
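    """The Lion optimizer with weight decay decoupled from the gradient update.

    Each step moves the parameters by ``lr`` times the sign of an interpolation
    between the gradient and the momentum buffer (``exp_avg``); weight decay is
    applied directly to the weights, scaled by ``lr / initial_lr``. The
    ``metric_functions`` mapping below defines per-parameter L2-norm metrics
    used for optimizer monitoring.
    """
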
    metric_functions = {
        'l2_norm/moment':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                optim_state['exp_avg']),
        'l2_norm/param':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.data),
        'l2_norm/update':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor),
        'l2_norm/grad':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.grad),
    }

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
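        """Set up parameter groups and validate hyperparameters.

        Args:
            params: Parameters (or parameter-group dicts) to optimize.
            lr: Learning rate; must be positive.
            betas: Interpolation coefficients for the update and the momentum
                buffer, each in ``[0, 1]``.
            weight_decay: Multiplicative decay applied directly to the weights
                each step, scaled by ``lr / initial_lr``.
        """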
        if lr <= 0.:
            raise Exception(f'Invalid LR: {lr}. LR must be > 0')
        if not all([0. <= beta <= 1. for beta in betas]):
            raise Exception(
                f'Invalid beta values: {betas}. All betas must be between 0 and 1.'
            )
        if weight_decay >= 1e-3:
            log.warning(
                f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? '
                +
                f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!'
            )

        defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}

        super().__init__(params, defaults)

        for group in self.param_groups:
            group['initial_lr'] = group['lr']

    @staticmethod
    def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float, initial_lr: float, wd: float, beta1: float,
              beta2: float) -> None:
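        """Apply one in-place Lion update to parameter ``p``.

        Decays the weights multiplicatively (scaled by ``lr / initial_lr``),
        steps in the direction of ``sign(lerp(exp_avg, grad, 1 - beta1))``, and
        then moves the momentum buffer toward the gradient by ``1 - beta2``.
        """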
        # stepweight decay: shrink the weights multiplicatively, scaled by lr / initial_lr
        if wd != 0:
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            p.data.mul_(1 - decay_factor * wd)

        # the update is the sign of an interpolation between gradient and momentum
        update = exp_avg.lerp(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # the momentum buffer is interpolated between the gradient and itself
        exp_avg.lerp_(grad, 1 - beta2)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
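        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.
        """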
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: p.grad is not None and p.requires_grad,
                            group['params']):
                grad, lr, initial_lr, wd, beta1, beta2, state = p.grad, group[
                    'lr'], group['initial_lr'], group[
                        'weight_decay'], *group['betas'], self.state[p]

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2)

        return loss

    def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
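        """Reduce per-parameter metrics across ranks.

        L2-norm metrics are assumed to have been squared by
        ``pre_reduce_metrics``; they are summed across ranks and square-rooted,
        while any other metrics are averaged over the world size.
        """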
        local_keys = list(optimizer_metrics.keys())
        all_gathered_keys = dist.all_gather_object(local_keys)
        all_keys = set()
        for keys in all_gathered_keys:
            all_keys.update(keys)

        # Sort keys so that every rank processes the metrics in the same order
        # Only L2 norm metric keys are present, so a regular sort is sufficient
        all_keys = sorted(all_keys)
        for metric in all_keys:
            if metric.startswith('l2_norm'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')

                optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
            else:
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')

                optimizer_metrics[metric] = reduced / dist.get_world_size()

        return optimizer_metrics

    def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Preprocess metrics so that they reduce correctly across ranks."""
        # Only L2 norm metric keys are present, so sorting can be skipped at this stage
        for metric in optimizer_metrics:
            # L2 norms must be squared before they are reduced via summation
            optimizer_metrics[metric] = optimizer_metrics[metric]**2
        return optimizer_metrics

    def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
                                     optimizer_metrics: dict):
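        """Compute the L2-norm metrics in ``metric_functions`` for one parameter.

        The step tensor is recomputed from the momentum buffer, the gradient,
        and the hyperparameters of the first parameter group, mirroring the
        update applied in ``lionw``.
        """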
        lr = self.param_groups[0]['lr']
        weight_decay = self.param_groups[0]['weight_decay']
        initial_lr = self.param_groups[0]['initial_lr']
        beta1, _ = self.param_groups[0]['betas']
        if param in self.state:
            param_optim_state = self.state[param]
            step_tensor = param_optim_state['exp_avg'].clone().lerp_(
                param.grad, 1 - beta1).sign_().mul_(lr)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)
            for metric in self.metric_functions:
                optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[
                    metric](param, param_optim_state, step_tensor)

        return optimizer_metrics
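

# --- Example usage (not part of the original module) ---
# A minimal, self-contained smoke test: a tiny linear model trained for a few
# steps with DecoupledLionW. The model, data, and hyperparameters below are
# illustrative placeholders only.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    optimizer = DecoupledLionW(model.parameters(),
                               lr=1e-3,
                               betas=(0.9, 0.99),
                               weight_decay=1e-5)
    x = torch.randn(32, 4)
    y = torch.randn(32, 1)
    for step_idx in range(5):
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f'step {step_idx}: loss = {loss.item():.4f}')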