# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import math
from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from composer.utils import dist
from torch.optim.optimizer import Optimizer

log = logging.getLogger(__name__)


class DecoupledLionW(Optimizer):
    """Lion optimizer with decoupled weight decay.

    The parameter update is the sign of an interpolation between the gradient
    and the momentum; weight decay is applied directly to the parameters,
    scaled by ``lr / initial_lr``, rather than folded into the gradient.
    """

    metric_functions = {
        'l2_norm/moment':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                optim_state['exp_avg']),
        'l2_norm/param':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.data),
        'l2_norm/update':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                step_tensor),
        'l2_norm/grad':
            lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                param.grad),
    }

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        if lr <= 0.:
            raise ValueError(f'Invalid LR: {lr}. LR must be > 0.')
        if not all(0. <= beta <= 1. for beta in betas):
            raise ValueError(
                f'Invalid beta values: {betas}. All betas must be between 0 and 1.'
            )
        if weight_decay >= 1e-3:
            log.warning(
                f'You are using a high value of `weight_decay={weight_decay}` for the '
                '`DecoupledLionW` optimizer. Are you sure you want to do this? '
                f"Your model's weights will be multiplied by {1.0 - weight_decay} on every step!"
            )

        defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}

        super().__init__(params, defaults)

        # Record the starting LR so weight decay can be rescaled as a
        # scheduler changes `lr` (see `lionw` below).
        for group in self.param_groups:
            group['initial_lr'] = group['lr']

    @staticmethod
    def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: float, initial_lr: float, wd: float, beta1: float,
              beta2: float) -> None:
        # Stepweight decay, decoupled from the gradient and scaled by the
        # ratio of the current LR to the initial LR.
        if wd != 0:
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            p.data.mul_(1 - decay_factor * wd)

        # The update is the sign of an interpolation between gradient and momentum.
        update = exp_avg.lerp(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # The momentum is an interpolation between the gradient and itself.
        exp_avg.lerp_(grad, 1 - beta2)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: p.grad is not None and p.requires_grad,
                            group['params']):
                grad, lr, initial_lr, wd = (p.grad, group['lr'],
                                            group['initial_lr'],
                                            group['weight_decay'])
                beta1, beta2 = group['betas']
                state = self.state[p]

                # Init state: an exponential moving average of gradient values.
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2)

        return loss

    def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Reduce pre-squared metrics across ranks (see `pre_reduce_metrics`)."""
        local_keys = list(optimizer_metrics.keys())
        all_gathered_keys = dist.all_gather_object(local_keys)
        all_keys = set()
        for keys in all_gathered_keys:
            all_keys.update(keys)

        # Sort keys so every rank reduces in the same order. Only L2-norm
        # metric keys are present, so a plain sort suffices.
        for metric in sorted(all_keys):
            if metric.startswith('l2_norm'):
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                # Norms were squared in `pre_reduce_metrics`, so take the
                # square root after summation.
                optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
            else:
                reduced = optimizer_metrics[metric]
                if dist.get_world_size() > 1:
                    dist.all_reduce(reduced, reduce_operation='SUM')
                optimizer_metrics[metric] = reduced / dist.get_world_size()

        return optimizer_metrics
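    # A minimal sketch (not part of the optimizer) of the math behind the
    # two-stage reduction above: `pre_reduce_metrics` (below) squares each
    # local L2 norm, the squares are summed across ranks, and
    # `dist_reduce_metrics` takes the square root, recovering the global norm
    # of the sharded tensor. For hypothetical shards x0 and x1:
    #
    #   x0, x1 = torch.randn(8), torch.randn(8)  # shards on rank 0 / rank 1
    #   sq_sum = torch.linalg.vector_norm(x0)**2 + torch.linalg.vector_norm(x1)**2
    #   full = torch.linalg.vector_norm(torch.cat([x0, x1]))
    #   assert torch.allclose(sq_sum.sqrt(), full)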
    def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
        """Preprocess metrics to reduce across ranks correctly."""
        # Only L2-norm metric keys are present, so sorting can be skipped at
        # this stage.
        for metric in optimizer_metrics:
            # L2 norms must be squared before they are reduced via summation.
            optimizer_metrics[metric] = optimizer_metrics[metric]**2
        return optimizer_metrics

    def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
                                     optimizer_metrics: dict):
        """Record per-parameter metrics (moment, param, update, grad norms)."""
        lr = self.param_groups[0]['lr']
        weight_decay = self.param_groups[0]['weight_decay']
        initial_lr = self.param_groups[0]['initial_lr']
        beta1, _ = self.param_groups[0]['betas']
        if param in self.state:
            param_optim_state = self.state[param]
            # Reconstruct the step that `lionw` would take, without mutating
            # the parameter or the optimizer state.
            step_tensor = param_optim_state['exp_avg'].clone().lerp_(
                param.grad, 1 - beta1).sign_().mul_(lr)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            # Fold the decoupled weight decay contribution into the reported step.
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)
            for metric in self.metric_functions:
                optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[
                    metric](param, param_optim_state, step_tensor)

        return optimizer_metrics
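

# A minimal usage sketch, assuming a toy CPU model; the shapes and
# hyperparameters below are illustrative only. The metric helpers above are
# typically driven by a monitoring callback (e.g. Composer's OptimizerMonitor)
# rather than called directly.
if __name__ == '__main__':
    model = torch.nn.Linear(4, 2)
    optimizer = DecoupledLionW(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=1e-4,
    )

    # One training step: forward, backward, Lion update with decoupled decay.
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()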