|
|
|
|
|
|
|
from typing import Any, Callable, Dict, Iterable, Optional, Tuple |
|
|
|
import torch |
|
|
|
|
|
class DecoupledLionW_8bit(torch.optim.Optimizer): |
|
"""LION optimizer with ~8 bits of state per parameter. |
|
|
|
This optimizer is a drop-in replacement for our regular LION optimizer |
|
with decoupled weight decay, but uses less memory, writes smaller |
|
checkpoints, and offers almost-numerically-identical convergence. |
|
|
|
    Its state saved per parameter is just an int8, though there are auxiliary
    scaling factors that bring the total memory per parameter to ~8.5 bits:
    one float16 scale is shared by each group of 32 parameters, adding
    16/32 = 0.5 bits per parameter. The exact quantization scheme is
    considered an implementation detail and may change.
|
|
|
When training on CPUs, however, no quantization will actually take place. |
|
|
|
See the LION paper (https://arxiv.org/abs/2302.06675) for details about |
|
the algorithm itself. |
|
|
|
Args: |
|
params: iterable of parameters to optimize or dicts defining |
|
parameter groups |
|
lr: learning rate |
|
        betas: two coefficients between 0 and 1 used to combine the current
            gradient and the momentum. With momentum m and gradient g, the
            update direction is sign(beta1 * m + (1 - beta1) * g), and the
            new momentum is beta2 * m + (1 - beta2) * g.
|
        weight_decay: Weights are multiplied by 1 - `weight_decay` after
|
each optimizer step. Note that we use decoupled weight decay, |
|
meaning that this decay does not contribute to the momentum. |
|
compress_state_dict: if True, this optimizer's `state_dict` will |
|
include quantized optimizer states. Otherwise, the optimizer |
|
states are converted to bfloat16 Tensors matching the shapes of |
|
their corresponding parameters. The former uses ~8.5 bits per |
|
parameter while the latter uses 16 bits per parameter. However, |
|
the former is less thoroughly tested and will not work with |
|
FSDP or other weight sharding approaches. |
|
quantize: If False, optimizer states will not actually be quantized. |
|
This option is available so that one can easily debug whether |
|
the quantization is causing any convergence issues. Because |
|
quantization is only supported for CUDA parameters, attempting to |
|
update a non-CUDA tensor will raise an error. |
|
error_correction: If True, float16 and bfloat16 parameters will be |
|
given an extra state variable, "errors." This tensor will be |
|
of the same shape as the parameter but of dtype uint8. This |
|
auxiliary variable is used to better approximate float32 updates |
|
by retaining information across optimizer steps. |
|
|
|
Raises: |
|
        NotImplementedError: If any of `quantize`, `compress_state_dict`,
|
or `error_correction` are `True` and either a) there is no CUDA |
|
device, or b) step() is executed on a non-CUDA parameter. |
|
""" |
|
|
|
def __init__(self, |
|
params: Iterable[torch.Tensor], |
|
lr: float = 1e-3, |
|
betas: Tuple[float, float] = (0.9, 0.99), |
|
weight_decay: float = 0, |
|
quantize: bool = True, |
|
compress_state_dict: bool = False, |
|
error_correction: bool = False, |
|
_fused: bool = True): |
|
|
|
        if lr < 0.0:
            raise ValueError(f'Invalid learning rate: {lr}')
        if not 0.0 <= betas[0] <= 1.0:
            raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}')
        if not 0.0 <= betas[1] <= 1.0:
            raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}')
        if weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay value: {weight_decay}')
|
|
|
if not torch.cuda.is_available(): |
|
needs_cuda = ' requires a CUDA device.' |
|
if quantize: |
|
raise NotImplementedError('Quantization' + needs_cuda) |
|
if error_correction: |
|
raise NotImplementedError('Error correction' + needs_cuda) |
|
if compress_state_dict: |
|
raise NotImplementedError('Quantized state dict' + needs_cuda) |
|
|
|
        _fused = _fused and quantize  # fused kernel requires quantized state
|
self._quantize = quantize |
|
self._error_correction = error_correction |
|
self._compress_state_dict = compress_state_dict |
|
|
|
defaults = { |
|
'lr': lr, |
|
'initial_lr': lr, |
|
'betas': betas, |
|
'weight_decay': weight_decay, |
|
'fused': _fused |
|
} |
|
super().__init__(params, defaults) |
|
|
|
@torch.no_grad() |
|
def step(self, closure: Optional[Callable] = None): |
|
loss = None |
|
if closure is not None: |
|
with torch.enable_grad(): |
|
loss = closure() |
|
|
|
for group in self.param_groups: |
|
for p in group['params']: |
|
self.step_param(p, group) |
|
|
|
return loss |
|
|
|
def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: |
|
if not p.requires_grad or p.grad is None: |
|
return |
|
if self._quantize and not p.is_cuda: |
|
raise NotImplementedError( |
|
f"Can't use quantization with param on {p.device} " + |
|
f'({p.shape}, {p.dtype}). If you need ' + |
|
'to use DecoupledLionW_8bit without a CUDA device, try ' + |
|
'creating this optimizer with quantize=False.') |
|
state = self.state[p] |
|
if 'exp_avg' not in state: |
|
mom = torch.zeros_like(p) |
|
state['exp_avg'] = _MaybeQuantizedTensor( |
|
mom, try_quantize=self._quantize) |
|
need_errs = (p.dtype != torch.float32) and self._error_correction |
|
if state.get('errors') is None and need_errs: |
|
numel = p.numel() |
|
                # pad to an even byte count so the buffer can be viewed as
                # 2-byte bfloat16 elements below
                numel += numel % 2
|
errors = torch.zeros(numel, dtype=torch.uint8, device=p.device) |
|
|
|
            # keep the buffer as a bfloat16 view; it is reinterpreted as
            # uint8 wherever the raw error bytes are needed
            state['errors'] = errors.view(torch.bfloat16)
|
decay_factor = hparams['weight_decay'] |
|
decay_factor *= hparams['lr'] / hparams['initial_lr'] |
|
errors: Optional[torch.Tensor] = None |
|
if 'errors' in state: |
|
errors = state['errors'] |
|
assert errors is not None |
|
            errors = errors.view(dtype=torch.uint8)
            errors = errors[:p.numel()].view(p.shape)  # drop any pad byte
|
_lion8b_step(momentums=state['exp_avg'], |
|
weights=p, |
|
grads=p.grad, |
|
beta1=hparams['betas'][0], |
|
beta2=hparams['betas'][1], |
|
lr=hparams['lr'], |
|
weight_decay=decay_factor, |
|
fused=hparams['fused'], |
|
errors=errors) |
|
|
|
def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: |
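        # Rebuild wrapper objects when loading a state dict: raw momentum
        # tensors become _MaybeQuantizedTensor instances, and error buffers
        # are restored to the bfloat16 views that step_param expects.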
|
|
|
|
|
        opt_state, _ = state.values()  # other entry is param_groups
|
for param_id in opt_state: |
|
param_state = opt_state[param_id] |
|
new_state = {} |
|
if any(k.startswith('exp_avg') for k in param_state): |
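                # the saved state holds either a plain 'exp_avg' tensor or
                # the 'exp_avg::quantized' / 'exp_avg::scales' pair written
                # when compress_state_dict=True; load_state_dict takes both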
|
|
|
|
|
|
|
|
|
qtensor = _MaybeQuantizedTensor(None, |
|
try_quantize=self._quantize) |
|
qtensor.load_state_dict(param_state, name='exp_avg') |
|
new_state['exp_avg'] = qtensor |
|
if 'errors' in param_state: |
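                # cast back to raw uint8 bytes, then reinterpret them as
                # the bfloat16 view used internally (inverse of state_dict)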
|
|
|
|
|
|
|
errs = param_state['errors'].to(dtype=torch.uint8).view( |
|
torch.bfloat16) |
|
new_state['errors'] = errs |
|
opt_state[param_id] = new_state |
|
super().__setstate__(state) |
|
|
|
def state_dict(self): |
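        # Emit plain tensors by default so the state dict plays nicely
        # with weight sharding (e.g. FSDP); only write the int8 payload
        # and scales when the user opted into compress_state_dict.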
|
|
|
|
|
|
|
|
|
|
|
|
|
d = super().state_dict() |
|
        opt_state, _ = d.values()  # other entry is param_groups
|
for param_id in opt_state: |
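            # shallow-copy the per-param dict so we never mutate the
            # entries referenced by self.state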
|
|
|
|
|
|
|
param_state = {k: v for k, v in opt_state[param_id].items()} |
|
if 'exp_avg' in param_state: |
|
qtensor = param_state.pop('exp_avg') |
|
assert isinstance(qtensor, _MaybeQuantizedTensor) |
|
param_state.update( |
|
qtensor.state_dict( |
|
name='exp_avg', |
|
allow_quantized=self._compress_state_dict)) |
|
if 'errors' in param_state: |
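            # store the raw error bytes as bfloat16 values (0-255 are all
            # exactly representable); __setstate__ reverses this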
|
|
|
|
|
param_state['errors'] = param_state['errors'].view( |
|
torch.uint8).to(dtype=torch.bfloat16) |
|
opt_state[param_id] = param_state |
|
return d |
|
|
|
|
|
class _MaybeQuantizedTensor: |
|
"""Helper class so 8b LION doesn't have to know quantization details. |
|
|
|
Important points about this class: |
|
* It handles CPU tensors not being quantized |
|
* It knows how to save + load state dicts, handling both the quantized |
|
and not quantized cases |
|
* It implements some parts of the torch.Tensor interface that we need, |
|
but is not intended to be a full torch.Tensor replacement |
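
    A rough usage sketch (assumes a CUDA device, since CPU tensors are
    stored unquantized):

        mqt = _MaybeQuantizedTensor(torch.randn(64, device='cuda'))
        mqt.is_quantized()          # True when quantization kernels loaded
        approx = mqt.materialize()  # dequantized floating-point copy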
|
""" |
|
|
|
def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True): |
|
super().__init__() |
|
self.data: Optional[torch.Tensor] = None |
|
self.quantized: Optional[torch.Tensor] = None |
|
self.scales: Optional[torch.Tensor] = None |
|
self._try_quantize = try_quantize and torch.cuda.is_available() |
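        # import the CUDA quantize/dequantize kernels lazily so that
        # CPU-only runs never load the extension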
|
|
|
|
|
self._f_encode = None |
|
self._f_decode = None |
|
if self._try_quantize: |
|
from turbo import dequantize8b, quantize8b |
|
self._f_encode = quantize8b |
|
self._f_decode = dequantize8b |
|
|
|
if data is not None: |
|
self.set_data(data) |
|
|
|
def state_dict(self, |
|
name: str, |
|
allow_quantized: bool = False) -> Dict[str, torch.Tensor]: |
|
if self.is_quantized() and allow_quantized: |
|
assert self.quantized is not None |
|
assert self.scales is not None |
|
return { |
|
f'{name}::quantized': self.quantized, |
|
f'{name}::scales': self.scales |
|
} |
|
return {name: self.materialize().to(dtype=torch.bfloat16)} |
|
|
|
def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None: |
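        # tolerate extra keys so callers can pass a whole per-param state
        # dict; keep only the entries belonging to `name`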
|
|
|
|
|
d = {k: v for k, v in d.items() if k.startswith(name)} |
|
if name in d: |
|
if len(d) != 1: |
|
raise ValueError( |
|
f'If state dict specifies {name}, it must not ' + |
|
f'specify other keys. Got {list(d.keys())}') |
|
self.set_data(d[name]) |
|
return |
|
|
|
self.quantized = d[f'{name}::quantized'].to(dtype=torch.int8) |
|
self.scales = d[f'{name}::scales'].to(dtype=torch.float16) |
|
|
|
def set_data(self, data: torch.Tensor) -> None: |
|
if self._try_quantize: |
|
if not data.is_cuda: |
|
raise NotImplementedError( |
|
f'Attempting to quantize a non-CUDA {data.dtype} tensor ' + |
|
f'on device {data.device} with shape {data.shape}.') |
|
self.data = None |
|
assert self._f_encode is not None |
|
self.quantized, self.scales = self._f_encode(data) |
|
else: |
|
self.data = data.to(dtype=torch.float32) |
|
self.quantized = None |
|
self.scales = None |
|
|
|
def is_quantized(self) -> bool: |
|
return self.data is None |
|
|
|
def materialize(self) -> torch.Tensor: |
|
if not self.is_quantized(): |
|
assert self.data is not None |
|
return self.data |
|
assert self._f_decode is not None |
|
assert self.quantized is not None |
|
assert self.scales is not None |
|
return self._f_decode(self.quantized, self.scales) |
|
|
|
@property |
|
def is_cuda(self) -> bool: |
|
if self.is_quantized(): |
|
assert self.quantized is not None |
|
return self.quantized.is_cuda |
|
assert self.data is not None |
|
return self.data.is_cuda |
|
|
|
@property |
|
    def shape(self) -> Tuple[int, ...]:
|
if self.is_quantized(): |
|
assert self.quantized is not None |
|
return self.quantized.shape |
|
assert self.data is not None |
|
return self.data.shape |
|
|
|
def numel(self) -> int: |
|
if self.is_quantized(): |
|
assert self.quantized is not None |
|
return self.quantized.numel() |
|
assert self.data is not None |
|
return self.data.numel() |
|
|
|
def __repr__(self): |
|
return (f'{self.__class__.__name__} quantized={self.is_quantized()} ' + |
|
f'shape={self.shape}') |
|
|
|
|
|
def lion_step_unfused(grads: torch.Tensor, |
|
weights: torch.Tensor, |
|
momentums: torch.Tensor, |
|
lr: float, |
|
beta1: float, |
|
beta2: float, |
|
weight_decay: float = 0) -> torch.Tensor: |
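    # Reference (unfused) LION step: runs on any device, upcasts grads and
    # momentum to float32, and returns the new momentum so the caller can
    # re-store (and possibly re-quantize) it.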
|
|
|
momentums = momentums.to(dtype=torch.float32) |
|
grads = grads.to(dtype=torch.float32) |
|
|
|
    # update direction: sign(beta1 * m + (1 - beta1) * g)
    update = momentums.lerp(grads, 1 - beta1).sign_()
    if weight_decay > 0:
        weights.mul_(1. - weight_decay)  # decoupled weight decay
|
|
|
weights.add_(update, alpha=-lr) |
|
momentums.lerp_(grads, 1. - beta2) |
|
return momentums |
|
|
|
|
|
def lion8b_step_fused(grads: torch.Tensor, |
|
weights: torch.Tensor, |
|
momentums: torch.Tensor, |
|
scales: torch.Tensor, |
|
lr: float, |
|
beta1: float, |
|
beta2: float, |
|
weight_decay: float, |
|
errors: Optional[torch.Tensor] = None) -> None: |
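    # Fused path: validate all inputs, then update weights and the int8
    # momentum in a single CUDA kernel launch.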
|
|
|
f16, bf16, f32 = torch.float16, torch.bfloat16, torch.float32 |
|
|
|
use_errors = (errors is not None) and (weights.dtype in (f16, bf16)) |
|
orig_shape = weights.shape |
|
|
|
|
|
    # quantization uses one float16 scale per group of 32 momentum values,
    # so the scales tensor must have one entry per group
    quantize_group_size = 32
    num_groups = ((weights.numel() + quantize_group_size - 1) //
                  quantize_group_size)
    if num_groups != scales.numel():
        raise ValueError(f'Expected {num_groups} quantization scales but ' +
                         f'received {scales.numel()}')
|
|
|
for name, tensor, allowed_dtypes in [('grad', grads, (f16, bf16, f32)), |
|
('param', weights, (f16, bf16, f32)), |
|
('momentum', momentums, [torch.int8]), |
|
('scales', scales, [f16]), |
|
('errors', errors, [torch.uint8])]: |
|
if name == 'errors' and not use_errors: |
|
continue |
|
if not tensor.is_cuda: |
|
raise ValueError( |
|
f'{name} must be on a CUDA device, not {tensor.device}') |
|
if not tensor.is_contiguous(): |
|
raise ValueError(f'{name} is not contiguous!') |
|
strides_unequal = tensor.stride() != weights.stride() |
|
if name not in ('scales', 'errors') and strides_unequal: |
|
raise ValueError(f'{name} stride {tensor.stride()} != ' + |
|
f'param stride {weights.stride()}') |
|
if tensor.dtype not in allowed_dtypes: |
|
raise ValueError(f'{name} must have dtype {allowed_dtypes}, not ' + |
|
f'{tensor.dtype}') |
|
if (name != 'scales') and (orig_shape != tensor.shape): |
|
raise ValueError(f'Param shape {orig_shape} != ' + |
|
f'{name} shape {tensor.shape}') |
|
|
|
if grads.dtype in (torch.float16, torch.bfloat16): |
|
allowed_dtypes = (grads.dtype, torch.float32) |
|
if weights.dtype not in allowed_dtypes: |
|
raise ValueError( |
|
f'Weights must be f32 or match grad dtype {grads.dtype}') |
|
|
|
|
|
    # all validation passed; hand off to the fused CUDA kernel
    from turbo import lion8b_step_cuda
|
return lion8b_step_cuda(grads=grads, |
|
weights=weights, |
|
momentums=momentums, |
|
scales=scales, |
|
lr=lr, |
|
beta1=beta1, |
|
beta2=beta2, |
|
weight_decay=weight_decay, |
|
errors=errors) |
|
|
|
|
|
def _lion8b_step(grads: torch.Tensor, |
|
weights: torch.Tensor, |
|
momentums: _MaybeQuantizedTensor, |
|
lr: float, |
|
beta1: float, |
|
beta2: float, |
|
weight_decay: float = 0, |
|
errors: Optional[torch.Tensor] = None, |
|
fused: bool = True) -> None: |
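    # Dispatcher: take the fused CUDA path when the momentum is quantized,
    # otherwise fall back to the unfused reference step and write the new
    # momentum back (re-quantizing it when possible).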
|
|
|
if fused and not momentums.is_quantized(): |
|
raise NotImplementedError( |
|
'Fused LION step only implemented with quantization.') |
|
|
|
if momentums.is_quantized() and fused: |
|
assert momentums.quantized is not None |
|
assert momentums.scales is not None |
|
return lion8b_step_fused(grads=grads, |
|
weights=weights, |
|
momentums=momentums.quantized, |
|
scales=momentums.scales, |
|
lr=lr, |
|
beta1=beta1, |
|
beta2=beta2, |
|
weight_decay=weight_decay, |
|
errors=errors) |
|
|
|
momentums_float = momentums.materialize() |
|
new_momentums = lion_step_unfused(grads=grads, |
|
weights=weights, |
|
momentums=momentums_float, |
|
lr=lr, |
|
beta1=beta1, |
|
beta2=beta2, |
|
weight_decay=weight_decay) |
|
momentums.set_data(new_momentums) |
|
|