naveensp committed
Commit 2e5456c · verified · 1 Parent(s): 10b98ea

Upload config.py with huggingface_hub

Files changed (1)
  1. config.py +1106 -0
config.py ADDED
@@ -0,0 +1,1106 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from glob import glob
from pathlib import Path
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import torch
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from omegaconf.errors import OmegaConfBaseException
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

from .aliases import PathOrStr
from .beam_search import Sampler
from .exceptions import OLMoConfigurationError
from .util import StrEnum

__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
    "OptimizerType",
    "OptimizerConfig",
    "SchedulerType",
    "SchedulerConfig",
    "DataConfig",
    "EvaluatorConfig",
    "TokenizerConfig",
    "TrainConfig",
    "PaddingDirection",
    "TruncationDirection",
    "SpeedMonitorConfig",
    "WandbConfig",
    "CompilerConfig",
    "FSDPPrecision",
    "FSDPWrapStrategy",
    "FSDPConfig",
    "CheckpointType",
]

C = TypeVar("C", bound="BaseConfig")
D = TypeVar("D", bound="DictConfig|ListConfig")


class BaseConfig:
    @classmethod
    def _register_resolvers(cls, validate_paths: bool = True):
        # Expands path globs into a list.
        def path_glob(*paths) -> List[str]:
            out = []
            for path in paths:
                matches = sorted(glob(path))
                if not matches and validate_paths:
                    raise FileNotFoundError(f"{path} does not match any files or dirs")
                out.extend(matches)
            return out

        # Chooses the first path in the arguments that exists.
        def path_choose(*paths) -> str:
            from .util import is_url

            for path in paths:
                if is_url(path) or Path(path).exists():
                    return path
            if validate_paths:
                raise FileNotFoundError(", ".join(paths))
            else:
                return ""

        # Finds the latest checkpoint in a folder.
        def path_last_checkpoint(path) -> str:
            from .util import find_latest_checkpoint

            latest_checkpoint = find_latest_checkpoint(path)
            if latest_checkpoint is None:
                if validate_paths:
                    raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
                else:
                    return ""
            else:
                return str(latest_checkpoint)

        om.register_new_resolver("path.glob", path_glob, replace=True)
        om.register_new_resolver("path.choose", path_choose, replace=True)
        om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        """
        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
        """
        return config

    @classmethod
    def new(cls: Type[C], **kwargs) -> C:
        cls._register_resolvers()
        conf = om.structured(cls)
        try:
            if kwargs:
                conf = om.merge(conf, kwargs)
            return cast(C, om.to_object(conf))
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))

    @classmethod
    def load(
        cls: Type[C],
        path: PathOrStr,
        overrides: Optional[List[str]] = None,
        key: Optional[str] = None,
        validate_paths: bool = True,
    ) -> C:
        """Load from a YAML file."""
        cls._register_resolvers(validate_paths=validate_paths)
        schema = om.structured(cls)
        try:
            raw = om.load(str(path))
            if key is not None:
                raw = raw[key]  # type: ignore
            raw = cls.update_legacy_settings(raw)
            conf = om.merge(schema, raw)
            if overrides:
                conf = om.merge(conf, om.from_dotlist(overrides))
            return cast(C, om.to_object(conf))
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))

    def save(self, path: PathOrStr) -> None:
        """Save to a YAML file."""
        om.save(config=self, f=str(path))

    def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
        out = asdict(self)  # type: ignore
        if exclude is not None:
            for name in exclude:
                if name in out:
                    del out[name]
        return out
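
# A minimal usage sketch of ``BaseConfig.load`` together with the resolvers
# registered above. The YAML path and override values are hypothetical, not
# files that ship with this repository.
#
#     cfg = TrainConfig.load(
#         "configs/my-run.yaml",  # hypothetical config file
#         overrides=["optimizer.learning_rate=2e-4", "run_name=my-run"],
#         validate_paths=False,  # don't require resolver paths to exist
#     )
#     cfg.save("configs/my-run-resolved.yaml")  # round-trip back to YAML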


class LayerNormType(StrEnum):
    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """


class ActivationType(StrEnum):
    gelu = "gelu"
    relu = "relu"
    swiglu = "swiglu"


class BlockType(StrEnum):
    sequential = "sequential"

    llama = "llama"
    """
    A block similar to the sequential block with slightly different
    implementations of operations like attention to imitate the behavior of Llama.
    """


class InitFnType(StrEnum):
    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW.
    This uses a truncated normal distribution with an adaptive standard deviation that depends
    on the size of the weights as well as the depth of the layer.
    """

    normal = "normal"
    """
    All weights are initialized from the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    All weights are initialized with the Kaiming method from a normal distribution.
    Note this currently won't work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """
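
# A rough sketch of what the ``fan_in`` strategy above means for a single linear
# weight (the real initialization code is not part of this config module):
#
#     import math
#     import torch.nn as nn
#
#     def fan_in_init_(weight: torch.Tensor) -> None:
#         d_in = weight.shape[1]  # input dimensionality of the kernel
#         nn.init.normal_(weight, mean=0.0, std=1.0 / math.sqrt(d_in))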


@dataclass
class ModelConfig(BaseConfig):
    """
    OLMo (model) configuration.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    clip_qkv: Optional[float] = None
    """
    Clip QKV to this value when set.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Deprecated. Use n_kv_heads instead.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    include_bias: bool = True
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cut off.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    ternary: bool = False
    """
    Use the ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"
    (https://arxiv.org/pdf/2402.17764.pdf).
    """

    @property
    def effective_n_kv_heads(self) -> int:
        if self.n_kv_heads is None:
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise OLMoConfigurationError(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )
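
# How ``effective_n_kv_heads`` resolves the attention variant, as a quick
# sketch with illustrative values:
#
#     mha = ModelConfig(n_heads=12)                # 12 KV heads (multi-head)
#     mqa = ModelConfig(n_heads=12, n_kv_heads=1)  # 1 KV head (multi-query)
#     gqa = ModelConfig(n_heads=12, n_kv_heads=4)  # 4 KV heads (grouped-query)
#     assert mha.effective_n_kv_heads == 12
#     assert mqa.effective_n_kv_heads == 1
#     assert gqa.effective_n_kv_heads == 4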


class OptimizerType(StrEnum):
    lionw = "lionw"
    adamw = "adamw"


@dataclass
class OptimizerConfig(BaseConfig):
    name: OptimizerType = OptimizerType.lionw
    learning_rate: float = 1.0e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)

    no_decay_norm_and_bias: Optional[bool] = None
    """
    Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
    """

    decay_norm_and_bias: bool = False
    decay_embeddings: bool = False
    metrics_log_interval: Optional[int] = None
    """
    The interval with which to collect and log detailed parameter-specific metrics.
    This only applies when logging to W&B, since these metrics won't be logged to the console.
    If not set, defaults to the wandb `log_interval`.
    """

    def __post_init__(self):
        self.betas = tuple(self.betas)  # type: ignore[assignment]

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
                new_config.name = "lionw"
                if hasattr(new_config, "eps"):
                    del new_config.eps

        return new_config
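
# Sketch of the legacy migration above: an old-style optimizer block named
# "decoupled_lionw" (with an ``eps`` field) is rewritten to the current schema.
# The values are illustrative.
#
#     legacy = om.create({"name": "decoupled_lionw", "learning_rate": 1e-4, "eps": 1e-8})
#     migrated = OptimizerConfig.update_legacy_settings(legacy)
#     assert migrated.name == "lionw" and "eps" not in migrated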


class SchedulerType(StrEnum):
    cosine_with_warmup = "cosine_with_warmup"
    linear_with_warmup = "linear_with_warmup"
    inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
    max_scheduler = "max_scheduler"
    constant = "constant"


class SchedulerUnits(StrEnum):
    steps = "steps"
    tokens = "tokens"


@dataclass
class SchedulerConfig(BaseConfig):
    name: SchedulerType = SchedulerType.cosine_with_warmup
    units: SchedulerUnits = SchedulerUnits.steps
    t_warmup: Union[int, float] = 100
    t_max: Optional[Union[int, float]] = None
    alpha_f: float = 0.1

    grad_clip_warmup_steps: Optional[Union[int, float]] = None
    """
    The warmup period for which the max grad norm (or norm ratio) will be set to its
    warmup value of `max_grad_norm * grad_clip_warmup_factor`.
    """

    grad_clip_warmup_factor: Optional[float] = None
    """
    The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
    vs after the warmup period.
    """


class PaddingDirection(StrEnum):
    right = "right"
    left = "left"


@dataclass
class DataConfig(BaseConfig):
    paths: Optional[List[str]] = None
    datasets: Optional[Dict[str, List[str]]] = None
    label_mask_paths: Optional[List[str]] = None
    pad_direction: PaddingDirection = PaddingDirection.right
    generate_attention_mask: bool = False
    num_workers: int = 0
    drop_last: bool = False
    pin_memory: bool = False
    prefetch_factor: Optional[int] = None
    persistent_workers: bool = False
    timeout: int = 0
    seed: Optional[int] = None


class EvaluatorType(StrEnum):
    downstream = "downstream"
    lm = "lm"


@dataclass
class EvaluatorConfig(BaseConfig):
    label: str
    type: EvaluatorType = EvaluatorType.lm
    data: DataConfig = field(default_factory=DataConfig)
    device_eval_batch_size: Optional[int] = None
    subset_num_batches: Optional[int] = None


class TruncationDirection(StrEnum):
    right = "right"
    left = "left"


@dataclass
class TokenizerConfig(BaseConfig):
    identifier: str = "gpt2"
    truncate_direction: TruncationDirection = TruncationDirection.right


@dataclass
class WandbConfig(BaseConfig):
    project: Optional[str] = None
    entity: Optional[str] = "ai2-llm"
    group: Optional[str] = None
    name: Optional[str] = None
    tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
    log_artifacts: bool = False
    rank_zero_only: bool = True
    log_interval: int = 1


@dataclass
class SpeedMonitorConfig(BaseConfig):
    window_size: int = 100
    gpu_flops_available: Optional[Union[float, int]] = None


@dataclass
class CompilerConfig(BaseConfig):
    mode: Optional[str] = None
    """
    The mode to compile the model in. At the moment this can be "default",
    "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
    (the fastest for larger models, but takes a long time to compile).
    """

    fullgraph: bool = False
    """
    If ``True``, require the model to compile as a single graph, i.e. don't allow it to be broken
    into several subgraphs. Note that this is not compatible with FSDP.
    """

    backend: str = "inductor"
    """
    The backend to use.
    """
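
# Rough sketch of how these settings map onto ``torch.compile``; the actual
# call happens in the training script, not in this module, and ``model`` is a
# placeholder for any ``torch.nn.Module``.
#
#     compile_cfg = CompilerConfig(mode="default")
#     model = torch.compile(
#         model,
#         mode=compile_cfg.mode,
#         fullgraph=compile_cfg.fullgraph,
#         backend=compile_cfg.backend,
#     )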


class FSDPWrapStrategy(StrEnum):
    by_block = "by_block"
    """
    Wrap each OLMo block with its own FSDP instance.
    """

    by_block_and_size = "by_block_and_size"
    """
    Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
    """

    by_block_group = "by_block_group"
    """
    Wrap each block group together into its own FSDP instance.
    This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
    """

    by_block_group_and_size = "by_block_group_and_size"
    """
    Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
    """

    size_based = "size_based"
    """
    Use PyTorch's default size-based auto wrap policy.
    """

    one_in_two = "one_in_two"
    one_in_three = "one_in_three"
    one_in_four = "one_in_four"
    one_in_five = "one_in_five"


class FSDPPrecision(StrEnum):
    pure = "pure"
    """
    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
    and ``buffer_dtype`` all set to the autocast precision data type.
    """

    mixed = "mixed"
    """
    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype`` and ``buffer_dtype``
    set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
    """


@dataclass
class FSDPConfig(BaseConfig):
    use_orig_params: bool = True
    """
    This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
    """

    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD

    wrapping_strategy: Optional[FSDPWrapStrategy] = None
    """
    The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
    FSDP instance.
    """

    precision: FSDPPrecision = FSDPPrecision.pure


class CheckpointType(StrEnum):
    sharded = "sharded"
    unsharded = "unsharded"
    sharded_ephemeral = "sharded_ephemeral"


class ShardedCheckpointerType(StrEnum):
    torch_new = "torch_new"
    torch_legacy = "torch_legacy"
    local = "local"


class ActivationCheckpointingStrategy(StrEnum):
    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one in two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one in three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one in four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Focus checkpointing on where it is cheap to recompute and saves the most memory.
    """


@dataclass
class TrainConfig(BaseConfig):
    """
    OLMo training configuration.
    """

    run_name: Optional[str] = None
    """
    The name of the run.
    """

    seed: int = 6198
    """
    Used to seed all initial RNG states.
    """

    epoch: Optional[int] = None
    """
    Increment this when starting a new epoch.
    """

    dry_run: bool = False
    """
    If ``True``, don't actually train.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    """
    OLMo Model configuration.
    """

    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    """
    Optimizer configuration.
    """

    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
    """
    Learning rate scheduler configuration.
    """

    data: DataConfig = field(default_factory=DataConfig)
    """
    Training data configuration.
    """

    restore_dataloader: bool = True
    """
    When restarting, restore the data loader to where it left off.
    If you are restarting in order to train on a different dataset, set this to ``False``.
    """

    fast_forward_batches: Optional[int] = None
    """
    When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
    This can be useful when restarting due to a loss spike in order to skip the data that
    corresponded to the spike.
    """

    evaluators: List[EvaluatorConfig] = field(default_factory=list)
    """
    Evaluation configurations.
    """

    eval_interval: int = 1000
    """
    How often (in terms of batches) to run evaluations.
    """

    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
    """
    Tokenizer configuration.
    """

    save_folder: str = "./"
    """
    The directory to save checkpoints to.
    """

    remote_save_folder: Optional[str] = None
    """
    A folder in a cloud bucket to upload saved checkpoints to.
    """

    canceled_check_interval: int = 50
    """
    How often (in batches) to check if the run has been canceled or reached its time limit.
    """

    save_interval: int = 1000
    """
    How often (in terms of steps) to save sharded training state checkpoints.
    """

    save_interval_unsharded: Optional[int] = None
    """
    How often (if at all) to save unsharded training state checkpoints.
    For large models it can be costly to save these, so it usually makes sense to save
    these less often than regular (sharded) training checkpoints.
    """

    save_interval_ephemeral: Optional[int] = None
    """
    How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
    as those saved every `save_interval` except that at most only the most recent one of these is kept.
    This is useful when you want to checkpoint often for restarts in case of failures, but don't
    want to keep the majority of these checkpoints.

    For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
    a temporary checkpoint every 100 steps in case your job fails. In that case you would
    set `save_interval=1000` and `save_interval_ephemeral=100`.
    """

    save_num_checkpoints_to_keep: int = -1
    """
    How many sharded checkpoints to keep.
    """

    save_num_unsharded_checkpoints_to_keep: int = -1
    """
    How many unsharded checkpoints to keep.
    """

    save_overwrite: bool = False
    """
    If ``True``, overwrite any conflicting checkpoint files.
    """

    force_save_unsharded: bool = False
    """
    Save an unsharded checkpoint before training (even during a dry run).
    Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
    checkpoint into an unsharded checkpoint.
    """

    no_pre_train_checkpoint: bool = False
    """
    Skip saving the pre-train checkpoint.
    """

    load_path: Optional[str] = None
    """
    The path to a training checkpoint to restore/resume from.

    Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
    a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
    For example,

    ```bash
    --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
    ```
    """

    load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
    """
    The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
    """

    reset_optimizer_state: bool = False
    """
    When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
    We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
    curve (according to the current learning rate schedule settings), and continues from there.
    """

    reset_trainer_state: bool = False
    """
    When this is set we don't restore the trainer state from a checkpoint.
    """

    sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
    """
    The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
    """

    new_style_checkpoints: Optional[bool] = None
    """
    Deprecated. Use ``sharded_checkpointer`` instead.
    """

    max_duration: Union[int, str] = 10000
    """
    How long to train for.

    If specified without a unit (the default), the units are assumed to be steps.
    You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
    2 trillion tokens.
    """

    global_train_batch_size: int = 512
    """
    The effective global batch size.
    """

    device_train_batch_size: Optional[int] = None  # calculated automatically
    """
    Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
    """

    device_train_microbatch_size: int = 16
    """
    The number of instances passed to the model in a single forward-backward pass. You should set
    this as large as you can based on available GPU memory.
    """

    device_eval_batch_size: int = 16
    """
    The number of evaluation instances passed to the model in a single forward pass on each device.
    """

    eval_subset_num_batches: int = -1
    """
    The number of batches to use for downstream evaluation from each dataset.
    """

    eval_on_load: bool = False
    """
    When resuming from a checkpoint, run the evaluation loop right away.
    """

    device_train_grad_accum: Optional[int] = None  # calculated automatically
    """
    Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
    """

    max_grad_norm: Optional[float] = None
    """
    Clip gradient norms to this value if set.
    """

    max_grad_norm_ratio: Optional[float] = None
    """
    If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
    This takes priority over `max_grad_norm` when set.
    """

    precision: Optional[str] = None
    """
    Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
    """

    wandb: Optional[WandbConfig] = None
    """
    Weights & Biases configuration.
    """

    speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
    """
    Speed monitor configuration.
    """

    console_log_interval: int = 1
    """
    How often to log to the console.
    """

    compile: Optional[CompilerConfig] = None
    """
    Settings for compiling the model with ``torch.compile()``.
    """

    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    """
    Fully sharded data parallel settings.
    """

    softmax_auxiliary_loss: bool = False
    """
    If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
    normalizing term to be close to 0.
    """

    time_limit: Optional[float] = 60 * 60 * 47.5
    """
    The maximum amount of time to train for before saving a checkpoint and ending early.
    On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
    to write out a final checkpoint.
    """

    extra_steps_after_cancel: int = 10
    """
    Under certain conditions when a run is canceled we train for a few extra steps after saving
    the final checkpoint so that when the run is restarted from the latest checkpoint we have some
    overlap in metrics.
    """

    early_stopping_factor: Optional[float] = None

    save_data_indices: bool = True
    """
    Save training data indices from each batch for each worker.
    """

    python_profiling: bool = False
    """
    Whether to run the Python profiler on batches 6, 7, and 8.
    """

    torch_profiling: bool = False
    """
    Whether to run the PyTorch profiler on batches 6, 7, and 8.
    """

    stop_at: Optional[int] = None
    """
    Stop at a specific step.
    """

    stop_after: Optional[int] = None
    """
    Stop after a specific number of steps.
    """

    activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
    """
    The activation checkpointing strategy to use.
    """

    fused_loss: Optional[bool] = None
    """
    Whether to use the fused CE loss function from `flash-attn`.
    """

    @property
    def autocast_precision(self) -> torch.dtype:
        if self.precision == "amp_bf16":
            return torch.bfloat16
        elif self.precision == "amp_fp16":
            return torch.float16
        elif self.precision == "fp32":
            return torch.float32
        else:
            raise ValueError(f"Unexpected precision type '{self.precision}'")

    @property
    def fsdp_precision(self) -> MixedPrecision:
        if self.fsdp.precision == FSDPPrecision.pure:
            return MixedPrecision(
                param_dtype=self.autocast_precision,
                reduce_dtype=self.autocast_precision,
                buffer_dtype=self.autocast_precision,
            )
        elif self.fsdp.precision == FSDPPrecision.mixed:
            return MixedPrecision(
                param_dtype=self.autocast_precision,
                reduce_dtype=torch.float32,
                buffer_dtype=self.autocast_precision,
            )
        else:
            raise NotImplementedError(f"{self.fsdp.precision}")

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "activation_checkpointing"):
                if new_config.activation_checkpointing is False:
                    new_config.activation_checkpointing = None
                if new_config.activation_checkpointing is True:
                    new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer

            if hasattr(new_config, "optimizer"):
                new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)

        return new_config
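

# End-to-end usage sketch with hypothetical values: build a config
# programmatically with ``new()`` and inspect the derived precision settings.
#
#     cfg = TrainConfig.new(
#         run_name="tiny-test",
#         precision="amp_bf16",
#         model={"d_model": 256, "n_heads": 4, "n_layers": 4},
#     )
#     assert cfg.autocast_precision is torch.bfloat16
#     print(cfg.fsdp_precision)  # MixedPrecision with bf16 param/reduce/buffer dtypes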