naveensp committed (verified)
Commit df47e60 · 1 Parent(s): 566a22e

Delete config.py

Files changed (1)
  1. config.py +0 -1106
config.py DELETED
@@ -1,1106 +0,0 @@
- from __future__ import annotations
-
- from dataclasses import asdict, dataclass, field
- from glob import glob
- from pathlib import Path
- from typing import (
-     Any,
-     Dict,
-     Iterable,
-     List,
-     Optional,
-     Tuple,
-     Type,
-     TypeVar,
-     Union,
-     cast,
- )
-
- import torch
- from omegaconf import DictConfig, ListConfig
- from omegaconf import OmegaConf as om
- from omegaconf.errors import OmegaConfBaseException
- from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
-
- from .aliases import PathOrStr
- from .beam_search import Sampler
- from .exceptions import OLMoConfigurationError
- from .util import StrEnum
-
- __all__ = [
-     "ActivationType",
-     "ActivationCheckpointingStrategy",
-     "BlockType",
-     "LayerNormType",
-     "InitFnType",
-     "ModelConfig",
-     "OptimizerType",
-     "OptimizerConfig",
-     "SchedulerType",
-     "SchedulerConfig",
-     "DataConfig",
-     "EvaluatorConfig",
-     "TokenizerConfig",
-     "TrainConfig",
-     "PaddingDirection",
-     "TruncationDirection",
-     "SpeedMonitorConfig",
-     "WandbConfig",
-     "CompilerConfig",
-     "WandbConfig",
-     "FSDPPrecision",
-     "FSDPWrapStrategy",
-     "FSDPConfig",
-     "CheckpointType",
- ]
-
- C = TypeVar("C", bound="BaseConfig")
- D = TypeVar("D", bound="DictConfig|ListConfig")
-
-
- class BaseConfig:
-     @classmethod
-     def _register_resolvers(cls, validate_paths: bool = True):
-         # Expands path globs into a list.
-         def path_glob(*paths) -> List[str]:
-             out = []
-             for path in paths:
-                 matches = sorted(glob(path))
-                 if not matches and validate_paths:
-                     raise FileNotFoundError(f"{path} does not match any files or dirs")
-                 out.extend(matches)
-             return out
-
-         # Chooses the first path in the arguments that exists.
-         def path_choose(*paths) -> str:
-             from .util import is_url
-
-             for path in paths:
-                 if is_url(path) or Path(path).exists():
-                     return path
-             if validate_paths:
-                 raise FileNotFoundError(", ".join(paths))
-             else:
-                 return ""
-
-         # Finds the latest checkpoint in a folder.
-         def path_last_checkpoint(path) -> str:
-             from .util import find_latest_checkpoint
-
-             latest_checkpoint = find_latest_checkpoint(path)
-             if latest_checkpoint is None:
-                 if validate_paths:
-                     raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
-                 else:
-                     return ""
-             else:
-                 return str(latest_checkpoint)
-
-         om.register_new_resolver("path.glob", path_glob, replace=True)
-         om.register_new_resolver("path.choose", path_choose, replace=True)
-         om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
-
-     @classmethod
-     def update_legacy_settings(cls, config: D) -> D:
-         """
-         Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
-         """
-         return config
-
-     @classmethod
-     def new(cls: Type[C], **kwargs) -> C:
-         cls._register_resolvers()
-         conf = om.structured(cls)
-         try:
-             if kwargs:
-                 conf = om.merge(conf, kwargs)
-             return cast(C, om.to_object(conf))
-         except OmegaConfBaseException as e:
-             raise OLMoConfigurationError(str(e))
-
-     @classmethod
-     def load(
-         cls: Type[C],
-         path: PathOrStr,
-         overrides: Optional[List[str]] = None,
-         key: Optional[str] = None,
-         validate_paths: bool = True,
-     ) -> C:
-         """Load from a YAML file."""
-         cls._register_resolvers(validate_paths=validate_paths)
-         schema = om.structured(cls)
-         try:
-             raw = om.load(str(path))
-             if key is not None:
-                 raw = raw[key]  # type: ignore
-             raw = cls.update_legacy_settings(raw)
-             conf = om.merge(schema, raw)
-             if overrides:
-                 conf = om.merge(conf, om.from_dotlist(overrides))
-             return cast(C, om.to_object(conf))
-         except OmegaConfBaseException as e:
-             raise OLMoConfigurationError(str(e))
-
-     def save(self, path: PathOrStr) -> None:
-         """Save to a YAML file."""
-         om.save(config=self, f=str(path))
-
-     def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
-         out = asdict(self)  # type: ignore
-         if exclude is not None:
-             for name in exclude:
-                 if name in out:
-                     del out[name]
-         return out
-
-
- class LayerNormType(StrEnum):
-     default = "default"
-     """
-     The default LayerNorm implementation, equivalent to PyTorch's built-in version.
-     """
-
-     low_precision = "low_precision"
-     """
-     A low-precision version of the default LayerNorm.
-     """
-
-     rms = "rms"
-     """
-     An RMSNorm implementation. When using ``torch.compile`` this is
-     probably the fastest implementation.
-     """
-
-
- class ActivationType(StrEnum):
-     gelu = "gelu"
-     relu = "relu"
-     swiglu = "swiglu"
-
-
- class BlockType(StrEnum):
-     sequential = "sequential"
-
-     llama = "llama"
-     """
-     A block similar to the sequential block with slightly different
-     implementations of operations like attention to imitate the behavior of Llama.
-     """
-
-
- class InitFnType(StrEnum):
-     mitchell = "mitchell"
-     """
-     The strategy suggested to us by Mitchell Wortsman from UW.
-     This uses a truncated normal distribution with an adaptive standard deviation that depends
-     on the size of the weights as well as the depth of the layer.
-     """
-
-     normal = "normal"
-     """
-     All weights are initialized from the same normal distribution.
-     """
-
-     kaiming_normal = "kaiming_normal"
-     """
-     All weights are initialized with the Kaiming method from a normal distribution.
-     Note this currently won't work with FSDP.
-     """
-
-     fan_in = "fan_in"
-     """
-     "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
-     is the input dimensionality of the kernel.
-     """
-
-     full_megatron = "full_megatron"
-     """
-     This is what metaseq calls "full megatron init". It is the init used for Llama 2.
-     """
-
-
- @dataclass
- class ModelConfig(BaseConfig):
-     """
-     OLMo (model) configuration.
-     """
-
-     # Note that the defaults for these attributes are equivalent to the base GPT2 model.
-
-     d_model: int = 768
-     """
-     The hidden size of the model.
-     """
-
-     n_heads: int = 12
-     """
-     The number of self-attention heads.
-     """
-
-     n_kv_heads: Optional[int] = None
-     """
-     The number of heads to use for keys and values. Defaults to `n_heads`.
-     Set this to ``None`` or ``n_heads`` for normal multi-head attention.
-     Set this to 1 for multi-query attention.
-     Set it to some in-between value for Llama2-style grouped query attention.
-     """
-
-     clip_qkv: Optional[float] = None
-     """
-     Clip QKV to this value when set.
-     """
-
-     n_layers: int = 12
-     """
-     The number of layers/blocks.
-     """
-
-     mlp_ratio: int = 4
-     """
-     The ratio of the inner MLP dimensionality to ``d_model``.
-     This is only used when ``mlp_hidden_size`` is not set.
-     """
-
-     mlp_hidden_size: Optional[int] = None
-     """
-     Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
-     """
-
-     activation_type: ActivationType = ActivationType.swiglu
-     """
-     The activation function to use within the MLP layers.
-     """
-
-     block_type: BlockType = BlockType.sequential
-     """
-     The transformer block implementation.
-     """
-
-     block_group_size: int = 1
-     """
-     The number of blocks to group together into a single parent block.
-     This has no effect on the number of parameters in the model and is only used to wrap groups
-     of blocks together with a single FSDP wrapper during training.
-     """
-
-     alibi: bool = False
-     """
-     If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
-     """
-
-     alibi_bias_max: float = 8.0
-     """
-     Maximum absolute value of ALiBi bias.
-     """
-
-     rope: bool = False
-     """
-     Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
-     """
-
-     rope_full_precision: bool = True
-     """
-     If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
-     apply RoPE at the precision of the input.
-     """
-
-     flash_attention: bool = False
-     """
-     If ``True``, use ``FlashAttention``.
-     """
-
-     attention_dropout: float = 0.1
-     """
-     The dropout probability within the attention modules.
-     """
-
-     multi_query_attention: Optional[bool] = None
-     """
-     Deprecated. Use n_kv_heads instead.
-     """
-
-     attention_layer_norm: bool = False
-     """
-     Apply layer norm to the keys and queries within the attention mechanism.
-     This can help stabilize training.
-     """
-
-     residual_dropout: float = 0.1
-     """
-     The dropout probability for the MLP and attention output within each block.
-     """
-
-     embedding_dropout: float = 0.1
-     """
-     The dropout probability for embeddings.
-     """
-
-     layer_norm_type: LayerNormType = LayerNormType.default
-     """
-     The layernorm implementation to use.
-     """
-
-     layer_norm_with_affine: bool = True
-     """
-     Whether to include bias and weight parameters for the layer norms.
-     This only affects layer norms that are immediately followed by a linear layer in the forward pass,
-     so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
-     to ``False``.
-     """
-
-     attention_layer_norm_with_affine: bool = True
-     """
-     Toggle affine transform for the QK norms.
-     """
-
-     max_sequence_length: int = 1024
-     """
-     The maximum input sequence length supported by the model.
-     """
-
-     include_bias: bool = True
-     """
-     Whether or not to include bias parameters in linear layers.
-     In PaLM, they got rid of all bias terms because they found that large
-     models tend to have near 0 bias terms anyway.
-     """
-
-     bias_for_layer_norm: Optional[bool] = None
-     """
-     Whether or not to include bias parameters in layer norm.
-     This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
-     layer norm.
-     When this is None (the default), it inherits the setting from include_bias.
-     """
-
-     scale_logits: bool = False
-     """
-     If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
-     """
-
-     vocab_size: int = 50257
-     """
-     Vocabulary size of the model.
-     """
-
-     embedding_size: Optional[int] = 50304
-     """
-     The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
-     to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
-     next multiple of 128 that's greater than ``vocab_size`` can improve throughput
-     substantially.
-     """
-
-     weight_tying: bool = True
-     """
-     Whether to tie output linear weights to the input embedding.
-     """
-
-     eos_token_id: int = 50256
-     """
-     The ID of the end-of-sentence special token.
-     """
-
-     pad_token_id: int = 50256
-     """
-     The ID of the token to use for padding. Defaults to the ID of the EOS token.
-     """
-
-     init_device: Optional[str] = None
-     """
-     The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
-     """
-
-     init_fn: InitFnType = InitFnType.normal
-     """
-     The weight initialization strategy.
-     """
-
-     init_std: float = 0.02
-     """
-     The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
-     as "normal".
-     """
-
-     init_cutoff_factor: Optional[float] = None
-     """
-     A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
-     as "normal". Setting this to None means values are not cutoff.
-     """
-
-     precision: Optional[str] = None
-     """
-     Precision used to train/evaluate with. You shouldn't set this directly.
-     See :data:`TrainConfig.precision` instead.
-     """
-
-     ternary: bool = False
-     """
-     Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf)
-     """
-
-     @property
-     def effective_n_kv_heads(self) -> int:
-         if self.n_kv_heads is None:
-             if self.multi_query_attention is True:
-                 return 1
-             else:
-                 return self.n_heads
-         else:
-             if self.multi_query_attention is None:
-                 return self.n_kv_heads
-             if self.multi_query_attention:
-                 n_kv_heads_should_be = 1
-             else:
-                 n_kv_heads_should_be = self.n_heads
-             if self.n_kv_heads == n_kv_heads_should_be:
-                 return n_kv_heads_should_be
-             else:
-                 raise OLMoConfigurationError(
-                     "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
-                 )
-
-
- class OptimizerType(StrEnum):
-     lionw = "lionw"
-     adamw = "adamw"
-
-
- @dataclass
- class OptimizerConfig(BaseConfig):
-     name: OptimizerType = OptimizerType.lionw
-     learning_rate: float = 1.0e-4
-     weight_decay: float = 0.01
-     betas: Tuple[float, float] = (0.9, 0.95)
-
-     no_decay_norm_and_bias: Optional[bool] = None
-     """
-     Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
-     """
-
-     decay_norm_and_bias: bool = False
-     decay_embeddings: bool = False
-     metrics_log_interval: Optional[int] = None
-     """
-     The interval with which to collect and log detailed parameter-specific metrics.
-     This only applies when logging to W&B, since these metrics won't be logged to the console.
-     If not set, defaults to the wandb `log_interval`.
-     """
-
-     def __post_init__(self):
-         self.betas = tuple(self.betas)  # type: ignore[assignment]
-
-     @classmethod
-     def update_legacy_settings(cls, config: D) -> D:
-         new_config = config.copy()
-         if om.is_dict(new_config):
-             assert isinstance(new_config, DictConfig)
-
-             if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
-                 new_config.name = "lionw"
-                 if hasattr(new_config, "eps"):
-                     del new_config.eps
-
-         return new_config
-
-
- class SchedulerType(StrEnum):
-     cosine_with_warmup = "cosine_with_warmup"
-     linear_with_warmup = "linear_with_warmup"
-     inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
-     max_scheduler = "max_scheduler"
-     constant = "constant"
-
-
- class SchedulerUnits(StrEnum):
-     steps = "steps"
-     tokens = "tokens"
-
-
- @dataclass
- class SchedulerConfig(BaseConfig):
-     name: SchedulerType = SchedulerType.cosine_with_warmup
-     units: SchedulerUnits = SchedulerUnits.steps
-     t_warmup: Union[int, float] = 100
-     t_max: Optional[Union[int, float]] = None
-     alpha_f: float = 0.1
-
-     grad_clip_warmup_steps: Optional[Union[int, float]] = None
-     """
-     The warmup period for which the max grad norm (or norm ratio) will be set to its
-     warmup value of `max_grad_norm * grad_clip_warmup_factor`.
-     """
-
-     grad_clip_warmup_factor: Optional[float] = None
-     """
-     The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
-     vs after the warmup period.
-     """
-
-
- class PaddingDirection(StrEnum):
-     right = "right"
-     left = "left"
-
-
- @dataclass
- class DataConfig(BaseConfig):
-     paths: Optional[List[str]] = None
-     datasets: Optional[Dict[str, List[str]]] = None
-     label_mask_paths: Optional[List[str]] = None
-     pad_direction: PaddingDirection = PaddingDirection.right
-     generate_attention_mask: bool = False
-     num_workers: int = 0
-     drop_last: bool = False
-     pin_memory: bool = False
-     prefetch_factor: Optional[int] = None
-     persistent_workers: bool = False
-     timeout: int = 0
-     seed: Optional[int] = None
-
-
- class EvaluatorType(StrEnum):
-     downstream = "downstream"
-     lm = "lm"
-
-
- @dataclass
- class EvaluatorConfig(BaseConfig):
-     label: str
-     type: EvaluatorType = EvaluatorType.lm
-     data: DataConfig = field(default_factory=DataConfig)
-     device_eval_batch_size: Optional[int] = None
-     subset_num_batches: Optional[int] = None
-
-
- class TruncationDirection(StrEnum):
-     right = "right"
-     left = "left"
-
-
- @dataclass
- class TokenizerConfig(BaseConfig):
-     identifier: str = "gpt2"
-     truncate_direction: TruncationDirection = TruncationDirection.right
-
-
- @dataclass
- class WandbConfig(BaseConfig):
-     project: Optional[str] = None
-     entity: Optional[str] = "ai2-llm"
-     group: Optional[str] = None
-     name: Optional[str] = None
-     tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
-     log_artifacts: bool = False
-     rank_zero_only: bool = True
-     log_interval: int = 1
-
-
- @dataclass
- class SpeedMonitorConfig(BaseConfig):
-     window_size: int = 100
-     gpu_flops_available: Optional[Union[float, int]] = None
-
-
- @dataclass
- class CompilerConfig(BaseConfig):
-     mode: Optional[str] = None
-     """
-     The mode to compile the model in. At the moment this can be "default",
-     "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
-     (the fastest for larger models, but takes a long time to compile).
-     """
-
-     fullgraph: bool = False
-     """
-     Whether it is OK to break model into several subgraphs when compiling.
-     Note that this is not compatible with FSDP.
-     """
-
-     backend: str = "inductor"
-     """
-     The backend to use.
-     """
-
-
- class FSDPWrapStrategy(StrEnum):
-     by_block = "by_block"
-     """
-     Wrap each OLMo block with its own FSDP instance.
-     """
-
-     by_block_and_size = "by_block_and_size"
-     """
-     Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
-     """
-
-     by_block_group = "by_block_group"
-     """
-     Wrap each block group together into its own FSDP instance.
-     This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
-     """
-
-     by_block_group_and_size = "by_block_group_and_size"
-     """
-     Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
-     """
-
-     size_based = "size_based"
-     """
-     Use PyTorch's default size-based auto wrap policy.
-     """
-
-     one_in_two = "one_in_two"
-     one_in_three = "one_in_three"
-     one_in_four = "one_in_four"
-     one_in_five = "one_in_five"
-
-
- class FSDPPrecision(StrEnum):
-     pure = "pure"
-     """
-     Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
-     and ``buffer_dtype`` all set to the autocast precision data type.
-     """
-
-     mixed = "mixed"
-     """
-     Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
-     set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
-     """
-
-
- @dataclass
- class FSDPConfig(BaseConfig):
-     use_orig_params: bool = True
-     """
-     This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
-     """
-
-     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
-
-     wrapping_strategy: Optional[FSDPWrapStrategy] = None
-     """
-     The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
-     FSDP instance.
-     """
-
-     precision: FSDPPrecision = FSDPPrecision.pure
-
-
- class CheckpointType(StrEnum):
-     sharded = "sharded"
-     unsharded = "unsharded"
-     sharded_ephemeral = "sharded_ephemeral"
-
-
- class ShardedCheckpointerType(StrEnum):
-     torch_new = "torch_new"
-     torch_legacy = "torch_legacy"
-     local = "local"
-
-
- class ActivationCheckpointingStrategy(StrEnum):
-     whole_layer = "whole_layer"
-     """
-     Checkpoint every transformer layer.
-     """
-
-     one_in_two = "one_in_two"
-     """
-     Checkpoint one in two transformer layers.
-     """
-
-     one_in_three = "one_in_three"
-     """
-     Checkpoint one in three transformer layers.
-     """
-
-     one_in_four = "one_in_four"
-     """
-     Checkpoint one in four transformer layers.
-     """
-
-     two_in_three = "two_in_three"
-     """
-     Checkpoint two out of every three transformer layers.
-     """
-
-     three_in_four = "three_in_four"
-     """
-     Checkpoint three out of every four transformer layers.
-     """
-
-     fine_grained = "fine_grained"
-     """
-     Focus checkpointing on where it is cheap to recompute and saves the most memory.
-     """
-
-
- @dataclass
- class TrainConfig(BaseConfig):
-     """
-     OLMo training configuration.
-     """
-
-     run_name: Optional[str] = None
-     """
-     The name of the run.
-     """
-
-     seed: int = 6198
-     """
-     Used to seed all initial RNG states.
-     """
-
-     epoch: Optional[int] = None
-     """
-     Increment this when starting a new epoch.
-     """
-
-     dry_run: bool = False
-     """
-     If ``True``, don't actually train.
-     """
-
-     model: ModelConfig = field(default_factory=ModelConfig)
-     """
-     OLMo Model configuration.
-     """
-
-     optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
-     """
-     Optimizer configuration.
-     """
-
-     scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
-     """
-     Learning rate scheduler configuration.
-     """
-
-     data: DataConfig = field(default_factory=DataConfig)
-     """
-     Training data configuration.
-     """
-
-     restore_dataloader: bool = True
-     """
-     When restarting, restore the data loader to where it left off.
-     If you are restarting in order to train on a different dataset, set this to ``False``.
-     """
-
-     fast_forward_batches: Optional[int] = None
-     """
-     When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
-     This can be useful when restarting due to a loss spike in order to skip the data that
-     corresponded to the spike.
-     """
-
-     evaluators: List[EvaluatorConfig] = field(default_factory=list)
-     """
-     Evaluation configurations.
-     """
-
-     eval_interval: int = 1000
-     """
-     How often (in terms of batches) to run evaluations.
-     """
-
-     tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
-     """
-     Tokenizer configuration.
-     """
-
-     save_folder: str = "./"
-     """
-     The directory to save checkpoints to.
-     """
-
-     remote_save_folder: Optional[str] = None
-     """
-     A folder in a cloud bucket to upload saved checkpoints to.
-     """
-
-     canceled_check_interval: int = 50
-     """
-     How often (in batches) to check if the run has been canceled or reached its time limit.
-     """
-
-     save_interval: int = 1000
-     """
-     How often (in terms of steps) to save sharded training state checkpoints.
-     """
-
-     save_interval_unsharded: Optional[int] = None
-     """
-     How often (if at all) to save unsharded training state checkpoints.
-     For large models it can be costly to save these, so it usually makes sense to save
-     these less often than regular (sharded) training checkpoints.
-     """
-
-     save_interval_ephemeral: Optional[int] = None
-     """
-     How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
-     as those saved every `save_interval` except that at most only the most recent one of these is kept.
-     This is useful when you want to checkpoint often for restarts in case of failures, but don't
-     want to keep the majority of these checkpoints.
-
-     For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
-     a temporary checkpoint every 100 steps in case your job fails. In that case you would
-     set `save_interval=1000` and `save_interval_ephemeral=100`.
-     """
-
-     save_num_checkpoints_to_keep: int = -1
-     """
-     How many sharded checkpoints to keep.
-     """
-
-     save_num_unsharded_checkpoints_to_keep: int = -1
-     """
-     How many unsharded checkpoints to keep.
-     """
-
-     save_overwrite: bool = False
-     """
-     If ``True``, overwrite any conflicting checkpoint files.
-     """
-
-     force_save_unsharded: bool = False
-     """
-     Save an unsharded checkpoint before training (even during a dry run).
-     Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
-     checkpoint into an unsharded checkpoint.
-     """
-
-     no_pre_train_checkpoint: bool = False
-     """
-     Skip saving pre-train checkpoint.
-     """
-
-     load_path: Optional[str] = None
-     """
-     The path to a training checkpoint to restore/resume from.
-
-     Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
-     a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
-     For example,
-
-     ```bash
-     --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
-     ```
-     """
-
-     load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
-     """
-     The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
-     """
-
-     reset_optimizer_state: bool = False
-     """
-     When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
-     We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
-     curve (according to the current learning rate schedule settings), and continues from there.
-     """
-
-     reset_trainer_state: bool = False
-     """
-     When this is set we don't restore the trainer state from a checkpoint.
-     """
-
-     sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
-     """
-     The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
-     """
-
-     new_style_checkpoints: Optional[bool] = None
-     """
-     Deprecated. Use ``sharded_checkpointer`` instead.
-     """
-
-     max_duration: Union[int, str] = 10000
-     """
-     How long to train for.
-
-     If specified without a unit (the default), the units are assumed to be steps.
-     You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
-     2 trillion tokens.
-     """
-
-     global_train_batch_size: int = 512
-     """
-     The effective global batch size.
-     """
-
-     device_train_batch_size: Optional[int] = None  # calculated automatically
-     """
-     Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
-     """
-
-     device_train_microbatch_size: int = 16
-     """
-     The number of instances passed to the model in a single forward-backward pass. You should set
-     this as large as you can based on available GPU memory.
-     """
-
-     device_eval_batch_size: int = 16
-     """
-     The number of evaluation instances passed to the model in a single forward pass on each device.
-     """
-
-     eval_subset_num_batches: int = -1
-     """
-     The number of batches to use for downstream evaluation from each dataset.
-     """
-
-     eval_on_load: bool = False
-     """
-     When resuming from a checkpoint, run the evaluation loop right away.
-     """
-
-     device_train_grad_accum: Optional[int] = None  # calculated automatically
-     """
-     Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
-     """
-
-     max_grad_norm: Optional[float] = None
-     """
-     Clip gradient norms to this value if set.
-     """
-
-     max_grad_norm_ratio: Optional[float] = None
-     """
-     If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
-     This takes priority over `max_grad_norm` when set.
-     """
-
-     precision: Optional[str] = None
-     """
-     Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
-     """
-
-     wandb: Optional[WandbConfig] = None
-     """
-     Weights & Biases configuration.
-     """
-
-     speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
-     """
-     Speed monitor configuration.
-     """
-
-     console_log_interval: int = 1
-     """
-     How often to log to the console.
-     """
-
-     compile: Optional[CompilerConfig] = None
-     """
-     Settings for compiling the model with ``torch.compile()``.
-     """
-
-     fsdp: FSDPConfig = field(default_factory=FSDPConfig)
-     """
-     Fully sharded data parallel settings.
-     """
-
-     softmax_auxiliary_loss: bool = False
-     """
-     If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
-     normalizing term to be close to 0.
-     """
-
-     time_limit: Optional[float] = 60 * 60 * 47.5
-     """
-     The maximum amount of time to train for before saving a checkpoint and ending early.
-     On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
-     to write out a final checkpoint.
-     """
-
-     extra_steps_after_cancel: int = 10
-     """
-     Under certain conditions when a run is canceled we train for a few extra steps after saving
-     the final checkpoint so that when the run is restarted from the latest checkpoint we have some
-     overlap in metrics.
-     """
-
-     early_stopping_factor: Optional[float] = None
-
-     save_data_indices: bool = True
-     """
-     Save training data indices from each batch for each worker.
-     """
-
-     python_profiling: bool = False
-     """
-     Whether to run the Python profiler on batches 6, 7, and 8.
-     """
-
-     torch_profiling: bool = False
-     """
-     Whether to run the PyTorch profiler on batches 6, 7, and 8.
-     """
-
-     stop_at: Optional[int] = None
-     """
-     Stop at a specific step.
-     """
-
-     stop_after: Optional[int] = None
-     """
-     Stop after a specific number of steps.
-     """
-
-     activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
-     """
-     The activation checkpointing strategy to use.
-     """
-
-     fused_loss: Optional[bool] = None
-     """
-     Whether to use the fused CE loss function from `flash-attn`.
-     """
-
-     @property
-     def autocast_precision(self) -> torch.dtype:
-         if self.precision == "amp_bf16":
-             return torch.bfloat16
-         elif self.precision == "amp_fp16":
-             return torch.float16
-         elif self.precision == "fp32":
-             return torch.float32
-         else:
-             raise ValueError(f"Unexpected precision type '{self.precision}'")
-
-     @property
-     def fsdp_precision(self) -> MixedPrecision:
-         if self.fsdp.precision == FSDPPrecision.pure:
-             return MixedPrecision(
-                 param_dtype=self.autocast_precision,
-                 reduce_dtype=self.autocast_precision,
-                 buffer_dtype=self.autocast_precision,
-             )
-         elif self.fsdp.precision == FSDPPrecision.mixed:
-             return MixedPrecision(
-                 param_dtype=self.autocast_precision,
-                 reduce_dtype=torch.float32,
-                 buffer_dtype=self.autocast_precision,
-             )
-         else:
-             raise NotImplementedError(f"{self.fsdp.precision}")
-
-     @classmethod
-     def update_legacy_settings(cls, config: D) -> D:
-         new_config = config.copy()
-         if om.is_dict(new_config):
-             assert isinstance(new_config, DictConfig)
-
-             if hasattr(new_config, "activation_checkpointing"):
-                 if new_config.activation_checkpointing is False:
-                     new_config.activation_checkpointing = None
-                 if new_config.activation_checkpointing is True:
-                     new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
-
-             if hasattr(new_config, "optimizer"):
-                 new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
-
-         return new_config
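
For reference, the deleted `config.py` defined the OLMo config dataclasses plus the `BaseConfig` helpers for building them from YAML and dotlist overrides via OmegaConf. A minimal sketch of how that API was typically used before this deletion, assuming the package is importable as `olmo` and using a hypothetical YAML path and hypothetical override values:

```python
# Sketch only: exercises the BaseConfig/TrainConfig API shown in the deleted file above.
# The module path, YAML file, and override values below are hypothetical placeholders.
from olmo.config import ModelConfig, TrainConfig

# Build a ModelConfig in memory, overriding a few defaults via kwargs (BaseConfig.new).
model_cfg = ModelConfig.new(d_model=1024, n_heads=16, rope=True)

# Load a full TrainConfig from YAML, merging dotlist overrides on top of the schema (BaseConfig.load).
cfg = TrainConfig.load(
    "configs/my-run.yaml",  # hypothetical config file
    overrides=["device_train_microbatch_size=8", "model.flash_attention=true"],
    validate_paths=False,  # don't fail if data globs/checkpoint paths aren't available locally
)

print(cfg.model.d_model, cfg.optimizer.learning_rate)
cfg.save("run_config_resolved.yaml")  # round-trip back to YAML via OmegaConf
```

Note that `load()` registers the `path.glob`, `path.choose`, and `path.last_checkpoint` resolvers before merging, which is why `validate_paths=False` is useful when the paths referenced by a config are not reachable from the current machine.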