File size: 18,249 Bytes
de4ade4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 |
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
import torch
from composer.utils import dist
from torch.optim.optimizer import Optimizer
from llmfoundry.optim.outlier_detection import OutlierDetector
# Module-level logger named after this module; the optimizers below use it to
# emit configuration warnings (e.g. an unusually high `weight_decay`).
log = logging.getLogger(__name__)
class DecoupledAdaLRLion(Optimizer):
    """DecoupledAdaLRLion.

    This class implements a variant of Lion which lowers the layerwise
    learning rate when the layer's moment becomes an outlier. A moment is an
    outlier if it is some multiple `outlier_threshold` times larger than the
    simple windowed moving average (MVA) of moment norms taken from steps T-1000
    to T-500. If an outlier is detected, the LR is lowered by `lr_penalty` for
    `timeout` steps. If N outliers are detected within `timeout` steps, the LR
    is scaled down by max(`lr_penalty` ** N, `min_scale`).

    Args:
        params (Iterable[torch.Parameter]): Model parameters to optimize
        lr (float): Learning rate for updates
        betas (Tuple[float]): Momentum factors
        weight_decay (float): Weight decay
        outlier_threshold (float): Multiplicative factor determining what
            constitutes an "outlier" relative to the MVA of moment norms.
        timeout (int): Number of steps to lower the learning rate for after
            seeing an outlier.
        lr_penalty (float): Multiplicative scale by which to lower the LR for
            each outlier.
        min_scale (float): Minimum allowed scaling of the LR.
    """

    # Per-parameter metrics computed by `report_per_parameter_metrics`. Each
    # callable receives (param, optim_state, step_tensor) and returns the L2
    # norm of one of them; the dict key becomes the metric-name prefix.
    metric_functions = {
        'l2_norm/moment':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                optim_state['exp_avg']),
        'l2_norm/param':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.data),
        'l2_norm/update':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor),
        'l2_norm/grad':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.grad),
    }

    def __init__(self,
                 params: Union[Iterable[torch.Tensor], Iterable[dict]],
                 lr: float = 1e-4,
                 betas: Tuple[float, float] = (0.9, 0.99),
                 weight_decay: float = 0.0,
                 outlier_threshold: float = 10.0,
                 timeout: int = 100,
                 lr_penalty: float = .707,
                 min_scale: float = 1e-4):
        """Validate hyperparameters and register the parameter groups.

        Raises:
            Exception: If `lr` is not strictly positive or any beta lies
                outside [0, 1].
        """
        if lr <= 0.:
            raise Exception(f'Invalid LR: {lr}. LR must be > 0')
        if not all(0. <= beta <= 1. for beta in betas):
            raise Exception(
                f'Invalid beta values: {betas} All betas must be between 0 and 1.'
            )
        if weight_decay >= 1e-3:
            log.warning(
                f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledAdaLRLion` optimizer. Are you sure you want to do this? '
                +
                f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!'
            )

        defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}
        super().__init__(params, defaults)

        # Remember the LR each group started with so that weight decay can be
        # scaled by (current lr / initial lr) in `lionw`.
        for group in self.param_groups:
            group['initial_lr'] = group['lr']

        self.outlier_threshold = outlier_threshold
        self.timeout = timeout
        self.lr_penalty = lr_penalty
        self.min_scale = min_scale

    @staticmethod
    def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float, initial_lr: float, wd: float, beta1: float,
              beta2: float) -> None:
        """Apply one Lion update with decoupled weight decay, in place.

        Args:
            p (torch.Tensor): Parameter to update (modified in place).
            grad (torch.Tensor): Gradient for `p`.
            exp_avg (torch.Tensor): Momentum buffer (modified in place).
            lr (float): Learning rate used for this update.
            initial_lr (float): LR the group started with; a falsy value
                disables the LR-relative scaling of weight decay.
            wd (float): Weight decay coefficient.
            beta1 (float): Interpolation factor used to form the update.
            beta2 (float): Interpolation factor used to refresh the momentum.
        """
        # stepweight decay, scaled by how far the LR has moved from its
        # initial value
        if wd != 0:
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            p.data.mul_(1 - decay_factor * wd)

        # update is the sign of an interpolation between gradient and momentum
        update = exp_avg.lerp(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # momentum is interp b/w gradient and itself
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def adjust_lr(lr: float, lr_penalty: float, num_times: int,
                  min_scale: float) -> float:
        """Adjusts LR.

        Multiplicatively scales down the LR by lr_penalty for each outlier
        that has occurred in the last `timeout` number of steps, capping the
        scaling to be no smaller than `min_scale`.

        Args:
            lr (float): Base learning rate
            lr_penalty (float): Scaling factor to multiply by for each outlier
            num_times (int): Number of outliers in the last `timeout` steps
            min_scale (float): Minimum scaling to apply to our LR.

        Returns:
            float: Scaled LR
        """
        return lr * max(min_scale, lr_penalty**num_times)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Perform a single optimization step.

        Args:
            closure (Optional[Callable]): Optional callable that re-evaluates
                the model (with gradients enabled) and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            base_lr = group['lr']
            initial_lr = group['initial_lr']
            wd = group['weight_decay']
            beta1, beta2 = group['betas']

            for p in filter(lambda p: p.grad is not None and p.requires_grad,
                            group['params']):
                grad = p.grad
                state = self.state[p]

                # init state - exponential moving average of gradient values,
                # the outlier tracker, and the steps at which outliers occurred
                if not state:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['moment_tracker'] = OutlierDetector(
                        self.outlier_threshold)
                    state['outlier_timestamp'] = []
                    state['step'] = 0

                exp_avg = state['exp_avg']

                # determine if the new moment resulting from this grad would
                # be an outlier; squared norms are summed across ranks before
                # taking the square root
                moment_norm = torch.linalg.vector_norm(
                    exp_avg.lerp(grad, 1 - beta2))**2
                if dist.get_world_size() > 1:
                    dist.all_reduce(moment_norm, reduce_operation='SUM')
                moment_norm = math.sqrt(moment_norm)

                if state['moment_tracker'].insert_observation(moment_norm):
                    state['outlier_timestamp'].append(state['step'])

                # forget outliers that are more than `timeout` steps old
                current_step = state['step']
                state['outlier_timestamp'] = [
                    ts for ts in state['outlier_timestamp']
                    if current_step - ts <= self.timeout
                ]

                layer_lr = self.adjust_lr(base_lr, self.lr_penalty,
                                          len(state['outlier_timestamp']),
                                          self.min_scale)
                self.lionw(p, grad, exp_avg, layer_lr, initial_lr, wd, beta1,
                           beta2)
                state['step'] += 1

        return loss

    def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Combine per-rank metrics into global metrics.

        Expects the values to have been prepared by `pre_reduce_metrics`.
        `l2_norm/*` values are sum-reduced and square-rooted, `cosine/*`
        values are sum-reduced and divided by the product of the two reduced
        norms, `layerwise_lr/*` values are left untouched, and everything else
        is averaged over the world size.
        """
        for metric in optimizer_metrics:
            if metric.startswith('l2_norm'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
            elif metric.startswith('cosine'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                _, vectors, layer = tuple(metric.split('/'))
                A, B = tuple(vectors.split('_'))
                # NOTE(review): this reads the already-reduced norms, so it
                # relies on the matching `l2_norm/*` keys appearing earlier in
                # the dict's iteration order - confirm against the caller.
                A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
                B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
                optimizer_metrics[metric] = reduced / (A_reduced_norm *
                                                       B_reduced_norm)
            elif metric.startswith('layerwise_lr'):
                continue
            else:
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = reduced / dist.get_world_size()

        return optimizer_metrics

    def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Preprocess metrics to reduce across ranks correctly.

        Only `l2_norm/*` and `cosine/*` entries are transformed. Other
        entries (notably `layerwise_lr/*`, which `dist_reduce_metrics` passes
        through unchanged) must not be squared here, otherwise the logged
        value would be the square of the real one.
        """
        # Handle L2 norms first so they are squared before the cosine
        # metrics, which depend on the squared values
        metrics = sorted(optimizer_metrics.keys(),
                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
        for metric in metrics:
            if metric.startswith('l2_norm'):
                # L2 norms need to be squared, before they are reduced via summation
                optimizer_metrics[metric] = optimizer_metrics[metric]**2
            elif metric.startswith('cosine'):
                _, vectors, layer = tuple(metric.split('/'))
                A, B = tuple(vectors.split('_'))
                # L2 norm would've been squared in previous branch
                A_rank_subset_norm = math.sqrt(
                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
                B_rank_subset_norm = math.sqrt(
                    optimizer_metrics[f'l2_norm/{B}/{layer}'])
                optimizer_metrics[
                    metric] *= A_rank_subset_norm * B_rank_subset_norm

        return optimizer_metrics

    def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
                                     optimizer_metrics: dict):
        """Add this parameter's metrics to `optimizer_metrics`.

        Args:
            param (torch.Tensor): Parameter to report on; it is skipped when
                it has no optimizer state yet.
            name (str): Name appended to every metric key.
            optimizer_metrics (dict): Mapping updated in place.

        Returns:
            dict: The same `optimizer_metrics` mapping.
        """
        # Hyperparameters are always read from the first param group.
        lr = self.param_groups[0]['lr']
        weight_decay = self.param_groups[0]['weight_decay']
        initial_lr = self.param_groups[0]['initial_lr']
        beta1, _ = self.param_groups[0]['betas']

        if param in self.state:
            param_optim_state = self.state[param]
            layerwise_lr = self.adjust_lr(
                lr, self.lr_penalty,
                len(param_optim_state['outlier_timestamp']), self.min_scale)

            # NOTE(review): the update tensor is built from the base `lr`,
            # not `layerwise_lr`, although `step` applies the layerwise value
            # - confirm whether that is intended.
            step_tensor = param_optim_state['exp_avg'].clone().lerp_(
                param.grad, 1 - beta1).sign_().mul_(lr)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)

            for metric in self.metric_functions:
                optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[
                    metric](param, param_optim_state, step_tensor)

            optimizer_metrics[f'layerwise_lr/{name}'] = torch.tensor(
                layerwise_lr)

        return optimizer_metrics
class DecoupledClipLion(Optimizer):
    """DecoupledClipLION.

    This class implements a variant of Lion which clips layerwise gradients
    that are "outliers". A gradient is an outlier if it is some multiple k times
    larger than the simple windowed moving average (MVA) of gradient norms taken
    from steps T-1000 to T-500. If an outlier is detected, it is rescaled so
    that its norm equals k * MVA.

    Args:
        params (Iterable[torch.Parameter]): Model parameters to optimize
        lr (float): Learning rate for updates
        betas (Tuple[float]): Momentum factors
        weight_decay (float): Weight decay
        outlier_threshold (float): Multiplicative factor determining what
            constitutes an "outlier" relative to the MVA of gradient norms.
    """

    # Per-parameter metrics computed by `report_per_parameter_metrics`. Each
    # callable receives (param, optim_state, step_tensor) and returns the L2
    # norm of one of them; the dict key becomes the metric-name prefix.
    metric_functions = {
        'l2_norm/moment':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                optim_state['exp_avg']),
        'l2_norm/param':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.data),
        'l2_norm/update':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor),
        'l2_norm/grad':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.grad),
    }

    def __init__(self,
                 params: Union[Iterable[torch.Tensor], Iterable[dict]],
                 lr: float = 1e-4,
                 betas: Tuple[float, float] = (0.9, 0.99),
                 weight_decay: float = 0.0,
                 outlier_threshold: float = 5.0):
        """Validate hyperparameters and register the parameter groups.

        Raises:
            Exception: If `lr` is not strictly positive or any beta lies
                outside [0, 1].
        """
        if lr <= 0.:
            raise Exception(f'Invalid LR: {lr}. LR must be > 0')
        if not all(0. <= beta <= 1. for beta in betas):
            raise Exception(
                f'Invalid beta values: {betas} All betas must be between 0 and 1.'
            )
        if weight_decay >= 1e-3:
            log.warning(
                f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledClipLion` optimizer. Are you sure you want to do this? '
                +
                f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!'
            )

        defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}
        super().__init__(params, defaults)

        # Remember the LR each group started with so that weight decay can be
        # scaled by (current lr / initial lr) in `lionw`.
        for group in self.param_groups:
            group['initial_lr'] = group['lr']

        self.outlier_threshold = outlier_threshold

    @staticmethod
    def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float, initial_lr: float, wd: float, beta1: float,
              beta2: float) -> None:
        """Apply one Lion update with decoupled weight decay, in place.

        Args:
            p (torch.Tensor): Parameter to update (modified in place).
            grad (torch.Tensor): Gradient for `p` (possibly clipped).
            exp_avg (torch.Tensor): Momentum buffer (modified in place).
            lr (float): Learning rate used for this update.
            initial_lr (float): LR the group started with; a falsy value
                disables the LR-relative scaling of weight decay.
            wd (float): Weight decay coefficient.
            beta1 (float): Interpolation factor used to form the update.
            beta2 (float): Interpolation factor used to refresh the momentum.
        """
        # stepweight decay, scaled by how far the LR has moved from its
        # initial value
        if wd != 0:
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            p.data.mul_(1 - decay_factor * wd)

        # update is the sign of an interpolation between gradient and momentum
        update = exp_avg.lerp(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # momentum is interp b/w gradient and itself
        exp_avg.lerp_(grad, 1 - beta2)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Perform a single optimization step, clipping outlier gradients.

        Args:
            closure (Optional[Callable]): Optional callable that re-evaluates
                the model (with gradients enabled) and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            initial_lr = group['initial_lr']
            wd = group['weight_decay']
            beta1, beta2 = group['betas']

            for p in filter(lambda p: p.grad is not None and p.requires_grad,
                            group['params']):
                grad = p.grad
                state = self.state[p]

                # init state - exponential moving average of gradient values,
                # the outlier tracker, and a count of clipped batches
                if not state:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['grad_tracker'] = OutlierDetector(
                        self.outlier_threshold)
                    state['clipped_batches'] = torch.tensor(0.0)

                exp_avg = state['exp_avg']

                # determine if this gradient's norm is an outlier; squared
                # norms are summed across ranks before taking the square root
                grad_norm = torch.linalg.vector_norm(grad)**2
                if dist.get_world_size() > 1:
                    dist.all_reduce(grad_norm, reduce_operation='SUM')
                grad_norm = math.sqrt(grad_norm)

                if state['grad_tracker'].insert_observation(grad_norm):
                    state['clipped_batches'] += 1.0
                    clip_norm = state['grad_tracker'].get_slow_mva(
                    ) * self.outlier_threshold
                    # `div` allocates a new tensor, so `p.grad` itself is
                    # left untouched; only the local copy is rescaled.
                    grad = grad.div(grad_norm).mul_(clip_norm)

                self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2)

        return loss

    def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Combine per-rank metrics into global metrics.

        Expects the values to have been prepared by `pre_reduce_metrics`.
        The union of metric keys over all ranks is processed in sorted order
        so that every rank issues the same collectives in the same order. A
        rank that lacks a key contributes zero instead of raising `KeyError`.
        `l2_norm/*` values are sum-reduced and square-rooted,
        `clipped_batches/*` values are left untouched, and everything else is
        averaged over the world size.
        """
        local_keys = list(optimizer_metrics.keys())
        all_gathered_keys = dist.all_gather_object(local_keys)
        all_keys = set()
        for keys in all_gathered_keys:
            all_keys.update(keys)

        # Sort keys to ensure every rank has the same keys order
        all_keys = sorted(all_keys)

        # Device for zero placeholders, chosen before any value is replaced:
        # prefer the device of a local norm tensor, else the current CUDA
        # device when available, else CPU.
        fallback_device = next(
            (value.device
             for key, value in optimizer_metrics.items()
             if key.startswith('l2_norm') and isinstance(value, torch.Tensor)
            ), None)
        if fallback_device is None:
            if torch.cuda.is_available():
                fallback_device = torch.device('cuda',
                                               torch.cuda.current_device())
            else:
                fallback_device = torch.device('cpu')

        def _local_or_zero(key: str) -> torch.Tensor:
            # A missing key means this rank holds none of that parameter, so
            # its contribution to a SUM reduction is zero.
            value = optimizer_metrics.get(key)
            if value is None:
                value = torch.tensor(0.0, device=fallback_device)
            return value

        for metric in all_keys:
            if metric.startswith('l2_norm'):
                reduced = _local_or_zero(metric)
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
            elif metric.startswith('clipped_batches'):
                continue
            else:
                reduced = _local_or_zero(metric)
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = reduced / dist.get_world_size()

        return optimizer_metrics

    def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Preprocess metrics to reduce across ranks correctly.

        `l2_norm/*` entries are squared and `cosine/*` entries are multiplied
        by the product of the two local norms; all other entries are left
        unchanged.
        """
        # Sort L2 norms first so they are squared before other metrics, which depend on squared values
        metrics = sorted(optimizer_metrics.keys(),
                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
        for metric in metrics:
            if metric.startswith('l2_norm'):
                # L2 norms need to be squared, before they are reduced via summation
                optimizer_metrics[metric] = optimizer_metrics[metric]**2
            elif metric.startswith('cosine'):
                _, vectors, layer = tuple(metric.split('/'))
                A, B = tuple(vectors.split('_'))
                # L2 norm would've been squared in previous branch
                A_rank_subset_norm = math.sqrt(
                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
                B_rank_subset_norm = math.sqrt(
                    optimizer_metrics[f'l2_norm/{B}/{layer}'])
                optimizer_metrics[
                    metric] *= A_rank_subset_norm * B_rank_subset_norm

        return optimizer_metrics

    def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
                                     optimizer_metrics: dict):
        """Add this parameter's metrics to `optimizer_metrics`.

        Args:
            param (torch.Tensor): Parameter to report on; it is skipped when
                it has no optimizer state yet.
            name (str): Name appended to every metric key.
            optimizer_metrics (dict): Mapping updated in place.

        Returns:
            dict: The same `optimizer_metrics` mapping.
        """
        # Hyperparameters are always read from the first param group.
        lr = self.param_groups[0]['lr']
        weight_decay = self.param_groups[0]['weight_decay']
        initial_lr = self.param_groups[0]['initial_lr']
        beta1, _ = self.param_groups[0]['betas']

        if param in self.state:
            param_optim_state = self.state[param]

            step_tensor = param_optim_state['exp_avg'].clone().lerp_(
                param.grad, 1 - beta1).sign_().mul_(lr)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)

            for metric in self.metric_functions:
                optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[
                    metric](param, param_optim_state, step_tensor)

            # Same tensor object that `step` increments on every clip.
            optimizer_metrics[f'clipped_batches/{name}'] = param_optim_state[
                'clipped_batches']

        return optimizer_metrics
|