naveensp committed · Commit 58e3628 (verified) · 1 Parent(s): 1412507

Delete optim.py

Files changed (1)
  1. optim.py +0 -769
optim.py DELETED
@@ -1,769 +0,0 @@
1
- import logging
2
- from abc import ABCMeta, abstractmethod
3
- from dataclasses import dataclass, replace
4
- from math import cos, pi, sqrt
5
- from typing import Any, Dict, List, Optional, Tuple
6
-
7
- import torch
8
- import torch.distributed as dist
9
- import torch.nn as nn
10
- from torch.distributed.fsdp import FullyShardedDataParallel
11
- from torch.optim.optimizer import Optimizer as OptimizerBase
12
-
13
- from .model import LayerNormBase, BitLinear158
14
- from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
15
- from .torch_util import get_default_device, is_distributed
16
-
17
- __all__ = [
18
- "Optimizer",
19
- "LionW",
20
- "AdamW",
21
- "Scheduler",
22
- "CosWithWarmup",
23
- "LinearWithWarmup",
24
- "InvSqrtWithWarmup",
25
- "MaxScheduler",
26
- "ConstantScheduler",
27
- "BoltOnWarmupScheduler",
28
- "build_optimizer",
29
- "build_scheduler",
30
- ]
31
-
32
-
33
- log = logging.getLogger(__name__)
34
-
35
-
36
- class Optimizer(OptimizerBase):
37
- def _clean_param_name(self, name: str) -> str:
38
- return name.replace("_fsdp_wrapped_module.", "")
39
-
40
- @torch.no_grad()
41
- def clip_grads_and_collect_metrics(
42
- self, global_step: int, collect_param_metrics: bool = True
43
- ) -> Dict[str, torch.Tensor]:
44
- """
45
- Clips gradients for every group that has the field `max_grad_norm`.
46
- At the same time collect metrics for each parameter and its gradient.
47
- """
48
- device = get_default_device()
49
-
50
- # NOTE (epwalsh): during distributed training we're making an assumption that the order of
51
- # the param groups, and of the params within each group, is the same across all ranks.
52
- # This is justified since we initialize the parameter groups in every rank by iterating over
53
- # `module.parameters()` or `module.named_modules()` / `module.named_parameters()`, each of which
54
- # provides a consistent order.
55
- # For each parameter (with a gradient) we'll collect:
56
- # - min, max, avg, norm of the param itself
57
- # - min, max, avg, norm of the param's gradient
58
- # - min, max, avg, norm of any additional per-parameter optimizer state metrics returned from
59
- # `self.get_state_for_param()`.
60
- # Afterwards we'll reduce all of these across ranks.
61
- per_param_min_metrics: List[torch.Tensor] = []
62
- per_param_max_metrics: List[torch.Tensor] = []
63
- per_param_sum_metrics: List[torch.Tensor] = []
64
- per_param_norm_metrics: List[torch.Tensor] = []
65
- per_param_numel_metrics: List[torch.Tensor] = []
66
-
67
- per_param_min_metric_names: List[str] = []
68
- per_param_max_metric_names: List[str] = []
69
- per_param_avg_metric_names: List[str] = []
70
- per_param_norm_metric_names: List[str] = []
71
-
72
- # Collect metrics locally.
73
- for group in self.param_groups:
74
- if is_distributed():
75
- # TODO (epwalsh): handle non-sharded params. We don't have any right now but we would
76
- # with ReLoRa, for example.
77
- assert group.get("sharded", True) is True
78
-
79
- for name, p in zip(group["param_names"], group["params"]):
80
- name = self._clean_param_name(name)
81
- # Always need to collect the norm of gradients for clipping, even if we're not collecting
82
- # other metrics.
83
- tensors: List[Optional[torch.Tensor]] = [p.grad]
84
- prefixes: List[str] = [f"grad/{name}"]
85
- if collect_param_metrics:
86
- state = self.get_state_for_param(p)
87
- sorted_state_keys = sorted([k for k in state.keys()])
88
- tensors.extend([p] + [state[key] for key in sorted_state_keys])
89
- prefixes.extend([f"param/{name}"] + [f"{key}/{name}" for key in sorted_state_keys])
90
- assert len(tensors) == len(prefixes)
91
-
92
- # Get min, max, avg, and norm for all `tensors` associated with the parameter.
93
- for x, prefix in zip(tensors, prefixes):
94
- # Grad or state tensors could be None for params that have their shards completely on
95
- # other ranks.
96
- if x is not None and x.numel() > 0:
97
- if collect_param_metrics:
98
- x_abs = x.abs()
99
- per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
100
- per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
101
- per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
102
- per_param_numel_metrics.append(
103
- torch.tensor([x.numel()], device=device, dtype=torch.float32)
104
- )
105
- per_param_norm_metrics.append(
106
- torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
107
- )
108
- else:
109
- if collect_param_metrics:
110
- per_param_min_metrics.append(
111
- torch.tensor([float("inf")], device=device, dtype=torch.float32)
112
- )
113
- per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
114
- per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
115
- per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
116
- per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
117
- if collect_param_metrics:
118
- per_param_min_metric_names.append(f"{prefix}.min")
119
- per_param_max_metric_names.append(f"{prefix}.max")
120
- per_param_avg_metric_names.append(f"{prefix}.avg")
121
- per_param_norm_metric_names.append(f"{prefix}.norm")
122
-
123
- assert (
124
- len(per_param_min_metrics)
125
- == len(per_param_min_metric_names)
126
- == len(per_param_max_metrics)
127
- == len(per_param_max_metric_names)
128
- == len(per_param_sum_metrics)
129
- == len(per_param_numel_metrics)
130
- == len(per_param_avg_metric_names)
131
- )
132
- assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)
133
-
134
- def is_grad_norm_metric(metric_name: str) -> bool:
135
- return metric_name.startswith("grad/") and metric_name.endswith(".norm")
136
-
137
- # Now reduce metrics over all ranks.
138
- total_grad_norm: torch.Tensor
139
- per_param_avg_metrics: List[torch.Tensor] = []
140
- if is_distributed(): # TODO (epwalsh): skip for non-sharded params
141
- # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
142
- # instead of an `all_reduce`, but we need `all_reduce` for norms so that all ranks
143
- # get the right value for gradient norms so they can clip correctly.
144
- # Reduce mins.
145
- if per_param_min_metrics:
146
- all_mins = torch.cat(per_param_min_metrics).to(device)
147
- dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
148
- per_param_min_metrics = all_mins.split(1)
149
- # Reduce maxs.
150
- if per_param_max_metrics:
151
- all_maxs = torch.cat(per_param_max_metrics).to(device)
152
- dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
153
- per_param_max_metrics = all_maxs.split(1)
154
- # Reduce sums or just norms.
155
- all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
156
- if per_param_sum_metrics and per_param_numel_metrics:
157
- all_sums = torch.cat(per_param_sum_metrics).to(device)
158
- all_numels = torch.cat(per_param_numel_metrics).to(device)
159
- all_sums_norms_numels = torch.cat(
160
- [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
161
- )
162
- dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
163
- all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
164
- # Get averages.
165
- # NOTE: could get infs for non-rank0 processes but that's okay.
166
- per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
167
- else:
168
- dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
169
- grad_norm_metric_mask = torch.tensor(
170
- [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
171
- )
172
- total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
173
- per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
174
- else:
175
- total_grad_norm = (
176
- torch.cat(
177
- [
178
- m
179
- for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
180
- if is_grad_norm_metric(n)
181
- ]
182
- )
183
- ** 2.0
184
- ).sum() ** 0.5
185
- per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]
186
-
187
- assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)
188
-
189
- # Collect all metrics into a single dict.
190
- all_metrics: Dict[str, torch.Tensor] = {}
191
- for metric_name, metric in zip(per_param_min_metric_names, per_param_min_metrics):
192
- all_metrics[metric_name] = metric.squeeze(0)
193
- for metric_name, metric in zip(per_param_max_metric_names, per_param_max_metrics):
194
- all_metrics[metric_name] = metric.squeeze(0)
195
- for metric_name, metric in zip(per_param_avg_metric_names, per_param_avg_metrics):
196
- all_metrics[metric_name] = metric.squeeze(0)
197
- for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
198
- all_metrics[metric_name] = metric.squeeze(0)
199
- all_metrics["total_grad_norm"] = total_grad_norm
200
-
201
- # Clip gradients.
202
- num_grads_clipped = 0
203
- num_eligible_grads = 0
204
- for group in self.param_groups:
205
- if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
206
- num_clipped = self._do_adaptive_clipping(
207
- group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
208
- )
209
- elif (max_norm := group.get("max_grad_norm")) is not None:
210
- num_clipped = self._do_global_fixed_clipping(
211
- group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
212
- )
213
- else:
214
- # No clipping needed.
215
- continue
216
- num_eligible_grads += len(group["params"])
217
- if num_clipped is not None:
218
- num_grads_clipped += num_clipped
219
-
220
- if collect_param_metrics:
221
- if num_eligible_grads > 0:
222
- clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
223
- else:
224
- clipping_rate = torch.tensor(0.0, device="cpu")
225
- all_metrics["clipping_rate"] = clipping_rate
226
- return all_metrics
227
- else:
228
- return {}
229
-
230
- @torch.no_grad()
231
- def _do_adaptive_clipping(
232
- self,
233
- group: Dict[str, Any],
234
- max_norm_ratio: float,
235
- global_step: int,
236
- all_metrics: Dict[str, torch.Tensor],
237
- collect_param_metrics: bool = True,
238
- ) -> Optional[int]:
239
- """
240
- Do adaptive gradient clipping on a param group.
241
-
242
- If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
243
- """
244
- device = get_default_device()
245
- num_grads_clipped = 0
246
- # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
247
- # the gradient (a scalar), not to be confused with the exponential average of the gradient.
248
- # TODO (epwalsh): handle optimizers that don't have betas.
249
- beta1, beta2 = group["betas"]
250
- beta = max(beta1, beta2)
251
- for name, p in zip(group["param_names"], group["params"]):
252
- name = self._clean_param_name(name)
253
- grad_norm = all_metrics.get(f"grad/{name}.norm")
254
- if grad_norm is None:
255
- continue
256
-
257
- # Get or initialize the exponential average of grad norm.
258
- # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
259
- # even parameters for which the corresponding local shard is empty. This has the potential to
260
- # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
261
- # So we should consider changing how we do this at some point so that we don't add any state
262
- # to parameters for which the local shard is empty. That would probably add extra distributed
263
- # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
264
- state = self.state[p]
265
- grad_norm_exp_avg = state.get("grad_norm_exp_avg")
266
- if grad_norm_exp_avg is None:
267
- grad_norm_exp_avg = grad_norm.clone().to(device)
268
- # We don't want to add anything to `state` until `state` has been initialized, otherwise
269
- # this will crash some optimizers which rely on checking `len(state)`. The downside here
270
- # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
271
- if global_step > 1:
272
- state["grad_norm_exp_avg"] = grad_norm_exp_avg
273
-
274
- max_allowed_norm = max_norm_ratio * grad_norm_exp_avg
275
- clip_coef = max_allowed_norm / (grad_norm + 1e-6)
276
-
277
- # Clip the gradients and update the exponential average.
278
- # Note that multiplying by the clamped coefficient is meaningless when it is
279
- # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
280
- clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
281
- if p.grad is not None:
282
- # p.grad could be None for some ranks when using FSDP.
283
- p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
284
-
285
- # Update the exponential average of the norm of the gradient with the clipped norm of the gradient.
286
- grad_norm_exp_avg.lerp_((grad_norm * clip_coef_clamped).to(grad_norm_exp_avg.device), 1 - beta)
287
- # Alternative: update with the *unclipped* norm of the gradient.
288
- # grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)
289
-
290
- if collect_param_metrics:
291
- # Can't avoid host-device sync here.
292
- if clip_coef_clamped < 1.0:
293
- num_grads_clipped += 1
294
- all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
295
- return num_grads_clipped if collect_param_metrics else None
296
-
297
- @torch.no_grad()
298
- def _do_global_fixed_clipping(
299
- self,
300
- group: Dict[str, Any],
301
- max_norm: float,
302
- all_metrics: Dict[str, torch.Tensor],
303
- collect_param_metrics: bool = True,
304
- ) -> Optional[int]:
305
- """
306
- Do global fixed gradient clipping on a param group.
307
-
308
- If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
309
- """
310
- device = get_default_device()
311
- total_grad_norm = all_metrics["total_grad_norm"]
312
- clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
313
- clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
314
- num_grads_clipped: Optional[int] = None
315
- if collect_param_metrics:
316
- # Can't avoid host-device sync here.
317
- if clip_coef_clamped < 1.0:
318
- num_grads_clipped = len(group["params"])
319
- for p in group["params"]:
320
- # Clip the gradients.
321
- # Note that multiplying by the clamped coefficient is meaningless when it is
322
- # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
323
- if p.grad is not None:
324
- # p.grad could be None for some ranks when using FSDP.
325
- p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
326
- return num_grads_clipped
327
-
328
- def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
329
- del module
330
- return {}
331
-
332
- def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
333
- del param
334
- return {}
335
-
336
-
337
- class LionW(Optimizer):
338
- """
339
- Adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py
340
- """
341
-
342
- def __init__(
343
- self,
344
- params,
345
- lr: float = 1e-4,
346
- betas: Tuple[float, float] = (0.9, 0.99),
347
- weight_decay: float = 0.0,
348
- ):
349
- assert lr > 0.0
350
- assert all([0.0 <= beta <= 1.0 for beta in betas])
351
- defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
352
- super().__init__(params, defaults)
353
- for group in self.param_groups:
354
- group["initial_lr"] = group["lr"]
355
- self._update_total_dot_prod: Optional[torch.Tensor] = None
356
- self._update_total_norm: Optional[torch.Tensor] = None
357
- self._signed_update_total_norm: Optional[torch.Tensor] = None
358
-
359
- def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
360
- update_total_dot_prod = self._update_total_dot_prod
361
- update_total_norm = self._update_total_norm
362
- signed_update_total_norm = self._signed_update_total_norm
363
- if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
364
- return {}
365
-
366
- if is_distributed() and isinstance(module, FullyShardedDataParallel):
367
- # Reduce total dot prod and norms across all ranks.
368
- update_total_norm = update_total_norm**2.0
369
- signed_update_total_norm = signed_update_total_norm**2.0
370
- # Reduce all together to avoid multiple communication calls.
371
- all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
372
- # Only need the final result on rank0, since that's where we log from.
373
- dist.reduce(all_together, 0)
374
- update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
375
- update_total_norm = update_total_norm**0.5
376
- signed_update_total_norm = signed_update_total_norm**0.5
377
-
378
- update_cos_sim = update_total_dot_prod / torch.max(
379
- update_total_norm * signed_update_total_norm, torch.tensor(1e-8, device=get_default_device())
380
- )
381
- return {"update_cos_sim": update_cos_sim}
382
-
383
- @torch.no_grad()
384
- def step(self, closure=None) -> None:
385
- if closure is not None:
386
- with torch.enable_grad():
387
- closure()
388
-
389
- update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
390
- update_norms = []
391
- signed_update_norms = []
392
-
393
- for group in self.param_groups:
394
- for p in group["params"]:
395
- if p.grad is None:
396
- continue
397
-
398
- # Apply decoupled weight decay (scaled by the learning rate)
399
- p.data.mul_(1 - group["lr"] * group["weight_decay"])
400
-
401
- grad = p.grad
402
- state = self.state[p]
403
-
404
- # State initialization
405
- if len(state) == 0:
406
- # Exponential moving average of gradient values
407
- state["exp_avg"] = torch.zeros_like(p)
408
-
409
- exp_avg = state["exp_avg"]
410
- beta1, beta2 = group["betas"]
411
-
412
- # Weight update
413
- update = exp_avg * beta1 + grad * (1 - beta1)
414
- signed_update = torch.sign(update)
415
- p.add_(signed_update, alpha=-group["lr"])
416
-
417
- # Decay the momentum running average coefficient
418
- exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
419
-
420
- # Track dot product and norms of update vs signed update in order to calculate
421
- # their cosine similarity.
422
- update_total_dot_prod = update_total_dot_prod.to(update.device)
423
- update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
424
- update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
425
- signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
426
-
427
- # Compute cosine similarity between update and signed update.
428
- self._update_total_dot_prod = update_total_dot_prod.to(get_default_device())
429
- self._update_total_norm = torch.linalg.vector_norm(
430
- torch.stack(update_norms),
431
- 2.0,
432
- dtype=torch.float32,
433
- ).to(get_default_device())
434
- self._signed_update_total_norm = torch.linalg.vector_norm(
435
- torch.stack(signed_update_norms),
436
- 2.0,
437
- dtype=torch.float32,
438
- ).to(get_default_device())
439
-
440
-
441
- class AdamW(torch.optim.AdamW, Optimizer):
442
- def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
443
- return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")} # type: ignore
444
-
445
-
446
- @dataclass
447
- class Scheduler(metaclass=ABCMeta):
448
- # NOTE: these fields are not given default values because otherwise dataclasses complains
449
- # about how the scheduler subclasses are defined.
450
- grad_clip_warmup_steps: Optional[int]
451
- grad_clip_warmup_factor: Optional[float]
452
-
453
- @abstractmethod
454
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
455
- raise NotImplementedError
456
-
457
- def _get_max_grad_norm_coeff(
458
- self, initial_value: Optional[float], step: int, max_steps: int
459
- ) -> Optional[float]:
460
- del max_steps # might need this in the future, but for now I just wanted to match the API of `get_lr()`.
461
- if initial_value is None:
462
- return None
463
- elif (
464
- self.grad_clip_warmup_steps is None
465
- or self.grad_clip_warmup_factor is None
466
- or step > self.grad_clip_warmup_steps
467
- ):
468
- return initial_value
469
- else:
470
- return self.grad_clip_warmup_factor * initial_value
471
-
472
- def get_max_grad_norm(
473
- self, initial_max_grad_norm: Optional[float], step: int, max_steps: int
474
- ) -> Optional[float]:
475
- return self._get_max_grad_norm_coeff(initial_max_grad_norm, step, max_steps)
476
-
477
- def get_max_grad_norm_ratio(
478
- self, initial_max_grad_norm_ratio: Optional[float], step: int, max_steps: int
479
- ) -> Optional[float]:
480
- return self._get_max_grad_norm_coeff(initial_max_grad_norm_ratio, step, max_steps)
481
-
482
- def _linear_warmup(self, initial_lr: float, step: int, warmup_steps: int = 2000) -> float:
483
- return initial_lr * (0.1 + 0.9 * min(step, warmup_steps) / warmup_steps)
484
-
485
-
486
- @dataclass
487
- class CosWithWarmup(Scheduler):
488
- warmup_steps: int
489
- alpha_f: float = 0.1
490
- t_max: Optional[int] = None
491
-
492
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
493
- max_steps = max_steps if self.t_max is None else self.t_max
494
- eta_min = initial_lr * self.alpha_f
495
- if step < self.warmup_steps:
496
- return self._linear_warmup(initial_lr, step, self.warmup_steps)
497
- elif step >= max_steps:
498
- return eta_min
499
- else:
500
- step = step - self.warmup_steps
501
- max_steps = max_steps - self.warmup_steps
502
- return eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2
503
-
504
-
505
- @dataclass
506
- class LinearWithWarmup(Scheduler):
507
- warmup_steps: int
508
- alpha_f: float = 0.1
509
- t_max: Optional[int] = None
510
-
511
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
512
- max_steps = max_steps if self.t_max is None else self.t_max
513
- eta_min = initial_lr * self.alpha_f
514
- if step < self.warmup_steps:
515
- return self._linear_warmup(initial_lr, step, self.warmup_steps)
516
- elif step >= max_steps:
517
- return eta_min
518
- else:
519
- step = step - self.warmup_steps
520
- max_steps = max_steps - self.warmup_steps
521
- return initial_lr - (initial_lr - eta_min) * (step / max_steps)
522
-
523
-
524
- @dataclass
525
- class InvSqrtWithWarmup(Scheduler):
526
- warmup_steps: int
527
-
528
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
529
- if step < self.warmup_steps:
530
- return self._linear_warmup(initial_lr, step, self.warmup_steps)
531
- del max_steps
532
- return initial_lr * sqrt(self.warmup_steps / max(self.warmup_steps, step))
533
-
534
-
535
- @dataclass
536
- class MaxScheduler(Scheduler):
537
- sched1: Scheduler
538
- sched2: Scheduler
539
-
540
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
541
- return max(
542
- self.sched1.get_lr(initial_lr, step, max_steps), self.sched2.get_lr(initial_lr, step, max_steps)
543
- )
544
-
545
-
546
- @dataclass
547
- class BoltOnWarmupScheduler(Scheduler):
548
- inner: Scheduler
549
- warmup_start: int
550
- warmup_end: int
551
-
552
- @classmethod
553
- def wrap(cls, scheduler: Scheduler, warmup_start: int, warmup_end: int) -> "BoltOnWarmupScheduler":
554
- return cls(
555
- grad_clip_warmup_steps=None,
556
- grad_clip_warmup_factor=None,
557
- inner=scheduler,
558
- warmup_start=warmup_start,
559
- warmup_end=warmup_end,
560
- )
561
-
562
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
563
- if step < self.warmup_start:
564
- return 0.0
565
- if step < self.warmup_end:
566
- lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
567
- return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
568
- else:
569
- return self.inner.get_lr(initial_lr, step, max_steps)
570
-
571
- def _get_max_grad_norm_coeff(
572
- self, initial_value: Optional[float], step: int, max_steps: int
573
- ) -> Optional[float]:
574
- return self.inner._get_max_grad_norm_coeff(initial_value, step, max_steps)
575
-
576
-
577
- @dataclass
578
- class ConstantScheduler(Scheduler):
579
- def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
580
- del step, max_steps
581
- return initial_lr
582
-
583
-
584
- PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
585
-
586
-
587
- def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]:
588
- """
589
- Separate parameters into weight-decay and no-weight-decay groups.
590
- """
591
- param_groups: List[Dict[str, Any]]
592
- param_group_defaults = {
593
- "sharded": isinstance(model, FullyShardedDataParallel),
594
- "max_grad_norm": cfg.max_grad_norm,
595
- "max_grad_norm_ratio": cfg.max_grad_norm_ratio,
596
- }
597
-
598
- # Separate out parameters that we don't want to apply weight decay to, like norms and biases.
599
- decay = set()
600
- no_decay = set()
601
- all_params = {}
602
- for mn, m in model.named_modules():
603
- for pn, p in m.named_parameters():
604
- # NOTE: because named_modules and named_parameters are recursive
605
- # we will see the same tensors p many many times, but doing it this way
606
- # allows us to know which parent module any tensor p belongs to...
607
- if not p.requires_grad:
608
- continue
609
-
610
- fpn = f"{mn}.{pn}" if mn else pn
611
- all_params[fpn] = p
612
-
613
- if pn.endswith("bias"):
614
- if cfg.optimizer.decay_norm_and_bias:
615
- decay.add(fpn)
616
- else:
617
- no_decay.add(fpn)
618
- elif pn.endswith("weight") and (isinstance(m, nn.Linear) or isinstance(m, BitLinear158)):
619
- decay.add(fpn)
620
- elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm)):
621
- if cfg.optimizer.decay_norm_and_bias:
622
- decay.add(fpn)
623
- else:
624
- no_decay.add(fpn)
625
- elif pn.endswith("weight") and isinstance(m, nn.Embedding):
626
- if cfg.optimizer.decay_embeddings:
627
- decay.add(fpn)
628
- else:
629
- no_decay.add(fpn)
630
-
631
- # Validate that we've considered every parameter
632
- inter_params = decay & no_decay
633
- union_params = decay | no_decay
634
- assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
635
- assert (
636
- len(all_params.keys() - union_params) == 0
637
- ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!"
638
-
639
- # Create the pytorch optimizer groups.
640
- decay_sorted = sorted(list(decay))
641
- no_decay_sorted = sorted(list(no_decay))
642
- param_groups = []
643
- if len(decay_sorted) > 0:
644
- param_groups.append(
645
- {
646
- "params": [all_params[pn] for pn in decay_sorted],
647
- "param_names": decay_sorted,
648
- **param_group_defaults,
649
- }
650
- )
651
- if len(no_decay_sorted) > 0:
652
- param_groups.append(
653
- {
654
- "params": [all_params[pn] for pn in no_decay_sorted],
655
- "param_names": no_decay_sorted,
656
- "weight_decay": 0.0,
657
- **param_group_defaults,
658
- }
659
- )
660
-
661
- # Validate fields.
662
- for group in param_groups:
663
- for key in PARAM_GROUP_FIELDS:
664
- assert key in group
665
-
666
- return param_groups
667
-
668
-
669
- def fix_optim_state_dict(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
670
- """
671
- Make sure old optim state dicts are compatible with new versions.
672
- """
673
- if len(state_dict["param_groups"]) == 1 and len(optimizer.param_groups) == 2:
674
- assert optimizer.param_groups[1]["weight_decay"] == 0.0
675
-
676
- # Decay
677
- decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
678
- decay_param_group["params"] = optimizer.state_dict()["param_groups"][0]["params"]
679
-
680
- # No decay.
681
- no_decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"}
682
- no_decay_param_group["weight_decay"] = 0.0
683
- no_decay_param_group["params"] = optimizer.state_dict()["param_groups"][1]["params"]
684
-
685
- state_dict["param_groups"] = [decay_param_group, no_decay_param_group]
686
-
687
- assert len(optimizer.param_groups) == len(state_dict["param_groups"])
688
-
689
- # Make sure:
690
- # - All required fields are included in the state dict,
691
- # - And that the values of those fields don't change from what's currently set in the optimizer,
692
- # since we might have changed those fields on purpose after a restart.
693
- for group, sd_group in zip(optimizer.param_groups, state_dict["param_groups"]):
694
- for key in PARAM_GROUP_FIELDS:
695
- sd_group[key] = group[key]
696
-
697
- return state_dict
698
-
699
-
700
- def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
701
- param_groups = get_param_groups(cfg, model)
702
- log.info(f"Constructing optimizer with {len(param_groups)} param groups")
703
- if cfg.optimizer.name == OptimizerType.lionw:
704
- return LionW(
705
- param_groups,
706
- lr=cfg.optimizer.learning_rate,
707
- betas=cfg.optimizer.betas,
708
- weight_decay=cfg.optimizer.weight_decay,
709
- )
710
- elif cfg.optimizer.name == OptimizerType.adamw:
711
- return AdamW(
712
- param_groups,
713
- lr=cfg.optimizer.learning_rate,
714
- betas=cfg.optimizer.betas,
715
- weight_decay=cfg.optimizer.weight_decay,
716
- eps=1e-5,
717
- )
718
- else:
719
- raise NotImplementedError
720
-
721
-
722
- def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = None) -> Scheduler:
723
- sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
724
- if sched_cfg.name == SchedulerType.cosine_with_warmup:
725
- return CosWithWarmup(
726
- grad_clip_warmup_steps=None
727
- if sched_cfg.grad_clip_warmup_steps is None
728
- else int(sched_cfg.grad_clip_warmup_steps),
729
- grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
730
- warmup_steps=int(sched_cfg.t_warmup),
731
- alpha_f=sched_cfg.alpha_f,
732
- t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
733
- )
734
- elif sched_cfg.name == SchedulerType.linear_with_warmup:
735
- return LinearWithWarmup(
736
- grad_clip_warmup_steps=None
737
- if sched_cfg.grad_clip_warmup_steps is None
738
- else int(sched_cfg.grad_clip_warmup_steps),
739
- grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
740
- warmup_steps=int(sched_cfg.t_warmup),
741
- alpha_f=sched_cfg.alpha_f,
742
- t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
743
- )
744
- elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
745
- return InvSqrtWithWarmup(
746
- grad_clip_warmup_steps=None
747
- if sched_cfg.grad_clip_warmup_steps is None
748
- else int(sched_cfg.grad_clip_warmup_steps),
749
- grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
750
- warmup_steps=int(sched_cfg.t_warmup),
751
- )
752
- elif sched_cfg.name == SchedulerType.max_scheduler:
753
- return MaxScheduler(
754
- grad_clip_warmup_steps=None
755
- if sched_cfg.grad_clip_warmup_steps is None
756
- else int(sched_cfg.grad_clip_warmup_steps),
757
- grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
758
- sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
759
- sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
760
- )
761
- elif sched_cfg.name == SchedulerType.constant:
762
- return ConstantScheduler(
763
- grad_clip_warmup_steps=None
764
- if sched_cfg.grad_clip_warmup_steps is None
765
- else int(sched_cfg.grad_clip_warmup_steps),
766
- grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
767
- )
768
- else:
769
- raise NotImplementedError
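
For context, below is a minimal sketch (not part of this commit) of how the public API of the deleted optim.py is typically wired into a training loop: build the optimizer and scheduler from a config, refresh per-group learning rate and clipping settings each step, then clip and step. The package name `olmo`, the `loss_fn` callable, and the logging cadence are assumptions for illustration; only the calls into the deleted module mirror its API as shown above.

```python
import logging
from typing import Callable, Iterable

import torch.nn as nn

from olmo.config import TrainConfig                      # assumed package name
from olmo.optim import build_optimizer, build_scheduler  # the module deleted above

log = logging.getLogger(__name__)


def training_loop(
    cfg: TrainConfig,
    model: nn.Module,
    dataloader: Iterable,
    loss_fn: Callable,
    max_steps: int,
) -> None:
    optimizer = build_optimizer(cfg, model)   # LionW or AdamW with decay / no-decay groups
    scheduler = build_scheduler(cfg)          # e.g. CosWithWarmup

    for step, batch in enumerate(dataloader, start=1):
        loss = loss_fn(model, batch)          # hypothetical loss callable
        loss.backward()

        # Refresh per-group LR and clipping settings from the scheduler; the optimizer
        # reads `max_grad_norm` / `max_grad_norm_ratio` from each group when clipping.
        for group in optimizer.param_groups:
            group["lr"] = scheduler.get_lr(cfg.optimizer.learning_rate, step, max_steps)
            group["max_grad_norm"] = scheduler.get_max_grad_norm(cfg.max_grad_norm, step, max_steps)
            group["max_grad_norm_ratio"] = scheduler.get_max_grad_norm_ratio(
                cfg.max_grad_norm_ratio, step, max_steps
            )

        # Clip gradients; collect detailed per-parameter metrics only occasionally.
        metrics = optimizer.clip_grads_and_collect_metrics(
            step, collect_param_metrics=(step % 100 == 0)
        )
        if metrics:
            log.info("step=%d total_grad_norm=%.4f", step, metrics["total_grad_norm"].item())

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        if step >= max_steps:
            break
```

The per-group fields written here (`max_grad_norm`, `max_grad_norm_ratio`) are the same ones `get_param_groups` attaches to every param group, which is how `clip_grads_and_collect_metrics` decides between global fixed clipping and adaptive clipping.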