# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import torch


class DecoupledLionW_8bit(torch.optim.Optimizer):
"""LION optimizer with ~8 bits of state per parameter.
This optimizer is a drop-in replacement for our regular LION optimizer
with decoupled weight decay, but uses less memory, writes smaller
checkpoints, and offers almost-numerically-identical convergence.
Its state saved per parameter is just an int8, though there are auxiliary
scaling factors that bring the total memory per parameter to ~8.5 bits.
The exact quantization scheme is considered an implementation detail
and may change.
When training on CPUs, however, no quantization will actually take place.
See the LION paper (https://arxiv.org/abs/2302.06675) for details about
the algorithm itself.
    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate
        betas: two coefficients between 0 and 1 used to combine the current
            gradients and the momentum. The first coefficient is the weight
            of the gradient when computing the update. The second is the
            weight of the gradient when computing the new momentum.
        weight_decay: Weights are multiplied by 1 - `weight_decay` after
            each optimizer step. Note that we use decoupled weight decay,
            meaning that this decay does not contribute to the momentum.
        compress_state_dict: if True, this optimizer's `state_dict` will
            include quantized optimizer states. Otherwise, the optimizer
            states are converted to bfloat16 Tensors matching the shapes of
            their corresponding parameters. The former uses ~8.5 bits per
            parameter while the latter uses 16 bits per parameter. However,
            the former is less thoroughly tested and will not work with
            FSDP or other weight sharding approaches.
        quantize: If False, optimizer states will not actually be quantized.
            This option is available so that one can easily debug whether
            the quantization is causing any convergence issues. Because
            quantization is only supported for CUDA parameters, attempting to
            update a non-CUDA tensor will raise an error.
        error_correction: If True, float16 and bfloat16 parameters will be
            given an extra state variable, "errors." This tensor will be
            of the same shape as the parameter but of dtype uint8. This
            auxiliary variable is used to better approximate float32 updates
            by retaining information across optimizer steps.
    Raises:
        NotImplementedError - If any of `quantize`, `compress_state_dict`,
            or `error_correction` are `True` and either a) there is no CUDA
            device, or b) step() is executed on a non-CUDA parameter.
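
    Example:
        A minimal usage sketch (illustrative, not from the original source;
        assumes a CUDA device is available so quantization can happen)::

            model = torch.nn.Linear(256, 256).cuda()  # any small module works
            opt = DecoupledLionW_8bit(model.parameters(), lr=1e-4,
                                      weight_decay=1e-5)
            out = model(torch.randn(8, 256, device='cuda'))
            out.pow(2).mean().backward()
            opt.step()
            opt.zero_grad()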
"""
def __init__(self,
params: Iterable[torch.Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...
if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= betas[0] <= 1.0:
raise ValueError('Invalid beta parameter at index 0: {}'.format(
betas[0]))
if not 0.0 <= betas[1] <= 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(
betas[1]))
if not 0.0 <= weight_decay:
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay))
if not torch.cuda.is_available():
needs_cuda = ' requires a CUDA device.'
if quantize:
raise NotImplementedError('Quantization' + needs_cuda)
if error_correction:
raise NotImplementedError('Error correction' + needs_cuda)
if compress_state_dict:
raise NotImplementedError('Quantized state dict' + needs_cuda)
_fused = _fused and quantize
self._quantize = quantize
self._error_correction = error_correction
self._compress_state_dict = compress_state_dict
defaults = {
'lr': lr,
'initial_lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'fused': _fused
}
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure: Optional[Callable] = None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
self.step_param(p, group)
return loss
def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
if not p.requires_grad or p.grad is None:
return
if self._quantize and not p.is_cuda:
raise NotImplementedError(
f"Can't use quantization with param on {p.device} " +
f'({p.shape}, {p.dtype}). If you need ' +
'to use DecoupledLionW_8bit without a CUDA device, try ' +
'creating this optimizer with quantize=False.')
state = self.state[p] # type:ignore using tensor as key
if 'exp_avg' not in state:
mom = torch.zeros_like(p)
state['exp_avg'] = _MaybeQuantizedTensor(
mom, try_quantize=self._quantize)
need_errs = (p.dtype != torch.float32) and self._error_correction
if state.get('errors') is None and need_errs:
numel = p.numel()
numel += numel % 2 # ensure even number of bytes
errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
# as of torch 2.1, FSDP can't shard ints for no reason
state['errors'] = errors.view(torch.bfloat16)
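        # scaling the decay by lr / initial_lr below keeps the decay strength
        # proportional to the current learning rate under an lr schedule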
        decay_factor = hparams['weight_decay']
        decay_factor *= hparams['lr'] / hparams['initial_lr']

        errors: Optional[torch.Tensor] = None
        if 'errors' in state:
            errors = state['errors']
            assert errors is not None  # pyright
            errors = errors.view(dtype=torch.uint8)
            errors = errors[:p.numel()].view(p.shape)  # strip padding + reshape
        _lion8b_step(momentums=state['exp_avg'],
                     weights=p,
                     grads=p.grad,
                     beta1=hparams['betas'][0],
                     beta2=hparams['betas'][1],
                     lr=hparams['lr'],
                     weight_decay=decay_factor,
                     fused=hparams['fused'],
                     errors=errors)

    def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
        # we override this function to quantize optimizer states when
        # loading a state dict
        opt_state, _ = state.values()  # other val is param_groups
        for param_id in opt_state:
            param_state = opt_state[param_id]
            new_state = {}
            if any(k.startswith('exp_avg') for k in param_state):
                # the keys can either be just "exp_avg" or
                # "exp_avg::quantized" and "exp_avg::scales", depending on
                # whether we saved it as quantized or not. The former case
                # gives us interop with regular LION.
                qtensor = _MaybeQuantizedTensor(None,
                                                try_quantize=self._quantize)
                qtensor.load_state_dict(param_state, name='exp_avg')
                new_state['exp_avg'] = qtensor
            if 'errors' in param_state:
                # we need to cast back to the correct dtype since optimizer
                # load_state_dict casts to param dtype for fp params; see
                # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7  # noqa
                errs = param_state['errors'].to(dtype=torch.uint8).view(
                    torch.bfloat16)
                new_state['errors'] = errs
            opt_state[param_id] = new_state
        super().__setstate__(state)

    def state_dict(self):
        # If the user hasn't opted into storing compressed state dicts
        # we have to make sure our states are regular torch.Tensors. This
        # is mostly needed to make FSDP happy in the case that we want to
        # resume training with a number of devices where
        # (param numel / device count) % quantization group size != 0
        # for any param.
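        # (e.g., a 1,000-element param sharded across 8 devices gives 125
        # elements per shard, which is not a multiple of the group size of 32)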
        d = super().state_dict()
        opt_state, _ = d.values()  # other val is param_groups
        for param_id in opt_state:
            # make a copy so that we don't mutate our self.state; opt_state
            # isn't the same as self.state, but its constituent dicts are
            # the same as those in self.state
            param_state = {k: v for k, v in opt_state[param_id].items()}
            if 'exp_avg' in param_state:  # true if we've taken any steps
                qtensor = param_state.pop('exp_avg')
                assert isinstance(qtensor, _MaybeQuantizedTensor)  # pyright
                param_state.update(
                    qtensor.state_dict(
                        name='exp_avg',
                        allow_quantized=self._compress_state_dict))
            if 'errors' in param_state:
                # fsdp apparently needs the states to be the same shape
                # as the params
                param_state['errors'] = param_state['errors'].view(
                    torch.uint8).to(dtype=torch.bfloat16)
            opt_state[param_id] = param_state
        return d


class _MaybeQuantizedTensor:
    """Helper class so 8b LION doesn't have to know quantization details.

    Important points about this class:
    * It handles CPU tensors not being quantized
    * It knows how to save + load state dicts, handling both the quantized
      and not quantized cases
    * It implements some parts of the torch.Tensor interface that we need,
      but is not intended to be a full torch.Tensor replacement
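
    Example (illustrative sketch, not from the original source; assumes a
    CUDA device so that quantization actually happens)::

        mom = _MaybeQuantizedTensor(torch.zeros(1024, device='cuda'),
                                    try_quantize=True)
        assert mom.is_quantized()   # stored as int8 values + fp16 scales
        dense = mom.materialize()   # decode back to a dense float tensor
        mom.set_data(dense * 0.9)   # re-encode the updated values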
"""
def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True):
super().__init__()
self.data: Optional[torch.Tensor] = None
self.quantized: Optional[torch.Tensor] = None
self.scales: Optional[torch.Tensor] = None
self._try_quantize = try_quantize and torch.cuda.is_available()
# conditionally import CUDA kernels
self._f_encode = None
self._f_decode = None
if self._try_quantize:
from turbo import dequantize8b, quantize8b
self._f_encode = quantize8b
self._f_decode = dequantize8b
if data is not None:
self.set_data(data)
def state_dict(self,
name: str,
allow_quantized: bool = False) -> Dict[str, torch.Tensor]:
if self.is_quantized() and allow_quantized:
assert self.quantized is not None # pyright
assert self.scales is not None # pyright
return {
f'{name}::quantized': self.quantized,
f'{name}::scales': self.scales
}
return {name: self.materialize().to(dtype=torch.bfloat16)}
def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None:
# we allow other keys in the state dict for convenience, so you can
# just pass this the whole opt state for a parameters
d = {k: v for k, v in d.items() if k.startswith(name)}
if name in d:
if len(d) != 1:
raise ValueError(
f'If state dict specifies {name}, it must not ' +
f'specify other keys. Got {list(d.keys())}')
self.set_data(d[name])
return
self.quantized = d[f'{name}::quantized'].to(dtype=torch.int8)
self.scales = d[f'{name}::scales'].to(dtype=torch.float16)
def set_data(self, data: torch.Tensor) -> None:
if self._try_quantize:
if not data.is_cuda:
raise NotImplementedError(
f'Attempting to quantize a non-CUDA {data.dtype} tensor ' +
f'on device {data.device} with shape {data.shape}.')
self.data = None
assert self._f_encode is not None # pyright
self.quantized, self.scales = self._f_encode(data)
else:
self.data = data.to(dtype=torch.float32)
self.quantized = None
self.scales = None
def is_quantized(self) -> bool:
return self.data is None
def materialize(self) -> torch.Tensor:
if not self.is_quantized():
assert self.data is not None # pyright
return self.data
assert self._f_decode is not None # pyright
assert self.quantized is not None # pyright
assert self.scales is not None # pyright
return self._f_decode(self.quantized, self.scales)
@property # property to mirror Tensor interface
def is_cuda(self) -> bool:
if self.is_quantized():
assert self.quantized is not None # pyright
return self.quantized.is_cuda
assert self.data is not None # pyright
return self.data.is_cuda
@property # property to mirror Tensor interface
def shape(self) -> Tuple[int]:
if self.is_quantized():
assert self.quantized is not None # pyright
return self.quantized.shape
assert self.data is not None # pyright
return self.data.shape
def numel(self) -> int:
if self.is_quantized():
assert self.quantized is not None # pyright
return self.quantized.numel()
assert self.data is not None # pyright
return self.data.numel()
def __repr__(self):
return (f'{self.__class__.__name__} quantized={self.is_quantized()} ' +
f'shape={self.shape}')
def lion_step_unfused(grads: torch.Tensor,
weights: torch.Tensor,
momentums: torch.Tensor,
lr: float,
beta1: float,
beta2: float,
weight_decay: float = 0) -> torch.Tensor:
# f32 cast to match fused impl + for compatibility with f32 grads or weights
momentums = momentums.to(dtype=torch.float32)
grads = grads.to(dtype=torch.float32)
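    # LION update (see https://arxiv.org/abs/2302.06675): the parameter step
    # is the sign of an interpolation between momentum and gradient,
    #     update = sign(beta1 * m + (1 - beta1) * g),
    # and the momentum is then refreshed with the second coefficient,
    #     m = beta2 * m + (1 - beta2) * g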
    update = momentums.lerp(grads, 1 - beta1).sign_()
    if weight_decay > 0:
        weights.mul_(1. - weight_decay)
    weights.add_(update, alpha=-lr)
    momentums.lerp_(grads, 1. - beta2)
    return momentums  # f32 upcast means not necessarily modified in place


def lion8b_step_fused(grads: torch.Tensor,
                      weights: torch.Tensor,
                      momentums: torch.Tensor,
                      scales: torch.Tensor,
                      lr: float,
                      beta1: float,
                      beta2: float,
                      weight_decay: float,
                      errors: Optional[torch.Tensor] = None) -> None:
    # just to save space in lists of allowed dtypes
    f16, bf16, f32 = torch.float16, torch.bfloat16, torch.float32
    use_errors = (errors is not None) and (weights.dtype in (f16, bf16))
    orig_shape = weights.shape

    # ------------------------------------------------ wall of error checking
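    # the momentum is quantized in groups of 32 int8 values, each group with
    # its own fp16 scale, so we expect exactly one scale per group below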
    quantize_group_size = 32
    num_groups = (weights.numel() + quantize_group_size -
                  1) // quantize_group_size
    if (num_groups != scales.numel()):
        raise ValueError(f'Expected {num_groups} quantization scales but ' +
                         f'received {scales.numel()}')

    for name, tensor, allowed_dtypes in [('grad', grads, (f16, bf16, f32)),
                                         ('param', weights, (f16, bf16, f32)),
                                         ('momentum', momentums, [torch.int8]),
                                         ('scales', scales, [f16]),
                                         ('errors', errors, [torch.uint8])]:
        if name == 'errors' and not use_errors:
            continue
        if not tensor.is_cuda:
            raise ValueError(
                f'{name} must be on a CUDA device, not {tensor.device}')
        if not tensor.is_contiguous():
            raise ValueError(f'{name} is not contiguous!')
        strides_unequal = tensor.stride() != weights.stride()
        if name not in ('scales', 'errors') and strides_unequal:
            raise ValueError(f'{name} stride {tensor.stride()} != ' +
                             f'param stride {weights.stride()}')
        if tensor.dtype not in allowed_dtypes:
            raise ValueError(f'{name} must have dtype {allowed_dtypes}, not ' +
                             f'{tensor.dtype}')
        if (name != 'scales') and (orig_shape != tensor.shape):
            raise ValueError(f'Param shape {orig_shape} != ' +
                             f'{name} shape {tensor.shape}')

    if grads.dtype in (torch.float16, torch.bfloat16):
        allowed_dtypes = (grads.dtype, torch.float32)
        if weights.dtype not in allowed_dtypes:
            raise ValueError(
                f'Weights must be f32 or match grad dtype {grads.dtype}')

    # ------------------------------------------------ actual function call
    from turbo import lion8b_step_cuda
    return lion8b_step_cuda(grads=grads,
                            weights=weights,
                            momentums=momentums,
                            scales=scales,
                            lr=lr,
                            beta1=beta1,
                            beta2=beta2,
                            weight_decay=weight_decay,
                            errors=errors)


def _lion8b_step(grads: torch.Tensor,
                 weights: torch.Tensor,
                 momentums: _MaybeQuantizedTensor,
                 lr: float,
                 beta1: float,
                 beta2: float,
                 weight_decay: float = 0,
                 errors: Optional[torch.Tensor] = None,
                 fused: bool = True) -> None:
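    # dispatch: use the fused CUDA kernel when the momentum state is
    # quantized; otherwise fall back to the unfused step on dense momentums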
    if fused and not momentums.is_quantized():
        raise NotImplementedError(
            'Fused LION step only implemented with quantization.')

    if momentums.is_quantized() and fused:
        assert momentums.quantized is not None  # pyright
        assert momentums.scales is not None  # pyright
        return lion8b_step_fused(grads=grads,
                                 weights=weights,
                                 momentums=momentums.quantized,
                                 scales=momentums.scales,
                                 lr=lr,
                                 beta1=beta1,
                                 beta2=beta2,
                                 weight_decay=weight_decay,
                                 errors=errors)

    momentums_float = momentums.materialize()
    new_momentums = lion_step_unfused(grads=grads,
                                      weights=weights,
                                      momentums=momentums_float,
                                      lr=lr,
                                      beta1=beta1,
                                      beta2=beta2,
                                      weight_decay=weight_decay)
    momentums.set_data(new_momentums)