naveensp commited on
Commit
886ea5f
1 Parent(s): a7c44ca

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. OLMo_Bitnet_1B/.gitattributes +35 -0
  2. OLMo_Bitnet_1B/README.md +38 -0
  3. OLMo_Bitnet_1B/__init__.py +0 -0
  4. OLMo_Bitnet_1B/__pycache__/__init__.cpython-310.pyc +0 -0
  5. OLMo_Bitnet_1B/__pycache__/__init__.cpython-311.pyc +0 -0
  6. OLMo_Bitnet_1B/__pycache__/__init__.cpython-312.pyc +0 -0
  7. OLMo_Bitnet_1B/__pycache__/aliases.cpython-310.pyc +0 -0
  8. OLMo_Bitnet_1B/__pycache__/aliases.cpython-311.pyc +0 -0
  9. OLMo_Bitnet_1B/__pycache__/aliases.cpython-312.pyc +0 -0
  10. OLMo_Bitnet_1B/__pycache__/beam_search.cpython-310.pyc +0 -0
  11. OLMo_Bitnet_1B/__pycache__/beam_search.cpython-311.pyc +0 -0
  12. OLMo_Bitnet_1B/__pycache__/beam_search.cpython-312.pyc +0 -0
  13. OLMo_Bitnet_1B/__pycache__/config.cpython-310.pyc +0 -0
  14. OLMo_Bitnet_1B/__pycache__/config.cpython-311.pyc +0 -0
  15. OLMo_Bitnet_1B/__pycache__/config.cpython-312.pyc +0 -0
  16. OLMo_Bitnet_1B/__pycache__/configuration_olmo.cpython-310.pyc +0 -0
  17. OLMo_Bitnet_1B/__pycache__/configuration_olmo.cpython-311.pyc +0 -0
  18. OLMo_Bitnet_1B/__pycache__/configuration_olmo.cpython-312.pyc +0 -0
  19. OLMo_Bitnet_1B/__pycache__/exceptions.cpython-310.pyc +0 -0
  20. OLMo_Bitnet_1B/__pycache__/exceptions.cpython-311.pyc +0 -0
  21. OLMo_Bitnet_1B/__pycache__/exceptions.cpython-312.pyc +0 -0
  22. OLMo_Bitnet_1B/__pycache__/initialization.cpython-310.pyc +0 -0
  23. OLMo_Bitnet_1B/__pycache__/initialization.cpython-311.pyc +0 -0
  24. OLMo_Bitnet_1B/__pycache__/initialization.cpython-312.pyc +0 -0
  25. OLMo_Bitnet_1B/__pycache__/model.cpython-310.pyc +0 -0
  26. OLMo_Bitnet_1B/__pycache__/model.cpython-311.pyc +0 -0
  27. OLMo_Bitnet_1B/__pycache__/model.cpython-312.pyc +0 -0
  28. OLMo_Bitnet_1B/__pycache__/modeling_olmo.cpython-310.pyc +0 -0
  29. OLMo_Bitnet_1B/__pycache__/modeling_olmo.cpython-311.pyc +0 -0
  30. OLMo_Bitnet_1B/__pycache__/modeling_olmo.cpython-312.pyc +0 -0
  31. OLMo_Bitnet_1B/__pycache__/optim.cpython-310.pyc +0 -0
  32. OLMo_Bitnet_1B/__pycache__/optim.cpython-311.pyc +0 -0
  33. OLMo_Bitnet_1B/__pycache__/optim.cpython-312.pyc +0 -0
  34. OLMo_Bitnet_1B/__pycache__/safetensors_util.cpython-310.pyc +0 -0
  35. OLMo_Bitnet_1B/__pycache__/safetensors_util.cpython-311.pyc +0 -0
  36. OLMo_Bitnet_1B/__pycache__/safetensors_util.cpython-312.pyc +0 -0
  37. OLMo_Bitnet_1B/__pycache__/torch_util.cpython-310.pyc +0 -0
  38. OLMo_Bitnet_1B/__pycache__/torch_util.cpython-311.pyc +0 -0
  39. OLMo_Bitnet_1B/__pycache__/torch_util.cpython-312.pyc +0 -0
  40. OLMo_Bitnet_1B/__pycache__/util.cpython-310.pyc +0 -0
  41. OLMo_Bitnet_1B/__pycache__/util.cpython-311.pyc +0 -0
  42. OLMo_Bitnet_1B/__pycache__/util.cpython-312.pyc +0 -0
  43. OLMo_Bitnet_1B/aliases.py +7 -0
  44. OLMo_Bitnet_1B/beam_search.py +1078 -0
  45. OLMo_Bitnet_1B/checkpoint.py +1671 -0
  46. OLMo_Bitnet_1B/config.json +50 -0
  47. OLMo_Bitnet_1B/config.py +1106 -0
  48. OLMo_Bitnet_1B/configuration_olmo.py +52 -0
  49. OLMo_Bitnet_1B/exceptions.py +50 -0
  50. OLMo_Bitnet_1B/initialization.py +95 -0
OLMo_Bitnet_1B/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
OLMo_Bitnet_1B/README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - allenai/dolma
5
+ ---
6
+ # OLMo-Bitnet-1B
7
+
8
+ OLMo-Bitnet-1B is a 1B parameter model trained using the method described in [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764).
9
+
10
+ It was trained on the first 60B tokens of the [Dolma](https://huggingface.co/datasets/allenai/dolma) dataset, so it is merely a research proof-of-concept to test out the methodolgy.
11
+
12
+ A separate training run was run with the exact same hyperparameters, but using standard fp16 weights.
13
+ The comparison can be found in [this wandb report](https://api.wandb.ai/links/emozilla/evltqiv7).
14
+
15
+
16
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6317aade83d8d2fd903192d9/NAw-hyWJl5ihVsAPqz3Xe.png)
17
+
18
+ Sample inference code
19
+
20
+ ```sh
21
+ pip install ai2-olmo
22
+ ```
23
+
24
+ ```python
25
+ import torch
26
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/OLMo-Bitnet-1B")
29
+ model = AutoModelForCausalLM.from_pretrained("NousResearch/OLMo-Bitnet-1B",
30
+ torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
31
+
32
+ streamer = TextStreamer(tokenizer)
33
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id,
34
+ temperature=0.8, repetition_penalty=1.1, do_sample=True,streamer=streamer)
35
+ pipe("The capitol of Paris is", max_new_tokens=256)
36
+ ```
37
+
38
+ Training was performed using [OLMo](https://github.com/allenai/OLMo).
OLMo_Bitnet_1B/__init__.py ADDED
File without changes
OLMo_Bitnet_1B/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
OLMo_Bitnet_1B/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (165 Bytes). View file
 
OLMo_Bitnet_1B/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (153 Bytes). View file
 
OLMo_Bitnet_1B/__pycache__/aliases.cpython-310.pyc ADDED
Binary file (268 Bytes). View file
 
OLMo_Bitnet_1B/__pycache__/aliases.cpython-311.pyc ADDED
Binary file (342 Bytes). View file
 
OLMo_Bitnet_1B/__pycache__/aliases.cpython-312.pyc ADDED
Binary file (300 Bytes). View file
 
OLMo_Bitnet_1B/__pycache__/beam_search.cpython-310.pyc ADDED
Binary file (31.6 kB). View file
 
OLMo_Bitnet_1B/__pycache__/beam_search.cpython-311.pyc ADDED
Binary file (48 kB). View file
 
OLMo_Bitnet_1B/__pycache__/beam_search.cpython-312.pyc ADDED
Binary file (45.6 kB). View file
 
OLMo_Bitnet_1B/__pycache__/config.cpython-310.pyc ADDED
Binary file (18.1 kB). View file
 
OLMo_Bitnet_1B/__pycache__/config.cpython-311.pyc ADDED
Binary file (28.4 kB). View file
 
OLMo_Bitnet_1B/__pycache__/config.cpython-312.pyc ADDED
Binary file (25 kB). View file
 
OLMo_Bitnet_1B/__pycache__/configuration_olmo.cpython-310.pyc ADDED
Binary file (1.83 kB). View file
 
OLMo_Bitnet_1B/__pycache__/configuration_olmo.cpython-311.pyc ADDED
Binary file (2.74 kB). View file
 
OLMo_Bitnet_1B/__pycache__/configuration_olmo.cpython-312.pyc ADDED
Binary file (2.36 kB). View file
 
OLMo_Bitnet_1B/__pycache__/exceptions.cpython-310.pyc ADDED
Binary file (1.45 kB). View file
 
OLMo_Bitnet_1B/__pycache__/exceptions.cpython-311.pyc ADDED
Binary file (1.99 kB). View file
 
OLMo_Bitnet_1B/__pycache__/exceptions.cpython-312.pyc ADDED
Binary file (1.68 kB). View file
 
OLMo_Bitnet_1B/__pycache__/initialization.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
OLMo_Bitnet_1B/__pycache__/initialization.cpython-311.pyc ADDED
Binary file (5.12 kB). View file
 
OLMo_Bitnet_1B/__pycache__/initialization.cpython-312.pyc ADDED
Binary file (5.09 kB). View file
 
OLMo_Bitnet_1B/__pycache__/model.cpython-310.pyc ADDED
Binary file (47.9 kB). View file
 
OLMo_Bitnet_1B/__pycache__/model.cpython-311.pyc ADDED
Binary file (90.1 kB). View file
 
OLMo_Bitnet_1B/__pycache__/model.cpython-312.pyc ADDED
Binary file (86.4 kB). View file
 
OLMo_Bitnet_1B/__pycache__/modeling_olmo.cpython-310.pyc ADDED
Binary file (6.53 kB). View file
 
OLMo_Bitnet_1B/__pycache__/modeling_olmo.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
OLMo_Bitnet_1B/__pycache__/modeling_olmo.cpython-312.pyc ADDED
Binary file (9.86 kB). View file
 
OLMo_Bitnet_1B/__pycache__/optim.cpython-310.pyc ADDED
Binary file (19.6 kB). View file
 
OLMo_Bitnet_1B/__pycache__/optim.cpython-311.pyc ADDED
Binary file (41 kB). View file
 
OLMo_Bitnet_1B/__pycache__/optim.cpython-312.pyc ADDED
Binary file (36 kB). View file
 
OLMo_Bitnet_1B/__pycache__/safetensors_util.cpython-310.pyc ADDED
Binary file (2.8 kB). View file
 
OLMo_Bitnet_1B/__pycache__/safetensors_util.cpython-311.pyc ADDED
Binary file (5.01 kB). View file
 
OLMo_Bitnet_1B/__pycache__/safetensors_util.cpython-312.pyc ADDED
Binary file (4.4 kB). View file
 
OLMo_Bitnet_1B/__pycache__/torch_util.cpython-310.pyc ADDED
Binary file (5.04 kB). View file
 
OLMo_Bitnet_1B/__pycache__/torch_util.cpython-311.pyc ADDED
Binary file (9.11 kB). View file
 
OLMo_Bitnet_1B/__pycache__/torch_util.cpython-312.pyc ADDED
Binary file (8.08 kB). View file
 
OLMo_Bitnet_1B/__pycache__/util.cpython-310.pyc ADDED
Binary file (19.7 kB). View file
 
OLMo_Bitnet_1B/__pycache__/util.cpython-311.pyc ADDED
Binary file (37.3 kB). View file
 
OLMo_Bitnet_1B/__pycache__/util.cpython-312.pyc ADDED
Binary file (33.2 kB). View file
 
OLMo_Bitnet_1B/aliases.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from os import PathLike
2
+ from typing import Union
3
+
4
+ __all__ = ["PathOrStr"]
5
+
6
+
7
+ PathOrStr = Union[str, PathLike]
OLMo_Bitnet_1B/beam_search.py ADDED
@@ -0,0 +1,1078 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a self-contained and flexible beam search implementation adapted from
3
+ AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
+ """
5
+
6
+ import copy
7
+ import warnings
8
+ from abc import abstractmethod
9
+ from inspect import signature
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
+
12
+ import torch
13
+
14
+ __all__ = [
15
+ "Sampler",
16
+ "DeterministicSampler",
17
+ "MultinomialSampler",
18
+ "TopKSampler",
19
+ "TopPSampler",
20
+ "GumbelSampler",
21
+ "FinalSequenceScorer",
22
+ "SequenceLogProbabilityScorer",
23
+ "LengthNormalizedSequenceLogProbabilityScorer",
24
+ "Constraint",
25
+ "RepeatedNGramBlockingConstraint",
26
+ "BeamSearch",
27
+ ]
28
+
29
+ StateType = Dict[str, torch.Tensor]
30
+ StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
+ StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
+
33
+ StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
+ """
35
+ The type of step function that can be passed to [`BeamSearch.search`](#search).
36
+
37
+ This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
+ or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
+ """
40
+
41
+ ConstraintStateType = List[List[Dict[str, Any]]]
42
+
43
+
44
+ class Sampler:
45
+ """
46
+ An abstract class that can be used to sample candidates (either nodes or beams)
47
+ within `BeamSearch`.
48
+
49
+ A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
+
51
+ `init_state()` takes three arguments:
52
+
53
+ - a tensor of starting log probs with shape `(batch_size,, num_classes)`,
54
+ - the batch size, an int,
55
+ - and the number of classes, also an int.
56
+
57
+ It returns a state dictionary with any state tensors needed for subsequent
58
+ calls to `sample_nodes()` and `sample_beams()`.
59
+
60
+ By default this method just returns an empty dictionary.
61
+
62
+ Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
+
64
+ - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
+ - an integer representing the number of samples to take for each example in the batch,
66
+ - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
+ track of state.
68
+
69
+ For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
+ `num_examples = beam_size * per_node_beam_size`.
71
+
72
+ The return value should be a tuple containing:
73
+
74
+ - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
+ - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
+ - and the updated state dictionary.
77
+
78
+ A default implementation of `sample_beams` is provided, which just deterministically
79
+ picks the `k` examples with highest log probability.
80
+ """
81
+
82
+ def init_state(
83
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
+ ) -> StateType:
85
+ del start_class_log_probabilities, batch_size, num_classes
86
+ return {}
87
+
88
+ @abstractmethod
89
+ def sample_nodes(
90
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
+ raise NotImplementedError
93
+
94
+ def sample_beams(
95
+ self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
+ del state
98
+ selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
+ return selected_log_probs, selected_indices, {}
100
+
101
+
102
+ class DeterministicSampler(Sampler):
103
+ """
104
+ A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
+ log probability.
106
+ """
107
+
108
+ def sample_nodes(
109
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
+ del state
112
+ selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
+ return selected_log_probs, selected_indices, {}
114
+
115
+
116
+ class MultinomialSampler(Sampler):
117
+ """
118
+ A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
+ in the default, non-deterministic way.
120
+
121
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
+ above 1.0 produces a flatter probability distribution.
123
+ :param with_replacement: Whether to sample with replacement.
124
+
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ temperature: float = 1.0,
130
+ with_replacement: bool = False,
131
+ ) -> None:
132
+ self.temperature = temperature
133
+ self.with_replacement = with_replacement
134
+
135
+ def sample_nodes(
136
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
+ if self.temperature != 1.0:
139
+ _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
+ else:
141
+ _probabilities = log_probs.exp()
142
+
143
+ selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
+
145
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
+
147
+
148
+ class TopKSampler(Sampler):
149
+ """
150
+ A `Sampler` which redistributes the probability mass function for nodes among the
151
+ top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
+
153
+ Beams are sampled in the default, deterministic way.
154
+
155
+ :param k: The number of top choices to be selected from.
156
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
+ above 1.0 produces a flatter probability distribution.
158
+ :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ k: int = 1,
164
+ temperature: float = 1.0,
165
+ with_replacement: bool = False,
166
+ ):
167
+ self.k = k
168
+ self.temperature = temperature or 1.0
169
+ self.with_replacement = with_replacement
170
+
171
+ def sample_nodes(
172
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
+ if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
+ raise ValueError(
176
+ "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
177
+ )
178
+
179
+ # shape (both): (batch_size, k)
180
+ top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
+
182
+ # Apply temperature if necessary.
183
+ # shape: (batch_size, k)
184
+ if self.temperature != 1.0:
185
+ top_k_log_probs = top_k_log_probs / self.temperature
186
+
187
+ # Re-normalize the subset.
188
+ # shape: (batch_size, k)
189
+ normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
+
191
+ # Sample from the re-normalized subset.
192
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
+ # shape: (batch_size, per_node_beam_size)
194
+ sampled_indices = torch.multinomial(
195
+ normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
+ )
197
+
198
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
+ # shape: (batch_size, per_node_beam_size)
200
+ indices = top_k_indices.gather(-1, sampled_indices)
201
+
202
+ return log_probs.gather(1, indices), indices, state
203
+
204
+
205
+ class TopPSampler(Sampler):
206
+ """
207
+ A `Sampler` which redistributes the probability mass function for nodes among
208
+ the top choices with a cumulative probability of at least `p`, then samples from that subset
209
+ after re-normalizing the probabilities.
210
+
211
+ Beams are sampled in the default, deterministic way.
212
+
213
+ :param p:
214
+ The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
+ examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
+ insufficient to sample without replacement from when calling `sample_nodes`, then the top
217
+ `per_node_beam_size` examples will be chosen.
218
+ :param temperature:
219
+ A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
+ above 1.0 produces a flatter probability distribution.
221
+ :param with_replacement:
222
+ If set to `True`, samples will be selected with replacement from the top choices.
223
+
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ p: float = 0.9,
229
+ temperature: float = 1.0,
230
+ with_replacement: bool = False,
231
+ ):
232
+ if p < 0.0 or p > 1.0:
233
+ raise ValueError("p must be a positive float no greater than 1.0")
234
+ self.p = p
235
+ self.temperature = temperature or 1.0
236
+ self.with_replacement = with_replacement
237
+
238
+ def sample_nodes(
239
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
+ if not per_node_beam_size <= log_probs.size()[1]:
242
+ raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
+
244
+ # First apply temperature coefficient:
245
+ if self.temperature != 1.0:
246
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
+ else:
248
+ _log_probs = log_probs
249
+
250
+ # Sort the probabilities in descending order to then find cumulative sum
251
+ log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
+
253
+ # shape: (batch_size, num_classes)
254
+ probabilities_descending = log_probs_descending.exp()
255
+ probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
+
257
+ # Create a mask for filtering out probabilities that don't make the top `p`.
258
+ # shape: (batch_size, num_classes)
259
+ exclusion_mask = probabilities_summed >= self.p
260
+
261
+ # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
+ exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
+ exclusion_mask[..., 0] = False
264
+
265
+ # Make sure there's at least `per_node_beam_size` options to be selected.
266
+ if not self.with_replacement:
267
+ exclusion_mask[..., :per_node_beam_size] = False
268
+
269
+ log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
+
271
+ # Now re-normalized the included log probs.
272
+ # shape: (batch_size, num_classes)
273
+ filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
+
275
+ # Sample from the re-normalized subset.
276
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
+ # shape: (batch_size, per_node_beam_size)
278
+ sampled_indices = torch.multinomial(
279
+ filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
+ )
281
+
282
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
+ # shape: (batch_size, per_node_beam_size)
284
+ selected_indices = sorting_indices.gather(-1, sampled_indices)
285
+
286
+ # Return (selected log probabilities, selected classes)
287
+ # shape: (len(log_probs),1) , (len(log_probs), 1)
288
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
+
290
+
291
+ class GumbelSampler(Sampler):
292
+ """
293
+ A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
+ [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
+ Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010]
296
+ (https://api.semanticscholar.org/CorpusID:76662039).
297
+
298
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
+ above 1.0 produces a flatter probability distribution.
300
+ """
301
+
302
+ def __init__(self, temperature: float = 1.0):
303
+ self.temperature = temperature
304
+
305
+ def init_state(
306
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
+ ) -> StateType:
308
+ # shape: (batch_size, num_classes)
309
+ zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
+
311
+ # shape: (batch_size, num_classes)
312
+ G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
+
314
+ return {"G_phi_S": G_phi_S}
315
+
316
+ def sample_nodes(
317
+ self,
318
+ log_probs: torch.Tensor,
319
+ per_node_beam_size: int,
320
+ state: StateType,
321
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
+ # First apply temperature coefficient:
323
+ # shape: (batch_size * beam_size, num_classes)
324
+ if self.temperature != 1.0:
325
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
+ else:
327
+ _log_probs = log_probs
328
+
329
+ # shape: (group_size,)
330
+ phi_S = state["phi_S"]
331
+
332
+ # shape: (group_size, num_classes)
333
+ phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
+
335
+ # shape: (group_size, num_classes)
336
+ phi_S_new = phi_S + _log_probs
337
+
338
+ # shape: (group_size, 1)
339
+ G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
+
341
+ # shape: (group_size, num_classes)
342
+ G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
+
344
+ # Replace NaNs with very negative number.
345
+ # shape: (group_size, num_classes)
346
+ # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
+
348
+ # shape (both): (group_size, per_node_beam_size)
349
+ top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
+
351
+ # shape: (group_size, per_node_beam_size)
352
+ top_log_probs = log_probs.gather(1, top_indices)
353
+
354
+ return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
+
356
+ def sample_beams(
357
+ self,
358
+ log_probs: torch.Tensor,
359
+ beam_size: int,
360
+ state: StateType,
361
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
+ """
363
+ Returns the beams with the highest perturbed log probabilities.
364
+ """
365
+ # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
+
367
+ batch_size = log_probs.size()[0]
368
+
369
+ # shape: (batch_size * beam_size, per_node_beam_size)
370
+ G_phi_S = state["G_phi_S"]
371
+
372
+ # shape: (batch_size, beam_size * per_node_beam_size)
373
+ G_phi_S = G_phi_S.reshape_as(log_probs)
374
+
375
+ # shape (both): (batch_size, beam_size)
376
+ G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
+
378
+ # shape: (batch_size, beam_size)
379
+ selected_log_probs = log_probs.gather(1, selected_indices)
380
+
381
+ # Now sort the selected beams by their true log prob.
382
+ # shape (all): (batch_size, beam_size)
383
+ selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
+ selected_indices = selected_indices.gather(1, sort_indices)
385
+ G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
+
387
+ # shape: (batch_size * beam_size,)
388
+ G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
+
390
+ # shape: (batch_size * beam_size,)
391
+ phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
+
393
+ return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
+
395
+ def gumbel(self, phi) -> torch.Tensor:
396
+ """
397
+ Sample `Gumbel(phi)`.
398
+
399
+ `phi` should have shape `(batch_size, num_classes)`.
400
+ """
401
+ return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
+
403
+ def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
+ """
405
+ Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
+
407
+ `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
+ shape `(batch_size, 1)`.
409
+ """
410
+ # Shape: (batch_size, num_classes)
411
+ G_phi = self.gumbel(phi)
412
+
413
+ # Now we find the maximum from these samples.
414
+ # Shape: (batch_size, )
415
+ Z, _ = G_phi.max(dim=-1)
416
+
417
+ # Shape: (batch_size, num_classes)
418
+ v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
+
420
+ # Shape: (batch_size, num_classes)
421
+ return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
+
423
+
424
+ class FinalSequenceScorer:
425
+ """
426
+ An abstract class that can be used to score the final generated sequences found
427
+ by beam search. Given the predicted sequences and the corresponding log probabilities of
428
+ those sequences, the class calculates and returns the final score of the sequences.
429
+
430
+ The default implementation scores the sequences using the sum of the log probabilities of
431
+ the sequence, which is passed as input.
432
+ """
433
+
434
+ @abstractmethod
435
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
+ """
437
+ Score the final predictions found by beam search.
438
+ Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
+
440
+ :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
+ :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
+ of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
+ :param end_index: The index of the end symbol.
444
+
445
+ """
446
+ raise NotImplementedError
447
+
448
+
449
+ class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
+ """
451
+ A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
+ across the sequence's tokens.
453
+ """
454
+
455
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
+ del predictions, end_index
457
+ # The sum of the sequence log probabilities is the input parameter, so just
458
+ # return it.
459
+ return log_probabilities
460
+
461
+
462
+ class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
+ """
464
+ A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
+ tokens in the sequence. It optionally includes a length penalty which promotes
466
+ or demotes sequences based on their lengths. The final score for a sequence will
467
+ be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
+ here includes the end token.
469
+
470
+ :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
+ A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
+ """
473
+
474
+ def __init__(self, length_penalty: float = 1.0):
475
+ super().__init__()
476
+ self.length_penalty = length_penalty
477
+
478
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
+ # shape: (batch_size, beam_size)
480
+ lengths = (predictions != end_index).long().sum(dim=2)
481
+
482
+ # If the sequence ended during beam search, the `log_probabilities` will include
483
+ # the transition to the end token. Therefore, in such situations, `lengths` is
484
+ # actually off by 1. This corrects for that.
485
+ # shape: (batch_size, beam_size)
486
+ is_end_token = predictions[:, :, -1] == end_index
487
+ lengths += is_end_token.long()
488
+
489
+ # shape: (batch_size, beam_size)
490
+ average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
+ return average_log_probs
492
+
493
+
494
+ class Constraint:
495
+ """
496
+ An abstract class that can be used to enforce constraints on the output predictions
497
+ by manipulating the class log probabilities during beam search.
498
+
499
+ A `Constraint` just has three methods that need to be implemented by subclasses:
500
+ `init_state()`, `apply()` and `_update_state()`.
501
+
502
+ `init_state()` takes one argument:
503
+
504
+ - the batch size, an int
505
+
506
+ It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
+ calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
+ Each inner list should be of length 1.
509
+
510
+ `apply()` takes two arguments:
511
+
512
+ - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
+ and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
+ - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
+ log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
+
517
+ The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
+ for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
+ the corresponding log probability to a negligible value such as `float("-inf")` or
520
+ `torch.finfo(class_log_probabilities.dtype).min`.
521
+
522
+ `_update_state()` takes two arguments:
523
+
524
+ - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
+ copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
+ directly edited in-place without affecting the others.
527
+ - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
+ step of beam search.
529
+
530
+ The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
+ length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
+
533
+ """
534
+
535
+ @abstractmethod
536
+ def init_state(
537
+ self,
538
+ batch_size: int,
539
+ ) -> ConstraintStateType:
540
+ raise NotImplementedError
541
+
542
+ @abstractmethod
543
+ def apply(
544
+ self,
545
+ state: ConstraintStateType,
546
+ class_log_probabilities: torch.Tensor,
547
+ ) -> torch.Tensor:
548
+ raise NotImplementedError
549
+
550
+ @staticmethod
551
+ def _copy_state(
552
+ state: ConstraintStateType,
553
+ batch_size: int,
554
+ beam_size: int,
555
+ last_backpointer: Optional[torch.Tensor] = None,
556
+ ) -> ConstraintStateType:
557
+ """
558
+ Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this
559
+ is not appropriate for your constraint, you will need to implement the copying yourself.
560
+ """
561
+ new_state = []
562
+ for i in range(batch_size):
563
+ batch_state = []
564
+ for j in range(beam_size):
565
+ if last_backpointer is None:
566
+ # This is the first prediction, so the backpointer is 0
567
+ backpointer = 0
568
+ else:
569
+ backpointer = last_backpointer[i, j].item()
570
+ batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
+ new_state.append(batch_state)
572
+ return new_state
573
+
574
+ def update_state(
575
+ self,
576
+ state: ConstraintStateType,
577
+ last_prediction: torch.Tensor,
578
+ last_backpointer: Optional[torch.Tensor] = None,
579
+ ) -> ConstraintStateType:
580
+ batch_size, beam_size = last_prediction.size()
581
+ new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
+ return self._update_state(new_state, last_prediction)
583
+
584
+ @abstractmethod
585
+ def _update_state(
586
+ self,
587
+ state: ConstraintStateType,
588
+ last_prediction: torch.Tensor,
589
+ ) -> ConstraintStateType:
590
+ raise NotImplementedError
591
+
592
+
593
+ class RepeatedNGramBlockingConstraint(Constraint):
594
+ def __init__(self, ngram_size: int, **kwargs) -> None:
595
+ super().__init__(**kwargs)
596
+ self.ngram_size = ngram_size
597
+
598
+ def init_state(
599
+ self,
600
+ batch_size: int,
601
+ ) -> ConstraintStateType:
602
+ return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
+
604
+ def apply(
605
+ self,
606
+ state: ConstraintStateType,
607
+ class_log_probabilities: torch.Tensor,
608
+ ) -> torch.Tensor:
609
+ for i, batch in enumerate(state):
610
+ for j, beam in enumerate(batch):
611
+ current_prefix = tuple(beam["current_prefix"])
612
+ seen_ngrams = beam["seen_ngrams"]
613
+ try:
614
+ disallowed_indices = seen_ngrams[current_prefix]
615
+ class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
+ class_log_probabilities.dtype
617
+ ).min
618
+ except KeyError:
619
+ # We have not seen this prefix before, so there is no index
620
+ # that needs to be blocked
621
+ pass
622
+ return class_log_probabilities
623
+
624
+ def _update_state(
625
+ self,
626
+ state: ConstraintStateType,
627
+ last_prediction: torch.Tensor,
628
+ ) -> ConstraintStateType:
629
+ for i, batch in enumerate(state):
630
+ for j, beam in enumerate(batch):
631
+ prediction = last_prediction[i, j].item()
632
+ prefix = beam["current_prefix"]
633
+ seen_ngrams = beam["seen_ngrams"]
634
+
635
+ if len(prefix) == self.ngram_size - 1:
636
+ # This is a new ngram that we have to remember
637
+ if tuple(prefix) not in seen_ngrams:
638
+ seen_ngrams[tuple(prefix)] = []
639
+ seen_ngrams[tuple(prefix)].append(prediction)
640
+
641
+ # Create the new prefix, removing the oldest index if the prefix
642
+ # is too long
643
+ prefix.append(prediction)
644
+ if len(prefix) == self.ngram_size:
645
+ prefix.pop(0)
646
+ return state
647
+
648
+
649
+ class BeamSearch:
650
+ """
651
+ Implements the beam search algorithm for decoding the most likely sequences.
652
+
653
+ :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
+
655
+ :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
+ of the predicted sequences.
657
+
658
+ :param beam_size: The width of the beam used.
659
+
660
+ :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
+ If not given, this just defaults to `beam_size`. Setting this parameter
662
+ to a number smaller than `beam_size` may give better results, as it can introduce
663
+ more diversity into the search. See
664
+ [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
+ (https://api.semanticscholar.org/CorpusID:2229477).
666
+
667
+ :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
+ If not specified, `DeterministicSampler` will be used, which just takes the
669
+ `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
+
671
+ Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
+ [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
+
674
+ :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
+ the predicted sequences. This does not include the start or end tokens. If `None`,
676
+ no minimum is enforced.
677
+
678
+ :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
+ The output from this module is what is returned by the `search` method. If not
680
+ specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
+ by the sum of the token log probabilities.
682
+
683
+ :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
+ provided, no constraints will be enforced.
685
+
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ end_index: int,
691
+ *,
692
+ max_steps: int = 50,
693
+ beam_size: int = 10,
694
+ per_node_beam_size: Optional[int] = None,
695
+ sampler: Optional[Sampler] = None,
696
+ min_steps: Optional[int] = None,
697
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
+ constraints: Optional[List[Constraint]] = None,
699
+ ) -> None:
700
+ if not max_steps > 0:
701
+ raise ValueError("max_steps must be positive")
702
+ if not beam_size > 0:
703
+ raise ValueError("beam_size must be positive")
704
+ if per_node_beam_size is not None and not per_node_beam_size > 0:
705
+ raise ValueError("per_node_beam_size must be positive")
706
+ if min_steps is not None:
707
+ if not min_steps >= 0:
708
+ raise ValueError("min_steps must be non-negative")
709
+ if not min_steps <= max_steps:
710
+ raise ValueError("min_steps must be less than or equal to max_steps")
711
+
712
+ self._end_index = end_index
713
+ self.max_steps = max_steps
714
+ self.beam_size = beam_size
715
+ self.per_node_beam_size = per_node_beam_size or beam_size
716
+ self.sampler = sampler or DeterministicSampler()
717
+ self.min_steps = min_steps or 0
718
+ self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
719
+ self.constraints = constraints or []
720
+
721
+ @staticmethod
722
+ def _reconstruct_sequences(predictions, backpointers):
723
+ # Reconstruct the sequences.
724
+ # shape: [(batch_size, beam_size, 1)]
725
+ reconstructed_predictions = [predictions[-1].unsqueeze(2)]
726
+
727
+ if not backpointers:
728
+ return reconstructed_predictions
729
+
730
+ # shape: (batch_size, beam_size)
731
+ cur_backpointers = backpointers[-1]
732
+
733
+ for timestep in range(len(predictions) - 2, 0, -1):
734
+ # shape: (batch_size, beam_size, 1)
735
+ cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
736
+
737
+ reconstructed_predictions.append(cur_preds)
738
+
739
+ # shape: (batch_size, beam_size)
740
+ cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
741
+
742
+ # shape: (batch_size, beam_size, 1)
743
+ final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
744
+
745
+ reconstructed_predictions.append(final_preds)
746
+
747
+ return reconstructed_predictions
748
+
749
+ def search(
750
+ self,
751
+ start_predictions: torch.Tensor,
752
+ start_state: StateType,
753
+ step: StepFunctionType,
754
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
755
+ """
756
+ Given a starting state and a step function, apply beam search to find the
757
+ most likely target sequences.
758
+
759
+ Returns a tuple of `(predictions, final_scores)`, where `predictions`
760
+ has shape `(batch_size, beam_size, max_steps)` and `final_scores`
761
+ has shape `(batch_size, beam_size)`.
762
+
763
+ .. note::
764
+ If your step function returns `-inf` for some log probabilities
765
+ (like if you're using a masked log-softmax) then some of the "best"
766
+ sequences returned may also have `-inf` log probability. Specifically
767
+ this happens when the beam size is smaller than the number of actions
768
+ with finite log probability (non-zero probability) returned by the step function.
769
+ Therefore if you're using a mask you may want to check the results from `search`
770
+ and potentially discard sequences with non-finite log probability.
771
+
772
+ :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
773
+ Usually the initial predictions are just the index of the "start" token
774
+ in the target vocabulary.
775
+
776
+ :param start_state: The initial state passed to the `step` function. Each value of the state dict
777
+ should be a tensor of shape `(batch_size, *)`, where `*` means any other
778
+ number of dimensions.
779
+
780
+ :param step: A function that is responsible for computing the next most likely tokens,
781
+ given the current state and the predictions from the last time step.
782
+ The function should accept two or three arguments:
783
+
784
+ - a tensor of shape `(group_size,)` or representing the index of the predicted
785
+ tokens from the last time step,
786
+ - the current state, a `StateType`, and
787
+ - optionally, the timestep, an `int`.
788
+
789
+ The `group_size` will be `batch_size * beam_size`, except in the initial
790
+ step, for which it will just be `batch_size`.
791
+
792
+ The function is expected to return a tuple, where the first element
793
+ is a tensor of shape `(group_size, vocab_size)` containing
794
+ the log probabilities of the tokens for the next step, and the second
795
+ element is the updated state. The tensor in the state should have shape
796
+ `(group_size, *)`, where `*` means any other number of dimensions.
797
+
798
+ """
799
+ step_signature = signature(step)
800
+ if len(step_signature.parameters) < 3:
801
+ # If the step function we're given does not take the time step argument, wrap it
802
+ # in one that does.
803
+ old_step = cast(StepFunctionTypeNoTimestep, step)
804
+
805
+ def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
806
+ del time_step
807
+ return old_step(last_predictions, state)
808
+
809
+ return self._search(start_predictions, start_state, new_step)
810
+ else:
811
+ return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
812
+
813
+ def _search(
814
+ self,
815
+ start_predictions: torch.Tensor,
816
+ start_state: StateType,
817
+ step: StepFunctionTypeWithTimestep,
818
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
819
+ batch_size = start_predictions.size()[0]
820
+
821
+ # List of (batch_size, beam_size) tensors. One for each time step. Does not
822
+ # include the start symbols, which are implicit.
823
+ predictions: List[torch.Tensor] = []
824
+
825
+ # List of (batch_size, beam_size) tensors. One for each time step. None for
826
+ # the first. Stores the index n for the parent prediction, i.e.
827
+ # predictions[t-1][i][n], that it came from.
828
+ backpointers: List[torch.Tensor] = []
829
+
830
+ constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
831
+
832
+ # Calculate the first timestep. This is done outside the main loop
833
+ # because we are going from a single decoder input (the output from the
834
+ # encoder) to the top `beam_size` decoder outputs. On the other hand,
835
+ # within the main loop we are going from the `beam_size` elements of the
836
+ # beam to `beam_size`^2 candidates from which we will select the top
837
+ # `beam_size` elements for the next iteration.
838
+ # shape: (batch_size, num_classes)
839
+ start_class_log_probabilities, state = step(start_predictions, start_state, 0)
840
+
841
+ num_classes = start_class_log_probabilities.size()[1]
842
+
843
+ # Make sure `per_node_beam_size` is not larger than `num_classes`.
844
+ if self.per_node_beam_size > num_classes:
845
+ raise ValueError(
846
+ f"Vocab size ({num_classes:d}) too small "
847
+ f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
848
+ f"Please decrease beam_size or per_node_beam_size."
849
+ )
850
+
851
+ sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
852
+
853
+ # Apply all constraints.
854
+ if self.constraints:
855
+ # shape: (batch_size, 1, num_classes)
856
+ expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
857
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
858
+ expanded_start_class_log_probabilities = constraint.apply(
859
+ constraint_state, expanded_start_class_log_probabilities
860
+ )
861
+ start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
862
+
863
+ # Prevent selecting the end symbol if there is any min_steps constraint
864
+ if self.min_steps >= 1:
865
+ start_class_log_probabilities[:, self._end_index] = torch.finfo(
866
+ start_class_log_probabilities.dtype
867
+ ).min
868
+
869
+ # Get the initial predicted classed and their log probabilities.
870
+ # shape: (batch_size, beam_size), (batch_size, beam_size)
871
+ (
872
+ start_top_log_probabilities,
873
+ start_predicted_classes,
874
+ sampler_state,
875
+ ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
876
+
877
+ if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
878
+ warnings.warn(
879
+ "Empty sequences predicted. You may want to increase the beam size or ensure "
880
+ "your step function is working properly.",
881
+ RuntimeWarning,
882
+ )
883
+ return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
884
+
885
+ # The log probabilities for the last time step.
886
+ # shape: (batch_size, beam_size)
887
+ last_log_probabilities = start_top_log_probabilities
888
+
889
+ # shape: [(batch_size, beam_size)]
890
+ predictions.append(start_predicted_classes)
891
+
892
+ # Log probability tensor that mandates that the end token is selected.
893
+ # shape: (batch_size * beam_size, num_classes)
894
+ log_probs_after_end = start_class_log_probabilities.new_full(
895
+ (batch_size * self.beam_size, num_classes),
896
+ torch.finfo(start_class_log_probabilities.dtype).min,
897
+ )
898
+ log_probs_after_end[:, self._end_index] = 0.0
899
+
900
+ # Set the same state for each element in the beam.
901
+ self._update_initial_state(state, batch_size)
902
+
903
+ for i, constraint in enumerate(self.constraints):
904
+ constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
905
+
906
+ for timestep in range(self.max_steps - 1):
907
+ # shape: (batch_size * beam_size,)
908
+ last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
909
+
910
+ # If every predicted token from the last step is `self._end_index`,
911
+ # then we can stop early.
912
+ if (last_predictions == self._end_index).all():
913
+ break
914
+ # Take a step. This get the predicted log probs of the next classes
915
+ # and updates the state.
916
+ # shape: (batch_size * beam_size, num_classes)
917
+ class_log_probabilities, state = step(last_predictions, state, timestep + 1)
918
+
919
+ # Apply all constraints.
920
+ if self.constraints:
921
+ # shape: (batch_size, beam_size, num_classes)
922
+ reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
923
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
924
+ reshaped_class_log_probabilities = constraint.apply(
925
+ constraint_state, reshaped_class_log_probabilities
926
+ )
927
+ # shape: (batch_size * beam_size, num_classes)
928
+ class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
929
+
930
+ # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
931
+ # of the sequence (because `timestep` is 0-indexed and we generated the first token
932
+ # before the for loop). Here we block the end index if the search is not allowed to
933
+ # terminate on this iteration.
934
+ if timestep + 2 <= self.min_steps:
935
+ class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
936
+
937
+ # shape: (batch_size * beam_size, num_classes)
938
+ last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
939
+ batch_size * self.beam_size, num_classes
940
+ )
941
+
942
+ # Here we are finding any beams where we predicted the end token in
943
+ # the previous timestep and replacing the distribution with a
944
+ # one-hot distribution, forcing the beam to predict the end token
945
+ # this timestep as well.
946
+ # shape: (batch_size * beam_size, num_classes)
947
+ cleaned_log_probabilities = torch.where(
948
+ last_predictions_expanded == self._end_index,
949
+ log_probs_after_end,
950
+ class_log_probabilities,
951
+ )
952
+
953
+ # shape (both): (batch_size * beam_size, per_node_beam_size)
954
+ top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
955
+ cleaned_log_probabilities, self.per_node_beam_size, sampler_state
956
+ )
957
+
958
+ # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
959
+ # so that we can add them to the current log probs for this timestep.
960
+ # This lets us maintain the log probability of each element on the beam.
961
+ # shape: (batch_size * beam_size, per_node_beam_size)
962
+ expanded_last_log_probabilities = (
963
+ last_log_probabilities.unsqueeze(2)
964
+ .expand(batch_size, self.beam_size, self.per_node_beam_size)
965
+ .reshape(batch_size * self.beam_size, self.per_node_beam_size)
966
+ )
967
+
968
+ # shape: (batch_size * beam_size, per_node_beam_size)
969
+ summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
970
+
971
+ # shape: (batch_size, beam_size * per_node_beam_size)
972
+ reshaped_summed = summed_top_log_probabilities.reshape(
973
+ batch_size, self.beam_size * self.per_node_beam_size
974
+ )
975
+
976
+ # shape: (batch_size, beam_size * per_node_beam_size)
977
+ reshaped_predicted_classes = predicted_classes.reshape(
978
+ batch_size, self.beam_size * self.per_node_beam_size
979
+ )
980
+
981
+ # Keep only the top `beam_size` beam indices.
982
+ # shape (both): (batch_size, beam_size)
983
+ (
984
+ restricted_beam_log_probs,
985
+ restricted_beam_indices,
986
+ sampler_state,
987
+ ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
988
+
989
+ # Use the beam indices to extract the corresponding classes.
990
+ # shape: (batch_size, beam_size)
991
+ restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
992
+
993
+ predictions.append(restricted_predicted_classes)
994
+
995
+ # shape: (batch_size, beam_size)
996
+ last_log_probabilities = restricted_beam_log_probs
997
+
998
+ # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
999
+ # indices with a common ancestor are grouped together. Hence
1000
+ # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
1001
+ # division as the tensor is a LongTensor.)
1002
+ # shape: (batch_size, beam_size)
1003
+ backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
1004
+ backpointers.append(backpointer)
1005
+
1006
+ # Keep only the pieces of the state tensors corresponding to the
1007
+ # ancestors created this iteration.
1008
+ self._update_state(state, backpointer)
1009
+
1010
+ for i, constraint in enumerate(self.constraints):
1011
+ constraint_states[i] = constraint.update_state(
1012
+ constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
1013
+ )
1014
+
1015
+ # Warn about "-inf" log probabilities if not using any constraints (negligible
1016
+ # log probabilities are expected when using constraints).
1017
+ if not self.constraints and (
1018
+ not torch.isfinite(last_log_probabilities).all()
1019
+ or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
1020
+ ):
1021
+ warnings.warn(
1022
+ "Negligible log probabilities encountered ('-inf' or equivalent). "
1023
+ "Some final sequences may not make sense. "
1024
+ "This can happen when the beam size is larger than the number of valid (non-zero "
1025
+ "probability) transitions that the step function produces.",
1026
+ RuntimeWarning,
1027
+ )
1028
+
1029
+ reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
1030
+
1031
+ # shape: (batch_size, beam_size, max_steps)
1032
+ all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
1033
+
1034
+ # Calculate the final sequence scores
1035
+ # shape: (batch_size, beam_size)
1036
+ final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
1037
+
1038
+ # Sort the sequences based on the final scores so the best scoring
1039
+ # sequence is at index 0
1040
+ sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
1041
+ sorted_all_predictions = torch.gather(
1042
+ all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
1043
+ )
1044
+
1045
+ return sorted_all_predictions, sorted_final_scores
1046
+
1047
+ def _update_initial_state(self, state: StateType, batch_size: int):
1048
+ """
1049
+ Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
1050
+ """
1051
+ for key, state_tensor in state.items():
1052
+ if state_tensor is None:
1053
+ continue
1054
+ # shape: (batch_size * beam_size, *)
1055
+ _, *last_dims = state_tensor.size()
1056
+ state[key] = (
1057
+ state_tensor.unsqueeze(1)
1058
+ .expand(batch_size, self.beam_size, *last_dims)
1059
+ .reshape(batch_size * self.beam_size, *last_dims)
1060
+ )
1061
+
1062
+ def _update_state(self, state: StateType, backpointer: torch.Tensor):
1063
+ batch_size = backpointer.size()[0]
1064
+
1065
+ for key, state_tensor in state.items():
1066
+ if state_tensor is None:
1067
+ continue
1068
+ _, *last_dims = state_tensor.size()
1069
+ # shape: (batch_size, beam_size, *)
1070
+ expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
1071
+ batch_size, self.beam_size, *last_dims
1072
+ )
1073
+ # shape: (batch_size * beam_size, *)
1074
+ state[key] = (
1075
+ state_tensor.reshape(batch_size, self.beam_size, *last_dims)
1076
+ .gather(1, expanded_backpointer)
1077
+ .reshape(batch_size * self.beam_size, *last_dims)
1078
+ )
OLMo_Bitnet_1B/checkpoint.py ADDED
@@ -0,0 +1,1671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import io
3
+ import logging
4
+ import pickle
5
+ import shutil
6
+ import traceback
7
+ from abc import ABCMeta, abstractmethod
8
+ from collections import defaultdict
9
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
10
+ from contextlib import contextmanager
11
+ from copy import deepcopy
12
+ from dataclasses import dataclass, field, replace
13
+ from functools import reduce
14
+ from multiprocessing import shared_memory
15
+ from pathlib import Path
16
+ from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.distributed.checkpoint as dist_cp
21
+ import torch.multiprocessing as mp
22
+ from packaging import version
23
+ from torch.distributed import _remote_device
24
+ from torch.distributed._shard._utils import narrow_tensor_by_index
25
+ from torch.distributed._shard.metadata import ShardMetadata
26
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
27
+ from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
28
+ from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
29
+ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
30
+ from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
31
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
32
+ from torch.distributed.fsdp import StateDictType
33
+ from torch.distributed.fsdp.api import (
34
+ FullOptimStateDictConfig,
35
+ FullStateDictConfig,
36
+ ShardedOptimStateDictConfig,
37
+ ShardedStateDictConfig,
38
+ )
39
+ from torch.futures import Future
40
+
41
+ try:
42
+ from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore
43
+ except ModuleNotFoundError:
44
+ from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore
45
+
46
+ from . import util
47
+
48
+ from .aliases import PathOrStr
49
+ from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
50
+ from .exceptions import OLMoCheckpointError
51
+ from .optim import Optimizer, fix_optim_state_dict
52
+ from .safetensors_util import safetensors_file_to_state_dict
53
+ from .torch_util import (
54
+ barrier,
55
+ gc_cuda,
56
+ get_fs_local_rank,
57
+ get_global_rank,
58
+ get_world_size,
59
+ )
60
+ from .util import (
61
+ _get_s3_client,
62
+ default_thread_count,
63
+ dir_is_empty,
64
+ get_bytes_range,
65
+ get_progress_bar,
66
+ resource_path,
67
+ upload,
68
+ wait_for,
69
+ )
70
+
71
+ __all__ = [
72
+ "save_fsdp_model_and_optim_state",
73
+ "load_fsdp_model_and_optim_state",
74
+ "load_fsdp_optim_state",
75
+ "save_state_dict",
76
+ "load_state_dict",
77
+ "load_model_state",
78
+ "RemoteFileSystemWriter",
79
+ "RemoteFileSystemReader",
80
+ "Checkpointer",
81
+ "FullCheckpointer",
82
+ "TorchNewStyleShardedCheckpointer",
83
+ "TorchLegacyShardedCheckpointer",
84
+ "LocalShardedCheckpointer",
85
+ "build_sharded_checkpointer",
86
+ ]
87
+
88
+
89
+ log = logging.getLogger(__name__)
90
+
91
+ MODEL_AND_OPTIM_FOLDER = "model_and_optim"
92
+
93
+
94
+ def save_fsdp_model_and_optim_state(
95
+ checkpoint_dir: PathOrStr,
96
+ fsdp_model: FSDP,
97
+ optim: Optimizer,
98
+ *,
99
+ upload_to: Optional[str] = None,
100
+ save_overwrite: bool = False,
101
+ ):
102
+ """
103
+ Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
104
+ functions. This should be used during distributed training and should be called by all ranks.
105
+
106
+ :param checkpoint_dir: The directory to save to.
107
+ :param fsdp_model: The FSDP model.
108
+ :param optim: The FSDP model's optimizer.
109
+ :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
110
+ :param save_overwrite: Overwrite existing files.
111
+
112
+ :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
113
+ """
114
+ checkpoint_dir = Path(checkpoint_dir)
115
+ target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
116
+ if save_overwrite:
117
+ if get_fs_local_rank() == 0:
118
+ shutil.rmtree(target_dir, ignore_errors=True)
119
+ elif not dir_is_empty(target_dir):
120
+ raise FileExistsError(target_dir)
121
+ barrier()
122
+ if get_fs_local_rank() == 0:
123
+ target_dir.mkdir(exist_ok=True, parents=True)
124
+ barrier()
125
+ with FSDP.state_dict_type(
126
+ fsdp_model,
127
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
128
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
129
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
130
+ ):
131
+ model_and_optim_state = {
132
+ "model": fsdp_model.state_dict(),
133
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
134
+ }
135
+ dist_cp.save_state_dict(
136
+ model_and_optim_state,
137
+ RemoteFileSystemWriter(
138
+ target_dir,
139
+ upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
140
+ save_overwrite=save_overwrite,
141
+ ),
142
+ )
143
+
144
+
145
+ def load_fsdp_model_and_optim_state(
146
+ checkpoint_dir: PathOrStr,
147
+ fsdp_model: FSDP,
148
+ optim: Optimizer,
149
+ *,
150
+ local_cache: Optional[PathOrStr] = None,
151
+ load_optimizer_state: bool = True,
152
+ ):
153
+ """
154
+ Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
155
+ functions. This should be used during distributed training and should be called by all ranks.
156
+
157
+ :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
158
+ :param fsdp_model: The FSDP model.
159
+ :param optim: The FSDP model's optimizer.
160
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
161
+ remote "directory" but there might be a cached version of the same artifacts.
162
+ :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
163
+
164
+ :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
165
+ """
166
+ load_path = str(checkpoint_dir).rstrip("/")
167
+ local_cache = None if local_cache is None else Path(local_cache)
168
+ with FSDP.state_dict_type(
169
+ fsdp_model,
170
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
171
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
172
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
173
+ ):
174
+ # Load the model state dict in place.
175
+ log.info("Loading model state...")
176
+ model_state = {"model": fsdp_model.state_dict()}
177
+ dist_cp.load_state_dict(
178
+ model_state,
179
+ RemoteFileSystemReader(
180
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
181
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
182
+ ),
183
+ )
184
+ fsdp_model.load_state_dict(model_state["model"])
185
+
186
+ if not load_optimizer_state:
187
+ return
188
+
189
+ # Load optim state dict in place.
190
+ log.info("Loading sharded optimizer state...")
191
+ optim_state = load_sharded_optimizer_state_dict(
192
+ model_state_dict=model_state["model"],
193
+ optimizer_key="optim",
194
+ storage_reader=RemoteFileSystemReader(
195
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
196
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
197
+ ),
198
+ )
199
+ del model_state
200
+ gc_cuda()
201
+ load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
202
+
203
+
204
+ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
205
+ log.info("Flattening sharded optimizer state...")
206
+ # NOTE: Careful! The order of the these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
207
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
208
+ flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
209
+ else:
210
+ flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
211
+ del optim_state
212
+ gc.collect()
213
+ log.info("Loading flattened optimizer state...")
214
+ # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
215
+ # which takes up unnecessary GPU memory.
216
+ for state in flattened_osd["state"].values():
217
+ for k in state.keys():
218
+ v = state[k]
219
+ if isinstance(v, torch.Tensor):
220
+ state[k] = v.to(device="cpu")
221
+ gc_cuda()
222
+ optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
223
+
224
+
225
+ def save_state_dict(
226
+ checkpoint_dir: PathOrStr,
227
+ fname: str,
228
+ state_dict: Dict[str, Any],
229
+ *,
230
+ upload_to: Optional[str] = None,
231
+ save_overwrite: bool = False,
232
+ synchronize: bool = True,
233
+ ):
234
+ """
235
+ Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
236
+ This can be used during distributed training or not. If during distributed training the ``fname`` should be unique
237
+ for each rank.
238
+
239
+ :param checkpoint_dir: The directory to save to.
240
+ :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
241
+ :param state_dict: The state dict to save.
242
+ :param upload_to: Optional, a remote "directory" to upload the file to.
243
+ :param save_overwrite: Overwrite existing files.
244
+ :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
245
+ this function from a single rank.
246
+
247
+ :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
248
+ """
249
+ checkpoint_dir = Path(checkpoint_dir)
250
+ target_path = checkpoint_dir / fname
251
+ if save_overwrite:
252
+ target_path.unlink(missing_ok=True)
253
+ elif target_path.is_file():
254
+ raise FileExistsError(target_path)
255
+ if synchronize:
256
+ barrier()
257
+ target_path.parent.mkdir(exist_ok=True, parents=True)
258
+ if synchronize:
259
+ barrier()
260
+ torch.save(state_dict, target_path)
261
+ if upload_to is not None:
262
+ upload_target = f"{upload_to.rstrip('/')}/{fname}"
263
+ log.info(f"Uploading {target_path} to {upload_target}...")
264
+ upload(target_path, upload_target, save_overwrite=save_overwrite)
265
+
266
+
267
+ def load_state_dict(
268
+ checkpoint_dir: PathOrStr,
269
+ fname: str,
270
+ *,
271
+ local_cache: Optional[PathOrStr] = None,
272
+ map_location: Optional[str] = None,
273
+ ):
274
+ """
275
+ Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
276
+ This can be used during distributed training or not.
277
+
278
+ :param checkpoint_dir: A local or remote checkpoint directory.
279
+ :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
280
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
281
+ remote "directory" but there might be a cached version of the same artifacts.
282
+
283
+ :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
284
+ """
285
+ if fname.endswith(".pt"):
286
+ # Try safetensors version first.
287
+ try:
288
+ path = resource_path(
289
+ str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
290
+ )
291
+ return safetensors_file_to_state_dict(path, map_location=map_location)
292
+ except FileNotFoundError:
293
+ pass
294
+
295
+ path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
296
+ return torch.load(path, map_location=map_location)
297
+
298
+
299
+ def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
300
+ """
301
+ Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
302
+ Note that ``model`` should not be wrapped with FSDP.
303
+ """
304
+ state_dict = {"model": model.state_dict()}
305
+ dist_cp.load_state_dict(
306
+ state_dict,
307
+ RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
308
+ no_dist=True,
309
+ )
310
+ model.load_state_dict(state_dict["model"])
311
+
312
+
313
+ class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
314
+ """
315
+ A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
316
+ directly to a cloud bucket when ``upload_to`` is specified.
317
+ """
318
+
319
+ def __init__(
320
+ self,
321
+ path: PathOrStr,
322
+ single_file_per_rank: bool = True,
323
+ sync_files: bool = True,
324
+ thread_count: Optional[int] = None,
325
+ per_thread_copy_ahead: int = 10_000_000,
326
+ upload_to: Optional[str] = None,
327
+ save_overwrite: bool = False,
328
+ ) -> None:
329
+ if thread_count is not None and thread_count <= 0:
330
+ raise ValueError("thread count must be at least 1")
331
+ super().__init__(
332
+ path,
333
+ single_file_per_rank=single_file_per_rank,
334
+ sync_files=sync_files,
335
+ # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
336
+ # returns because uploading big checkpoint files with multiple threads causes
337
+ # boto3 to fail in weird ways.
338
+ thread_count=thread_count or 1,
339
+ per_thread_copy_ahead=per_thread_copy_ahead,
340
+ )
341
+ self.upload_to = None if upload_to is None else upload_to.rstrip("/")
342
+ self.save_overwrite = save_overwrite
343
+
344
+ def write_data(
345
+ self,
346
+ plan: dist_cp.SavePlan,
347
+ planner: dist_cp.SavePlanner,
348
+ ) -> Future[List[WriteResult]]:
349
+ fut = super().write_data(plan, planner)
350
+ if self.upload_to is not None:
351
+ files_to_upload = set()
352
+ for write_result in fut.wait():
353
+ files_to_upload.add(write_result.storage_data.relative_path)
354
+
355
+ # Create the global S3 client up front to work around a threading issue in boto.
356
+ if self.upload_to.startswith("s3://"):
357
+ _get_s3_client("s3")
358
+ elif self.upload_to.startswith("r2://"):
359
+ _get_s3_client("r2")
360
+
361
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
362
+ futures = []
363
+ for fname in files_to_upload:
364
+ source = self.path / fname
365
+ target = f"{self.upload_to}/{fname}"
366
+ log.info(f"Uploading {source} to {target}...")
367
+ futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
368
+ for f in as_completed(futures):
369
+ try:
370
+ f.result()
371
+ except BaseException:
372
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
373
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
374
+ # sure we're raising a simple error type that can be pickled.
375
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
376
+ return fut
377
+
378
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
379
+ super().finish(metadata, results)
380
+ if self.upload_to is not None:
381
+ source = self.path / ".metadata"
382
+ target = f"{self.upload_to}/.metadata"
383
+ log.info(f"Uploading {source} to {target}...")
384
+ upload(source, target, save_overwrite=self.save_overwrite)
385
+
386
+
387
+ class RemoteFileSystemReader(dist_cp.StorageReader):
388
+ """
389
+ A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
390
+ that can read data directly from cloud storage as well as a local directory.
391
+ """
392
+
393
+ def __init__(
394
+ self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
395
+ ):
396
+ super().__init__()
397
+ if thread_count is not None and thread_count <= 0:
398
+ raise ValueError("thread count must be at least 1")
399
+ self.path = str(path).rstrip("/")
400
+ self.cache = None if local_cache is None else Path(local_cache)
401
+ self.thread_count = thread_count or default_thread_count()
402
+ self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
403
+ self._metadata: Optional[Metadata] = None
404
+
405
+ def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
406
+ if self.cache is not None and (path := self.cache / relative_path).is_file():
407
+ return get_bytes_range(path, offset, length)
408
+ else:
409
+ return get_bytes_range(f"{self.path}/{relative_path}", offset, length)
410
+
411
+ def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
412
+ sinfo = self.storage_data[read_item.storage_index]
413
+ content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
414
+ return (read_item, content)
415
+
416
+ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
417
+ # Create the global S3 client up front to work around a threading issue in boto.
418
+ if isinstance(self.path, str):
419
+ if self.path.startswith("s3://"):
420
+ _get_s3_client("s3")
421
+ elif self.path.startswith("r2://"):
422
+ _get_s3_client("r2")
423
+
424
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
425
+ read_item_content_futures = []
426
+ for read_item in plan.items:
427
+ read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
428
+ read_item_content_results = []
429
+ for f in as_completed(read_item_content_futures):
430
+ try:
431
+ read_item_content_results.append(f.result())
432
+ except BaseException:
433
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
434
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
435
+ # sure we're raising a simple error type that can be pickled.
436
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
437
+
438
+ # Modified from `FileSystemReader.read_data()`
439
+ for read_item, content in read_item_content_results:
440
+ bytes = io.BytesIO(content)
441
+ bytes.seek(0)
442
+ if read_item.type == LoadItemType.BYTE_IO:
443
+ planner.load_bytes(read_item, bytes)
444
+ else:
445
+ tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
446
+ tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
447
+ target_tensor = planner.resolve_tensor(read_item).detach()
448
+
449
+ assert (
450
+ target_tensor.size() == tensor.size()
451
+ ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
452
+ target_tensor.copy_(tensor)
453
+ planner.commit_tensor(read_item, target_tensor)
454
+
455
+ fut: Future = Future()
456
+ fut.set_result(None)
457
+ return fut
458
+
459
+ def read_metadata(self) -> Metadata:
460
+ if self._metadata is None:
461
+ with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
462
+ self._metadata = pickle.load(metadata_file)
463
+ return self._metadata
464
+
465
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
466
+ del is_coordinator
467
+ self.storage_data = metadata.storage_data
468
+ assert self.storage_data is not None
469
+
470
+ def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
471
+ return plan
472
+
473
+ def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
474
+ return global_plan
475
+
476
+
477
+ class Checkpointer(metaclass=ABCMeta):
478
+ def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
479
+ self.cfg = cfg
480
+ self.thread_count = thread_count or default_thread_count()
481
+
482
+ @abstractmethod
483
+ def save_checkpoint(
484
+ self,
485
+ dir: PathOrStr,
486
+ fsdp_model: FSDP,
487
+ optim: Optimizer,
488
+ train_state: Dict[str, Any],
489
+ *,
490
+ upload_to: Optional[str] = None,
491
+ ) -> None:
492
+ raise NotImplementedError
493
+
494
+ @abstractmethod
495
+ def restore_checkpoint(
496
+ self,
497
+ load_path: PathOrStr,
498
+ fsdp_model: FSDP,
499
+ optim: Optimizer,
500
+ *,
501
+ local_cache: Optional[PathOrStr] = None,
502
+ load_optimizer_state: bool = True,
503
+ ) -> Dict[str, Any]:
504
+ """
505
+ Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
506
+ """
507
+ raise NotImplementedError
508
+
509
+ def unshard_checkpoint(
510
+ self,
511
+ load_path: PathOrStr,
512
+ *,
513
+ local_cache: Optional[PathOrStr] = None,
514
+ load_optimizer_state: bool = True,
515
+ load_trainer_state: bool = True,
516
+ device: Optional[torch.device] = None,
517
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
518
+ """
519
+ Unshard a checkpoint.
520
+
521
+ Note this is not marked abstract because child classes are not required to implemented this.
522
+ """
523
+ del load_path, local_cache, load_optimizer_state, load_trainer_state, device
524
+ raise NotImplementedError
525
+
526
+ @contextmanager
527
+ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
528
+ # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
529
+ checkpoint_dir = Path(dir)
530
+ if not dir_is_empty(checkpoint_dir):
531
+ if self.cfg.save_overwrite:
532
+ if get_fs_local_rank() == 0:
533
+ shutil.rmtree(checkpoint_dir, ignore_errors=True)
534
+ else:
535
+ raise FileExistsError(checkpoint_dir)
536
+ # No need to mkdir here since we'll directly replace the temporary directory with
537
+ # this directory below.
538
+ barrier()
539
+
540
+ # Prepare temporary directory. We don't have to be as careful here, we can
541
+ # just remove it if it already exists.
542
+ checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
543
+ if get_fs_local_rank() == 0:
544
+ shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
545
+ checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
546
+
547
+ barrier()
548
+
549
+ # Yield temporary directory for `.save_checkpoint()` to use.
550
+ yield checkpoint_dir_tmp
551
+
552
+ barrier()
553
+
554
+ # Finally if all went well replace the temporary directory with the actual
555
+ # checkpoint directory.
556
+ if get_fs_local_rank() == 0:
557
+ # Replace temp directory with target checkpoint directory.
558
+ try:
559
+ checkpoint_dir_tmp.replace(checkpoint_dir)
560
+ except FileNotFoundError:
561
+ # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
562
+ # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
563
+ # file-systems.
564
+ if not checkpoint_dir.exists():
565
+ raise
566
+
567
+ # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
568
+ # replacing the temp directory with the final directory from rank 0 might not be immediately
569
+ # realized in the file systems of the other ranks.
570
+ # So we wait here across all ranks until that final checkpoint directory is visible.
571
+ wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)
572
+
573
+ barrier()
574
+
575
+ def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
576
+ if get_global_rank() == 0:
577
+ log.info("Saving config...")
578
+ self.cfg.save(config_path := Path(dir) / "config.yaml")
579
+ if upload_to is not None:
580
+ upload_target = f"{upload_to}/config.yaml"
581
+ log.info(f"Uploading {config_path} to {upload_target}")
582
+ upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
583
+
584
+
585
+ class FullCheckpointer(Checkpointer):
586
+ """
587
+ A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
588
+ """
589
+
590
+ def save_checkpoint(
591
+ self,
592
+ dir: PathOrStr,
593
+ fsdp_model: FSDP,
594
+ optim: Optimizer,
595
+ trainer_state: Dict[str, Any],
596
+ *,
597
+ upload_to: Optional[str] = None,
598
+ ) -> None:
599
+ with self._temporary_wd(dir) as checkpoint_dir:
600
+ with FSDP.state_dict_type(
601
+ fsdp_model,
602
+ state_dict_type=StateDictType.FULL_STATE_DICT,
603
+ state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
604
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
605
+ ):
606
+ # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
607
+ # First the model state.
608
+ model_state_dict = fsdp_model.state_dict()
609
+ if get_global_rank() == 0:
610
+ log.info("Saving model state...")
611
+ save_state_dict(
612
+ checkpoint_dir,
613
+ "model.pt",
614
+ model_state_dict,
615
+ upload_to=upload_to,
616
+ save_overwrite=self.cfg.save_overwrite,
617
+ synchronize=False,
618
+ )
619
+ del model_state_dict
620
+ barrier()
621
+
622
+ # Then the optimizer state.
623
+ optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
624
+ if get_global_rank() == 0:
625
+ log.info("Saving optim state...")
626
+ save_state_dict(
627
+ checkpoint_dir,
628
+ "optim.pt",
629
+ optim_state_dict,
630
+ upload_to=upload_to,
631
+ save_overwrite=self.cfg.save_overwrite,
632
+ synchronize=False,
633
+ )
634
+ del optim_state_dict
635
+ barrier()
636
+
637
+ # Save trainer state.
638
+ if get_global_rank() == 0:
639
+ log.info("Saving trainer state...")
640
+ save_state_dict(
641
+ checkpoint_dir,
642
+ "train.pt",
643
+ trainer_state,
644
+ upload_to=upload_to,
645
+ save_overwrite=self.cfg.save_overwrite,
646
+ synchronize=False,
647
+ )
648
+ # Save config.
649
+ self._save_config(checkpoint_dir, upload_to=upload_to)
650
+
651
+ def restore_checkpoint(
652
+ self,
653
+ load_path: PathOrStr,
654
+ fsdp_model: FSDP,
655
+ optim: Optimizer,
656
+ *,
657
+ local_cache: Optional[PathOrStr] = None,
658
+ load_optimizer_state: bool = True,
659
+ ) -> Dict[str, Any]:
660
+ with FSDP.state_dict_type(
661
+ fsdp_model,
662
+ state_dict_type=StateDictType.FULL_STATE_DICT,
663
+ state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
664
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
665
+ ):
666
+ with torch.no_grad():
667
+ # fill everything with NaN, so we can check afterwards that every parameter has been restored
668
+ for module_name, module in fsdp_model.named_modules():
669
+ if not isinstance(module, FSDP):
670
+ continue
671
+ for param in module.params:
672
+ param.fill_(torch.nan)
673
+
674
+ # restore params from checkpoint
675
+ state_dict_to_load = load_state_dict(
676
+ load_path, "model.pt", local_cache=local_cache, map_location="cpu"
677
+ )
678
+ (
679
+ state_dict_to_load,
680
+ og_keys_to_new,
681
+ ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)
682
+
683
+ for module_name, module in fsdp_model.named_modules():
684
+ if not isinstance(module, FSDP):
685
+ continue
686
+ for param in module.params:
687
+ assert param._is_flat_param
688
+ for fqn, spi in zip(param._fqns, param._shard_param_infos):
689
+ if not spi.in_shard:
690
+ continue
691
+ key = f"{module_name}.{fqn}"
692
+ key = key.replace("_fsdp_wrapped_module.", "")
693
+ key = key.lstrip(".")
694
+ t = state_dict_to_load[key]
695
+ t = t.flatten()
696
+ param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
697
+ t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
698
+ )
699
+
700
+ # make sure that every parameter has been restored
701
+ for module_name, module in fsdp_model.named_modules():
702
+ if not isinstance(module, FSDP):
703
+ continue
704
+ for param in module.params:
705
+ if torch.isnan(param).any():
706
+ raise ValueError(
707
+ f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
708
+ )
709
+
710
+ # Load optimizer state.
711
+ if load_optimizer_state:
712
+ optim_state_dict_to_load = load_state_dict(
713
+ load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
714
+ )
715
+ optim_state_dict_to_load = self._make_optim_state_dict_compatible(
716
+ optim_state_dict_to_load,
717
+ og_keys_to_new,
718
+ )
719
+ load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
720
+ del optim_state_dict_to_load
721
+
722
+ # Load other state.
723
+ try:
724
+ trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
725
+ except FileNotFoundError:
726
+ # for backwards compatibility
727
+ trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
728
+ barrier()
729
+ return trainer_state
730
+
731
+ def _make_optim_state_dict_compatible(
732
+ self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
733
+ ) -> Dict[str, Any]:
734
+ # This state dict comes in two forms: one where the state keys are integers and one where the
735
+ # keys are fully qualified parameter names. The latter case is easier to deal with here so we
736
+ # first transform the integer key form into the FQN key form.
737
+ if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
738
+ id_to_fqn: Dict[int, str] = {}
739
+ for group in optim_state_dict["param_groups"]:
740
+ new_param_names = []
741
+ for fqn, id in zip(group["param_names"], group["params"]):
742
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
743
+ id_to_fqn[id] = fqn
744
+ new_param_names.append(fqn)
745
+ group["param_names"] = new_param_names
746
+ group["params"] = new_param_names
747
+ for id in list(optim_state_dict["state"].keys()):
748
+ optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
749
+ else:
750
+ # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
751
+ for group in optim_state_dict["param_groups"]:
752
+ group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
753
+ group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
754
+ assert group["param_names"] == group["params"]
755
+ for key in list(optim_state_dict["state"].keys()):
756
+ optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
757
+ "state"
758
+ ].pop(key)
759
+
760
+ # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
761
+ # First fix param names in the state.
762
+ for og_key, new_keys in og_keys_to_new.items():
763
+ og_state = optim_state_dict["state"].pop(og_key, None)
764
+ if og_state is None:
765
+ continue
766
+ for i, new_key in enumerate(new_keys):
767
+ if i == len(new_keys) - 1:
768
+ optim_state_dict["state"][new_key] = og_state
769
+ else:
770
+ optim_state_dict["state"][new_key] = deepcopy(og_state)
771
+ # Now fix param names in the param groups.
772
+ for group in optim_state_dict["param_groups"]:
773
+ og_names = group["params"]
774
+ new_names = []
775
+ for og_key in og_names:
776
+ for new_key in og_keys_to_new[og_key]:
777
+ new_names.append(new_key)
778
+ group["params"] = new_names
779
+ group["param_names"] = new_names
780
+
781
+ return optim_state_dict
782
+
783
+ def load_checkpoint(
784
+ self,
785
+ load_path: PathOrStr,
786
+ *,
787
+ local_cache: Optional[PathOrStr] = None,
788
+ load_optimizer_state: bool = True,
789
+ device: Optional[torch.device] = None,
790
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
791
+ device = device if device is not None else torch.device("cpu")
792
+ model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device) # type: ignore
793
+ optim_state = None
794
+ if load_optimizer_state:
795
+ optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device) # type: ignore
796
+ return model_state, optim_state
797
+
798
+
799
+ class TorchNewStyleShardedCheckpointer(Checkpointer):
800
+ """
801
+ A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
802
+ """
803
+
804
+ def save_checkpoint(
805
+ self,
806
+ dir: PathOrStr,
807
+ fsdp_model: FSDP,
808
+ optim: Optimizer,
809
+ trainer_state: Dict[str, Any],
810
+ *,
811
+ upload_to: Optional[str] = None,
812
+ ) -> None:
813
+ with self._temporary_wd(dir) as checkpoint_dir:
814
+ # Save model and optim state.
815
+ save_fsdp_model_and_optim_state(
816
+ checkpoint_dir,
817
+ fsdp_model,
818
+ optim,
819
+ upload_to=upload_to,
820
+ save_overwrite=self.cfg.save_overwrite,
821
+ )
822
+
823
+ # Save trainer state.
824
+ log.info("Saving trainer state...")
825
+ save_state_dict(
826
+ checkpoint_dir,
827
+ f"train/rank{get_global_rank()}.pt",
828
+ trainer_state,
829
+ upload_to=upload_to,
830
+ save_overwrite=self.cfg.save_overwrite,
831
+ )
832
+
833
+ # Save config.
834
+ self._save_config(checkpoint_dir, upload_to=upload_to)
835
+
836
+ def restore_checkpoint(
837
+ self,
838
+ load_path: PathOrStr,
839
+ fsdp_model: FSDP,
840
+ optim: Optimizer,
841
+ *,
842
+ local_cache: Optional[PathOrStr] = None,
843
+ load_optimizer_state: bool = True,
844
+ ) -> Dict[str, Any]:
845
+ # Load model and optimizer state in place.
846
+ log.info("Loading model and optimizer state...")
847
+ load_fsdp_model_and_optim_state(
848
+ load_path,
849
+ fsdp_model,
850
+ optim,
851
+ local_cache=local_cache,
852
+ load_optimizer_state=load_optimizer_state,
853
+ )
854
+
855
+ # Load trainer state dict.
856
+ log.info("Loading trainer state...")
857
+ try:
858
+ trainer_state = load_state_dict(
859
+ load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
860
+ )
861
+ except FileNotFoundError:
862
+ # Fall back to rank 0 train state.
863
+ # This can happen when we're restoring a checkpoint with a different world size.
864
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
865
+ barrier()
866
+ return trainer_state
867
+
868
+
869
+ class TorchLegacyShardedCheckpointer(Checkpointer):
870
+ """
871
+ A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
872
+ and optim state.
873
+
874
+ The world size must be kept consistent when using this checkpointer.
875
+ """
876
+
877
+ def save_checkpoint(
878
+ self,
879
+ dir: PathOrStr,
880
+ fsdp_model: FSDP,
881
+ optim: Optimizer,
882
+ trainer_state: Dict[str, Any],
883
+ *,
884
+ upload_to: Optional[str] = None,
885
+ ) -> None:
886
+ with self._temporary_wd(dir) as checkpoint_dir:
887
+ with FSDP.state_dict_type(
888
+ fsdp_model,
889
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
890
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
891
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
892
+ ):
893
+ state_dict = {
894
+ "model": fsdp_model.state_dict(),
895
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
896
+ **trainer_state,
897
+ }
898
+ save_state_dict(
899
+ checkpoint_dir,
900
+ f"rank{get_global_rank()}.pt",
901
+ state_dict,
902
+ upload_to=upload_to,
903
+ save_overwrite=self.cfg.save_overwrite,
904
+ )
905
+
906
+ # Save config.
907
+ self._save_config(checkpoint_dir, upload_to=upload_to)
908
+
909
+ def restore_checkpoint(
910
+ self,
911
+ load_path: PathOrStr,
912
+ fsdp_model: FSDP,
913
+ optim: Optimizer,
914
+ *,
915
+ local_cache: Optional[PathOrStr] = None,
916
+ load_optimizer_state: bool = True,
917
+ ) -> Dict[str, Any]:
918
+ with FSDP.state_dict_type(
919
+ fsdp_model,
920
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
921
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
922
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
923
+ ):
924
+ # Deserialize state dict.
925
+ state_dict = load_state_dict(
926
+ load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
927
+ )
928
+
929
+ # Load model and optimizer state.
930
+ log.info("Loading model state...")
931
+ fsdp_model.load_state_dict(state_dict["model"])
932
+ del state_dict["model"]
933
+ if load_optimizer_state:
934
+ log.info("Loading optimizer state...")
935
+ load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
936
+ del state_dict["optim"]
937
+
938
+ barrier()
939
+ return state_dict
940
+
941
+ def unshard_checkpoint(
942
+ self,
943
+ load_path: PathOrStr,
944
+ *,
945
+ local_cache: Optional[PathOrStr] = None,
946
+ load_optimizer_state: bool = True,
947
+ load_trainer_state: bool = True,
948
+ device: Optional[torch.device] = None,
949
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
950
+ assert local_cache is None, "this method currently only supports local files"
951
+ full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
952
+ model_state = full_state_dict.pop("model")
953
+ optim_state = full_state_dict.pop("optim")
954
+ return (
955
+ model_state,
956
+ optim_state if load_optimizer_state else None,
957
+ full_state_dict if load_trainer_state else None,
958
+ )
959
+
960
+ def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
961
+ key = tuple() if key is None else key
962
+ if isinstance(state, (list, tuple, set)):
963
+ for i, sub_state in enumerate(state):
964
+ self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
965
+ elif isinstance(state, dict):
966
+ for name in state.keys():
967
+ self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
968
+ elif isinstance(state, ShardedTensor):
969
+ self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
970
+ return
971
+ else:
972
+ return
973
+
974
+ def _get_shard_placement_and_rank_sizes(
975
+ self, shards_metadata: List[ShardMetadata], world_size: int
976
+ ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
977
+ def shard_size(shard_md):
978
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
979
+
980
+ rank_sizes = [0 for _ in range(world_size)]
981
+ shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
982
+ for shard_md in shards_metadata:
983
+ shard_rank = cast(_remote_device, shard_md.placement).rank()
984
+ assert shard_rank is not None
985
+ if shard_rank >= world_size:
986
+ raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
987
+
988
+ shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
989
+ rank_sizes[shard_rank] += shard_size(shard_md)
990
+
991
+ return shard_placement, rank_sizes
992
+
993
+ def _copy_sharded_tensor_to_shared_mem(
994
+ self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
995
+ ) -> Any:
996
+ shard0_md = sharded_tensor.metadata()
997
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
998
+ shard0_md.shards_metadata, world_size
999
+ )
1000
+
1001
+ rank_size = rank_sizes[rank]
1002
+ assert rank_size >= 0
1003
+ if rank_size == 0:
1004
+ return
1005
+
1006
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1007
+ numpy_type = np.float32
1008
+
1009
+ sharded_memory_name = "-".join(key + (str(rank),))
1010
+
1011
+ shm = shared_memory.SharedMemory(
1012
+ create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
1013
+ )
1014
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1015
+
1016
+ for local_shard in sharded_tensor.local_shards():
1017
+ shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
1018
+ assert shard_rank == rank
1019
+
1020
+ src = local_shard.tensor.flatten()
1021
+ shard_offset = shard_placement[local_shard.metadata][1]
1022
+
1023
+ np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
1024
+
1025
+ shm.close()
1026
+
1027
+ def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
1028
+ shard_number = int(shard_filepath.name[4:-3])
1029
+ log.info("Starting unsharding shard number %d to shared memory", shard_number)
1030
+
1031
+ with self._patch_sharded_tensor_load():
1032
+ shard = torch.load(shard_filepath, map_location="cpu")
1033
+ log.debug("Done loading shard number %d", shard_number)
1034
+
1035
+ self._copy_sharded_tensors_to_shared_mem(
1036
+ shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
1037
+ )
1038
+ log.info("Done unsharding shard number %d to shared memory", shard_number)
1039
+
1040
+ def _unshard_using_sharded_mem(
1041
+ self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
1042
+ ) -> Any:
1043
+ return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
1044
+
1045
+ def _unshard_state_using_shared_mem(
1046
+ self, state: Any, world_size: int, device: torch.device, key: Tuple
1047
+ ) -> Any:
1048
+ if isinstance(state, (list, tuple, set)):
1049
+ return state.__class__(
1050
+ self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
1051
+ for i, sub_state in enumerate(state)
1052
+ )
1053
+ elif isinstance(state, dict):
1054
+ return {
1055
+ name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
1056
+ for name in state.keys()
1057
+ }
1058
+ elif isinstance(state, ShardedTensor):
1059
+ return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
1060
+ elif isinstance(state, torch.Tensor):
1061
+ return state.to(device=device)
1062
+ else:
1063
+ return state
1064
+
1065
+ def _unshard_tensor_using_shared_mem(
1066
+ self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
1067
+ ) -> torch.Tensor:
1068
+ shard0_md = sharded_tensor.metadata()
1069
+
1070
+ def shard_size(shard_md):
1071
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1072
+
1073
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1074
+ shard0_md.shards_metadata, world_size
1075
+ )
1076
+
1077
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1078
+ numpy_type = np.float32
1079
+
1080
+ out = torch.empty(
1081
+ *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
1082
+ )
1083
+ dims = len(sharded_tensor.metadata().size)
1084
+ for shard_md, (rank, rank_offset) in shard_placement.items():
1085
+ if rank >= world_size:
1086
+ raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
1087
+
1088
+ sharded_memory_name = "-".join(key + (str(rank),))
1089
+ shm = shared_memory.SharedMemory(name=sharded_memory_name)
1090
+
1091
+ rank_size = rank_sizes[rank]
1092
+ assert rank_size >= 0
1093
+ if rank_size == 0:
1094
+ continue
1095
+
1096
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1097
+
1098
+ tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
1099
+ tensor = tensor.view(shard_md.shard_sizes)
1100
+
1101
+ out_narrow_view = out
1102
+ for dim in range(dims):
1103
+ out_narrow_view = out_narrow_view.narrow(
1104
+ dim,
1105
+ shard_md.shard_offsets[dim],
1106
+ shard_md.shard_sizes[dim],
1107
+ )
1108
+
1109
+ out_narrow_view.copy_(tensor)
1110
+
1111
+ shm.close()
1112
+ shm.unlink()
1113
+
1114
+ return out
1115
+
1116
+ @contextmanager
1117
+ def _patch_sharded_tensor_load(self):
1118
+ """
1119
+ Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
1120
+ """
1121
+
1122
+ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
1123
+ ret = func(*args)
1124
+ if type(ret) is not new_type:
1125
+ ret = ret.as_subclass(new_type)
1126
+
1127
+ # Shortcut the construction of ShardedTensor
1128
+ # This is in the top 5 of my worst hacks.
1129
+ if isinstance(ret, ShardedTensor):
1130
+ ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
1131
+ return ret
1132
+
1133
+ # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
1134
+ # Tensor does define __setstate__ even though it doesn't define
1135
+ # __getstate__. So only use __setstate__ if it is NOT the one defined
1136
+ # on Tensor
1137
+ if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
1138
+ ret.__setstate__(state)
1139
+ else:
1140
+ ret = torch._utils._set_obj_state(ret, state)
1141
+ return ret
1142
+
1143
+ original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
1144
+ try:
1145
+ torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
1146
+ yield
1147
+ finally:
1148
+ torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
1149
+
1150
+ def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
1151
+ """
1152
+ The current unsharding implementation consists of:
1153
+
1154
+ 1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
1155
+ 2. Loading 1 shard on the main process as a base unsharded object.
1156
+ 3. Using the sharded tensors in shared memory to populate the base unsharded object.
1157
+
1158
+ This implementation replaced a prior implementation that instead loaded
1159
+ all shards using threads, because that implementation turned out to
1160
+ be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
1161
+ The current implementation is slower than the old one in many scenarios,
1162
+ but is significantly faster in the above mentioned case (e.g. 30 minutes)
1163
+ if there are enough CPUs.
1164
+ """
1165
+
1166
+ input_dir = Path(input_dir)
1167
+ skip_keys = skip_keys or set()
1168
+
1169
+ shard_filepaths = list(input_dir.glob("rank*.pt"))
1170
+ world_size = len(shard_filepaths)
1171
+ if world_size == 0:
1172
+ raise RuntimeError("No shards found for unsharding")
1173
+
1174
+ log.info("Number of shards: %d", world_size)
1175
+ shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
1176
+ min_ram_required_estimate_gb = shard_size_gb * world_size
1177
+ log.info(
1178
+ "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
1179
+ )
1180
+
1181
+ log.info("Copying sharded tensors to shared memory using multiple processes")
1182
+ # Copy sharded data to shared memory using multiple processes, so this process can load
1183
+ # from memory rather than disk. We spawn a new process instead of forking since shared memory
1184
+ # appears to get deleted when forked processes end for some reason.
1185
+ executor = ProcessPoolExecutor(
1186
+ mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
1187
+ )
1188
+ futures = []
1189
+ for shard_filepath in shard_filepaths:
1190
+ shard_rank = int(shard_filepath.name[4:-3])
1191
+
1192
+ if shard_rank >= world_size:
1193
+ raise RuntimeError(
1194
+ f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
1195
+ )
1196
+
1197
+ futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
1198
+
1199
+ for f in as_completed(futures):
1200
+ f.result()
1201
+ executor.shutdown()
1202
+
1203
+ log.info("Loading a shard on the main process to be unsharded state")
1204
+ with self._patch_sharded_tensor_load():
1205
+ state = torch.load(shard_filepaths[0], map_location="cpu")
1206
+
1207
+ for key in skip_keys:
1208
+ if key in state:
1209
+ del state[key]
1210
+
1211
+ log.info("Unsharding from %d shards ...", world_size)
1212
+ return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
1213
+
1214
+
1215
+ @dataclass
1216
+ class _LocalShardedCheckpointerMetadata(BaseConfig):
1217
+ world_size: int = field(default_factory=get_world_size)
1218
+
1219
+
1220
+ @dataclass
1221
+ class _FlatParamShard:
1222
+ full_shape: torch.Size
1223
+ shard_offsets: Tuple[int, int]
1224
+ shard_data: Optional[torch.Tensor]
1225
+
1226
+ def copy_into(self, full_tensor: torch.Tensor) -> None:
1227
+ assert self.shard_data is not None
1228
+ full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
1229
+ assert self.shard_data.shape == full_tensor_shard_view.shape
1230
+ full_tensor_shard_view.copy_(self.shard_data)
1231
+
1232
+
1233
+ class LocalShardedCheckpointer(Checkpointer):
1234
+ """
1235
+ A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
1236
+ The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
1237
+
1238
+ The world size must be kept consistent when using this checkpointer. However, you can easily
1239
+ reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
1240
+ using :meth:`unshard_checkpoint()` (no distributed initialization required).
1241
+ """
1242
+
1243
+ # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1244
+ _FLAT_PARAM_METADATA_TO_SAVE = (
1245
+ "_fqns",
1246
+ "_shard_param_offsets",
1247
+ "_shard_indices",
1248
+ "_numels",
1249
+ "_numels_with_padding",
1250
+ "_shapes",
1251
+ "_shard_numel_padded",
1252
+ "_shard_param_infos",
1253
+ )
1254
+
1255
+ def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1256
+ """
1257
+ Returns a list of FSDP modules with their FQN.
1258
+ """
1259
+ modules = []
1260
+ for name, module in fsdp_model.named_modules():
1261
+ if isinstance(module, FSDP):
1262
+ modules.append((name, module))
1263
+ return modules
1264
+
1265
+ def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1266
+ from torch.distributed.fsdp._runtime_utils import _lazy_init
1267
+
1268
+ # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1269
+ # an FSDP state dict through the built-in methods.
1270
+ if torch.cuda.is_available():
1271
+ torch.cuda.synchronize()
1272
+ _lazy_init(fsdp_model, fsdp_model)
1273
+
1274
+ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1275
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
1276
+ return fsdp_model._handles # type: ignore
1277
+ elif version.parse(torch.__version__) < version.parse("2.3.0"):
1278
+ # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1279
+ if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1280
+ return [fsdp_model._handle] # type: ignore
1281
+ else:
1282
+ return []
1283
+ else:
1284
+ # Need to verify FSDP internals with newer versions.
1285
+ raise NotImplementedError
1286
+
1287
+ @torch.no_grad()
1288
+ def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1289
+ self._prepare_fsdp_model(fsdp_model)
1290
+ module_data = []
1291
+ for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1292
+ handle_data = []
1293
+ for handle in self._fsdp_handles(fsdp_module):
1294
+ data: Dict[str, Any] = {}
1295
+ # This is a `FlatParameter` instance.
1296
+ # See `torch.distributed.fsdp.flat_param` for the API.
1297
+ flat_param = handle.flat_param
1298
+ data["flat_param.data"] = flat_param.detach()
1299
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1300
+ if hasattr(flat_param, key):
1301
+ data[f"flat_param.{key}"] = getattr(flat_param, key)
1302
+ handle_data.append(data)
1303
+ module_data.append({"handles": handle_data, "name": module_fqn})
1304
+ return {"modules": module_data}
1305
+
1306
+ @torch.no_grad()
1307
+ def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1308
+ """Load the state produced from `self._get_flat_param_state_to_save()`."""
1309
+ self._prepare_fsdp_model(fsdp_model)
1310
+ fsdp_modules = self._fsdp_modules(fsdp_model)
1311
+ assert len(model_state["modules"]) == len(fsdp_modules)
1312
+ for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1313
+ handles = self._fsdp_handles(fsdp_module)
1314
+ assert len(handles) == len(module_data["handles"])
1315
+ for handle, data in zip(handles, module_data["handles"]):
1316
+ flat_param = handle.flat_param
1317
+ # Make sure metadata matches.
1318
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1319
+ if hasattr(flat_param, key):
1320
+ assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1321
+ # Load the flat sharded data.
1322
+ flat_param.copy_(data["flat_param.data"])
1323
+
1324
+ def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1325
+ if get_fs_local_rank() == 0:
1326
+ log.info("Saving metadata...")
1327
+ metadata = _LocalShardedCheckpointerMetadata()
1328
+ metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1329
+ if upload_to is not None and get_global_rank() == 0:
1330
+ upload_target = f"{upload_to}/metadata.yaml"
1331
+ log.info(f"Uploading {metadata_path} to {upload_target}")
1332
+ upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1333
+
1334
+ def _load_metadata(
1335
+ self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1336
+ ) -> _LocalShardedCheckpointerMetadata:
1337
+ metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1338
+ return _LocalShardedCheckpointerMetadata.load(metadata_path)
1339
+
1340
+ def save_checkpoint(
1341
+ self,
1342
+ dir: PathOrStr,
1343
+ fsdp_model: FSDP,
1344
+ optim: Optimizer,
1345
+ trainer_state: Dict[str, Any],
1346
+ *,
1347
+ upload_to: Optional[str] = None,
1348
+ ) -> None:
1349
+ with self._temporary_wd(dir) as checkpoint_dir:
1350
+ # Gather local FSDP flat params data to save.
1351
+ # We also save some flat param metadata like the corresponding fully qualified names (fqns)
1352
+ # of each original parameter so we can validate that the sharding is the same when loading
1353
+ # one of these checkpoints.
1354
+ log.info("Saving local FSDP flat params data...")
1355
+ save_state_dict(
1356
+ checkpoint_dir,
1357
+ f"model/rank{get_global_rank()}.pt",
1358
+ self._get_flat_param_state_to_save(fsdp_model),
1359
+ upload_to=upload_to,
1360
+ save_overwrite=self.cfg.save_overwrite,
1361
+ )
1362
+
1363
+ # Save optimizer state.
1364
+ log.info("Saving local optimizer state...")
1365
+ save_state_dict(
1366
+ checkpoint_dir,
1367
+ f"optim/rank{get_global_rank()}.pt",
1368
+ optim.state_dict(),
1369
+ upload_to=upload_to,
1370
+ save_overwrite=self.cfg.save_overwrite,
1371
+ )
1372
+
1373
+ # Save trainer state.
1374
+ log.info("Saving trainer state...")
1375
+ save_state_dict(
1376
+ checkpoint_dir,
1377
+ f"train/rank{get_global_rank()}.pt",
1378
+ trainer_state,
1379
+ upload_to=upload_to,
1380
+ save_overwrite=self.cfg.save_overwrite,
1381
+ )
1382
+
1383
+ # Save metadata.
1384
+ self._save_metadata(checkpoint_dir, upload_to=upload_to)
1385
+
1386
+ # Save config. We do this last b/c the presence of a config in a remote checkpoint
1387
+ # "directory" indicates that the folder is valid, as a opposed to a partially
1388
+ # uploaded checkpoint directory that failed before completing.
1389
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1390
+
1391
+ def restore_checkpoint(
1392
+ self,
1393
+ load_path: PathOrStr,
1394
+ fsdp_model: FSDP,
1395
+ optim: Optimizer,
1396
+ *,
1397
+ local_cache: Optional[PathOrStr] = None,
1398
+ load_optimizer_state: bool = True,
1399
+ ) -> Dict[str, Any]:
1400
+ # Load metadata and make sure checkpoint is compatible.
1401
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1402
+ assert metadata.world_size == get_world_size()
1403
+
1404
+ # Load local FSDP flat param data.
1405
+ log.info("Loading local FSDP flat params data...")
1406
+ model_state = load_state_dict(
1407
+ load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1408
+ )
1409
+ self._load_flat_param_state(fsdp_model, model_state)
1410
+ del model_state
1411
+
1412
+ # Load local optim state.
1413
+ if load_optimizer_state:
1414
+ log.info("Loading local optimizer state...")
1415
+ optim_state = load_state_dict(
1416
+ load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1417
+ )
1418
+ # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
1419
+ # in every rank, and keep this in the optimizer state. But this causes issues when loading the
1420
+ # state since torch sees the state is non-empty for some params which would normally be empty,
1421
+ # and then assumes it should have all of the other state tensors for that param, which is doesn't.
1422
+ # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
1423
+ # Not the end of the world but there's probably a better way around this without resetting
1424
+ # the metric.
1425
+ for param_id in list(optim_state["state"].keys()):
1426
+ state = optim_state["state"][param_id]
1427
+ if "grad_norm_exp_avg" in state:
1428
+ del state["grad_norm_exp_avg"]
1429
+ if len(state) == 0:
1430
+ del optim_state["state"][param_id]
1431
+ optim.load_state_dict(optim_state)
1432
+ del optim_state
1433
+
1434
+ # Load local trainer state.
1435
+ log.info("Loading local trainer state...")
1436
+ trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1437
+ barrier()
1438
+ return trainer_state
1439
+
1440
+ def _iter_flat_param_shards(
1441
+ self, model_state: Dict[str, Any]
1442
+ ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
1443
+ for module_data in model_state["modules"]:
1444
+ module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
1445
+ for handle in module_data["handles"]:
1446
+ flat_data = handle["flat_param.data"]
1447
+ if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
1448
+ # If there's padding in the flat param it should be on the right.
1449
+ assert (flat_data[-num_padding:] == 0).all()
1450
+ # NOTE: this changes depending on the torch version, but we don't do a version
1451
+ # check since we might be trying to unshard an old checkpoint that was stored
1452
+ # with a different torch version than we're currently running with.
1453
+ if "flat_param._shard_indices" in handle:
1454
+ # torch <=2.0.1
1455
+ param_start = handle["flat_param._shard_indices"][0]
1456
+ current_flat_index = 0
1457
+ for relative_fqn, full_shape, (offset_start, offset_end) in zip(
1458
+ handle["flat_param._fqns"][param_start:],
1459
+ handle["flat_param._shapes"][param_start:],
1460
+ handle["flat_param._shard_param_offsets"],
1461
+ ):
1462
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1463
+ numel_shard = offset_end - offset_start + 1
1464
+ flat_param_shard = _FlatParamShard(
1465
+ full_shape=full_shape,
1466
+ shard_offsets=(offset_start, offset_end),
1467
+ shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
1468
+ )
1469
+ current_flat_index += numel_shard
1470
+ yield root_fqn, flat_param_shard
1471
+ else:
1472
+ # torch >=2.1.0
1473
+ for relative_fqn, full_shape, shard_param_info in zip(
1474
+ handle["flat_param._fqns"],
1475
+ handle["flat_param._shapes"],
1476
+ handle["flat_param._shard_param_infos"],
1477
+ ):
1478
+ if not shard_param_info.in_shard:
1479
+ continue
1480
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1481
+ flat_param_shard = _FlatParamShard(
1482
+ full_shape=full_shape,
1483
+ shard_offsets=(
1484
+ shard_param_info.intra_param_start_idx,
1485
+ shard_param_info.intra_param_end_idx,
1486
+ ),
1487
+ shard_data=flat_data[
1488
+ shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
1489
+ + shard_param_info.numel_in_shard
1490
+ ],
1491
+ )
1492
+ yield root_fqn, flat_param_shard
1493
+
1494
+ def unshard_checkpoint(
1495
+ self,
1496
+ load_path: PathOrStr,
1497
+ *,
1498
+ local_cache: Optional[PathOrStr] = None,
1499
+ load_optimizer_state: bool = True,
1500
+ load_trainer_state: bool = True,
1501
+ device: Optional[torch.device] = None,
1502
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1503
+ device = device or torch.device("cpu")
1504
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1505
+
1506
+ # Gather paths model state, potentially downloading them.
1507
+ log.info("Gathering model state dicts...")
1508
+ model_state_paths = self._gather_state_dict_paths(
1509
+ load_path, "model", metadata.world_size, local_cache=local_cache
1510
+ )
1511
+
1512
+ # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
1513
+ log.info("Materializing full parameters...")
1514
+ full_model_state: Dict[str, torch.Tensor] = {}
1515
+ # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
1516
+ # the full optimizer state below without having to reload the model state dicts.
1517
+ flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
1518
+ for rank, path in enumerate(model_state_paths):
1519
+ log.info(f"Loading shards from rank {rank}...")
1520
+ model_state = torch.load(path, map_location="cpu")
1521
+ for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
1522
+ if root_fqn not in full_model_state:
1523
+ log.info(
1524
+ f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
1525
+ )
1526
+ assert flat_param_shard.shard_data is not None
1527
+ full_model_state[root_fqn] = torch.empty(
1528
+ flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
1529
+ )
1530
+ # Fill with NaNs so we can validate that the whole parameter has been populated
1531
+ # afterwards.
1532
+ full_model_state[root_fqn].fill_(torch.nan)
1533
+ # Copy over the local shard to the relevant part of the full parameter.
1534
+ full_param = full_model_state[root_fqn]
1535
+ log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
1536
+ flat_param_shard.copy_into(full_param)
1537
+ flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
1538
+
1539
+ log.info("Validating full parameters...")
1540
+ for key, tensor in full_model_state.items():
1541
+ if torch.isnan(tensor).any():
1542
+ raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
1543
+
1544
+ trainer_state: Optional[Dict[str, Any]] = None
1545
+ if load_trainer_state:
1546
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1547
+
1548
+ if not load_optimizer_state:
1549
+ return full_model_state, None, trainer_state
1550
+
1551
+ log.info("Gathering optim state dicts...")
1552
+ optim_state_paths = self._gather_state_dict_paths(
1553
+ load_path, "optim", metadata.world_size, local_cache=local_cache
1554
+ )
1555
+
1556
+ log.info("Materializing full optim state...")
1557
+ full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
1558
+ fqn_to_id: Dict[str, int] = {}
1559
+ id_to_fqn: Dict[int, str] = {}
1560
+ for rank, path in enumerate(optim_state_paths):
1561
+ log.info(f"Loading sharded optim state from rank {rank}...")
1562
+ optim_state = torch.load(path, map_location="cpu")
1563
+
1564
+ # Initialize param groups.
1565
+ # We assume parameter groups are the same across all ranks.
1566
+ # The only thing that differs across ranks is the state for each local sharded param.
1567
+ if "param_groups" not in full_optim_state:
1568
+ full_optim_state["param_groups"] = optim_state["param_groups"]
1569
+ else:
1570
+ assert full_optim_state["param_groups"] == optim_state["param_groups"]
1571
+
1572
+ # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
1573
+ if not fqn_to_id or not id_to_fqn:
1574
+ for group in full_optim_state["param_groups"]:
1575
+ for fqn, id in zip(group["param_names"], group["params"]):
1576
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
1577
+ fqn_to_id[fqn] = id
1578
+ id_to_fqn[id] = fqn
1579
+
1580
+ # Iterate over local shard state and copy into the full state.
1581
+ for id, shard_state in optim_state["state"].items():
1582
+ fqn = id_to_fqn[id]
1583
+ flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
1584
+ full_state = full_optim_state["state"][id]
1585
+ for key, shard_value in shard_state.items():
1586
+ assert isinstance(shard_value, torch.Tensor)
1587
+ if shard_value.shape == torch.Size([]):
1588
+ # Add singleton tensors directly to full state. These should be the same across
1589
+ # all ranks.
1590
+ assert key in ("step", "grad_norm_exp_avg") # sanity check
1591
+ if key not in full_state:
1592
+ full_state[key] = shard_value.to(device)
1593
+ else:
1594
+ assert full_state[key] == shard_value
1595
+ else:
1596
+ # Otherwise we have a sharded param state.
1597
+ # If the corresponding full param state hasn't been materialized yet, do so now.
1598
+ assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
1599
+ if key not in full_state:
1600
+ log.info(
1601
+ f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
1602
+ )
1603
+ full_state[key] = torch.empty(
1604
+ flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
1605
+ )
1606
+ full_state_value = full_state[key]
1607
+
1608
+ # Copy over the local shard state to the relevant part of the full parameter state.
1609
+ log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
1610
+ replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
1611
+
1612
+ # Lastly, clean up the parameter names in param groups.
1613
+ for group in full_optim_state["param_groups"]:
1614
+ group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
1615
+
1616
+ return full_model_state, full_optim_state, trainer_state
1617
+
1618
+ def _get_state_dict_path(
1619
+ self,
1620
+ load_path: PathOrStr,
1621
+ state_dict_type: str,
1622
+ rank: int,
1623
+ *,
1624
+ local_cache: Optional[PathOrStr] = None,
1625
+ progress=None,
1626
+ ) -> Tuple[int, Path]:
1627
+ fname = f"{state_dict_type}/rank{rank}.pt"
1628
+ return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
1629
+
1630
+ def _gather_state_dict_paths(
1631
+ self,
1632
+ load_path: PathOrStr,
1633
+ state_dict_type: str,
1634
+ world_size: int,
1635
+ *,
1636
+ local_cache: Optional[PathOrStr] = None,
1637
+ ) -> List[Path]:
1638
+ progress = get_progress_bar()
1639
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
1640
+ futures = []
1641
+ for rank in range(world_size):
1642
+ future = executor.submit(
1643
+ self._get_state_dict_path,
1644
+ load_path,
1645
+ state_dict_type,
1646
+ rank,
1647
+ local_cache=local_cache,
1648
+ progress=progress,
1649
+ )
1650
+ futures.append(future)
1651
+
1652
+ results: Dict[int, Path] = {}
1653
+ for future in as_completed(futures):
1654
+ rank, path = future.result()
1655
+ results[rank] = path
1656
+
1657
+ return [results[rank] for rank in range(world_size)]
1658
+
1659
+
1660
+ def build_sharded_checkpointer(
1661
+ cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
1662
+ ) -> Checkpointer:
1663
+ name = name or cfg.sharded_checkpointer
1664
+ if name == ShardedCheckpointerType.torch_new:
1665
+ return TorchNewStyleShardedCheckpointer(cfg)
1666
+ elif name == ShardedCheckpointerType.torch_legacy:
1667
+ return TorchLegacyShardedCheckpointer(cfg)
1668
+ elif name == ShardedCheckpointerType.local:
1669
+ return LocalShardedCheckpointer(cfg)
1670
+ else:
1671
+ raise NotImplementedError(name)
OLMo_Bitnet_1B/config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_type": "swiglu",
3
+ "alibi": false,
4
+ "alibi_bias_max": 8.0,
5
+ "architectures": [
6
+ "OLMoModelForCausalLM"
7
+ ],
8
+ "attention_dropout": 0.0,
9
+ "attention_layer_norm": false,
10
+ "attention_layer_norm_with_affine": false,
11
+ "bias_for_layer_norm": false,
12
+ "block_group_size": 1,
13
+ "block_type": "sequential",
14
+ "clip_qkv": null,
15
+ "d_model": 2048,
16
+ "embedding_dropout": 0.0,
17
+ "embedding_size": 50304,
18
+ "eos_token_id": 50279,
19
+ "flash_attention": true,
20
+ "include_bias": false,
21
+ "init_cutoff_factor": null,
22
+ "init_device": "cpu",
23
+ "init_fn": "mitchell",
24
+ "init_std": 0.02,
25
+ "layer_norm_type": "rms",
26
+ "layer_norm_with_affine": true,
27
+ "max_sequence_length": 2048,
28
+ "mlp_hidden_size": null,
29
+ "mlp_ratio": 8,
30
+ "model_type": "olmo",
31
+ "multi_query_attention": false,
32
+ "n_heads": 16,
33
+ "n_layers": 16,
34
+ "pad_token_id": 1,
35
+ "precision": "amp_bf16",
36
+ "residual_dropout": 0.0,
37
+ "rope": true,
38
+ "rope_full_precision": true,
39
+ "scale_logits": false,
40
+ "ternary": true,
41
+ "transformers_version": "4.38.2",
42
+ "use_cache": true,
43
+ "vocab_size": 50280,
44
+ "inference_mode":false,
45
+ "weight_tying": true,
46
+ "auto_map": {
47
+ "AutoConfig": "configuration_olmo.OLMoConfig",
48
+ "AutoModelForCausalLM": "modeling_olmo.OLMoForCausalLM"
49
+ }
50
+ }
OLMo_Bitnet_1B/config.py ADDED
@@ -0,0 +1,1106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from glob import glob
5
+ from pathlib import Path
6
+ from typing import (
7
+ Any,
8
+ Dict,
9
+ Iterable,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ cast,
17
+ )
18
+
19
+ import torch
20
+ from omegaconf import DictConfig, ListConfig
21
+ from omegaconf import OmegaConf as om
22
+ from omegaconf.errors import OmegaConfBaseException
23
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
24
+
25
+ from .aliases import PathOrStr
26
+ from .beam_search import Sampler
27
+ from .exceptions import OLMoConfigurationError
28
+ from .util import StrEnum
29
+
30
+ __all__ = [
31
+ "ActivationType",
32
+ "ActivationCheckpointingStrategy",
33
+ "BlockType",
34
+ "LayerNormType",
35
+ "InitFnType",
36
+ "ModelConfig",
37
+ "OptimizerType",
38
+ "OptimizerConfig",
39
+ "SchedulerType",
40
+ "SchedulerConfig",
41
+ "DataConfig",
42
+ "EvaluatorConfig",
43
+ "TokenizerConfig",
44
+ "TrainConfig",
45
+ "PaddingDirection",
46
+ "TruncationDirection",
47
+ "SpeedMonitorConfig",
48
+ "WandbConfig",
49
+ "CompilerConfig",
50
+ "WandbConfig",
51
+ "FSDPPrecision",
52
+ "FSDPWrapStrategy",
53
+ "FSDPConfig",
54
+ "CheckpointType",
55
+ ]
56
+
57
+ C = TypeVar("C", bound="BaseConfig")
58
+ D = TypeVar("D", bound="DictConfig|ListConfig")
59
+
60
+
61
+ class BaseConfig:
62
+ @classmethod
63
+ def _register_resolvers(cls, validate_paths: bool = True):
64
+ # Expands path globs into a list.
65
+ def path_glob(*paths) -> List[str]:
66
+ out = []
67
+ for path in paths:
68
+ matches = sorted(glob(path))
69
+ if not matches and validate_paths:
70
+ raise FileNotFoundError(f"{path} does not match any files or dirs")
71
+ out.extend(matches)
72
+ return out
73
+
74
+ # Chooses the first path in the arguments that exists.
75
+ def path_choose(*paths) -> str:
76
+ from .util import is_url
77
+
78
+ for path in paths:
79
+ if is_url(path) or Path(path).exists():
80
+ return path
81
+ if validate_paths:
82
+ raise FileNotFoundError(", ".join(paths))
83
+ else:
84
+ return ""
85
+
86
+ # Finds the latest checkpoint in a folder.
87
+ def path_last_checkpoint(path) -> str:
88
+ from .util import find_latest_checkpoint
89
+
90
+ latest_checkpoint = find_latest_checkpoint(path)
91
+ if latest_checkpoint is None:
92
+ if validate_paths:
93
+ raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
94
+ else:
95
+ return ""
96
+ else:
97
+ return str(latest_checkpoint)
98
+
99
+ om.register_new_resolver("path.glob", path_glob, replace=True)
100
+ om.register_new_resolver("path.choose", path_choose, replace=True)
101
+ om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
102
+
103
+ @classmethod
104
+ def update_legacy_settings(cls, config: D) -> D:
105
+ """
106
+ Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
107
+ """
108
+ return config
109
+
110
+ @classmethod
111
+ def new(cls: Type[C], **kwargs) -> C:
112
+ cls._register_resolvers()
113
+ conf = om.structured(cls)
114
+ try:
115
+ if kwargs:
116
+ conf = om.merge(conf, kwargs)
117
+ return cast(C, om.to_object(conf))
118
+ except OmegaConfBaseException as e:
119
+ raise OLMoConfigurationError(str(e))
120
+
121
+ @classmethod
122
+ def load(
123
+ cls: Type[C],
124
+ path: PathOrStr,
125
+ overrides: Optional[List[str]] = None,
126
+ key: Optional[str] = None,
127
+ validate_paths: bool = True,
128
+ ) -> C:
129
+ """Load from a YAML file."""
130
+ cls._register_resolvers(validate_paths=validate_paths)
131
+ schema = om.structured(cls)
132
+ try:
133
+ raw = om.load(str(path))
134
+ if key is not None:
135
+ raw = raw[key] # type: ignore
136
+ raw = cls.update_legacy_settings(raw)
137
+ conf = om.merge(schema, raw)
138
+ if overrides:
139
+ conf = om.merge(conf, om.from_dotlist(overrides))
140
+ return cast(C, om.to_object(conf))
141
+ except OmegaConfBaseException as e:
142
+ raise OLMoConfigurationError(str(e))
143
+
144
+ def save(self, path: PathOrStr) -> None:
145
+ """Save to a YAML file."""
146
+ om.save(config=self, f=str(path))
147
+
148
+ def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
149
+ out = asdict(self) # type: ignore
150
+ if exclude is not None:
151
+ for name in exclude:
152
+ if name in out:
153
+ del out[name]
154
+ return out
155
+
156
+
157
+ class LayerNormType(StrEnum):
158
+ default = "default"
159
+ """
160
+ The default LayerNorm implementation, equivalent to PyTorch's built-in version.
161
+ """
162
+
163
+ low_precision = "low_precision"
164
+ """
165
+ A low-precision version of the default LayerNorm.
166
+ """
167
+
168
+ rms = "rms"
169
+ """
170
+ An RMSNorm implementation. When using ``torch.compile`` this is
171
+ probably the fastest implementation.
172
+ """
173
+
174
+
175
+ class ActivationType(StrEnum):
176
+ gelu = "gelu"
177
+ relu = "relu"
178
+ swiglu = "swiglu"
179
+
180
+
181
+ class BlockType(StrEnum):
182
+ sequential = "sequential"
183
+
184
+ llama = "llama"
185
+ """
186
+ A block similar to the sequential block with slightly different
187
+ implementations of operations like attention to imitate the behavior of Llama.
188
+ """
189
+
190
+
191
+ class InitFnType(StrEnum):
192
+ mitchell = "mitchell"
193
+ """
194
+ The strategy suggested to us by Mitchell Wortsman from UW.
195
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
196
+ on the size of the weights as well as the depth of the layer.
197
+ """
198
+
199
+ normal = "normal"
200
+ """
201
+ All weights are initialized from the same normal distribution.
202
+ """
203
+
204
+ kaiming_normal = "kaiming_normal"
205
+ """
206
+ All weights are initialized with the Kaiming method from a normal distribution.
207
+ Note this currently won't work with FSDP.
208
+ """
209
+
210
+ fan_in = "fan_in"
211
+ """
212
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
213
+ is the input dimensionality of the kernel.
214
+ """
215
+
216
+ full_megatron = "full_megatron"
217
+ """
218
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
219
+ """
220
+
221
+
222
+ @dataclass
223
+ class ModelConfig(BaseConfig):
224
+ """
225
+ OLMo (model) configuration.
226
+ """
227
+
228
+ # Note that the defaults for these attributes are equivalent to the base GPT2 model.
229
+
230
+ d_model: int = 768
231
+ """
232
+ The hidden size of the model.
233
+ """
234
+
235
+ n_heads: int = 12
236
+ """
237
+ The number of self-attention heads.
238
+ """
239
+
240
+ n_kv_heads: Optional[int] = None
241
+ """
242
+ The number of heads to use for keys and values. Defaults to `n_heads`.
243
+ Set this to ``None`` or ``n_heads`` for normal multi-head attention.
244
+ Set this to 1 for multi-query attention.
245
+ Set it to some in-between value for Llama2-style grouped query attention.
246
+ """
247
+
248
+ clip_qkv: Optional[float] = None
249
+ """
250
+ Clip QKV to this value when set.
251
+ """
252
+
253
+ n_layers: int = 12
254
+ """
255
+ The number of layers/blocks.
256
+ """
257
+
258
+ mlp_ratio: int = 4
259
+ """
260
+ The ratio of the inner MLP dimensionality to ``d_model``.
261
+ This is only used when ``mlp_hidden_size`` is not set.
262
+ """
263
+
264
+ mlp_hidden_size: Optional[int] = None
265
+ """
266
+ Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
267
+ """
268
+
269
+ activation_type: ActivationType = ActivationType.swiglu
270
+ """
271
+ The activation function to use within the MLP layers.
272
+ """
273
+
274
+ block_type: BlockType = BlockType.sequential
275
+ """
276
+ The transformer block implementation.
277
+ """
278
+
279
+ block_group_size: int = 1
280
+ """
281
+ The number of blocks to group together into a single parent block.
282
+ This has no affect on the number of parameters in the model and is only used to wrap groups
283
+ of blocks together with a single FSDP wrapper during training.
284
+ """
285
+
286
+ alibi: bool = False
287
+ """
288
+ If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
289
+ """
290
+
291
+ alibi_bias_max: float = 8.0
292
+ """
293
+ Maximum absolute value of ALiBi bias.
294
+ """
295
+
296
+ rope: bool = False
297
+ """
298
+ Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
299
+ """
300
+
301
+ rope_full_precision: bool = True
302
+ """
303
+ If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
304
+ apply RoPE at the precision of the input.
305
+ """
306
+
307
+ flash_attention: bool = False
308
+ """
309
+ If ``True``, use ``FlashAttention``.
310
+ """
311
+
312
+ attention_dropout: float = 0.1
313
+ """
314
+ The dropout probability within the attention modules.
315
+ """
316
+
317
+ multi_query_attention: Optional[bool] = None
318
+ """
319
+ Deprecated. Use n_kv_heads instead.
320
+ """
321
+
322
+ attention_layer_norm: bool = False
323
+ """
324
+ Apply layer norm to the keys and queries within the attention mechanism.
325
+ This can help stabilize training.
326
+ """
327
+
328
+ residual_dropout: float = 0.1
329
+ """
330
+ The dropout probability for the MLP and attention output within each block.
331
+ """
332
+
333
+ embedding_dropout: float = 0.1
334
+ """
335
+ The dropout probability for embeddings.
336
+ """
337
+
338
+ layer_norm_type: LayerNormType = LayerNormType.default
339
+ """
340
+ The layernorm implementation to use.
341
+ """
342
+
343
+ layer_norm_with_affine: bool = True
344
+ """
345
+ Whether to include bias and weight parameters for the layer norms.
346
+ This only affects layer norms that are immediately followed by a linear layer in the forward pass,
347
+ so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
348
+ to ``False``.
349
+ """
350
+
351
+ attention_layer_norm_with_affine: bool = True
352
+ """
353
+ Toggle affine transform for the QK norms.
354
+ """
355
+
356
+ max_sequence_length: int = 1024
357
+ """
358
+ The maximum input sequence length supported by the model.
359
+ """
360
+
361
+ include_bias: bool = True
362
+ """
363
+ Whether or not to include bias parameters in linear layers.
364
+ In PaLM, they got rid of all bias terms because they found that large
365
+ models tend to have near 0 bias terms anyway.
366
+ """
367
+
368
+ bias_for_layer_norm: Optional[bool] = None
369
+ """
370
+ Whether or not to include bias parameters in layer norm.
371
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
372
+ layer norm.
373
+ When this is None (the default), it inherits the setting from include_bias.
374
+ """
375
+
376
+ scale_logits: bool = False
377
+ """
378
+ If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
379
+ """
380
+
381
+ vocab_size: int = 50257
382
+ """
383
+ Vocabulary size of the model.
384
+ """
385
+
386
+ embedding_size: Optional[int] = 50304
387
+ """
388
+ The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
389
+ to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
390
+ next multiple of 128 that's greater than ``vocab_size`` can improve throughput
391
+ substantially.
392
+ """
393
+
394
+ weight_tying: bool = True
395
+ """
396
+ Whether to tie output linear weights to the input embedding.
397
+ """
398
+
399
+ eos_token_id: int = 50256
400
+ """
401
+ The ID of the end-of-sentence special token.
402
+ """
403
+
404
+ pad_token_id: int = 50256
405
+ """
406
+ The ID of the token to use for padding. Defaults to the ID of the EOS token.
407
+ """
408
+
409
+ init_device: Optional[str] = None
410
+ """
411
+ The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
412
+ """
413
+
414
+ init_fn: InitFnType = InitFnType.normal
415
+ """
416
+ The weight initialization strategy.
417
+ """
418
+
419
+ init_std: float = 0.02
420
+ """
421
+ The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
422
+ as "normal".
423
+ """
424
+
425
+ init_cutoff_factor: Optional[float] = None
426
+ """
427
+ A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
428
+ as "normal". Setting this to None means values are not cutoff.
429
+ """
430
+
431
+ precision: Optional[str] = None
432
+ """
433
+ Precision used to train/evaluate with. You shouldn't set this directly.
434
+ See :data:`TrainConfig.precision` instead.
435
+ """
436
+
437
+ ternary: bool = False
438
+ """
439
+ Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf)
440
+ """
441
+
442
+ @property
443
+ def effective_n_kv_heads(self) -> int:
444
+ if self.n_kv_heads is None:
445
+ if self.multi_query_attention is True:
446
+ return 1
447
+ else:
448
+ return self.n_heads
449
+ else:
450
+ if self.multi_query_attention is None:
451
+ return self.n_kv_heads
452
+ if self.multi_query_attention:
453
+ n_kv_heads_should_be = 1
454
+ else:
455
+ n_kv_heads_should_be = self.n_heads
456
+ if self.n_kv_heads == n_kv_heads_should_be:
457
+ return n_kv_heads_should_be
458
+ else:
459
+ raise OLMoConfigurationError(
460
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
461
+ )
462
+
463
+
464
+ class OptimizerType(StrEnum):
465
+ lionw = "lionw"
466
+ adamw = "adamw"
467
+
468
+
469
+ @dataclass
470
+ class OptimizerConfig(BaseConfig):
471
+ name: OptimizerType = OptimizerType.lionw
472
+ learning_rate: float = 1.0e-4
473
+ weight_decay: float = 0.01
474
+ betas: Tuple[float, float] = (0.9, 0.95)
475
+
476
+ no_decay_norm_and_bias: Optional[bool] = None
477
+ """
478
+ Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
479
+ """
480
+
481
+ decay_norm_and_bias: bool = False
482
+ decay_embeddings: bool = False
483
+ metrics_log_interval: Optional[int] = None
484
+ """
485
+ The interval with which to collect and log detailed parameter-specific metrics.
486
+ This only applies when logging to W&B, since these metrics won't be logged to the console.
487
+ If not set, defaults to the wandb `log_interval`.
488
+ """
489
+
490
+ def __post_init__(self):
491
+ self.betas = tuple(self.betas) # type: ignore[assignment]
492
+
493
+ @classmethod
494
+ def update_legacy_settings(cls, config: D) -> D:
495
+ new_config = config.copy()
496
+ if om.is_dict(new_config):
497
+ assert isinstance(new_config, DictConfig)
498
+
499
+ if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
500
+ new_config.name = "lionw"
501
+ if hasattr(new_config, "eps"):
502
+ del new_config.eps
503
+
504
+ return new_config
505
+
506
+
507
+ class SchedulerType(StrEnum):
508
+ cosine_with_warmup = "cosine_with_warmup"
509
+ linear_with_warmup = "linear_with_warmup"
510
+ inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
511
+ max_scheduler = "max_scheduler"
512
+ constant = "constant"
513
+
514
+
515
+ class SchedulerUnits(StrEnum):
516
+ steps = "steps"
517
+ tokens = "tokens"
518
+
519
+
520
+ @dataclass
521
+ class SchedulerConfig(BaseConfig):
522
+ name: SchedulerType = SchedulerType.cosine_with_warmup
523
+ units: SchedulerUnits = SchedulerUnits.steps
524
+ t_warmup: Union[int, float] = 100
525
+ t_max: Optional[Union[int, float]] = None
526
+ alpha_f: float = 0.1
527
+
528
+ grad_clip_warmup_steps: Optional[Union[int, float]] = None
529
+ """
530
+ The warmup period for which the max grad norm (or norm ratio) will be set to its
531
+ warmup value of `max_grad_norm * grad_clip_warmup_factor`.
532
+ """
533
+
534
+ grad_clip_warmup_factor: Optional[float] = None
535
+ """
536
+ The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
537
+ vs after the warmup period.
538
+ """
539
+
540
+
541
+ class PaddingDirection(StrEnum):
542
+ right = "right"
543
+ left = "left"
544
+
545
+
546
+ @dataclass
547
+ class DataConfig(BaseConfig):
548
+ paths: Optional[List[str]] = None
549
+ datasets: Optional[Dict[str, List[str]]] = None
550
+ label_mask_paths: Optional[List[str]] = None
551
+ pad_direction: PaddingDirection = PaddingDirection.right
552
+ generate_attention_mask: bool = False
553
+ num_workers: int = 0
554
+ drop_last: bool = False
555
+ pin_memory: bool = False
556
+ prefetch_factor: Optional[int] = None
557
+ persistent_workers: bool = False
558
+ timeout: int = 0
559
+ seed: Optional[int] = None
560
+
561
+
562
+ class EvaluatorType(StrEnum):
563
+ downstream = "downstream"
564
+ lm = "lm"
565
+
566
+
567
+ @dataclass
568
+ class EvaluatorConfig(BaseConfig):
569
+ label: str
570
+ type: EvaluatorType = EvaluatorType.lm
571
+ data: DataConfig = field(default_factory=DataConfig)
572
+ device_eval_batch_size: Optional[int] = None
573
+ subset_num_batches: Optional[int] = None
574
+
575
+
576
+ class TruncationDirection(StrEnum):
577
+ right = "right"
578
+ left = "left"
579
+
580
+
581
+ @dataclass
582
+ class TokenizerConfig(BaseConfig):
583
+ identifier: str = "gpt2"
584
+ truncate_direction: TruncationDirection = TruncationDirection.right
585
+
586
+
587
+ @dataclass
588
+ class WandbConfig(BaseConfig):
589
+ project: Optional[str] = None
590
+ entity: Optional[str] = "ai2-llm"
591
+ group: Optional[str] = None
592
+ name: Optional[str] = None
593
+ tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
594
+ log_artifacts: bool = False
595
+ rank_zero_only: bool = True
596
+ log_interval: int = 1
597
+
598
+
599
+ @dataclass
600
+ class SpeedMonitorConfig(BaseConfig):
601
+ window_size: int = 100
602
+ gpu_flops_available: Optional[Union[float, int]] = None
603
+
604
+
605
+ @dataclass
606
+ class CompilerConfig(BaseConfig):
607
+ mode: Optional[str] = None
608
+ """
609
+ The mode to compile the model in. At the moment this can be "default",
610
+ "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
611
+ (the fastest for larger models, but takes a long time to compile).
612
+ """
613
+
614
+ fullgraph: bool = False
615
+ """
616
+ Whether it is OK to break model into several subgraphs when compiling.
617
+ Note that this is not compatible with FSDP.
618
+ """
619
+
620
+ backend: str = "inductor"
621
+ """
622
+ The backend to use.
623
+ """
624
+
625
+
626
+ class FSDPWrapStrategy(StrEnum):
627
+ by_block = "by_block"
628
+ """
629
+ Wrap each OLMo block with its own FSDP instance.
630
+ """
631
+
632
+ by_block_and_size = "by_block_and_size"
633
+ """
634
+ Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
635
+ """
636
+
637
+ by_block_group = "by_block_group"
638
+ """
639
+ Wrap each block group together into its own FSDP instance.
640
+ This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
641
+ """
642
+
643
+ by_block_group_and_size = "by_block_group_and_size"
644
+ """
645
+ Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
646
+ """
647
+
648
+ size_based = "size_based"
649
+ """
650
+ Used PyTorch's default size-based auto wrap policy.
651
+ """
652
+
653
+ one_in_two = "one_in_two"
654
+ one_in_three = "one_in_three"
655
+ one_in_four = "one_in_four"
656
+ one_in_five = "one_in_five"
657
+
658
+
659
+ class FSDPPrecision(StrEnum):
660
+ pure = "pure"
661
+ """
662
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
663
+ and ``buffer_dtype`` all set to the autocast precision data type.
664
+ """
665
+
666
+ mixed = "mixed"
667
+ """
668
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
669
+ set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
670
+ """
671
+
672
+
673
+ @dataclass
674
+ class FSDPConfig(BaseConfig):
675
+ use_orig_params: bool = True
676
+ """
677
+ This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
678
+ """
679
+
680
+ sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
681
+
682
+ wrapping_strategy: Optional[FSDPWrapStrategy] = None
683
+ """
684
+ The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
685
+ FSDP instance.
686
+ """
687
+
688
+ precision: FSDPPrecision = FSDPPrecision.pure
689
+
690
+
691
+ class CheckpointType(StrEnum):
692
+ sharded = "sharded"
693
+ unsharded = "unsharded"
694
+ sharded_ephemeral = "sharded_ephemeral"
695
+
696
+
697
+ class ShardedCheckpointerType(StrEnum):
698
+ torch_new = "torch_new"
699
+ torch_legacy = "torch_legacy"
700
+ local = "local"
701
+
702
+
703
+ class ActivationCheckpointingStrategy(StrEnum):
704
+ whole_layer = "whole_layer"
705
+ """
706
+ Checkpoint every transformer layer.
707
+ """
708
+
709
+ one_in_two = "one_in_two"
710
+ """
711
+ Checkpoint one in two transformer layers.
712
+ """
713
+
714
+ one_in_three = "one_in_three"
715
+ """
716
+ Checkpoint one in three transformer layers.
717
+ """
718
+
719
+ one_in_four = "one_in_four"
720
+ """
721
+ Checkpoint one in four transformer layers.
722
+ """
723
+
724
+ two_in_three = "two_in_three"
725
+ """
726
+ Checkpoint two out of every three transformer layers.
727
+ """
728
+
729
+ three_in_four = "three_in_four"
730
+ """
731
+ Checkpoint three out of four of every transformer layers.
732
+ """
733
+
734
+ fine_grained = "fine_grained"
735
+ """
736
+ Focus checkpointing on where it is cheap to recompute and saves most memory.
737
+ """
738
+
739
+
740
+ @dataclass
741
+ class TrainConfig(BaseConfig):
742
+ """
743
+ OLMo training configuration.
744
+ """
745
+
746
+ run_name: Optional[str] = None
747
+ """
748
+ The name of the run.
749
+ """
750
+
751
+ seed: int = 6198
752
+ """
753
+ Used to seed all initial RNG states.
754
+ """
755
+
756
+ epoch: Optional[int] = None
757
+ """
758
+ Increment this when starting a new epoch.
759
+ """
760
+
761
+ dry_run: bool = False
762
+ """
763
+ If ``True``, don't actually train.
764
+ """
765
+
766
+ model: ModelConfig = field(default_factory=ModelConfig)
767
+ """
768
+ OLMo Model configuration.
769
+ """
770
+
771
+ optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
772
+ """
773
+ Optimizer configuration.
774
+ """
775
+
776
+ scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
777
+ """
778
+ Learning rate scheduler configuration.
779
+ """
780
+
781
+ data: DataConfig = field(default_factory=DataConfig)
782
+ """
783
+ Training data configuration.
784
+ """
785
+
786
+ restore_dataloader: bool = True
787
+ """
788
+ When restarting, restore the data loader to where it left off.
789
+ If you restarting in order to train on a different dataset, set this to ``False``.
790
+ """
791
+
792
+ fast_forward_batches: Optional[int] = None
793
+ """
794
+ When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
795
+ This can be useful when restarting due to a loss spike in order to skip the data that
796
+ corresponded to the spike.
797
+ """
798
+
799
+ evaluators: List[EvaluatorConfig] = field(default_factory=list)
800
+ """
801
+ Evaluation configurations.
802
+ """
803
+
804
+ eval_interval: int = 1000
805
+ """
806
+ How often (in terms of batches) to run evaluations.
807
+ """
808
+
809
+ tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
810
+ """
811
+ Tokenizer configuration.
812
+ """
813
+
814
+ save_folder: str = "./"
815
+ """
816
+ The directory to save checkpoints to.
817
+ """
818
+
819
+ remote_save_folder: Optional[str] = None
820
+ """
821
+ A folder in a cloud bucket to upload saved checkpoints to.
822
+ """
823
+
824
+ canceled_check_interval: int = 50
825
+ """
826
+ How often (in batches) to check if the run has been canceled or reached its time limit.
827
+ """
828
+
829
+ save_interval: int = 1000
830
+ """
831
+ How often (in terms of steps) to save sharded training state checkpoints.
832
+ """
833
+
834
+ save_interval_unsharded: Optional[int] = None
835
+ """
836
+ How often (if at all) to save unsharded training state checkpoint.
837
+ For large models it can be costly to save these, so it usually makes sense to save
838
+ these less often than regular (sharded) training checkpoints.
839
+ """
840
+
841
+ save_interval_ephemeral: Optional[int] = None
842
+ """
843
+ How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
844
+ as those saved every `save_interval` except that at most only the most recent one of these is kept.
845
+ This is useful when you want to checkpoint often for restarts in case of failures, but don't
846
+ want to keep the majority of these checkpoints.
847
+
848
+ For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
849
+ a temporary checkpoint every 100 steps in case your job fails. In that case you would
850
+ set `save_interval=1000` and `save_interval_ephemeral=100`.
851
+ """
852
+
853
+ save_num_checkpoints_to_keep: int = -1
854
+ """
855
+ How many sharded checkpoints to keep.
856
+ """
857
+
858
+ save_num_unsharded_checkpoints_to_keep: int = -1
859
+ """
860
+ How many unsharded checkpoints to keep.
861
+ """
862
+
863
+ save_overwrite: bool = False
864
+ """
865
+ If ``True``, overwrite any conflicting checkpoint files.
866
+ """
867
+
868
+ force_save_unsharded: bool = False
869
+ """
870
+ Save an unsharded checkpoint before training (even during a dry run).
871
+ Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
872
+ checkpoint into an unsharded checkpoint.
873
+ """
874
+
875
+ no_pre_train_checkpoint: bool = False
876
+ """
877
+ Skip saving pre-train checkpoint.
878
+ """
879
+
880
+ load_path: Optional[str] = None
881
+ """
882
+ The path to a training checkpoint to restore/resume from.
883
+
884
+ Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes
885
+ a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
886
+ For example,
887
+
888
+ ```bash
889
+ --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
890
+ ```
891
+ """
892
+
893
+ load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
894
+ """
895
+ The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
896
+ """
897
+
898
+ reset_optimizer_state: bool = False
899
+ """
900
+ When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
901
+ We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
902
+ curve (according to the current learning rate schedule settings), and continues from there.
903
+ """
904
+
905
+ reset_trainer_state: bool = False
906
+ """
907
+ When this is set we don't restore the trainer state from a checkpoint.
908
+ """
909
+
910
+ sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
911
+ """
912
+ The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
913
+ """
914
+
915
+ new_style_checkpoints: Optional[bool] = None
916
+ """
917
+ Deprecated. Use ``sharded_checkpointer`` instead.
918
+ """
919
+
920
+ max_duration: Union[int, str] = 10000
921
+ """
922
+ How long to train for.
923
+
924
+ If specified without a unit (the default), the units are assumed to be steps.
925
+ You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
926
+ 2 trillion tokens.
927
+ """
928
+
929
+ global_train_batch_size: int = 512
930
+ """
931
+ The effective global batch size.
932
+ """
933
+
934
+ device_train_batch_size: Optional[int] = None # calculated automatically
935
+ """
936
+ Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
937
+ """
938
+
939
+ device_train_microbatch_size: int = 16
940
+ """
941
+ The number of instances passed to the model in a single forward-backward pass. You should set
942
+ this as large as you can based on available GPU memory.
943
+ """
944
+
945
+ device_eval_batch_size: int = 16
946
+ """
947
+ The number of evaluation instances passed to the model in a single forward pass on each device.
948
+ """
949
+
950
+ eval_subset_num_batches: int = -1
951
+ """
952
+ The number of batches to use for downstream evaluation from each dataset.
953
+ """
954
+
955
+ eval_on_load: bool = False
956
+ """
957
+ When resuming from a checkpoint, run the evaluation loop right away.
958
+ """
959
+
960
+ device_train_grad_accum: Optional[int] = None # calculated automatically
961
+ """
962
+ Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
963
+ """
964
+
965
+ max_grad_norm: Optional[float] = None
966
+ """
967
+ Clip gradient norms to this value if set.
968
+ """
969
+
970
+ max_grad_norm_ratio: Optional[float] = None
971
+ """
972
+ If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
973
+ This takes priority over `max_grad_norm` when set.
974
+ """
975
+
976
+ precision: Optional[str] = None
977
+ """
978
+ Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
979
+ """
980
+
981
+ wandb: Optional[WandbConfig] = None
982
+ """
983
+ Weights & Biases configuration.
984
+ """
985
+
986
+ speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
987
+ """
988
+ Speed monitor configuration.
989
+ """
990
+
991
+ console_log_interval: int = 1
992
+ """
993
+ How often to log to the console.
994
+ """
995
+
996
+ compile: Optional[CompilerConfig] = None
997
+ """
998
+ Settings for compiling the model with ``torch.compile()``.
999
+ """
1000
+
1001
+ fsdp: FSDPConfig = field(default_factory=FSDPConfig)
1002
+ """
1003
+ Fully sharded data parallel settings.
1004
+ """
1005
+
1006
+ softmax_auxiliary_loss: bool = False
1007
+ """
1008
+ If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
1009
+ normalizing term to be close to 0.
1010
+ """
1011
+
1012
+ time_limit: Optional[float] = 60 * 60 * 47.5
1013
+ """
1014
+ The maximum amount of time to train for before saving a checkpoint and ending early.
1015
+ On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
1016
+ to write out a final checkpoint.
1017
+ """
1018
+
1019
+ extra_steps_after_cancel: int = 10
1020
+ """
1021
+ Under certain conditions when a run is canceled we train for a few extra steps after saving
1022
+ the final checkpoint so that when the run is restarted from the latest checkpoint we have some
1023
+ overlap in metrics.
1024
+ """
1025
+
1026
+ early_stopping_factor: Optional[float] = None
1027
+
1028
+ save_data_indices: bool = True
1029
+ """
1030
+ Save training data indices from each batch for each worker.
1031
+ """
1032
+
1033
+ python_profiling: bool = False
1034
+ """
1035
+ Whether to run the Python profiler on batches 6, 7, and 8.
1036
+ """
1037
+
1038
+ torch_profiling: bool = False
1039
+ """
1040
+ Whether to run the PyTorch profiler on batches 6, 7, and 8.
1041
+ """
1042
+
1043
+ stop_at: Optional[int] = None
1044
+ """
1045
+ Stop at a specific step.
1046
+ """
1047
+
1048
+ stop_after: Optional[int] = None
1049
+ """
1050
+ Stop after a specific number of steps.
1051
+ """
1052
+
1053
+ activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
1054
+ """
1055
+ The activation checkpointing strategy to use.
1056
+ """
1057
+
1058
+ fused_loss: Optional[bool] = None
1059
+ """
1060
+ Whether to use the fused CE loss function from `flash-attn`.
1061
+ """
1062
+
1063
+ @property
1064
+ def autocast_precision(self) -> torch.dtype:
1065
+ if self.precision == "amp_bf16":
1066
+ return torch.bfloat16
1067
+ elif self.precision == "amp_fp16":
1068
+ return torch.float16
1069
+ elif self.precision == "fp32":
1070
+ return torch.float32
1071
+ else:
1072
+ raise ValueError(f"Unexpected precision type '{self.precision}'")
1073
+
1074
+ @property
1075
+ def fsdp_precision(self) -> MixedPrecision:
1076
+ if self.fsdp.precision == FSDPPrecision.pure:
1077
+ return MixedPrecision(
1078
+ param_dtype=self.autocast_precision,
1079
+ reduce_dtype=self.autocast_precision,
1080
+ buffer_dtype=self.autocast_precision,
1081
+ )
1082
+ elif self.fsdp.precision == FSDPPrecision.mixed:
1083
+ return MixedPrecision(
1084
+ param_dtype=self.autocast_precision,
1085
+ reduce_dtype=torch.float32,
1086
+ buffer_dtype=self.autocast_precision,
1087
+ )
1088
+ else:
1089
+ raise NotImplementedError(f"{self.fsdp.precision}")
1090
+
1091
+ @classmethod
1092
+ def update_legacy_settings(cls, config: D) -> D:
1093
+ new_config = config.copy()
1094
+ if om.is_dict(new_config):
1095
+ assert isinstance(new_config, DictConfig)
1096
+
1097
+ if hasattr(new_config, "activation_checkpointing"):
1098
+ if new_config.activation_checkpointing is False:
1099
+ new_config.activation_checkpointing = None
1100
+ if new_config.activation_checkpointing is True:
1101
+ new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
1102
+
1103
+ if hasattr(new_config, "optimizer"):
1104
+ new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
1105
+
1106
+ return new_config
OLMo_Bitnet_1B/configuration_olmo.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OLMo configuration
3
+ """
4
+
5
+ from transformers import AutoConfig, PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ from .config import ModelConfig
9
+ from .aliases import PathOrStr
10
+ from .beam_search import Sampler
11
+ from .exceptions import OLMoError
12
+ from .initialization import ModuleType
13
+ from .optim import Optimizer
14
+ from .util import StrEnum
15
+ from .safetensors_util import STKey
16
+ from .torch_util import seed_all
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
+ class OLMoConfig(PretrainedConfig):
22
+ model_type = "olmo"
23
+ keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
24
+
25
+ def __init__(self, use_cache: bool = False, **kwargs):
26
+ model_config = ModelConfig()
27
+ all_kwargs = model_config.asdict()
28
+ all_kwargs.update(kwargs)
29
+ all_kwargs.update({"use_cache": use_cache})
30
+ all_kwargs.update(
31
+ {
32
+ "architectures": all_kwargs.get("architectures", ["OLMoModelForCausalLM"])
33
+ or ["OLMoModelForCausalLM"]
34
+ }
35
+ )
36
+ super().__init__(**all_kwargs)
37
+
38
+ @property
39
+ def num_attention_heads(self):
40
+ return self.n_heads
41
+
42
+ @property
43
+ def num_hidden_layers(self):
44
+ return self.n_layers
45
+
46
+ @property
47
+ def hidden_size(self):
48
+ return self.d_model
49
+
50
+
51
+ # Register the config class so that it is available for transformer pipelines, auto-loading etc.
52
+ # AutoConfig.register("olmo", OLMoConfig)
OLMo_Bitnet_1B/exceptions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "OLMoError",
3
+ "OLMoConfigurationError",
4
+ "OLMoCliError",
5
+ "OLMoEnvironmentError",
6
+ "OLMoNetworkError",
7
+ "OLMoCheckpointError",
8
+ ]
9
+
10
+
11
+ class OLMoError(Exception):
12
+ """
13
+ Base class for all custom OLMo exceptions.
14
+ """
15
+
16
+
17
+ class OLMoConfigurationError(OLMoError):
18
+ """
19
+ An error with a configuration file.
20
+ """
21
+
22
+
23
+ class OLMoCliError(OLMoError):
24
+ """
25
+ An error from incorrect CLI usage.
26
+ """
27
+
28
+
29
+ class OLMoEnvironmentError(OLMoError):
30
+ """
31
+ An error from incorrect environment variables.
32
+ """
33
+
34
+
35
+ class OLMoNetworkError(OLMoError):
36
+ """
37
+ An error with a network request.
38
+ """
39
+
40
+
41
+ class OLMoCheckpointError(OLMoError):
42
+ """
43
+ An error occurred reading or writing from a checkpoint.
44
+ """
45
+
46
+
47
+ class OLMoThreadError(Exception):
48
+ """
49
+ Raised when a thread fails.
50
+ """
OLMo_Bitnet_1B/initialization.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .config import InitFnType, ModelConfig
8
+ from .util import StrEnum
9
+
10
+ __all__ = ["init_weights", "ModuleType"]
11
+
12
+
13
+ class ModuleType(StrEnum):
14
+ in_module = "in"
15
+ out_module = "out"
16
+ emb = "emb"
17
+ final_out = "final_out"
18
+
19
+
20
+ def init_weights(
21
+ config: ModelConfig,
22
+ module: Union[nn.Linear, nn.Embedding],
23
+ d: Optional[int] = None,
24
+ layer_id: Optional[int] = None,
25
+ std_factor: float = 1.0,
26
+ type_of_module: Optional[ModuleType] = None,
27
+ ) -> None:
28
+ """
29
+ Initialize weights of a linear or embedding module.
30
+
31
+ :param config: The model config.
32
+ :param module: The linear or embedding submodule to initialize.
33
+ :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
34
+ for fused layers.
35
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
36
+ ``1 / sqrt(2 * (layer_id + 1))``.
37
+ """
38
+ d = d if d is not None else config.d_model
39
+ if config.init_fn == InitFnType.normal:
40
+ std = config.init_std * std_factor
41
+ if config.init_cutoff_factor is not None:
42
+ cutoff_value = config.init_cutoff_factor * std
43
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
44
+ else:
45
+ nn.init.normal_(module.weight, mean=0.0, std=std)
46
+ elif config.init_fn == InitFnType.mitchell:
47
+ std = std_factor / math.sqrt(d)
48
+ if layer_id is not None:
49
+ std = std / math.sqrt(2 * (layer_id + 1))
50
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
51
+ elif config.init_fn == InitFnType.kaiming_normal:
52
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
53
+ elif config.init_fn == InitFnType.fan_in:
54
+ std = std_factor / math.sqrt(d)
55
+ nn.init.normal_(module.weight, mean=0.0, std=std)
56
+ elif config.init_fn == InitFnType.full_megatron:
57
+ if type_of_module is None:
58
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
59
+
60
+ cutoff_factor = config.init_cutoff_factor
61
+ if cutoff_factor is None:
62
+ cutoff_factor = 3
63
+
64
+ if type_of_module == ModuleType.in_module:
65
+ # for att_proj (same as QKV), ff_proj
66
+ std = config.init_std
67
+ elif type_of_module == ModuleType.out_module:
68
+ # for attn_out, ff_out
69
+ std = config.init_std / math.sqrt(2.0 * config.n_layers)
70
+ elif type_of_module == ModuleType.emb:
71
+ # positional embeddings (wpe)
72
+ # token embeddings (wte)
73
+ std = config.init_std
74
+ elif type_of_module == ModuleType.final_out:
75
+ # final output (ff_out)
76
+ std = config.d_model**-0.5
77
+ else:
78
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
79
+ nn.init.trunc_normal_(
80
+ module.weight,
81
+ mean=0.0,
82
+ std=std,
83
+ a=-cutoff_factor * std,
84
+ b=cutoff_factor * std,
85
+ )
86
+ else:
87
+ raise NotImplementedError(config.init_fn)
88
+
89
+ if isinstance(module, nn.Linear):
90
+ if module.bias is not None:
91
+ nn.init.zeros_(module.bias)
92
+
93
+ if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
94
+ with torch.no_grad():
95
+ module.weight.div_(math.sqrt(2 * config.n_layers))