# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import torch


class DecoupledLionW_8bit(torch.optim.Optimizer):
    """LION optimizer with ~8 bits of state per parameter.

    This optimizer is a drop-in replacement for our regular LION optimizer
    with decoupled weight decay, but uses less memory, writes smaller
    checkpoints, and offers almost-numerically-identical convergence.

    Its state saved per parameter is just an int8, though there are auxiliary
    scaling factors that bring the total memory per parameter to ~8.5 bits.
    The exact quantization scheme is considered an implementation detail
    and may change.

    When training on CPUs, however, no quantization will actually take place.

    See the LION paper (https://arxiv.org/abs/2302.06675) for details about
    the algorithm itself.

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate
        betas: two coefficients between 0 and 1 used to combine the current
            gradients and the momentum. The first coefficient is the weight
            of the gradient when computing the update. The second is the
            weight of the gradient when computing the new momentum.
        weight_decay: Weights are multiplied by 1 - `weight_decay` after
            each optimizer step. Note that we use decoupled weight decay,
            meaning that this decay does not contribute to the momentum.
        compress_state_dict: if True, this optimizer's `state_dict` will
            include quantized optimizer states. Otherwise, the optimizer
            states are converted to bfloat16 Tensors matching the shapes of
            their corresponding parameters. The former uses ~8.5 bits per
            parameter while the latter uses 16 bits per parameter. However,
            the former is less thoroughly tested and will not work with
            FSDP or other weight sharding approaches.
        quantize: If False, optimizer states will not actually be quantized.
            This option is available so that one can easily debug whether
            the quantization is causing any convergence issues. Because
            quantization is only supported for CUDA parameters, attempting to
            update a non-CUDA tensor will raise an error.
        error_correction: If True, float16 and bfloat16 parameters will be
            given an extra state variable, "errors." This tensor will be
            of the same shape as the parameter but of dtype uint8. This
            auxiliary variable is used to better approximate float32 updates
            by retaining information across optimizer steps.

    Raises:
        NotImplementedError: If any of `quantize`, `compress_state_dict`,
            or `error_correction` are `True` and either a) there is no CUDA
            device, or b) step() is executed on a non-CUDA parameter.
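
    Example:
        A minimal usage sketch (illustrative, not from the original source;
        assumes a CUDA device and that ``model`` and ``batch`` already exist):

        >>> opt = DecoupledLionW_8bit(model.parameters(), lr=1e-4,
        ...                           weight_decay=1e-5)
        >>> loss = model(batch).sum()
        >>> loss.backward()
        >>> opt.step()
        >>> opt.zero_grad()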
    """

    def __init__(self,
                 params: Iterable[torch.Tensor],
                 lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.99),
                 weight_decay: float = 0,
                 quantize: bool = True,
                 compress_state_dict: bool = False,
                 error_correction: bool = False,
                 _fused: bool = True):  # XXX this flag is mostly for testing...

        if lr < 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= betas[0] <= 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(
                betas[0]))
        if not 0.0 <= betas[1] <= 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(
                betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay))

        if not torch.cuda.is_available():
            needs_cuda = ' requires a CUDA device.'
            if quantize:
                raise NotImplementedError('Quantization' + needs_cuda)
            if error_correction:
                raise NotImplementedError('Error correction' + needs_cuda)
            if compress_state_dict:
                raise NotImplementedError('Quantized state dict' + needs_cuda)

        _fused = _fused and quantize
        self._quantize = quantize
        self._error_correction = error_correction
        self._compress_state_dict = compress_state_dict

        defaults = {
            'lr': lr,
            'initial_lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'fused': _fused
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                self.step_param(p, group)

        return loss

    def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
        if not p.requires_grad or p.grad is None:
            return
        if self._quantize and not p.is_cuda:
            raise NotImplementedError(
                f"Can't use quantization with param on {p.device} " +
                f'({p.shape}, {p.dtype}). If you need ' +
                'to use DecoupledLionW_8bit without a CUDA device, try ' +
                'creating this optimizer with quantize=False.')
        state = self.state[p]  # type:ignore using tensor as key
        if 'exp_avg' not in state:
            mom = torch.zeros_like(p)
            state['exp_avg'] = _MaybeQuantizedTensor(
                mom, try_quantize=self._quantize)
        need_errs = (p.dtype != torch.float32) and self._error_correction
        if state.get('errors') is None and need_errs:
            numel = p.numel()
            numel += numel % 2  # ensure even number of bytes
            errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
            # as of torch 2.1, FSDP can't shard integer tensors, so we store
            # the error bytes as bfloat16 and view them as uint8 when needed
            state['errors'] = errors.view(torch.bfloat16)
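        # decoupled weight decay is scaled by the schedule's progress
        # (lr / initial_lr), so the decay shrinks as the LR is annealed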
        decay_factor = hparams['weight_decay']
        decay_factor *= hparams['lr'] / hparams['initial_lr']
        errors: Optional[torch.Tensor] = None
        if 'errors' in state:
            errors = state['errors']
            assert errors is not None  # pyright
            errors = errors.view(dtype=torch.uint8)
            errors = errors[:p.numel()].view(p.shape)  # strip padding + reshape
        _lion8b_step(momentums=state['exp_avg'],
                     weights=p,
                     grads=p.grad,
                     beta1=hparams['betas'][0],
                     beta2=hparams['betas'][1],
                     lr=hparams['lr'],
                     weight_decay=decay_factor,
                     fused=hparams['fused'],
                     errors=errors)

    def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
        # we override this function to quantize optimizer states when
        # loading a state dict
        opt_state, _ = state.values()  # other val is param_groups
        for param_id in opt_state:
            param_state = opt_state[param_id]
            new_state = {}
            if any(k.startswith('exp_avg') for k in param_state):
                # the keys can either be just "exp_avg" or
                # "exp_avg::quantized" and "exp_avg::scales", depending on
                # whether we saved it as quantized or not. The former case
                # gives us interop with regular LION.
                qtensor = _MaybeQuantizedTensor(None,
                                                try_quantize=self._quantize)
                qtensor.load_state_dict(param_state, name='exp_avg')
                new_state['exp_avg'] = qtensor
            if 'errors' in param_state:
                # we need to cast back to the correct dtype since optimizer
                # load_state_dict casts to param dtype for fp params; see
                # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa
                errs = param_state['errors'].to(dtype=torch.uint8).view(
                    torch.bfloat16)
                new_state['errors'] = errs
            opt_state[param_id] = new_state
        super().__setstate__(state)

    def state_dict(self):
        # If the user hasn't opted into storing compressed state dicts
        # we have to make sure our states are regular torch.Tensors. This
        # is mostly needed to make FSDP happy in the case that we want to
        # resume training with a number of devices where
        #   (param numel / device count) % quantization group size != 0
        # for any param.
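        # Illustrative key layouts for a parameter's momentum state:
        #   compress_state_dict=True:  {'exp_avg::quantized': int8 tensor,
        #                               'exp_avg::scales': float16 tensor}
        #   compress_state_dict=False: {'exp_avg': bfloat16 tensor shaped
        #                               like the parameter}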
        d = super().state_dict()
        opt_state, _ = d.values()  # other val is param_groups
        for param_id in opt_state:
            # make a copy so that we don't mutate our self.state; opt_state
            # isn't the same as self.state, but its constituent dicts are
            # the same as those in self.state
            param_state = {k: v for k, v in opt_state[param_id].items()}
            if 'exp_avg' in param_state:  # true if we've taken any steps
                qtensor = param_state.pop('exp_avg')
                assert isinstance(qtensor, _MaybeQuantizedTensor)  # pyright
                param_state.update(
                    qtensor.state_dict(
                        name='exp_avg',
                        allow_quantized=self._compress_state_dict))
            if 'errors' in param_state:
                # fsdp apparently needs the states to be the same shape
                # as the params
                param_state['errors'] = param_state['errors'].view(
                    torch.uint8).to(dtype=torch.bfloat16)
            opt_state[param_id] = param_state
        return d


class _MaybeQuantizedTensor:
    """Helper class so 8b LION doesn't have to know quantization details.

    Important points about this class:
    * It handles CPU tensors not being quantized
    * It knows how to save + load state dicts, handling both the quantized
        and not quantized cases
    * It implements some parts of the torch.Tensor interface that we need,
        but is not intended to be a full torch.Tensor replacement
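
    Example (illustrative sketch; assumes a CUDA device and that the `turbo`
    quantization kernels are installed):

    >>> mom = _MaybeQuantizedTensor(torch.zeros(64, device='cuda'))
    >>> mom.is_quantized()
    True
    >>> d = mom.state_dict(name='exp_avg', allow_quantized=True)
    >>> restored = _MaybeQuantizedTensor(None, try_quantize=True)
    >>> restored.load_state_dict(d, name='exp_avg')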
    """

    def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True):
        super().__init__()
        self.data: Optional[torch.Tensor] = None
        self.quantized: Optional[torch.Tensor] = None
        self.scales: Optional[torch.Tensor] = None
        self._try_quantize = try_quantize and torch.cuda.is_available()

        # conditionally import CUDA kernels
        self._f_encode = None
        self._f_decode = None
        if self._try_quantize:
            from turbo import dequantize8b, quantize8b
            self._f_encode = quantize8b
            self._f_decode = dequantize8b

        if data is not None:
            self.set_data(data)

    def state_dict(self,
                   name: str,
                   allow_quantized: bool = False) -> Dict[str, torch.Tensor]:
        if self.is_quantized() and allow_quantized:
            assert self.quantized is not None  # pyright
            assert self.scales is not None  # pyright
            return {
                f'{name}::quantized': self.quantized,
                f'{name}::scales': self.scales
            }
        return {name: self.materialize().to(dtype=torch.bfloat16)}

    def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None:
        # we allow other keys in the state dict for convenience, so you can
        # just pass in the whole optimizer state for a parameter
        d = {k: v for k, v in d.items() if k.startswith(name)}
        if name in d:
            if len(d) != 1:
                raise ValueError(
                    f'If state dict specifies {name}, it must not ' +
                    f'specify other keys. Got {list(d.keys())}')
            self.set_data(d[name])
            return

        self.quantized = d[f'{name}::quantized'].to(dtype=torch.int8)
        self.scales = d[f'{name}::scales'].to(dtype=torch.float16)

    def set_data(self, data: torch.Tensor) -> None:
        if self._try_quantize:
            if not data.is_cuda:
                raise NotImplementedError(
                    f'Attempting to quantize a non-CUDA {data.dtype} tensor ' +
                    f'on device {data.device} with shape {data.shape}.')
            self.data = None
            assert self._f_encode is not None  # pyright
            self.quantized, self.scales = self._f_encode(data)
        else:
            self.data = data.to(dtype=torch.float32)
            self.quantized = None
            self.scales = None

    def is_quantized(self) -> bool:
        return self.data is None

    def materialize(self) -> torch.Tensor:
        if not self.is_quantized():
            assert self.data is not None  # pyright
            return self.data
        assert self._f_decode is not None  # pyright
        assert self.quantized is not None  # pyright
        assert self.scales is not None  # pyright
        return self._f_decode(self.quantized, self.scales)

    @property  # property to mirror Tensor interface
    def is_cuda(self) -> bool:
        if self.is_quantized():
            assert self.quantized is not None  # pyright
            return self.quantized.is_cuda
        assert self.data is not None  # pyright
        return self.data.is_cuda

    @property  # property to mirror Tensor interface
    def shape(self) -> Tuple[int, ...]:
        if self.is_quantized():
            assert self.quantized is not None  # pyright
            return self.quantized.shape
        assert self.data is not None  # pyright
        return self.data.shape

    def numel(self) -> int:
        if self.is_quantized():
            assert self.quantized is not None  # pyright
            return self.quantized.numel()
        assert self.data is not None  # pyright
        return self.data.numel()

    def __repr__(self):
        return (f'{self.__class__.__name__} quantized={self.is_quantized()} ' +
                f'shape={self.shape}')


def lion_step_unfused(grads: torch.Tensor,
                      weights: torch.Tensor,
                      momentums: torch.Tensor,
                      lr: float,
                      beta1: float,
                      beta2: float,
                      weight_decay: float = 0) -> torch.Tensor:
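    # LION update rule (https://arxiv.org/abs/2302.06675), computed below
    # with lerp():
    #   update = sign(beta1 * m + (1 - beta1) * g)
    #   w     <- (1 - weight_decay) * w - lr * update
    #   m     <- beta2 * m + (1 - beta2) * g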
    # f32 cast to match fused impl + for compatibility with f32 grads or weights
    momentums = momentums.to(dtype=torch.float32)
    grads = grads.to(dtype=torch.float32)

    update = momentums.lerp(grads, 1 - beta1).sign_()
    if weight_decay > 0:
        weights.mul_(1. - weight_decay)

    weights.add_(update, alpha=-lr)
    momentums.lerp_(grads, 1. - beta2)
    return momentums  # f32 upcast means not necessarily modified in place


def lion8b_step_fused(grads: torch.Tensor,
                      weights: torch.Tensor,
                      momentums: torch.Tensor,
                      scales: torch.Tensor,
                      lr: float,
                      beta1: float,
                      beta2: float,
                      weight_decay: float,
                      errors: Optional[torch.Tensor] = None) -> None:
    # just to save space in lists of allowed dtypes
    f16, bf16, f32 = torch.float16, torch.bfloat16, torch.float32

    use_errors = (errors is not None) and (weights.dtype in (f16, bf16))
    orig_shape = weights.shape

    # ------------------------------------------------ wall of error checking
    quantize_group_size = 32
    num_groups = (weights.numel() + quantize_group_size -
                  1) // quantize_group_size
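    # i.e., ceil(weights.numel() / 32): the quantized momentum carries one
    # float16 scale per group of (up to) 32 int8 values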
    if (num_groups != scales.numel()):
        raise ValueError(f'Expected {num_groups} quantization scales but ' +
                         f'received {scales.numel()}')

    for name, tensor, allowed_dtypes in [('grad', grads, (f16, bf16, f32)),
                                         ('param', weights, (f16, bf16, f32)),
                                         ('momentum', momentums, [torch.int8]),
                                         ('scales', scales, [f16]),
                                         ('errors', errors, [torch.uint8])]:
        if name == 'errors' and not use_errors:
            continue
        if not tensor.is_cuda:
            raise ValueError(
                f'{name} must be on a CUDA device, not {tensor.device}')
        if not tensor.is_contiguous():
            raise ValueError(f'{name} is not contiguous!')
        strides_unequal = tensor.stride() != weights.stride()
        if name not in ('scales', 'errors') and strides_unequal:
            raise ValueError(f'{name} stride {tensor.stride()} != ' +
                             f'param stride {weights.stride()}')
        if tensor.dtype not in allowed_dtypes:
            raise ValueError(f'{name} must have dtype {allowed_dtypes}, not ' +
                             f'{tensor.dtype}')
        if (name != 'scales') and (orig_shape != tensor.shape):
            raise ValueError(f'Param shape {orig_shape} != ' +
                             f'{name} shape {tensor.shape}')

    if grads.dtype in (torch.float16, torch.bfloat16):
        allowed_dtypes = (grads.dtype, torch.float32)
        if weights.dtype not in allowed_dtypes:
            raise ValueError(
                f'Weights must be f32 or match grad dtype {grads.dtype}')

    # ------------------------------------------------ actual function call
    from turbo import lion8b_step_cuda
    return lion8b_step_cuda(grads=grads,
                            weights=weights,
                            momentums=momentums,
                            scales=scales,
                            lr=lr,
                            beta1=beta1,
                            beta2=beta2,
                            weight_decay=weight_decay,
                            errors=errors)


def _lion8b_step(grads: torch.Tensor,
                 weights: torch.Tensor,
                 momentums: _MaybeQuantizedTensor,
                 lr: float,
                 beta1: float,
                 beta2: float,
                 weight_decay: float = 0,
                 errors: Optional[torch.Tensor] = None,
                 fused: bool = True) -> None:

    if fused and not momentums.is_quantized():
        raise NotImplementedError(
            'Fused LION step only implemented with quantization.')

    if momentums.is_quantized() and fused:
        assert momentums.quantized is not None  # pyright
        assert momentums.scales is not None  # pyright
        return lion8b_step_fused(grads=grads,
                                 weights=weights,
                                 momentums=momentums.quantized,
                                 scales=momentums.scales,
                                 lr=lr,
                                 beta1=beta1,
                                 beta2=beta2,
                                 weight_decay=weight_decay,
                                 errors=errors)

    momentums_float = momentums.materialize()
    new_momentums = lion_step_unfused(grads=grads,
                                      weights=weights,
                                      momentums=momentums_float,
                                      lr=lr,
                                      beta1=beta1,
                                      beta2=beta2,
                                      weight_decay=weight_decay)
    momentums.set_data(new_momentums)