import logging
import math
from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from composer.utils import dist
from torch.optim.optimizer import Optimizer

from llmfoundry.optim.outlier_detection import OutlierDetector

log = logging.getLogger(__name__)


class DecoupledAdaLRLion(Optimizer):
    """DecoupledAdaLRLion.

    This class implements a variant of Lion which lowers the layerwise
    learning rate when the layer's moment becomes an outlier. A moment is an
    outlier if it is some multiple `outlier_threshold` times larger than the
    simple windowed moving average (MVA) of moment norms taken from steps
    T-1000 to T-500. If an outlier is detected, the LR is lowered by
    `lr_penalty` for `timeout` steps. If N outliers are detected within
    `timeout` steps, the LR is scaled down by max(`lr_penalty` ** N,
    `min_scale`).

    Args:
        params (Iterable[torch.Tensor]): Model parameters to optimize.
        lr (float): Learning rate for updates.
        betas (Tuple[float, float]): Momentum factors.
        weight_decay (float): Weight decay.
        outlier_threshold (float): Multiplicative factor determining what
            constitutes an "outlier" relative to the MVA of moment norms.
        timeout (int): Number of steps to lower the learning rate for after
            seeing an outlier.
        lr_penalty (float): Multiplicative scale by which to lower the LR for
            each outlier.
        min_scale (float): Minimum allowed scaling of the LR.
    """
    metric_functions = {
        'l2_norm/moment':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                optim_state['exp_avg']),
        'l2_norm/param':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.data),
        'l2_norm/update':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor),
        'l2_norm/grad':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.grad),
    }

    def __init__(self,
                 params: Union[Iterable[torch.Tensor], Iterable[dict]],
                 lr: float = 1e-4,
                 betas: Tuple[float, float] = (0.9, 0.99),
                 weight_decay: float = 0.0,
                 outlier_threshold: float = 10.0,
                 timeout: int = 100,
                 lr_penalty: float = .707,
                 min_scale: float = 1e-4):
        if lr <= 0.:
            raise Exception(f'Invalid LR: {lr}. LR must be > 0')
        if not all([0. <= beta <= 1. for beta in betas]):
            raise Exception(
                f'Invalid beta values: {betas}. All betas must be between 0 and 1.'
            )
        if weight_decay >= 1e-3:
            log.warning(
                f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledAdaLRLion` optimizer. Are you sure you want to do this? '
                +
                f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!'
            )

        defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}
        super().__init__(params, defaults)

        for group in self.param_groups:
            group['initial_lr'] = group['lr']
        self.outlier_threshold = outlier_threshold
        self.timeout = timeout
        self.lr_penalty = lr_penalty
        self.min_scale = min_scale

    @staticmethod
    def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float, initial_lr: float, wd: float, beta1: float,
              beta2: float) -> None:
        # Decoupled weight decay, rescaled by the ratio of the current LR to
        # the initial LR.
        if wd != 0:
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            p.data.mul_(1 - decay_factor * wd)

        # Lion update: step in the direction of sign(interp(moment, grad)).
        update = exp_avg.lerp(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # Update the exponential moving average of the gradient.
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def adjust_lr(lr: float, lr_penalty: float, num_times: int,
                  min_scale: float) -> float:
        """Adjusts LR.

        Multiplicatively scales down the LR by `lr_penalty` for each outlier
        that has occurred in the last `timeout` number of steps, capping the
        scaling to be no smaller than `min_scale`.

        Args:
            lr (float): Base learning rate
            lr_penalty (float): Scaling factor to multiply by for each outlier
            num_times (int): Number of outliers in the last `timeout` steps
            min_scale (float): Minimum scaling to apply to our LR.

        Returns:
            float: Scaled LR
        """
        return lr * max(min_scale, lr_penalty**num_times)
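    # Worked example (illustrative, not executed): with the defaults
    # lr=1e-4, lr_penalty=0.707 and min_scale=1e-4,
    #   adjust_lr(1e-4, 0.707, 0, 1e-4) == 1e-4       (no recent outliers)
    #   adjust_lr(1e-4, 0.707, 2, 1e-4) ~= 5.0e-5     (0.707 ** 2 ~= 0.5)
    # and the effective LR can never drop below min_scale * lr == 1e-8.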

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: p.grad is not None and p.requires_grad,
                            group['params']):
                grad, lr, initial_lr, wd, beta1, beta2, state = p.grad, group[
                    'lr'], group['initial_lr'], group[
                        'weight_decay'], *group['betas'], self.state[p]

                # Lazily initialize per-parameter state on the first step.
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['moment_tracker'] = OutlierDetector(
                        self.outlier_threshold)
                    state['outlier_timestamp'] = []
                    state['step'] = 0

                exp_avg = state['exp_avg']

                # Norm of what the moment will be after this step, summed
                # across ranks when the parameter is sharded.
                moment_norm = torch.linalg.vector_norm(
                    exp_avg.lerp(grad, 1 - beta2))**2
                if dist.get_world_size() > 1:
                    dist.all_reduce(moment_norm, reduce_operation='SUM')
                moment_norm = math.sqrt(moment_norm)

                # Record the step index if the moment norm is an outlier.
                if state['moment_tracker'].insert_observation(moment_norm):
                    state['outlier_timestamp'].append(state['step'])

                # Drop outliers that are older than `timeout` steps.
                removed = []
                for ts in state['outlier_timestamp']:
                    if state['step'] - ts > self.timeout:
                        removed.append(ts)
                for ts in removed:
                    state['outlier_timestamp'].remove(ts)

                # Penalize the LR once per outlier still inside the window.
                lr = self.adjust_lr(lr, self.lr_penalty,
                                    len(state['outlier_timestamp']),
                                    self.min_scale)
                self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2)
                state['step'] += 1

        return loss

    def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        for metric in optimizer_metrics:
            if metric.startswith('l2_norm'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')

                optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
            elif metric.startswith('cosine'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')

                _, vectors, layer = tuple(metric.split('/'))
                A, B = tuple(vectors.split('_'))

                A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
                B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
                optimizer_metrics[metric] = reduced / (A_reduced_norm *
                                                       B_reduced_norm)
            elif metric.startswith('layerwise_lr'):
                continue
            else:
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = reduced / dist.get_world_size()

        return optimizer_metrics

    def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Preprocess metrics to reduce across ranks correctly."""
        for metric in optimizer_metrics:
            # Square only the L2 norm metrics; they are summed across ranks
            # and square-rooted again in `dist_reduce_metrics`. Other metrics
            # (e.g. `layerwise_lr`) are reported as-is.
            if metric.startswith('l2_norm'):
                optimizer_metrics[metric] = optimizer_metrics[metric]**2
        return optimizer_metrics
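    # Why the square/sum/sqrt scheme works: for a tensor x sharded across
    # ranks as x_1, ..., x_R,
    #   ||x||_2 = sqrt(||x_1||_2 ** 2 + ... + ||x_R||_2 ** 2),
    # so squaring locally, all-reducing with SUM, and taking the square root
    # recovers the L2 norm of the full tensor.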

    def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
                                     optimizer_metrics: dict):
        lr = self.param_groups[0]['lr']
        weight_decay = self.param_groups[0]['weight_decay']
        initial_lr = self.param_groups[0]['initial_lr']

        beta1, _ = self.param_groups[0]['betas']
        if param in self.state:
            param_optim_state = self.state[param]
            layerwise_lr = self.adjust_lr(
                lr, self.lr_penalty,
                len(param_optim_state['outlier_timestamp']), self.min_scale)

            step_tensor = param_optim_state['exp_avg'].clone().lerp_(
                param.grad, 1 - beta1).sign_().mul_(lr)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)
            for metric in self.metric_functions:
                optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[
                    metric](param, param_optim_state, step_tensor)

            optimizer_metrics[f'layerwise_lr/{name}'] = torch.tensor(
                layerwise_lr)

        return optimizer_metrics


class DecoupledClipLion(Optimizer):
    """DecoupledClipLion.

    This class implements a variant of Lion which clips layerwise gradients
    that are "outliers". A gradient is an outlier if it is some multiple k
    times larger than the simple windowed moving average (MVA) of gradient
    norms taken from steps T-1000 to T-500. If an outlier is detected, the
    gradient is clipped to have norm k * MVA.

    Args:
        params (Iterable[torch.Tensor]): Model parameters to optimize.
        lr (float): Learning rate for updates.
        betas (Tuple[float, float]): Momentum factors.
        weight_decay (float): Weight decay.
        outlier_threshold (float): Multiplicative factor determining what
            constitutes an "outlier" relative to the MVA of gradient norms.
    """
    metric_functions = {
        'l2_norm/moment':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                optim_state['exp_avg']),
        'l2_norm/param':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.data),
        'l2_norm/update':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor),
        'l2_norm/grad':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.grad),
    }

    def __init__(self,
                 params: Union[Iterable[torch.Tensor], Iterable[dict]],
                 lr: float = 1e-4,
                 betas: Tuple[float, float] = (0.9, 0.99),
                 weight_decay: float = 0.0,
                 outlier_threshold: float = 5.0):
        if lr <= 0.:
            raise Exception(f'Invalid LR: {lr}. LR must be > 0')
        if not all([0. <= beta <= 1. for beta in betas]):
            raise Exception(
                f'Invalid beta values: {betas}. All betas must be between 0 and 1.'
            )
        if weight_decay >= 1e-3:
            log.warning(
                f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledClipLion` optimizer. Are you sure you want to do this? '
                +
                f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!'
            )

        defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}
        super().__init__(params, defaults)

        for group in self.param_groups:
            group['initial_lr'] = group['lr']
        self.outlier_threshold = outlier_threshold

    @staticmethod
    def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float, initial_lr: float, wd: float, beta1: float,
              beta2: float) -> None:
        # Decoupled weight decay, rescaled by the ratio of the current LR to
        # the initial LR.
        if wd != 0:
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            p.data.mul_(1 - decay_factor * wd)

        # Lion update: step in the direction of sign(interp(moment, grad)).
        update = exp_avg.lerp(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # Update the exponential moving average of the gradient.
        exp_avg.lerp_(grad, 1 - beta2)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: p.grad is not None and p.requires_grad,
                            group['params']):
                grad, lr, initial_lr, wd, beta1, beta2, state = p.grad, group[
                    'lr'], group['initial_lr'], group[
                        'weight_decay'], *group['betas'], self.state[p]

                # Lazily initialize per-parameter state on the first step.
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['grad_tracker'] = OutlierDetector(
                        self.outlier_threshold)
                    state['clipped_batches'] = torch.tensor(0.0)

                exp_avg = state['exp_avg']

                # Gradient norm for this layer, summed across ranks when the
                # parameter is sharded.
                grad_norm = torch.linalg.vector_norm(grad)**2
                if dist.get_world_size() > 1:
                    dist.all_reduce(grad_norm, reduce_operation='SUM')
                grad_norm = math.sqrt(grad_norm)

                # If the gradient norm is an outlier, rescale the gradient so
                # its norm equals `slow_mva * outlier_threshold`.
                if state['grad_tracker'].insert_observation(grad_norm):
                    state['clipped_batches'] += 1.0
                    clip_norm = state['grad_tracker'].get_slow_mva(
                    ) * self.outlier_threshold
                    grad = grad.div(grad_norm).mul_(clip_norm)

                self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2)

        return loss

    def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        local_keys = list(optimizer_metrics.keys())
        all_gathered_keys = dist.all_gather_object(local_keys)
        all_keys = set()
        for keys in all_gathered_keys:
            all_keys.update(keys)

        all_keys = sorted(all_keys)
        for metric in all_keys:
            if metric.startswith('l2_norm'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')

                optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
            elif metric.startswith('clipped_batches'):
                continue
            else:
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = reduced / dist.get_world_size()

        return optimizer_metrics

    def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Preprocess metrics to reduce across ranks correctly."""
        metrics = optimizer_metrics.keys()
        metrics = sorted(metrics,
                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
        for metric in metrics:
            if metric.startswith('l2_norm'):
                optimizer_metrics[metric] = optimizer_metrics[metric]**2
            elif metric.startswith('cosine'):
                _, vectors, layer = tuple(metric.split('/'))
                A, B = tuple(vectors.split('_'))

                A_rank_subset_norm = math.sqrt(
                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
                B_rank_subset_norm = math.sqrt(
                    optimizer_metrics[f'l2_norm/{B}/{layer}'])

                optimizer_metrics[
                    metric] *= A_rank_subset_norm * B_rank_subset_norm

        return optimizer_metrics
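    # Explanatory note: squaring the rank-local L2 norms prepares them to be
    # summed and square-rooted during reduction. For any cosine metrics, a
    # rank-local cosine equals dot(A_r, B_r) / (||A_r|| * ||B_r||), so
    # multiplying it by the rank-local norms recovers the local dot product
    # dot(A_r, B_r), which can then be summed across ranks.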

    def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
                                     optimizer_metrics: dict):
        lr = self.param_groups[0]['lr']
        weight_decay = self.param_groups[0]['weight_decay']
        initial_lr = self.param_groups[0]['initial_lr']

        beta1, _ = self.param_groups[0]['betas']
        if param in self.state:
            param_optim_state = self.state[param]
            step_tensor = param_optim_state['exp_avg'].clone().lerp_(
                param.grad, 1 - beta1).sign_().mul_(lr)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)
            for metric in self.metric_functions:
                optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[
                    metric](param, param_optim_state, step_tensor)

            optimizer_metrics[f'clipped_batches/{name}'] = param_optim_state[
                'clipped_batches']

        return optimizer_metrics