naveensp committed
Commit: a7c44ca
Parent: 29b160c

Delete folder olmo_bitnet_1b with huggingface_hub
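For reference, a folder deletion like the one in this commit is typically issued through huggingface_hub's `HfApi.delete_folder`; the sketch below is illustrative only, and the repo id is a hypothetical placeholder.

```python
# Illustrative sketch of deleting a folder from a Hub repo with huggingface_hub.
# Assumes an authenticated token (e.g. via `huggingface-cli login`); the repo id
# below is a hypothetical placeholder, not taken from this page.
from huggingface_hub import HfApi

api = HfApi()
api.delete_folder(
    path_in_repo="olmo_bitnet_1b",
    repo_id="naveensp/<repo-name>",  # hypothetical placeholder
    repo_type="model",
    commit_message="Delete folder olmo_bitnet_1b with huggingface_hub",
)
```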

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.

Files changed (50)
  1. olmo_bitnet_1b/.gitattributes +0 -35
  2. olmo_bitnet_1b/README.md +0 -38
  3. olmo_bitnet_1b/__init__.py +0 -0
  4. olmo_bitnet_1b/__pycache__/__init__.cpython-310.pyc +0 -0
  5. olmo_bitnet_1b/__pycache__/__init__.cpython-311.pyc +0 -0
  6. olmo_bitnet_1b/__pycache__/__init__.cpython-312.pyc +0 -0
  7. olmo_bitnet_1b/__pycache__/aliases.cpython-310.pyc +0 -0
  8. olmo_bitnet_1b/__pycache__/aliases.cpython-311.pyc +0 -0
  9. olmo_bitnet_1b/__pycache__/aliases.cpython-312.pyc +0 -0
  10. olmo_bitnet_1b/__pycache__/beam_search.cpython-310.pyc +0 -0
  11. olmo_bitnet_1b/__pycache__/beam_search.cpython-311.pyc +0 -0
  12. olmo_bitnet_1b/__pycache__/beam_search.cpython-312.pyc +0 -0
  13. olmo_bitnet_1b/__pycache__/config.cpython-310.pyc +0 -0
  14. olmo_bitnet_1b/__pycache__/config.cpython-311.pyc +0 -0
  15. olmo_bitnet_1b/__pycache__/config.cpython-312.pyc +0 -0
  16. olmo_bitnet_1b/__pycache__/configuration_olmo.cpython-310.pyc +0 -0
  17. olmo_bitnet_1b/__pycache__/configuration_olmo.cpython-311.pyc +0 -0
  18. olmo_bitnet_1b/__pycache__/configuration_olmo.cpython-312.pyc +0 -0
  19. olmo_bitnet_1b/__pycache__/exceptions.cpython-310.pyc +0 -0
  20. olmo_bitnet_1b/__pycache__/exceptions.cpython-311.pyc +0 -0
  21. olmo_bitnet_1b/__pycache__/exceptions.cpython-312.pyc +0 -0
  22. olmo_bitnet_1b/__pycache__/initialization.cpython-310.pyc +0 -0
  23. olmo_bitnet_1b/__pycache__/initialization.cpython-311.pyc +0 -0
  24. olmo_bitnet_1b/__pycache__/initialization.cpython-312.pyc +0 -0
  25. olmo_bitnet_1b/__pycache__/model.cpython-310.pyc +0 -0
  26. olmo_bitnet_1b/__pycache__/model.cpython-311.pyc +0 -0
  27. olmo_bitnet_1b/__pycache__/model.cpython-312.pyc +0 -0
  28. olmo_bitnet_1b/__pycache__/modeling_olmo.cpython-310.pyc +0 -0
  29. olmo_bitnet_1b/__pycache__/modeling_olmo.cpython-311.pyc +0 -0
  30. olmo_bitnet_1b/__pycache__/modeling_olmo.cpython-312.pyc +0 -0
  31. olmo_bitnet_1b/__pycache__/optim.cpython-310.pyc +0 -0
  32. olmo_bitnet_1b/__pycache__/optim.cpython-311.pyc +0 -0
  33. olmo_bitnet_1b/__pycache__/optim.cpython-312.pyc +0 -0
  34. olmo_bitnet_1b/__pycache__/safetensors_util.cpython-310.pyc +0 -0
  35. olmo_bitnet_1b/__pycache__/safetensors_util.cpython-311.pyc +0 -0
  36. olmo_bitnet_1b/__pycache__/safetensors_util.cpython-312.pyc +0 -0
  37. olmo_bitnet_1b/__pycache__/torch_util.cpython-310.pyc +0 -0
  38. olmo_bitnet_1b/__pycache__/torch_util.cpython-311.pyc +0 -0
  39. olmo_bitnet_1b/__pycache__/torch_util.cpython-312.pyc +0 -0
  40. olmo_bitnet_1b/__pycache__/util.cpython-310.pyc +0 -0
  41. olmo_bitnet_1b/__pycache__/util.cpython-311.pyc +0 -0
  42. olmo_bitnet_1b/__pycache__/util.cpython-312.pyc +0 -0
  43. olmo_bitnet_1b/aliases.py +0 -7
  44. olmo_bitnet_1b/beam_search.py +0 -1078
  45. olmo_bitnet_1b/checkpoint.py +0 -1671
  46. olmo_bitnet_1b/config.json +0 -50
  47. olmo_bitnet_1b/config.py +0 -1106
  48. olmo_bitnet_1b/configuration_olmo.py +0 -52
  49. olmo_bitnet_1b/exceptions.py +0 -50
  50. olmo_bitnet_1b/initialization.py +0 -95
olmo_bitnet_1b/.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

olmo_bitnet_1b/README.md DELETED
@@ -1,38 +0,0 @@
- ---
- license: apache-2.0
- datasets:
- - allenai/dolma
- ---
- # OLMo-Bitnet-1B
-
- OLMo-Bitnet-1B is a 1B parameter model trained using the method described in [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764).
-
- It was trained on the first 60B tokens of the [Dolma](https://huggingface.co/datasets/allenai/dolma) dataset, so it is merely a research proof-of-concept to test out the methodology.
-
- A separate training run was performed with the exact same hyperparameters, but using standard fp16 weights.
- The comparison can be found in [this wandb report](https://api.wandb.ai/links/emozilla/evltqiv7).
-
-
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6317aade83d8d2fd903192d9/NAw-hyWJl5ihVsAPqz3Xe.png)
-
- Sample inference code:
-
- ```sh
- pip install ai2-olmo
- ```
-
- ```python
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
-
- tokenizer = AutoTokenizer.from_pretrained("NousResearch/OLMo-Bitnet-1B")
- model = AutoModelForCausalLM.from_pretrained("NousResearch/OLMo-Bitnet-1B",
-     torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
-
- streamer = TextStreamer(tokenizer)
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id,
-     temperature=0.8, repetition_penalty=1.1, do_sample=True, streamer=streamer)
- pipe("The capitol of Paris is", max_new_tokens=256)
- ```
-
- Training was performed using [OLMo](https://github.com/allenai/OLMo).

olmo_bitnet_1b/__init__.py DELETED
File without changes
olmo_bitnet_1b/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (146 Bytes)
 
olmo_bitnet_1b/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (165 Bytes)
 
olmo_bitnet_1b/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (153 Bytes)
 
olmo_bitnet_1b/__pycache__/aliases.cpython-310.pyc DELETED
Binary file (268 Bytes)
 
olmo_bitnet_1b/__pycache__/aliases.cpython-311.pyc DELETED
Binary file (342 Bytes)
 
olmo_bitnet_1b/__pycache__/aliases.cpython-312.pyc DELETED
Binary file (300 Bytes)
 
olmo_bitnet_1b/__pycache__/beam_search.cpython-310.pyc DELETED
Binary file (31.6 kB)
 
olmo_bitnet_1b/__pycache__/beam_search.cpython-311.pyc DELETED
Binary file (48 kB)
 
olmo_bitnet_1b/__pycache__/beam_search.cpython-312.pyc DELETED
Binary file (45.6 kB)
 
olmo_bitnet_1b/__pycache__/config.cpython-310.pyc DELETED
Binary file (18.1 kB)
 
olmo_bitnet_1b/__pycache__/config.cpython-311.pyc DELETED
Binary file (28.4 kB)
 
olmo_bitnet_1b/__pycache__/config.cpython-312.pyc DELETED
Binary file (25 kB)
 
olmo_bitnet_1b/__pycache__/configuration_olmo.cpython-310.pyc DELETED
Binary file (1.83 kB)
 
olmo_bitnet_1b/__pycache__/configuration_olmo.cpython-311.pyc DELETED
Binary file (2.74 kB)
 
olmo_bitnet_1b/__pycache__/configuration_olmo.cpython-312.pyc DELETED
Binary file (2.36 kB)
 
olmo_bitnet_1b/__pycache__/exceptions.cpython-310.pyc DELETED
Binary file (1.45 kB)
 
olmo_bitnet_1b/__pycache__/exceptions.cpython-311.pyc DELETED
Binary file (1.99 kB)
 
olmo_bitnet_1b/__pycache__/exceptions.cpython-312.pyc DELETED
Binary file (1.68 kB)
 
olmo_bitnet_1b/__pycache__/initialization.cpython-310.pyc DELETED
Binary file (2.71 kB)
 
olmo_bitnet_1b/__pycache__/initialization.cpython-311.pyc DELETED
Binary file (5.12 kB)
 
olmo_bitnet_1b/__pycache__/initialization.cpython-312.pyc DELETED
Binary file (5.09 kB)
 
olmo_bitnet_1b/__pycache__/model.cpython-310.pyc DELETED
Binary file (47.9 kB)
 
olmo_bitnet_1b/__pycache__/model.cpython-311.pyc DELETED
Binary file (90.1 kB)
 
olmo_bitnet_1b/__pycache__/model.cpython-312.pyc DELETED
Binary file (86.4 kB)
 
olmo_bitnet_1b/__pycache__/modeling_olmo.cpython-310.pyc DELETED
Binary file (6.53 kB)
 
olmo_bitnet_1b/__pycache__/modeling_olmo.cpython-311.pyc DELETED
Binary file (10.1 kB)
 
olmo_bitnet_1b/__pycache__/modeling_olmo.cpython-312.pyc DELETED
Binary file (9.86 kB)
 
olmo_bitnet_1b/__pycache__/optim.cpython-310.pyc DELETED
Binary file (19.6 kB)
 
olmo_bitnet_1b/__pycache__/optim.cpython-311.pyc DELETED
Binary file (41 kB)
 
olmo_bitnet_1b/__pycache__/optim.cpython-312.pyc DELETED
Binary file (36 kB)
 
olmo_bitnet_1b/__pycache__/safetensors_util.cpython-310.pyc DELETED
Binary file (2.8 kB)
 
olmo_bitnet_1b/__pycache__/safetensors_util.cpython-311.pyc DELETED
Binary file (5.01 kB)
 
olmo_bitnet_1b/__pycache__/safetensors_util.cpython-312.pyc DELETED
Binary file (4.4 kB)
 
olmo_bitnet_1b/__pycache__/torch_util.cpython-310.pyc DELETED
Binary file (5.04 kB)
 
olmo_bitnet_1b/__pycache__/torch_util.cpython-311.pyc DELETED
Binary file (9.11 kB)
 
olmo_bitnet_1b/__pycache__/torch_util.cpython-312.pyc DELETED
Binary file (8.08 kB)
 
olmo_bitnet_1b/__pycache__/util.cpython-310.pyc DELETED
Binary file (19.7 kB)
 
olmo_bitnet_1b/__pycache__/util.cpython-311.pyc DELETED
Binary file (37.3 kB)
 
olmo_bitnet_1b/__pycache__/util.cpython-312.pyc DELETED
Binary file (33.2 kB)
 
olmo_bitnet_1b/aliases.py DELETED
@@ -1,7 +0,0 @@
- from os import PathLike
- from typing import Union
-
- __all__ = ["PathOrStr"]
-
-
- PathOrStr = Union[str, PathLike]
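As an aside, this alias is the sort of annotation used for functions that accept either a plain string or an `os.PathLike` object; the snippet below is a minimal illustration (the `resolve_path` helper is hypothetical, not part of the repo).

```python
# Minimal illustration of the PathOrStr alias; resolve_path is a hypothetical helper.
from os import PathLike
from pathlib import Path
from typing import Union

PathOrStr = Union[str, PathLike]


def resolve_path(path: PathOrStr) -> Path:
    # Accepts both "checkpoints/model.pt" and Path("checkpoints/model.pt").
    return Path(path).expanduser().resolve()


print(resolve_path("~/checkpoints"))       # str input
print(resolve_path(Path("~") / "data"))    # PathLike input
```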
olmo_bitnet_1b/beam_search.py DELETED
@@ -1,1078 +0,0 @@
1
- """
2
- This is a self-contained and flexible beam search implementation adapted from
3
- AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
- """
5
-
6
- import copy
7
- import warnings
8
- from abc import abstractmethod
9
- from inspect import signature
10
- from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
-
12
- import torch
13
-
14
- __all__ = [
15
- "Sampler",
16
- "DeterministicSampler",
17
- "MultinomialSampler",
18
- "TopKSampler",
19
- "TopPSampler",
20
- "GumbelSampler",
21
- "FinalSequenceScorer",
22
- "SequenceLogProbabilityScorer",
23
- "LengthNormalizedSequenceLogProbabilityScorer",
24
- "Constraint",
25
- "RepeatedNGramBlockingConstraint",
26
- "BeamSearch",
27
- ]
28
-
29
- StateType = Dict[str, torch.Tensor]
30
- StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
- StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
-
33
- StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
- """
35
- The type of step function that can be passed to [`BeamSearch.search`](#search).
36
-
37
- This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
- or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
- """
40
-
41
- ConstraintStateType = List[List[Dict[str, Any]]]
42
-
43
-
44
- class Sampler:
45
- """
46
- An abstract class that can be used to sample candidates (either nodes or beams)
47
- within `BeamSearch`.
48
-
49
- A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
-
51
- `init_state()` takes three arguments:
52
-
53
- - a tensor of starting log probs with shape `(batch_size, num_classes)`,
54
- - the batch size, an int,
55
- - and the number of classes, also an int.
56
-
57
- It returns a state dictionary with any state tensors needed for subsequent
58
- calls to `sample_nodes()` and `sample_beams()`.
59
-
60
- By default this method just returns an empty dictionary.
61
-
62
- Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
-
64
- - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
- - an integer representing the number of samples to take for each example in the batch,
66
- - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
- track of state.
68
-
69
- For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
- `num_examples = beam_size * per_node_beam_size`.
71
-
72
- The return value should be a tuple containing:
73
-
74
- - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
- - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
- - and the updated state dictionary.
77
-
78
- A default implementation of `sample_beams` is provided, which just deterministically
79
- picks the `k` examples with highest log probability.
80
- """
81
-
82
- def init_state(
83
- self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
- ) -> StateType:
85
- del start_class_log_probabilities, batch_size, num_classes
86
- return {}
87
-
88
- @abstractmethod
89
- def sample_nodes(
90
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
- raise NotImplementedError
93
-
94
- def sample_beams(
95
- self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
- del state
98
- selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
- return selected_log_probs, selected_indices, {}
100
-
101
-
102
- class DeterministicSampler(Sampler):
103
- """
104
- A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
- log probability.
106
- """
107
-
108
- def sample_nodes(
109
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
- del state
112
- selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
- return selected_log_probs, selected_indices, {}
114
-
115
-
116
- class MultinomialSampler(Sampler):
117
- """
118
- A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
- in the default, non-deterministic way.
120
-
121
- :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
- above 1.0 produces a flatter probability distribution.
123
- :param with_replacement: Whether to sample with replacement.
124
-
125
- """
126
-
127
- def __init__(
128
- self,
129
- temperature: float = 1.0,
130
- with_replacement: bool = False,
131
- ) -> None:
132
- self.temperature = temperature
133
- self.with_replacement = with_replacement
134
-
135
- def sample_nodes(
136
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
- if self.temperature != 1.0:
139
- _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
- else:
141
- _probabilities = log_probs.exp()
142
-
143
- selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
-
145
- return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
-
147
-
148
- class TopKSampler(Sampler):
149
- """
150
- A `Sampler` which redistributes the probability mass function for nodes among the
151
- top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
-
153
- Beams are sampled in the default, deterministic way.
154
-
155
- :param k: The number of top choices to be selected from.
156
- :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
- above 1.0 produces a flatter probability distribution.
158
- :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
- """
160
-
161
- def __init__(
162
- self,
163
- k: int = 1,
164
- temperature: float = 1.0,
165
- with_replacement: bool = False,
166
- ):
167
- self.k = k
168
- self.temperature = temperature or 1.0
169
- self.with_replacement = with_replacement
170
-
171
- def sample_nodes(
172
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
- if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
- raise ValueError(
176
- "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
177
- )
178
-
179
- # shape (both): (batch_size, k)
180
- top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
-
182
- # Apply temperature if necessary.
183
- # shape: (batch_size, k)
184
- if self.temperature != 1.0:
185
- top_k_log_probs = top_k_log_probs / self.temperature
186
-
187
- # Re-normalize the subset.
188
- # shape: (batch_size, k)
189
- normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
-
191
- # Sample from the re-normalized subset.
192
- # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
- # shape: (batch_size, per_node_beam_size)
194
- sampled_indices = torch.multinomial(
195
- normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
- )
197
-
198
- # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
- # shape: (batch_size, per_node_beam_size)
200
- indices = top_k_indices.gather(-1, sampled_indices)
201
-
202
- return log_probs.gather(1, indices), indices, state
203
-
204
-
205
- class TopPSampler(Sampler):
206
- """
207
- A `Sampler` which redistributes the probability mass function for nodes among
208
- the top choices with a cumulative probability of at least `p`, then samples from that subset
209
- after re-normalizing the probabilities.
210
-
211
- Beams are sampled in the default, deterministic way.
212
-
213
- :param p:
214
- The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
- examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
- insufficient to sample without replacement from when calling `sample_nodes`, then the top
217
- `per_node_beam_size` examples will be chosen.
218
- :param temperature:
219
- A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
- above 1.0 produces a flatter probability distribution.
221
- :param with_replacement:
222
- If set to `True`, samples will be selected with replacement from the top choices.
223
-
224
- """
225
-
226
- def __init__(
227
- self,
228
- p: float = 0.9,
229
- temperature: float = 1.0,
230
- with_replacement: bool = False,
231
- ):
232
- if p < 0.0 or p > 1.0:
233
- raise ValueError("p must be a positive float no greater than 1.0")
234
- self.p = p
235
- self.temperature = temperature or 1.0
236
- self.with_replacement = with_replacement
237
-
238
- def sample_nodes(
239
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
- if not per_node_beam_size <= log_probs.size()[1]:
242
- raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
-
244
- # First apply temperature coefficient:
245
- if self.temperature != 1.0:
246
- _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
- else:
248
- _log_probs = log_probs
249
-
250
- # Sort the probabilities in descending order to then find cumulative sum
251
- log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
-
253
- # shape: (batch_size, num_classes)
254
- probabilities_descending = log_probs_descending.exp()
255
- probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
-
257
- # Create a mask for filtering out probabilities that don't make the top `p`.
258
- # shape: (batch_size, num_classes)
259
- exclusion_mask = probabilities_summed >= self.p
260
-
261
- # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
- exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
- exclusion_mask[..., 0] = False
264
-
265
- # Make sure there's at least `per_node_beam_size` options to be selected.
266
- if not self.with_replacement:
267
- exclusion_mask[..., :per_node_beam_size] = False
268
-
269
- log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
-
271
- # Now re-normalize the included log probs.
272
- # shape: (batch_size, num_classes)
273
- filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
-
275
- # Sample from the re-normalized subset.
276
- # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
- # shape: (batch_size, per_node_beam_size)
278
- sampled_indices = torch.multinomial(
279
- filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
- )
281
-
282
- # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
- # shape: (batch_size, per_node_beam_size)
284
- selected_indices = sorting_indices.gather(-1, sampled_indices)
285
-
286
- # Return (selected log probabilities, selected classes)
287
- # shape: (len(log_probs),1) , (len(log_probs), 1)
288
- return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
-
290
-
291
- class GumbelSampler(Sampler):
292
- """
293
- A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
- [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
- Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
296
- (https://api.semanticscholar.org/CorpusID:76662039).
297
-
298
- :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
- above 1.0 produces a flatter probability distribution.
300
- """
301
-
302
- def __init__(self, temperature: float = 1.0):
303
- self.temperature = temperature
304
-
305
- def init_state(
306
- self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
- ) -> StateType:
308
- # shape: (batch_size, num_classes)
309
- zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
-
311
- # shape: (batch_size, num_classes)
312
- G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
-
314
- return {"G_phi_S": G_phi_S}
315
-
316
- def sample_nodes(
317
- self,
318
- log_probs: torch.Tensor,
319
- per_node_beam_size: int,
320
- state: StateType,
321
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
- # First apply temperature coefficient:
323
- # shape: (batch_size * beam_size, num_classes)
324
- if self.temperature != 1.0:
325
- _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
- else:
327
- _log_probs = log_probs
328
-
329
- # shape: (group_size,)
330
- phi_S = state["phi_S"]
331
-
332
- # shape: (group_size, num_classes)
333
- phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
-
335
- # shape: (group_size, num_classes)
336
- phi_S_new = phi_S + _log_probs
337
-
338
- # shape: (group_size, 1)
339
- G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
-
341
- # shape: (group_size, num_classes)
342
- G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
-
344
- # Replace NaNs with very negative number.
345
- # shape: (group_size, num_classes)
346
- # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
-
348
- # shape (both): (group_size, per_node_beam_size)
349
- top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
-
351
- # shape: (group_size, per_node_beam_size)
352
- top_log_probs = log_probs.gather(1, top_indices)
353
-
354
- return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
-
356
- def sample_beams(
357
- self,
358
- log_probs: torch.Tensor,
359
- beam_size: int,
360
- state: StateType,
361
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
- """
363
- Returns the beams with the highest perturbed log probabilities.
364
- """
365
- # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
-
367
- batch_size = log_probs.size()[0]
368
-
369
- # shape: (batch_size * beam_size, per_node_beam_size)
370
- G_phi_S = state["G_phi_S"]
371
-
372
- # shape: (batch_size, beam_size * per_node_beam_size)
373
- G_phi_S = G_phi_S.reshape_as(log_probs)
374
-
375
- # shape (both): (batch_size, beam_size)
376
- G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
-
378
- # shape: (batch_size, beam_size)
379
- selected_log_probs = log_probs.gather(1, selected_indices)
380
-
381
- # Now sort the selected beams by their true log prob.
382
- # shape (all): (batch_size, beam_size)
383
- selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
- selected_indices = selected_indices.gather(1, sort_indices)
385
- G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
-
387
- # shape: (batch_size * beam_size,)
388
- G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
-
390
- # shape: (batch_size * beam_size,)
391
- phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
-
393
- return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
-
395
- def gumbel(self, phi) -> torch.Tensor:
396
- """
397
- Sample `Gumbel(phi)`.
398
-
399
- `phi` should have shape `(batch_size, num_classes)`.
400
- """
401
- return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
-
403
- def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
- """
405
- Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
-
407
- `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
- shape `(batch_size, 1)`.
409
- """
410
- # Shape: (batch_size, num_classes)
411
- G_phi = self.gumbel(phi)
412
-
413
- # Now we find the maximum from these samples.
414
- # Shape: (batch_size, )
415
- Z, _ = G_phi.max(dim=-1)
416
-
417
- # Shape: (batch_size, num_classes)
418
- v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
-
420
- # Shape: (batch_size, num_classes)
421
- return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
-
423
-
424
- class FinalSequenceScorer:
425
- """
426
- An abstract class that can be used to score the final generated sequences found
427
- by beam search. Given the predicted sequences and the corresponding log probabilities of
428
- those sequences, the class calculates and returns the final score of the sequences.
429
-
430
- The default implementation scores the sequences using the sum of the log probabilities of
431
- the sequence, which is passed as input.
432
- """
433
-
434
- @abstractmethod
435
- def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
- """
437
- Score the final predictions found by beam search.
438
- Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
-
440
- :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
- :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
- of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
- :param end_index: The index of the end symbol.
444
-
445
- """
446
- raise NotImplementedError
447
-
448
-
449
- class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
- """
451
- A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
- across the sequence's tokens.
453
- """
454
-
455
- def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
- del predictions, end_index
457
- # The sum of the sequence log probabilities is the input parameter, so just
458
- # return it.
459
- return log_probabilities
460
-
461
-
462
- class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
- """
464
- A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
- tokens in the sequence. It optionally includes a length penalty which promotes
466
- or demotes sequences based on their lengths. The final score for a sequence will
467
- be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
- here includes the end token.
469
-
470
- :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
- A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
- """
473
-
474
- def __init__(self, length_penalty: float = 1.0):
475
- super().__init__()
476
- self.length_penalty = length_penalty
477
-
478
- def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
- # shape: (batch_size, beam_size)
480
- lengths = (predictions != end_index).long().sum(dim=2)
481
-
482
- # If the sequence ended during beam search, the `log_probabilities` will include
483
- # the transition to the end token. Therefore, in such situations, `lengths` is
484
- # actually off by 1. This corrects for that.
485
- # shape: (batch_size, beam_size)
486
- is_end_token = predictions[:, :, -1] == end_index
487
- lengths += is_end_token.long()
488
-
489
- # shape: (batch_size, beam_size)
490
- average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
- return average_log_probs
492
-
493
-
494
- class Constraint:
495
- """
496
- An abstract class that can be used to enforce constraints on the output predictions
497
- by manipulating the class log probabilities during beam search.
498
-
499
- A `Constraint` just has three methods that need to be implemented by subclasses:
500
- `init_state()`, `apply()` and `_update_state()`.
501
-
502
- `init_state()` takes one argument:
503
-
504
- - the batch size, an int
505
-
506
- It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
- calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
- Each inner list should be of length 1.
509
-
510
- `apply()` takes two arguments:
511
-
512
- - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
- and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
- - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
- log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
-
517
- The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
- for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
- the corresponding log probability to a negligible value such as `float("-inf")` or
520
- `torch.finfo(class_log_probabilities.dtype).min`.
521
-
522
- `_update_state()` takes two arguments:
523
-
524
- - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
- copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
- directly edited in-place without affecting the others.
527
- - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
- step of beam search.
529
-
530
- The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
- length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
-
533
- """
534
-
535
- @abstractmethod
536
- def init_state(
537
- self,
538
- batch_size: int,
539
- ) -> ConstraintStateType:
540
- raise NotImplementedError
541
-
542
- @abstractmethod
543
- def apply(
544
- self,
545
- state: ConstraintStateType,
546
- class_log_probabilities: torch.Tensor,
547
- ) -> torch.Tensor:
548
- raise NotImplementedError
549
-
550
- @staticmethod
551
- def _copy_state(
552
- state: ConstraintStateType,
553
- batch_size: int,
554
- beam_size: int,
555
- last_backpointer: Optional[torch.Tensor] = None,
556
- ) -> ConstraintStateType:
557
- """
558
- Copies the `state`. This method copies the data in `state` using `copy.deepcopy()`. If this
559
- is not appropriate for your constraint, you will need to implement the copying yourself.
560
- """
561
- new_state = []
562
- for i in range(batch_size):
563
- batch_state = []
564
- for j in range(beam_size):
565
- if last_backpointer is None:
566
- # This is the first prediction, so the backpointer is 0
567
- backpointer = 0
568
- else:
569
- backpointer = last_backpointer[i, j].item()
570
- batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
- new_state.append(batch_state)
572
- return new_state
573
-
574
- def update_state(
575
- self,
576
- state: ConstraintStateType,
577
- last_prediction: torch.Tensor,
578
- last_backpointer: Optional[torch.Tensor] = None,
579
- ) -> ConstraintStateType:
580
- batch_size, beam_size = last_prediction.size()
581
- new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
- return self._update_state(new_state, last_prediction)
583
-
584
- @abstractmethod
585
- def _update_state(
586
- self,
587
- state: ConstraintStateType,
588
- last_prediction: torch.Tensor,
589
- ) -> ConstraintStateType:
590
- raise NotImplementedError
591
-
592
-
593
- class RepeatedNGramBlockingConstraint(Constraint):
594
- def __init__(self, ngram_size: int, **kwargs) -> None:
595
- super().__init__(**kwargs)
596
- self.ngram_size = ngram_size
597
-
598
- def init_state(
599
- self,
600
- batch_size: int,
601
- ) -> ConstraintStateType:
602
- return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
-
604
- def apply(
605
- self,
606
- state: ConstraintStateType,
607
- class_log_probabilities: torch.Tensor,
608
- ) -> torch.Tensor:
609
- for i, batch in enumerate(state):
610
- for j, beam in enumerate(batch):
611
- current_prefix = tuple(beam["current_prefix"])
612
- seen_ngrams = beam["seen_ngrams"]
613
- try:
614
- disallowed_indices = seen_ngrams[current_prefix]
615
- class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
- class_log_probabilities.dtype
617
- ).min
618
- except KeyError:
619
- # We have not seen this prefix before, so there is no index
620
- # that needs to be blocked
621
- pass
622
- return class_log_probabilities
623
-
624
- def _update_state(
625
- self,
626
- state: ConstraintStateType,
627
- last_prediction: torch.Tensor,
628
- ) -> ConstraintStateType:
629
- for i, batch in enumerate(state):
630
- for j, beam in enumerate(batch):
631
- prediction = last_prediction[i, j].item()
632
- prefix = beam["current_prefix"]
633
- seen_ngrams = beam["seen_ngrams"]
634
-
635
- if len(prefix) == self.ngram_size - 1:
636
- # This is a new ngram that we have to remember
637
- if tuple(prefix) not in seen_ngrams:
638
- seen_ngrams[tuple(prefix)] = []
639
- seen_ngrams[tuple(prefix)].append(prediction)
640
-
641
- # Create the new prefix, removing the oldest index if the prefix
642
- # is too long
643
- prefix.append(prediction)
644
- if len(prefix) == self.ngram_size:
645
- prefix.pop(0)
646
- return state
647
-
648
-
649
- class BeamSearch:
650
- """
651
- Implements the beam search algorithm for decoding the most likely sequences.
652
-
653
- :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
-
655
- :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
- of the predicted sequences.
657
-
658
- :param beam_size: The width of the beam used.
659
-
660
- :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
- If not given, this just defaults to `beam_size`. Setting this parameter
662
- to a number smaller than `beam_size` may give better results, as it can introduce
663
- more diversity into the search. See
664
- [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
- (https://api.semanticscholar.org/CorpusID:2229477).
666
-
667
- :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
- If not specified, `DeterministicSampler` will be used, which just takes the
669
- `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
-
671
- Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
- [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
-
674
- :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
- the predicted sequences. This does not include the start or end tokens. If `None`,
676
- no minimum is enforced.
677
-
678
- :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
- The output from this module is what is returned by the `search` method. If not
680
- specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
- by the sum of the token log probabilities.
682
-
683
- :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
- provided, no constraints will be enforced.
685
-
686
- """
687
-
688
- def __init__(
689
- self,
690
- end_index: int,
691
- *,
692
- max_steps: int = 50,
693
- beam_size: int = 10,
694
- per_node_beam_size: Optional[int] = None,
695
- sampler: Optional[Sampler] = None,
696
- min_steps: Optional[int] = None,
697
- final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
- constraints: Optional[List[Constraint]] = None,
699
- ) -> None:
700
- if not max_steps > 0:
701
- raise ValueError("max_steps must be positive")
702
- if not beam_size > 0:
703
- raise ValueError("beam_size must be positive")
704
- if per_node_beam_size is not None and not per_node_beam_size > 0:
705
- raise ValueError("per_node_beam_size must be positive")
706
- if min_steps is not None:
707
- if not min_steps >= 0:
708
- raise ValueError("min_steps must be non-negative")
709
- if not min_steps <= max_steps:
710
- raise ValueError("min_steps must be less than or equal to max_steps")
711
-
712
- self._end_index = end_index
713
- self.max_steps = max_steps
714
- self.beam_size = beam_size
715
- self.per_node_beam_size = per_node_beam_size or beam_size
716
- self.sampler = sampler or DeterministicSampler()
717
- self.min_steps = min_steps or 0
718
- self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
719
- self.constraints = constraints or []
720
-
721
- @staticmethod
722
- def _reconstruct_sequences(predictions, backpointers):
723
- # Reconstruct the sequences.
724
- # shape: [(batch_size, beam_size, 1)]
725
- reconstructed_predictions = [predictions[-1].unsqueeze(2)]
726
-
727
- if not backpointers:
728
- return reconstructed_predictions
729
-
730
- # shape: (batch_size, beam_size)
731
- cur_backpointers = backpointers[-1]
732
-
733
- for timestep in range(len(predictions) - 2, 0, -1):
734
- # shape: (batch_size, beam_size, 1)
735
- cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
736
-
737
- reconstructed_predictions.append(cur_preds)
738
-
739
- # shape: (batch_size, beam_size)
740
- cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
741
-
742
- # shape: (batch_size, beam_size, 1)
743
- final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
744
-
745
- reconstructed_predictions.append(final_preds)
746
-
747
- return reconstructed_predictions
748
-
749
- def search(
750
- self,
751
- start_predictions: torch.Tensor,
752
- start_state: StateType,
753
- step: StepFunctionType,
754
- ) -> Tuple[torch.Tensor, torch.Tensor]:
755
- """
756
- Given a starting state and a step function, apply beam search to find the
757
- most likely target sequences.
758
-
759
- Returns a tuple of `(predictions, final_scores)`, where `predictions`
760
- has shape `(batch_size, beam_size, max_steps)` and `final_scores`
761
- has shape `(batch_size, beam_size)`.
762
-
763
- .. note::
764
- If your step function returns `-inf` for some log probabilities
765
- (like if you're using a masked log-softmax) then some of the "best"
766
- sequences returned may also have `-inf` log probability. Specifically
767
- this happens when the beam size is smaller than the number of actions
768
- with finite log probability (non-zero probability) returned by the step function.
769
- Therefore if you're using a mask you may want to check the results from `search`
770
- and potentially discard sequences with non-finite log probability.
771
-
772
- :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
773
- Usually the initial predictions are just the index of the "start" token
774
- in the target vocabulary.
775
-
776
- :param start_state: The initial state passed to the `step` function. Each value of the state dict
777
- should be a tensor of shape `(batch_size, *)`, where `*` means any other
778
- number of dimensions.
779
-
780
- :param step: A function that is responsible for computing the next most likely tokens,
781
- given the current state and the predictions from the last time step.
782
- The function should accept two or three arguments:
783
-
784
- - a tensor of shape `(group_size,)` representing the index of the predicted
785
- tokens from the last time step,
786
- - the current state, a `StateType`, and
787
- - optionally, the timestep, an `int`.
788
-
789
- The `group_size` will be `batch_size * beam_size`, except in the initial
790
- step, for which it will just be `batch_size`.
791
-
792
- The function is expected to return a tuple, where the first element
793
- is a tensor of shape `(group_size, vocab_size)` containing
794
- the log probabilities of the tokens for the next step, and the second
795
- element is the updated state. The tensor in the state should have shape
796
- `(group_size, *)`, where `*` means any other number of dimensions.
797
-
798
- """
799
- step_signature = signature(step)
800
- if len(step_signature.parameters) < 3:
801
- # If the step function we're given does not take the time step argument, wrap it
802
- # in one that does.
803
- old_step = cast(StepFunctionTypeNoTimestep, step)
804
-
805
- def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
806
- del time_step
807
- return old_step(last_predictions, state)
808
-
809
- return self._search(start_predictions, start_state, new_step)
810
- else:
811
- return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
812
-
813
- def _search(
814
- self,
815
- start_predictions: torch.Tensor,
816
- start_state: StateType,
817
- step: StepFunctionTypeWithTimestep,
818
- ) -> Tuple[torch.Tensor, torch.Tensor]:
819
- batch_size = start_predictions.size()[0]
820
-
821
- # List of (batch_size, beam_size) tensors. One for each time step. Does not
822
- # include the start symbols, which are implicit.
823
- predictions: List[torch.Tensor] = []
824
-
825
- # List of (batch_size, beam_size) tensors. One for each time step. None for
826
- # the first. Stores the index n for the parent prediction, i.e.
827
- # predictions[t-1][i][n], that it came from.
828
- backpointers: List[torch.Tensor] = []
829
-
830
- constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
831
-
832
- # Calculate the first timestep. This is done outside the main loop
833
- # because we are going from a single decoder input (the output from the
834
- # encoder) to the top `beam_size` decoder outputs. On the other hand,
835
- # within the main loop we are going from the `beam_size` elements of the
836
- # beam to `beam_size`^2 candidates from which we will select the top
837
- # `beam_size` elements for the next iteration.
838
- # shape: (batch_size, num_classes)
839
- start_class_log_probabilities, state = step(start_predictions, start_state, 0)
840
-
841
- num_classes = start_class_log_probabilities.size()[1]
842
-
843
- # Make sure `per_node_beam_size` is not larger than `num_classes`.
844
- if self.per_node_beam_size > num_classes:
845
- raise ValueError(
846
- f"Vocab size ({num_classes:d}) too small "
847
- f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
848
- f"Please decrease beam_size or per_node_beam_size."
849
- )
850
-
851
- sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
852
-
853
- # Apply all constraints.
854
- if self.constraints:
855
- # shape: (batch_size, 1, num_classes)
856
- expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
857
- for constraint, constraint_state in zip(self.constraints, constraint_states):
858
- expanded_start_class_log_probabilities = constraint.apply(
859
- constraint_state, expanded_start_class_log_probabilities
860
- )
861
- start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
862
-
863
- # Prevent selecting the end symbol if there is any min_steps constraint
864
- if self.min_steps >= 1:
865
- start_class_log_probabilities[:, self._end_index] = torch.finfo(
866
- start_class_log_probabilities.dtype
867
- ).min
868
-
869
- # Get the initial predicted classes and their log probabilities.
870
- # shape: (batch_size, beam_size), (batch_size, beam_size)
871
- (
872
- start_top_log_probabilities,
873
- start_predicted_classes,
874
- sampler_state,
875
- ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
876
-
877
- if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
878
- warnings.warn(
879
- "Empty sequences predicted. You may want to increase the beam size or ensure "
880
- "your step function is working properly.",
881
- RuntimeWarning,
882
- )
883
- return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
884
-
885
- # The log probabilities for the last time step.
886
- # shape: (batch_size, beam_size)
887
- last_log_probabilities = start_top_log_probabilities
888
-
889
- # shape: [(batch_size, beam_size)]
890
- predictions.append(start_predicted_classes)
891
-
892
- # Log probability tensor that mandates that the end token is selected.
893
- # shape: (batch_size * beam_size, num_classes)
894
- log_probs_after_end = start_class_log_probabilities.new_full(
895
- (batch_size * self.beam_size, num_classes),
896
- torch.finfo(start_class_log_probabilities.dtype).min,
897
- )
898
- log_probs_after_end[:, self._end_index] = 0.0
899
-
900
- # Set the same state for each element in the beam.
901
- self._update_initial_state(state, batch_size)
902
-
903
- for i, constraint in enumerate(self.constraints):
904
- constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
905
-
906
- for timestep in range(self.max_steps - 1):
907
- # shape: (batch_size * beam_size,)
908
- last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
909
-
910
- # If every predicted token from the last step is `self._end_index`,
911
- # then we can stop early.
912
- if (last_predictions == self._end_index).all():
913
- break
914
- # Take a step. This gets the predicted log probs of the next classes
915
- # and updates the state.
916
- # shape: (batch_size * beam_size, num_classes)
917
- class_log_probabilities, state = step(last_predictions, state, timestep + 1)
918
-
919
- # Apply all constraints.
920
- if self.constraints:
921
- # shape: (batch_size, beam_size, num_classes)
922
- reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
923
- for constraint, constraint_state in zip(self.constraints, constraint_states):
924
- reshaped_class_log_probabilities = constraint.apply(
925
- constraint_state, reshaped_class_log_probabilities
926
- )
927
- # shape: (batch_size * beam_size, num_classes)
928
- class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
929
-
930
- # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
931
- # of the sequence (because `timestep` is 0-indexed and we generated the first token
932
- # before the for loop). Here we block the end index if the search is not allowed to
933
- # terminate on this iteration.
934
- if timestep + 2 <= self.min_steps:
935
- class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
936
-
937
- # shape: (batch_size * beam_size, num_classes)
938
- last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
939
- batch_size * self.beam_size, num_classes
940
- )
941
-
942
- # Here we are finding any beams where we predicted the end token in
943
- # the previous timestep and replacing the distribution with a
944
- # one-hot distribution, forcing the beam to predict the end token
945
- # this timestep as well.
946
- # shape: (batch_size * beam_size, num_classes)
947
- cleaned_log_probabilities = torch.where(
948
- last_predictions_expanded == self._end_index,
949
- log_probs_after_end,
950
- class_log_probabilities,
951
- )
952
-
953
- # shape (both): (batch_size * beam_size, per_node_beam_size)
954
- top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
955
- cleaned_log_probabilities, self.per_node_beam_size, sampler_state
956
- )
957
-
958
- # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
959
- # so that we can add them to the current log probs for this timestep.
960
- # This lets us maintain the log probability of each element on the beam.
961
- # shape: (batch_size * beam_size, per_node_beam_size)
962
- expanded_last_log_probabilities = (
963
- last_log_probabilities.unsqueeze(2)
964
- .expand(batch_size, self.beam_size, self.per_node_beam_size)
965
- .reshape(batch_size * self.beam_size, self.per_node_beam_size)
966
- )
967
-
968
- # shape: (batch_size * beam_size, per_node_beam_size)
969
- summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
970
-
971
- # shape: (batch_size, beam_size * per_node_beam_size)
972
- reshaped_summed = summed_top_log_probabilities.reshape(
973
- batch_size, self.beam_size * self.per_node_beam_size
974
- )
975
-
976
- # shape: (batch_size, beam_size * per_node_beam_size)
977
- reshaped_predicted_classes = predicted_classes.reshape(
978
- batch_size, self.beam_size * self.per_node_beam_size
979
- )
980
-
981
- # Keep only the top `beam_size` beam indices.
982
- # shape (both): (batch_size, beam_size)
983
- (
984
- restricted_beam_log_probs,
985
- restricted_beam_indices,
986
- sampler_state,
987
- ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
988
-
989
- # Use the beam indices to extract the corresponding classes.
990
- # shape: (batch_size, beam_size)
991
- restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
992
-
993
- predictions.append(restricted_predicted_classes)
994
-
995
- # shape: (batch_size, beam_size)
996
- last_log_probabilities = restricted_beam_log_probs
997
-
998
- # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
999
- # indices with a common ancestor are grouped together. Hence
1000
- # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
1001
- # division as the tensor is a LongTensor.)
1002
- # shape: (batch_size, beam_size)
1003
- backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
1004
- backpointers.append(backpointer)
1005
-
1006
- # Keep only the pieces of the state tensors corresponding to the
1007
- # ancestors created this iteration.
1008
- self._update_state(state, backpointer)
1009
-
1010
- for i, constraint in enumerate(self.constraints):
1011
- constraint_states[i] = constraint.update_state(
1012
- constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
1013
- )
1014
-
1015
- # Warn about "-inf" log probabilities if not using any constraints (negligible
1016
- # log probabilities are expected when using constraints).
1017
- if not self.constraints and (
1018
- not torch.isfinite(last_log_probabilities).all()
1019
- or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
1020
- ):
1021
- warnings.warn(
1022
- "Negligible log probabilities encountered ('-inf' or equivalent). "
1023
- "Some final sequences may not make sense. "
1024
- "This can happen when the beam size is larger than the number of valid (non-zero "
1025
- "probability) transitions that the step function produces.",
1026
- RuntimeWarning,
1027
- )
1028
-
1029
- reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
1030
-
1031
- # shape: (batch_size, beam_size, max_steps)
1032
- all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
1033
-
1034
- # Calculate the final sequence scores
1035
- # shape: (batch_size, beam_size)
1036
- final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
1037
-
1038
- # Sort the sequences based on the final scores so the best scoring
1039
- # sequence is at index 0
1040
- sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
1041
- sorted_all_predictions = torch.gather(
1042
- all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
1043
- )
1044
-
1045
- return sorted_all_predictions, sorted_final_scores
1046
-
1047
- def _update_initial_state(self, state: StateType, batch_size: int):
1048
- """
1049
- Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
1050
- """
1051
- for key, state_tensor in state.items():
1052
- if state_tensor is None:
1053
- continue
1054
- # shape: (batch_size * beam_size, *)
1055
- _, *last_dims = state_tensor.size()
1056
- state[key] = (
1057
- state_tensor.unsqueeze(1)
1058
- .expand(batch_size, self.beam_size, *last_dims)
1059
- .reshape(batch_size * self.beam_size, *last_dims)
1060
- )
1061
-
1062
- def _update_state(self, state: StateType, backpointer: torch.Tensor):
1063
- batch_size = backpointer.size()[0]
1064
-
1065
- for key, state_tensor in state.items():
1066
- if state_tensor is None:
1067
- continue
1068
- _, *last_dims = state_tensor.size()
1069
- # shape: (batch_size, beam_size, *)
1070
- expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
1071
- batch_size, self.beam_size, *last_dims
1072
- )
1073
- # shape: (batch_size * beam_size, *)
1074
- state[key] = (
1075
- state_tensor.reshape(batch_size, self.beam_size, *last_dims)
1076
- .gather(1, expanded_backpointer)
1077
- .reshape(batch_size * self.beam_size, *last_dims)
1078
- )
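The docstrings above spell out the step-function contract in the abstract; the toy driver below (not part of the repo, and assuming the `BeamSearch` class above is in scope) shows how `search` is typically called with a state-free step function over a tiny vocabulary.

```python
# Toy driver for the BeamSearch class above; the transition table and vocabulary
# are made up for illustration and carry no model semantics.
import torch

vocab_size = 5
end_index = 4  # treat token 4 as the EOS token

# Static next-token log-probabilities, ignoring history for simplicity.
transition_log_probs = torch.nn.functional.log_softmax(
    torch.randn(vocab_size, vocab_size), dim=-1
)


def step(last_predictions, state):
    # last_predictions: (group_size,) -> return (group_size, vocab_size) log probs.
    return transition_log_probs[last_predictions], state


beam_search = BeamSearch(end_index=end_index, max_steps=10, beam_size=3)
start_predictions = torch.zeros(2, dtype=torch.long)  # batch of 2, start token 0
predictions, scores = beam_search.search(start_predictions, {}, step)
print(predictions.shape)  # (batch_size=2, beam_size=3, <= max_steps)
print(scores.shape)       # (2, 3)
```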
 
olmo_bitnet_1b/checkpoint.py DELETED
@@ -1,1671 +0,0 @@
1
- import gc
2
- import io
3
- import logging
4
- import pickle
5
- import shutil
6
- import traceback
7
- from abc import ABCMeta, abstractmethod
8
- from collections import defaultdict
9
- from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
10
- from contextlib import contextmanager
11
- from copy import deepcopy
12
- from dataclasses import dataclass, field, replace
13
- from functools import reduce
14
- from multiprocessing import shared_memory
15
- from pathlib import Path
16
- from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast
17
-
18
- import numpy as np
19
- import torch
20
- import torch.distributed.checkpoint as dist_cp
21
- import torch.multiprocessing as mp
22
- from packaging import version
23
- from torch.distributed import _remote_device
24
- from torch.distributed._shard._utils import narrow_tensor_by_index
25
- from torch.distributed._shard.metadata import ShardMetadata
26
- from torch.distributed._shard.sharded_tensor import ShardedTensor
27
- from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
28
- from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
29
- from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
30
- from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
31
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
32
- from torch.distributed.fsdp import StateDictType
33
- from torch.distributed.fsdp.api import (
34
- FullOptimStateDictConfig,
35
- FullStateDictConfig,
36
- ShardedOptimStateDictConfig,
37
- ShardedStateDictConfig,
38
- )
39
- from torch.futures import Future
40
-
41
- try:
42
- from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore
43
- except ModuleNotFoundError:
44
- from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore
45
-
46
- from . import util
47
-
48
- from .aliases import PathOrStr
49
- from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
50
- from .exceptions import OLMoCheckpointError
51
- from .optim import Optimizer, fix_optim_state_dict
52
- from .safetensors_util import safetensors_file_to_state_dict
53
- from .torch_util import (
54
- barrier,
55
- gc_cuda,
56
- get_fs_local_rank,
57
- get_global_rank,
58
- get_world_size,
59
- )
60
- from .util import (
61
- _get_s3_client,
62
- default_thread_count,
63
- dir_is_empty,
64
- get_bytes_range,
65
- get_progress_bar,
66
- resource_path,
67
- upload,
68
- wait_for,
69
- )
70
-
71
- __all__ = [
72
- "save_fsdp_model_and_optim_state",
73
- "load_fsdp_model_and_optim_state",
74
- "load_fsdp_optim_state",
75
- "save_state_dict",
76
- "load_state_dict",
77
- "load_model_state",
78
- "RemoteFileSystemWriter",
79
- "RemoteFileSystemReader",
80
- "Checkpointer",
81
- "FullCheckpointer",
82
- "TorchNewStyleShardedCheckpointer",
83
- "TorchLegacyShardedCheckpointer",
84
- "LocalShardedCheckpointer",
85
- "build_sharded_checkpointer",
86
- ]
87
-
88
-
89
- log = logging.getLogger(__name__)
90
-
91
- MODEL_AND_OPTIM_FOLDER = "model_and_optim"
92
-
93
-
94
- def save_fsdp_model_and_optim_state(
95
- checkpoint_dir: PathOrStr,
96
- fsdp_model: FSDP,
97
- optim: Optimizer,
98
- *,
99
- upload_to: Optional[str] = None,
100
- save_overwrite: bool = False,
101
- ):
102
- """
103
- Use this to save a state dict for an FSDP model and its optimizer via :mod:`torch.distributed.checkpoint`
104
- functions. This should be used during distributed training and should be called by all ranks.
105
-
106
- :param checkpoint_dir: The directory to save to.
107
- :param fsdp_model: The FSDP model.
108
- :param optim: The FSDP model's optimizer.
109
- :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
110
- :param save_overwrite: Overwrite existing files.
111
-
112
- :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
113
- """
114
- checkpoint_dir = Path(checkpoint_dir)
115
- target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
116
- if save_overwrite:
117
- if get_fs_local_rank() == 0:
118
- shutil.rmtree(target_dir, ignore_errors=True)
119
- elif not dir_is_empty(target_dir):
120
- raise FileExistsError(target_dir)
121
- barrier()
122
- if get_fs_local_rank() == 0:
123
- target_dir.mkdir(exist_ok=True, parents=True)
124
- barrier()
125
- with FSDP.state_dict_type(
126
- fsdp_model,
127
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
128
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
129
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
130
- ):
131
- model_and_optim_state = {
132
- "model": fsdp_model.state_dict(),
133
- "optim": FSDP.optim_state_dict(fsdp_model, optim),
134
- }
135
- dist_cp.save_state_dict(
136
- model_and_optim_state,
137
- RemoteFileSystemWriter(
138
- target_dir,
139
- upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
140
- save_overwrite=save_overwrite,
141
- ),
142
- )
143
-
144
-
145
- def load_fsdp_model_and_optim_state(
146
- checkpoint_dir: PathOrStr,
147
- fsdp_model: FSDP,
148
- optim: Optimizer,
149
- *,
150
- local_cache: Optional[PathOrStr] = None,
151
- load_optimizer_state: bool = True,
152
- ):
153
- """
154
- Use this to load a state dict for an FSDP model and its optimizer via :mod:`torch.distributed.checkpoint`
155
- functions. This should be used during distributed training and should be called by all ranks.
156
-
157
- :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
158
- :param fsdp_model: The FSDP model.
159
- :param optim: The FSDP model's optimizer.
160
- :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
161
- remote "directory" but there might be a cached version of the same artifacts.
162
- :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
163
-
164
- :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
165
- """
166
- load_path = str(checkpoint_dir).rstrip("/")
167
- local_cache = None if local_cache is None else Path(local_cache)
168
- with FSDP.state_dict_type(
169
- fsdp_model,
170
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
171
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
172
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
173
- ):
174
- # Load the model state dict in place.
175
- log.info("Loading model state...")
176
- model_state = {"model": fsdp_model.state_dict()}
177
- dist_cp.load_state_dict(
178
- model_state,
179
- RemoteFileSystemReader(
180
- f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
181
- local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
182
- ),
183
- )
184
- fsdp_model.load_state_dict(model_state["model"])
185
-
186
- if not load_optimizer_state:
187
- return
188
-
189
- # Load optim state dict in place.
190
- log.info("Loading sharded optimizer state...")
191
- optim_state = load_sharded_optimizer_state_dict(
192
- model_state_dict=model_state["model"],
193
- optimizer_key="optim",
194
- storage_reader=RemoteFileSystemReader(
195
- f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
196
- local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
197
- ),
198
- )
199
- del model_state
200
- gc_cuda()
201
- load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
202
-
203
-
204
- def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
205
- log.info("Flattening sharded optimizer state...")
206
- # NOTE: Careful! The order of these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
207
- if version.parse(torch.__version__) < version.parse("2.1.0"):
208
- flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
209
- else:
210
- flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
211
- del optim_state
212
- gc.collect()
213
- log.info("Loading flattened optimizer state...")
214
- # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
215
- # which takes up unnecessary GPU memory.
216
- for state in flattened_osd["state"].values():
217
- for k in state.keys():
218
- v = state[k]
219
- if isinstance(v, torch.Tensor):
220
- state[k] = v.to(device="cpu")
221
- gc_cuda()
222
- optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
223
-
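
Right before `Optimizer.load_state_dict()` is called above, every tensor in the flattened optimizer state is moved onto CPU, because `load_state_dict()` deep-copies the whole state dict and a GPU-resident copy would waste device memory. A minimal standalone sketch of that in-place move (the toy state dict below is structure only):

import torch

def optimizer_state_to_cpu(flattened_osd: dict) -> None:
    # Move every tensor held under the optimizer "state" sub-dict to CPU, in place.
    for state in flattened_osd["state"].values():
        for k, v in list(state.items()):
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device="cpu")

osd = {"state": {0: {"exp_avg": torch.ones(3), "step": 10}}, "param_groups": []}
optimizer_state_to_cpu(osd)
print(osd["state"][0]["exp_avg"].device)  # cpu
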
224
-
225
- def save_state_dict(
226
- checkpoint_dir: PathOrStr,
227
- fname: str,
228
- state_dict: Dict[str, Any],
229
- *,
230
- upload_to: Optional[str] = None,
231
- save_overwrite: bool = False,
232
- synchronize: bool = True,
233
- ):
234
- """
235
- Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
236
- This can be used during distributed training or not. If used during distributed training, the ``fname`` should be unique
237
- for each rank.
238
-
239
- :param checkpoint_dir: The directory to save to.
240
- :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
241
- :param state_dict: The state dict to save.
242
- :param upload_to: Optional, a remote "directory" to upload the file to.
243
- :param save_overwrite: Overwrite existing files.
244
- :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
245
- this function from a single rank.
246
-
247
- :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
248
- """
249
- checkpoint_dir = Path(checkpoint_dir)
250
- target_path = checkpoint_dir / fname
251
- if save_overwrite:
252
- target_path.unlink(missing_ok=True)
253
- elif target_path.is_file():
254
- raise FileExistsError(target_path)
255
- if synchronize:
256
- barrier()
257
- target_path.parent.mkdir(exist_ok=True, parents=True)
258
- if synchronize:
259
- barrier()
260
- torch.save(state_dict, target_path)
261
- if upload_to is not None:
262
- upload_target = f"{upload_to.rstrip('/')}/{fname}"
263
- log.info(f"Uploading {target_path} to {upload_target}...")
264
- upload(target_path, upload_target, save_overwrite=save_overwrite)
265
-
266
-
267
- def load_state_dict(
268
- checkpoint_dir: PathOrStr,
269
- fname: str,
270
- *,
271
- local_cache: Optional[PathOrStr] = None,
272
- map_location: Optional[str] = None,
273
- ):
274
- """
275
- Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
276
- This can be used during distributed training or not.
277
-
278
- :param checkpoint_dir: A local or remote checkpoint directory.
279
- :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
280
- :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
281
- remote "directory" but there might be a cached version of the same artifacts.
282
-
283
- :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
284
- """
285
- if fname.endswith(".pt"):
286
- # Try safetensors version first.
287
- try:
288
- path = resource_path(
289
- str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
290
- )
291
- return safetensors_file_to_state_dict(path, map_location=map_location)
292
- except FileNotFoundError:
293
- pass
294
-
295
- path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
296
- return torch.load(path, map_location=map_location)
297
-
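
`load_state_dict` above first looks for a `.safetensors` sibling of the requested `.pt` file and only then falls back to `torch.load()`. A stripped-down, local-filesystem-only sketch of that fallback (no remote paths or caching; the function and `ckpt_dir` names are illustrative, and the `safetensors` package is assumed to be installed, as it is elsewhere in this module):

from pathlib import Path

import torch
from safetensors.torch import load_file as load_safetensors

def load_local_state_dict(ckpt_dir: str, fname: str, map_location: str = "cpu"):
    # Prefer a ".safetensors" sibling of a ".pt" file, then fall back to torch.load().
    ckpt_dir_path = Path(ckpt_dir)
    if fname.endswith(".pt"):
        st_path = ckpt_dir_path / (fname[:-2] + "safetensors")
        if st_path.is_file():
            return load_safetensors(str(st_path), device=map_location)
    return torch.load(ckpt_dir_path / fname, map_location=map_location)
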
298
-
299
- def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
300
- """
301
- Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
302
- Note that ``model`` should not be wrapped with FSDP.
303
- """
304
- state_dict = {"model": model.state_dict()}
305
- dist_cp.load_state_dict(
306
- state_dict,
307
- RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
308
- no_dist=True,
309
- )
310
- model.load_state_dict(state_dict["model"])
311
-
312
-
313
- class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
314
- """
315
- A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
316
- directly to a cloud bucket when ``upload_to`` is specified.
317
- """
318
-
319
- def __init__(
320
- self,
321
- path: PathOrStr,
322
- single_file_per_rank: bool = True,
323
- sync_files: bool = True,
324
- thread_count: Optional[int] = None,
325
- per_thread_copy_ahead: int = 10_000_000,
326
- upload_to: Optional[str] = None,
327
- save_overwrite: bool = False,
328
- ) -> None:
329
- if thread_count is not None and thread_count <= 0:
330
- raise ValueError("thread count must be at least 1")
331
- super().__init__(
332
- path,
333
- single_file_per_rank=single_file_per_rank,
334
- sync_files=sync_files,
335
- # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
336
- # returns because uploading big checkpoint files with multiple threads causes
337
- # boto3 to fail in weird ways.
338
- thread_count=thread_count or 1,
339
- per_thread_copy_ahead=per_thread_copy_ahead,
340
- )
341
- self.upload_to = None if upload_to is None else upload_to.rstrip("/")
342
- self.save_overwrite = save_overwrite
343
-
344
- def write_data(
345
- self,
346
- plan: dist_cp.SavePlan,
347
- planner: dist_cp.SavePlanner,
348
- ) -> Future[List[WriteResult]]:
349
- fut = super().write_data(plan, planner)
350
- if self.upload_to is not None:
351
- files_to_upload = set()
352
- for write_result in fut.wait():
353
- files_to_upload.add(write_result.storage_data.relative_path)
354
-
355
- # Create the global S3 client up front to work around a threading issue in boto.
356
- if self.upload_to.startswith("s3://"):
357
- _get_s3_client("s3")
358
- elif self.upload_to.startswith("r2://"):
359
- _get_s3_client("r2")
360
-
361
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
362
- futures = []
363
- for fname in files_to_upload:
364
- source = self.path / fname
365
- target = f"{self.upload_to}/{fname}"
366
- log.info(f"Uploading {source} to {target}...")
367
- futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
368
- for f in as_completed(futures):
369
- try:
370
- f.result()
371
- except BaseException:
372
- # NOTE: we might get an error here that can't be pickled, which causes a different failure
373
- # later when PyTorch tries to reduce that error across ranks. So here we just make
374
- # sure we're raising a simple error type that can be pickled.
375
- raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
376
- return fut
377
-
378
- def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
379
- super().finish(metadata, results)
380
- if self.upload_to is not None:
381
- source = self.path / ".metadata"
382
- target = f"{self.upload_to}/.metadata"
383
- log.info(f"Uploading {source} to {target}...")
384
- upload(source, target, save_overwrite=self.save_overwrite)
385
-
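
Both the writer above and the reader below push their per-file work onto a thread pool and re-wrap any failure in a plain exception type, since an exotic error raised inside a worker may not pickle when PyTorch later tries to reduce it across ranks. A generic sketch of that pattern (the `SimpleCheckpointError` and `run_in_threads` names are illustrative):

import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

class SimpleCheckpointError(Exception):
    """A plain exception type that is always safe to pickle across processes and ranks."""

def run_in_threads(fn, items, max_workers=4):
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(fn, item) for item in items]
        for fut in as_completed(futures):
            try:
                results.append(fut.result())
            except BaseException:
                # Re-raise as a simple, picklable error carrying the original traceback text.
                raise SimpleCheckpointError(f"Original error:\n{traceback.format_exc()}")
    return results

print(sorted(run_in_threads(lambda x: x * x, [1, 2, 3])))  # [1, 4, 9]
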
386
-
387
- class RemoteFileSystemReader(dist_cp.StorageReader):
388
- """
389
- A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
390
- that can read data directly from cloud storage as well as a local directory.
391
- """
392
-
393
- def __init__(
394
- self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
395
- ):
396
- super().__init__()
397
- if thread_count is not None and thread_count <= 0:
398
- raise ValueError("thread count must be at least 1")
399
- self.path = str(path).rstrip("/")
400
- self.cache = None if local_cache is None else Path(local_cache)
401
- self.thread_count = thread_count or default_thread_count()
402
- self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
403
- self._metadata: Optional[Metadata] = None
404
-
405
- def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
406
- if self.cache is not None and (path := self.cache / relative_path).is_file():
407
- return get_bytes_range(path, offset, length)
408
- else:
409
- return get_bytes_range(f"{self.path}/{relative_path}", offset, length)
410
-
411
- def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
412
- sinfo = self.storage_data[read_item.storage_index]
413
- content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
414
- return (read_item, content)
415
-
416
- def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
417
- # Create the global S3 client up front to work around a threading issue in boto.
418
- if isinstance(self.path, str):
419
- if self.path.startswith("s3://"):
420
- _get_s3_client("s3")
421
- elif self.path.startswith("r2://"):
422
- _get_s3_client("r2")
423
-
424
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
425
- read_item_content_futures = []
426
- for read_item in plan.items:
427
- read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
428
- read_item_content_results = []
429
- for f in as_completed(read_item_content_futures):
430
- try:
431
- read_item_content_results.append(f.result())
432
- except BaseException:
433
- # NOTE: we might get an error here that can't be pickled, which causes a different failure
434
- # later when PyTorch tries to reduce that error across ranks. So here we just make
435
- # sure we're raising a simple error type that can be pickled.
436
- raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
437
-
438
- # Modified from `FileSystemReader.read_data()`
439
- for read_item, content in read_item_content_results:
440
- bytes = io.BytesIO(content)
441
- bytes.seek(0)
442
- if read_item.type == LoadItemType.BYTE_IO:
443
- planner.load_bytes(read_item, bytes)
444
- else:
445
- tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
446
- tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
447
- target_tensor = planner.resolve_tensor(read_item).detach()
448
-
449
- assert (
450
- target_tensor.size() == tensor.size()
451
- ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
452
- target_tensor.copy_(tensor)
453
- planner.commit_tensor(read_item, target_tensor)
454
-
455
- fut: Future = Future()
456
- fut.set_result(None)
457
- return fut
458
-
459
- def read_metadata(self) -> Metadata:
460
- if self._metadata is None:
461
- with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
462
- self._metadata = pickle.load(metadata_file)
463
- return self._metadata
464
-
465
- def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
466
- del is_coordinator
467
- self.storage_data = metadata.storage_data
468
- assert self.storage_data is not None
469
-
470
- def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
471
- return plan
472
-
473
- def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
474
- return global_plan
475
-
476
-
477
- class Checkpointer(metaclass=ABCMeta):
478
- def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
479
- self.cfg = cfg
480
- self.thread_count = thread_count or default_thread_count()
481
-
482
- @abstractmethod
483
- def save_checkpoint(
484
- self,
485
- dir: PathOrStr,
486
- fsdp_model: FSDP,
487
- optim: Optimizer,
488
- train_state: Dict[str, Any],
489
- *,
490
- upload_to: Optional[str] = None,
491
- ) -> None:
492
- raise NotImplementedError
493
-
494
- @abstractmethod
495
- def restore_checkpoint(
496
- self,
497
- load_path: PathOrStr,
498
- fsdp_model: FSDP,
499
- optim: Optimizer,
500
- *,
501
- local_cache: Optional[PathOrStr] = None,
502
- load_optimizer_state: bool = True,
503
- ) -> Dict[str, Any]:
504
- """
505
- Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
506
- """
507
- raise NotImplementedError
508
-
509
- def unshard_checkpoint(
510
- self,
511
- load_path: PathOrStr,
512
- *,
513
- local_cache: Optional[PathOrStr] = None,
514
- load_optimizer_state: bool = True,
515
- load_trainer_state: bool = True,
516
- device: Optional[torch.device] = None,
517
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
518
- """
519
- Unshard a checkpoint.
520
-
521
- Note this is not marked abstract because child classes are not required to implement this.
522
- """
523
- del load_path, local_cache, load_optimizer_state, load_trainer_state, device
524
- raise NotImplementedError
525
-
526
- @contextmanager
527
- def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
528
- # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
529
- checkpoint_dir = Path(dir)
530
- if not dir_is_empty(checkpoint_dir):
531
- if self.cfg.save_overwrite:
532
- if get_fs_local_rank() == 0:
533
- shutil.rmtree(checkpoint_dir, ignore_errors=True)
534
- else:
535
- raise FileExistsError(checkpoint_dir)
536
- # No need to mkdir here since we'll directly replace the temporary directory with
537
- # this directory below.
538
- barrier()
539
-
540
- # Prepare temporary directory. We don't have to be as careful here, we can
541
- # just remove it if it already exists.
542
- checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
543
- if get_fs_local_rank() == 0:
544
- shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
545
- checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
546
-
547
- barrier()
548
-
549
- # Yield temporary directory for `.save_checkpoint()` to use.
550
- yield checkpoint_dir_tmp
551
-
552
- barrier()
553
-
554
- # Finally if all went well replace the temporary directory with the actual
555
- # checkpoint directory.
556
- if get_fs_local_rank() == 0:
557
- # Replace temp directory with target checkpoint directory.
558
- try:
559
- checkpoint_dir_tmp.replace(checkpoint_dir)
560
- except FileNotFoundError:
561
- # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
562
- # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
563
- # file-systems.
564
- if not checkpoint_dir.exists():
565
- raise
566
-
567
- # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
568
- # replacing the temp directory with the final directory from rank 0 might not be immediately
569
- # realized in the file systems of the other ranks.
570
- # So we wait here across all ranks until that final checkpoint directory is visible.
571
- wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)
572
-
573
- barrier()
574
-
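
`_temporary_wd` above stages the checkpoint in a `-tmp` sibling directory and only renames it to the final name once every rank has finished, so a partially written checkpoint never appears under the real path. A single-process sketch of the same stage-then-promote idea (no barriers or rank checks; the paths are illustrative):

import shutil
from contextlib import contextmanager
from pathlib import Path

@contextmanager
def staged_dir(final_dir: Path):
    # Yield a "-tmp" sibling to write into; promote it to final_dir only on success.
    tmp_dir = final_dir.with_name(final_dir.name + "-tmp")
    shutil.rmtree(tmp_dir, ignore_errors=True)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    yield tmp_dir
    # Only reached if the body didn't raise: swap the staged directory into place.
    shutil.rmtree(final_dir, ignore_errors=True)
    tmp_dir.replace(final_dir)

with staged_dir(Path("demo-checkpoint")) as wd:
    (wd / "model.pt").write_bytes(b"fake checkpoint payload")
print(sorted(p.name for p in Path("demo-checkpoint").iterdir()))  # ['model.pt']
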
575
- def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
576
- if get_global_rank() == 0:
577
- log.info("Saving config...")
578
- self.cfg.save(config_path := Path(dir) / "config.yaml")
579
- if upload_to is not None:
580
- upload_target = f"{upload_to}/config.yaml"
581
- log.info(f"Uploading {config_path} to {upload_target}")
582
- upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
583
-
584
-
585
- class FullCheckpointer(Checkpointer):
586
- """
587
- A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
588
- """
589
-
590
- def save_checkpoint(
591
- self,
592
- dir: PathOrStr,
593
- fsdp_model: FSDP,
594
- optim: Optimizer,
595
- trainer_state: Dict[str, Any],
596
- *,
597
- upload_to: Optional[str] = None,
598
- ) -> None:
599
- with self._temporary_wd(dir) as checkpoint_dir:
600
- with FSDP.state_dict_type(
601
- fsdp_model,
602
- state_dict_type=StateDictType.FULL_STATE_DICT,
603
- state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
604
- optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
605
- ):
606
- # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
607
- # First the model state.
608
- model_state_dict = fsdp_model.state_dict()
609
- if get_global_rank() == 0:
610
- log.info("Saving model state...")
611
- save_state_dict(
612
- checkpoint_dir,
613
- "model.pt",
614
- model_state_dict,
615
- upload_to=upload_to,
616
- save_overwrite=self.cfg.save_overwrite,
617
- synchronize=False,
618
- )
619
- del model_state_dict
620
- barrier()
621
-
622
- # Then the optimizer state.
623
- optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
624
- if get_global_rank() == 0:
625
- log.info("Saving optim state...")
626
- save_state_dict(
627
- checkpoint_dir,
628
- "optim.pt",
629
- optim_state_dict,
630
- upload_to=upload_to,
631
- save_overwrite=self.cfg.save_overwrite,
632
- synchronize=False,
633
- )
634
- del optim_state_dict
635
- barrier()
636
-
637
- # Save trainer state.
638
- if get_global_rank() == 0:
639
- log.info("Saving trainer state...")
640
- save_state_dict(
641
- checkpoint_dir,
642
- "train.pt",
643
- trainer_state,
644
- upload_to=upload_to,
645
- save_overwrite=self.cfg.save_overwrite,
646
- synchronize=False,
647
- )
648
- # Save config.
649
- self._save_config(checkpoint_dir, upload_to=upload_to)
650
-
651
- def restore_checkpoint(
652
- self,
653
- load_path: PathOrStr,
654
- fsdp_model: FSDP,
655
- optim: Optimizer,
656
- *,
657
- local_cache: Optional[PathOrStr] = None,
658
- load_optimizer_state: bool = True,
659
- ) -> Dict[str, Any]:
660
- with FSDP.state_dict_type(
661
- fsdp_model,
662
- state_dict_type=StateDictType.FULL_STATE_DICT,
663
- state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
664
- optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
665
- ):
666
- with torch.no_grad():
667
- # fill everything with NaN, so we can check afterwards that every parameter has been restored
668
- for module_name, module in fsdp_model.named_modules():
669
- if not isinstance(module, FSDP):
670
- continue
671
- for param in module.params:
672
- param.fill_(torch.nan)
673
-
674
- # restore params from checkpoint
675
- state_dict_to_load = load_state_dict(
676
- load_path, "model.pt", local_cache=local_cache, map_location="cpu"
677
- )
678
- (
679
- state_dict_to_load,
680
- og_keys_to_new,
681
- ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)
682
-
683
- for module_name, module in fsdp_model.named_modules():
684
- if not isinstance(module, FSDP):
685
- continue
686
- for param in module.params:
687
- assert param._is_flat_param
688
- for fqn, spi in zip(param._fqns, param._shard_param_infos):
689
- if not spi.in_shard:
690
- continue
691
- key = f"{module_name}.{fqn}"
692
- key = key.replace("_fsdp_wrapped_module.", "")
693
- key = key.lstrip(".")
694
- t = state_dict_to_load[key]
695
- t = t.flatten()
696
- param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
697
- t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
698
- )
699
-
700
- # make sure that every parameter has been restored
701
- for module_name, module in fsdp_model.named_modules():
702
- if not isinstance(module, FSDP):
703
- continue
704
- for param in module.params:
705
- if torch.isnan(param).any():
706
- raise ValueError(
707
- f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
708
- )
709
-
710
- # Load optimizer state.
711
- if load_optimizer_state:
712
- optim_state_dict_to_load = load_state_dict(
713
- load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
714
- )
715
- optim_state_dict_to_load = self._make_optim_state_dict_compatible(
716
- optim_state_dict_to_load,
717
- og_keys_to_new,
718
- )
719
- load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
720
- del optim_state_dict_to_load
721
-
722
- # Load other state.
723
- try:
724
- trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
725
- except FileNotFoundError:
726
- # for backwards compatibility
727
- trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
728
- barrier()
729
- return trainer_state
730
-
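
The restore path above deliberately poisons every parameter with NaN before copying shards in, then checks that nothing is still NaN afterwards, which turns a silently skipped parameter into a loud failure. A toy version of that fill-and-verify idea on a plain module, without FSDP (it assumes the checkpoint itself contains no NaNs):

import torch
import torch.nn as nn

def restore_with_nan_check(module: nn.Module, state_dict: dict) -> None:
    # Poison params with NaN, copy the checkpoint in, then verify full coverage.
    with torch.no_grad():
        for p in module.parameters():
            p.fill_(torch.nan)
        for name, p in module.named_parameters():
            p.copy_(state_dict[name])
        for name, p in module.named_parameters():
            if torch.isnan(p).any():
                raise ValueError(f"Parameter '{name}' was not fully restored")

model = nn.Linear(4, 2)
restore_with_nan_check(model, {k: torch.ones_like(v) for k, v in model.state_dict().items()})
print("all parameters restored")
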
731
- def _make_optim_state_dict_compatible(
732
- self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
733
- ) -> Dict[str, Any]:
734
- # This state dict comes in two forms: one where the state keys are integers and one where the
735
- # keys are fully qualified parameter names. The latter case is easier to deal with here so we
736
- # first transform the integer key form into the FQN key form.
737
- if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
738
- id_to_fqn: Dict[int, str] = {}
739
- for group in optim_state_dict["param_groups"]:
740
- new_param_names = []
741
- for fqn, id in zip(group["param_names"], group["params"]):
742
- fqn = fqn.replace("_fsdp_wrapped_module.", "")
743
- id_to_fqn[id] = fqn
744
- new_param_names.append(fqn)
745
- group["param_names"] = new_param_names
746
- group["params"] = new_param_names
747
- for id in list(optim_state_dict["state"].keys()):
748
- optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
749
- else:
750
- # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
751
- for group in optim_state_dict["param_groups"]:
752
- group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
753
- group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
754
- assert group["param_names"] == group["params"]
755
- for key in list(optim_state_dict["state"].keys()):
756
- optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
757
- "state"
758
- ].pop(key)
759
-
760
- # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
761
- # First fix param names in the state.
762
- for og_key, new_keys in og_keys_to_new.items():
763
- og_state = optim_state_dict["state"].pop(og_key, None)
764
- if og_state is None:
765
- continue
766
- for i, new_key in enumerate(new_keys):
767
- if i == len(new_keys) - 1:
768
- optim_state_dict["state"][new_key] = og_state
769
- else:
770
- optim_state_dict["state"][new_key] = deepcopy(og_state)
771
- # Now fix param names in the param groups.
772
- for group in optim_state_dict["param_groups"]:
773
- og_names = group["params"]
774
- new_names = []
775
- for og_key in og_names:
776
- for new_key in og_keys_to_new[og_key]:
777
- new_names.append(new_key)
778
- group["params"] = new_names
779
- group["param_names"] = new_names
780
-
781
- return optim_state_dict
782
-
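
The first branch of `_make_optim_state_dict_compatible` rewrites integer parameter IDs into fully qualified names so that later renames can be keyed by FQN. A compact sketch of that ID-to-FQN rewrite on a fake optimizer state dict (structure only, values made up):

optim_sd = {
    "state": {0: {"step": 10}, 1: {"step": 10}},
    "param_groups": [{"params": [0, 1], "param_names": ["_fsdp_wrapped_module.w", "b"]}],
}

id_to_fqn = {}
for group in optim_sd["param_groups"]:
    new_names = []
    for fqn, pid in zip(group["param_names"], group["params"]):
        fqn = fqn.replace("_fsdp_wrapped_module.", "")  # strip the FSDP wrapper prefix
        id_to_fqn[pid] = fqn
        new_names.append(fqn)
    group["param_names"] = new_names
    group["params"] = new_names
for pid in list(optim_sd["state"].keys()):
    optim_sd["state"][id_to_fqn[pid]] = optim_sd["state"].pop(pid)

print(list(optim_sd["state"].keys()))         # ['w', 'b']
print(optim_sd["param_groups"][0]["params"])  # ['w', 'b']
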
783
- def load_checkpoint(
784
- self,
785
- load_path: PathOrStr,
786
- *,
787
- local_cache: Optional[PathOrStr] = None,
788
- load_optimizer_state: bool = True,
789
- device: Optional[torch.device] = None,
790
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
791
- device = device if device is not None else torch.device("cpu")
792
- model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device) # type: ignore
793
- optim_state = None
794
- if load_optimizer_state:
795
- optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device) # type: ignore
796
- return model_state, optim_state
797
-
798
-
799
- class TorchNewStyleShardedCheckpointer(Checkpointer):
800
- """
801
- A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
802
- """
803
-
804
- def save_checkpoint(
805
- self,
806
- dir: PathOrStr,
807
- fsdp_model: FSDP,
808
- optim: Optimizer,
809
- trainer_state: Dict[str, Any],
810
- *,
811
- upload_to: Optional[str] = None,
812
- ) -> None:
813
- with self._temporary_wd(dir) as checkpoint_dir:
814
- # Save model and optim state.
815
- save_fsdp_model_and_optim_state(
816
- checkpoint_dir,
817
- fsdp_model,
818
- optim,
819
- upload_to=upload_to,
820
- save_overwrite=self.cfg.save_overwrite,
821
- )
822
-
823
- # Save trainer state.
824
- log.info("Saving trainer state...")
825
- save_state_dict(
826
- checkpoint_dir,
827
- f"train/rank{get_global_rank()}.pt",
828
- trainer_state,
829
- upload_to=upload_to,
830
- save_overwrite=self.cfg.save_overwrite,
831
- )
832
-
833
- # Save config.
834
- self._save_config(checkpoint_dir, upload_to=upload_to)
835
-
836
- def restore_checkpoint(
837
- self,
838
- load_path: PathOrStr,
839
- fsdp_model: FSDP,
840
- optim: Optimizer,
841
- *,
842
- local_cache: Optional[PathOrStr] = None,
843
- load_optimizer_state: bool = True,
844
- ) -> Dict[str, Any]:
845
- # Load model and optimizer state in place.
846
- log.info("Loading model and optimizer state...")
847
- load_fsdp_model_and_optim_state(
848
- load_path,
849
- fsdp_model,
850
- optim,
851
- local_cache=local_cache,
852
- load_optimizer_state=load_optimizer_state,
853
- )
854
-
855
- # Load trainer state dict.
856
- log.info("Loading trainer state...")
857
- try:
858
- trainer_state = load_state_dict(
859
- load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
860
- )
861
- except FileNotFoundError:
862
- # Fall back to rank 0 train state.
863
- # This can happen when we're restoring a checkpoint with a different world size.
864
- trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
865
- barrier()
866
- return trainer_state
867
-
868
-
869
- class TorchLegacyShardedCheckpointer(Checkpointer):
870
- """
871
- A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
872
- and optim state.
873
-
874
- The world size must be kept consistent when using this checkpointer.
875
- """
876
-
877
- def save_checkpoint(
878
- self,
879
- dir: PathOrStr,
880
- fsdp_model: FSDP,
881
- optim: Optimizer,
882
- trainer_state: Dict[str, Any],
883
- *,
884
- upload_to: Optional[str] = None,
885
- ) -> None:
886
- with self._temporary_wd(dir) as checkpoint_dir:
887
- with FSDP.state_dict_type(
888
- fsdp_model,
889
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
890
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
891
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
892
- ):
893
- state_dict = {
894
- "model": fsdp_model.state_dict(),
895
- "optim": FSDP.optim_state_dict(fsdp_model, optim),
896
- **trainer_state,
897
- }
898
- save_state_dict(
899
- checkpoint_dir,
900
- f"rank{get_global_rank()}.pt",
901
- state_dict,
902
- upload_to=upload_to,
903
- save_overwrite=self.cfg.save_overwrite,
904
- )
905
-
906
- # Save config.
907
- self._save_config(checkpoint_dir, upload_to=upload_to)
908
-
909
- def restore_checkpoint(
910
- self,
911
- load_path: PathOrStr,
912
- fsdp_model: FSDP,
913
- optim: Optimizer,
914
- *,
915
- local_cache: Optional[PathOrStr] = None,
916
- load_optimizer_state: bool = True,
917
- ) -> Dict[str, Any]:
918
- with FSDP.state_dict_type(
919
- fsdp_model,
920
- state_dict_type=StateDictType.SHARDED_STATE_DICT,
921
- state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
922
- optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
923
- ):
924
- # Deserialize state dict.
925
- state_dict = load_state_dict(
926
- load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
927
- )
928
-
929
- # Load model and optimizer state.
930
- log.info("Loading model state...")
931
- fsdp_model.load_state_dict(state_dict["model"])
932
- del state_dict["model"]
933
- if load_optimizer_state:
934
- log.info("Loading optimizer state...")
935
- load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
936
- del state_dict["optim"]
937
-
938
- barrier()
939
- return state_dict
940
-
941
- def unshard_checkpoint(
942
- self,
943
- load_path: PathOrStr,
944
- *,
945
- local_cache: Optional[PathOrStr] = None,
946
- load_optimizer_state: bool = True,
947
- load_trainer_state: bool = True,
948
- device: Optional[torch.device] = None,
949
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
950
- assert local_cache is None, "this method currently only supports local files"
951
- full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
952
- model_state = full_state_dict.pop("model")
953
- optim_state = full_state_dict.pop("optim")
954
- return (
955
- model_state,
956
- optim_state if load_optimizer_state else None,
957
- full_state_dict if load_trainer_state else None,
958
- )
959
-
960
- def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
961
- key = tuple() if key is None else key
962
- if isinstance(state, (list, tuple, set)):
963
- for i, sub_state in enumerate(state):
964
- self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
965
- elif isinstance(state, dict):
966
- for name in state.keys():
967
- self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
968
- elif isinstance(state, ShardedTensor):
969
- self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
970
- return
971
- else:
972
- return
973
-
974
- def _get_shard_placement_and_rank_sizes(
975
- self, shards_metadata: List[ShardMetadata], world_size: int
976
- ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
977
- def shard_size(shard_md):
978
- return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
979
-
980
- rank_sizes = [0 for _ in range(world_size)]
981
- shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
982
- for shard_md in shards_metadata:
983
- shard_rank = cast(_remote_device, shard_md.placement).rank()
984
- assert shard_rank is not None
985
- if shard_rank >= world_size:
986
- raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
987
-
988
- shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
989
- rank_sizes[shard_rank] += shard_size(shard_md)
990
-
991
- return shard_placement, rank_sizes
992
-
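
`_get_shard_placement_and_rank_sizes` walks the shard metadata once, recording for each shard its owning rank plus its element offset within that rank's flat buffer, and tallying the total number of elements per rank. The same bookkeeping on plain tuples standing in for `ShardMetadata` (shard sizes are made up):

from functools import reduce

shards = [(0, (2, 3)), (1, (2, 3)), (0, (1, 3))]  # (rank, shard_sizes) pairs
world_size = 2

def shard_numel(sizes):
    return reduce(lambda x, y: x * y, sizes)

rank_sizes = [0] * world_size
placement = {}  # shard index -> (rank, offset into that rank's flat buffer)
for idx, (rank, sizes) in enumerate(shards):
    placement[idx] = (rank, rank_sizes[rank])
    rank_sizes[rank] += shard_numel(sizes)

print(placement)   # {0: (0, 0), 1: (1, 0), 2: (0, 6)}
print(rank_sizes)  # [9, 6]
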
993
- def _copy_sharded_tensor_to_shared_mem(
994
- self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
995
- ) -> Any:
996
- shard0_md = sharded_tensor.metadata()
997
- shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
998
- shard0_md.shards_metadata, world_size
999
- )
1000
-
1001
- rank_size = rank_sizes[rank]
1002
- assert rank_size >= 0
1003
- if rank_size == 0:
1004
- return
1005
-
1006
- assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1007
- numpy_type = np.float32
1008
-
1009
- sharded_memory_name = "-".join(key + (str(rank),))
1010
-
1011
- shm = shared_memory.SharedMemory(
1012
- create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
1013
- )
1014
- np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1015
-
1016
- for local_shard in sharded_tensor.local_shards():
1017
- shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
1018
- assert shard_rank == rank
1019
-
1020
- src = local_shard.tensor.flatten()
1021
- shard_offset = shard_placement[local_shard.metadata][1]
1022
-
1023
- np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
1024
-
1025
- shm.close()
1026
-
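
The method above parks each rank's flat shard in a named `multiprocessing.shared_memory` segment so another process can read it back without touching disk again. A minimal producer/consumer round trip of that mechanism within one process (the segment name is illustrative and must not already exist):

import numpy as np
from multiprocessing import shared_memory

data = np.arange(8, dtype=np.float32)

# "Producer": create a named segment and copy the shard into it.
shm = shared_memory.SharedMemory(create=True, size=data.nbytes, name="demo-shard-0")
np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)[:] = data
shm.close()  # drop this handle; the named segment itself stays alive

# "Consumer": attach by name, copy the data out, then release the segment.
shm2 = shared_memory.SharedMemory(name="demo-shard-0")
restored = np.ndarray(data.shape, dtype=np.float32, buffer=shm2.buf).copy()
shm2.close()
shm2.unlink()  # destroy the segment once every reader is done

print(np.array_equal(restored, data))  # True
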
1027
- def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
1028
- shard_number = int(shard_filepath.name[4:-3])
1029
- log.info("Starting unsharding shard number %d to shared memory", shard_number)
1030
-
1031
- with self._patch_sharded_tensor_load():
1032
- shard = torch.load(shard_filepath, map_location="cpu")
1033
- log.debug("Done loading shard number %d", shard_number)
1034
-
1035
- self._copy_sharded_tensors_to_shared_mem(
1036
- shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
1037
- )
1038
- log.info("Done unsharding shard number %d to shared memory", shard_number)
1039
-
1040
- def _unshard_using_sharded_mem(
1041
- self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
1042
- ) -> Any:
1043
- return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
1044
-
1045
- def _unshard_state_using_shared_mem(
1046
- self, state: Any, world_size: int, device: torch.device, key: Tuple
1047
- ) -> Any:
1048
- if isinstance(state, (list, tuple, set)):
1049
- return state.__class__(
1050
- self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
1051
- for i, sub_state in enumerate(state)
1052
- )
1053
- elif isinstance(state, dict):
1054
- return {
1055
- name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
1056
- for name in state.keys()
1057
- }
1058
- elif isinstance(state, ShardedTensor):
1059
- return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
1060
- elif isinstance(state, torch.Tensor):
1061
- return state.to(device=device)
1062
- else:
1063
- return state
1064
-
1065
- def _unshard_tensor_using_shared_mem(
1066
- self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
1067
- ) -> torch.Tensor:
1068
- shard0_md = sharded_tensor.metadata()
1069
-
1070
- def shard_size(shard_md):
1071
- return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1072
-
1073
- shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1074
- shard0_md.shards_metadata, world_size
1075
- )
1076
-
1077
- assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1078
- numpy_type = np.float32
1079
-
1080
- out = torch.empty(
1081
- *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
1082
- )
1083
- dims = len(sharded_tensor.metadata().size)
1084
- for shard_md, (rank, rank_offset) in shard_placement.items():
1085
- if rank >= world_size:
1086
- raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
1087
-
1088
- sharded_memory_name = "-".join(key + (str(rank),))
1089
- shm = shared_memory.SharedMemory(name=sharded_memory_name)
1090
-
1091
- rank_size = rank_sizes[rank]
1092
- assert rank_size >= 0
1093
- if rank_size == 0:
1094
- continue
1095
-
1096
- np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1097
-
1098
- tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
1099
- tensor = tensor.view(shard_md.shard_sizes)
1100
-
1101
- out_narrow_view = out
1102
- for dim in range(dims):
1103
- out_narrow_view = out_narrow_view.narrow(
1104
- dim,
1105
- shard_md.shard_offsets[dim],
1106
- shard_md.shard_sizes[dim],
1107
- )
1108
-
1109
- out_narrow_view.copy_(tensor)
1110
-
1111
- shm.close()
1112
- shm.unlink()
1113
-
1114
- return out
1115
-
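
Re-assembly above places each shard into the right region of the full tensor by chaining one `narrow()` call per dimension and copying into the resulting view. A tiny standalone example of that placement (offsets and sizes are made up):

import torch

full = torch.zeros(4, 6)
shard = torch.ones(2, 3)
shard_offsets = [1, 3]          # where this shard starts in each dimension
shard_sizes = list(shard.shape)

view = full
for dim in range(full.dim()):
    view = view.narrow(dim, shard_offsets[dim], shard_sizes[dim])
view.copy_(shard)               # writes through the view into `full`

print(full[1:3, 3:6])           # a 2x3 block of ones inside the zero tensor
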
1116
- @contextmanager
1117
- def _patch_sharded_tensor_load(self):
1118
- """
1119
- Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
1120
- """
1121
-
1122
- def _rebuild_from_type_v2_monkey(func, new_type, args, state):
1123
- ret = func(*args)
1124
- if type(ret) is not new_type:
1125
- ret = ret.as_subclass(new_type)
1126
-
1127
- # Shortcut the construction of ShardedTensor
1128
- # This is in the top 5 of my worst hacks.
1129
- if isinstance(ret, ShardedTensor):
1130
- ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
1131
- return ret
1132
-
1133
- # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
1134
- # Tensor does define __setstate__ even though it doesn't define
1135
- # __getstate__. So only use __setstate__ if it is NOT the one defined
1136
- # on Tensor
1137
- if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
1138
- ret.__setstate__(state)
1139
- else:
1140
- ret = torch._utils._set_obj_state(ret, state)
1141
- return ret
1142
-
1143
- original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
1144
- try:
1145
- torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
1146
- yield
1147
- finally:
1148
- torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
1149
-
1150
- def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
1151
- """
1152
- The current unsharding implementation consists of:
1153
-
1154
- 1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
1155
- 2. Loading 1 shard on the main process as a base unsharded object.
1156
- 3. Using the sharded tensors in shared memory to populate the base unsharded object.
1157
-
1158
- This implementation replaced a prior implementation that instead loaded
1159
- all shards using threads, because that implementation turned out to
1160
- be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
1161
- The current implementation is slower than the old one in many scenarios,
1162
- but is significantly faster in the above-mentioned case (e.g. 30 minutes)
1163
- if there are enough CPUs.
1164
- """
1165
-
1166
- input_dir = Path(input_dir)
1167
- skip_keys = skip_keys or set()
1168
-
1169
- shard_filepaths = list(input_dir.glob("rank*.pt"))
1170
- world_size = len(shard_filepaths)
1171
- if world_size == 0:
1172
- raise RuntimeError("No shards found for unsharding")
1173
-
1174
- log.info("Number of shards: %d", world_size)
1175
- shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
1176
- min_ram_required_estimate_gb = shard_size_gb * world_size
1177
- log.info(
1178
- "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
1179
- )
1180
-
1181
- log.info("Copying sharded tensors to shared memory using multiple processes")
1182
- # Copy sharded data to shared memory using multiple processes, so this process can load
1183
- # from memory rather than disk. We spawn a new process instead of forking since shared memory
1184
- # appears to get deleted when forked processes end for some reason.
1185
- executor = ProcessPoolExecutor(
1186
- mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
1187
- )
1188
- futures = []
1189
- for shard_filepath in shard_filepaths:
1190
- shard_rank = int(shard_filepath.name[4:-3])
1191
-
1192
- if shard_rank >= world_size:
1193
- raise RuntimeError(
1194
- f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
1195
- )
1196
-
1197
- futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
1198
-
1199
- for f in as_completed(futures):
1200
- f.result()
1201
- executor.shutdown()
1202
-
1203
- log.info("Loading a shard on the main process to be unsharded state")
1204
- with self._patch_sharded_tensor_load():
1205
- state = torch.load(shard_filepaths[0], map_location="cpu")
1206
-
1207
- for key in skip_keys:
1208
- if key in state:
1209
- del state[key]
1210
-
1211
- log.info("Unsharding from %d shards ...", world_size)
1212
- return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
1213
-
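
The `_unshard` docstring above describes fanning the shard loads out to separate spawned processes (spawned rather than forked so the shared-memory segments survive). A minimal sketch of that dispatch shape with a stand-in work function (`load_one_shard` is a placeholder, not part of this module):

import torch.multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed

def load_one_shard(shard_rank: int) -> int:
    # Stand-in for "load rank{N}.pt and copy its sharded tensors to shared memory".
    return shard_rank * shard_rank

if __name__ == "__main__":
    executor = ProcessPoolExecutor(max_workers=2, mp_context=mp.get_context("spawn"))
    futures = [executor.submit(load_one_shard, rank) for rank in range(4)]
    for fut in as_completed(futures):
        print(fut.result())
    executor.shutdown()
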
1214
-
1215
- @dataclass
1216
- class _LocalShardedCheckpointerMetadata(BaseConfig):
1217
- world_size: int = field(default_factory=get_world_size)
1218
-
1219
-
1220
- @dataclass
1221
- class _FlatParamShard:
1222
- full_shape: torch.Size
1223
- shard_offsets: Tuple[int, int]
1224
- shard_data: Optional[torch.Tensor]
1225
-
1226
- def copy_into(self, full_tensor: torch.Tensor) -> None:
1227
- assert self.shard_data is not None
1228
- full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
1229
- assert self.shard_data.shape == full_tensor_shard_view.shape
1230
- full_tensor_shard_view.copy_(self.shard_data)
1231
-
1232
-
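
`_FlatParamShard.copy_into` above writes a 1-D shard into a flattened view of the full parameter using inclusive start/end offsets. A small example of that slice-and-copy (the offsets are arbitrary):

import torch

full = torch.zeros(3, 4)                 # 12 elements when flattened
shard_offsets = (4, 7)                   # inclusive start/end into the flat view
shard_data = torch.arange(1.0, 5.0)      # 4 values for offsets 4..7

flat_view = full.view(-1)[shard_offsets[0] : shard_offsets[1] + 1]
assert shard_data.shape == flat_view.shape
flat_view.copy_(shard_data)              # writes through the view into `full`

print(full)
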
1233
- class LocalShardedCheckpointer(Checkpointer):
1234
- """
1235
- A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
1236
- The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
1237
-
1238
- The world size must be kept consistent when using this checkpointer. However, you can easily
1239
- reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
1240
- using :meth:`unshard_checkpoint()` (no distributed initialization required).
1241
- """
1242
-
1243
- # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1244
- _FLAT_PARAM_METADATA_TO_SAVE = (
1245
- "_fqns",
1246
- "_shard_param_offsets",
1247
- "_shard_indices",
1248
- "_numels",
1249
- "_numels_with_padding",
1250
- "_shapes",
1251
- "_shard_numel_padded",
1252
- "_shard_param_infos",
1253
- )
1254
-
1255
- def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1256
- """
1257
- Returns a list of FSDP modules with their FQN.
1258
- """
1259
- modules = []
1260
- for name, module in fsdp_model.named_modules():
1261
- if isinstance(module, FSDP):
1262
- modules.append((name, module))
1263
- return modules
1264
-
1265
- def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1266
- from torch.distributed.fsdp._runtime_utils import _lazy_init
1267
-
1268
- # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1269
- # an FSDP state dict through the built-in methods.
1270
- if torch.cuda.is_available():
1271
- torch.cuda.synchronize()
1272
- _lazy_init(fsdp_model, fsdp_model)
1273
-
1274
- def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1275
- if version.parse(torch.__version__) < version.parse("2.1.0"):
1276
- return fsdp_model._handles # type: ignore
1277
- elif version.parse(torch.__version__) < version.parse("2.3.0"):
1278
- # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1279
- if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1280
- return [fsdp_model._handle] # type: ignore
1281
- else:
1282
- return []
1283
- else:
1284
- # Need to verify FSDP internals with newer versions.
1285
- raise NotImplementedError
1286
-
1287
- @torch.no_grad()
1288
- def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1289
- self._prepare_fsdp_model(fsdp_model)
1290
- module_data = []
1291
- for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1292
- handle_data = []
1293
- for handle in self._fsdp_handles(fsdp_module):
1294
- data: Dict[str, Any] = {}
1295
- # This is a `FlatParameter` instance.
1296
- # See `torch.distributed.fsdp.flat_param` for the API.
1297
- flat_param = handle.flat_param
1298
- data["flat_param.data"] = flat_param.detach()
1299
- for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1300
- if hasattr(flat_param, key):
1301
- data[f"flat_param.{key}"] = getattr(flat_param, key)
1302
- handle_data.append(data)
1303
- module_data.append({"handles": handle_data, "name": module_fqn})
1304
- return {"modules": module_data}
1305
-
1306
- @torch.no_grad()
1307
- def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1308
- """Load the state produced from `self._get_flat_param_state_to_save()`."""
1309
- self._prepare_fsdp_model(fsdp_model)
1310
- fsdp_modules = self._fsdp_modules(fsdp_model)
1311
- assert len(model_state["modules"]) == len(fsdp_modules)
1312
- for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1313
- handles = self._fsdp_handles(fsdp_module)
1314
- assert len(handles) == len(module_data["handles"])
1315
- for handle, data in zip(handles, module_data["handles"]):
1316
- flat_param = handle.flat_param
1317
- # Make sure metadata matches.
1318
- for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1319
- if hasattr(flat_param, key):
1320
- assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1321
- # Load the flat sharded data.
1322
- flat_param.copy_(data["flat_param.data"])
1323
-
1324
- def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1325
- if get_fs_local_rank() == 0:
1326
- log.info("Saving metadata...")
1327
- metadata = _LocalShardedCheckpointerMetadata()
1328
- metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1329
- if upload_to is not None and get_global_rank() == 0:
1330
- upload_target = f"{upload_to}/metadata.yaml"
1331
- log.info(f"Uploading {metadata_path} to {upload_target}")
1332
- upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1333
-
1334
- def _load_metadata(
1335
- self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1336
- ) -> _LocalShardedCheckpointerMetadata:
1337
- metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1338
- return _LocalShardedCheckpointerMetadata.load(metadata_path)
1339
-
1340
- def save_checkpoint(
1341
- self,
1342
- dir: PathOrStr,
1343
- fsdp_model: FSDP,
1344
- optim: Optimizer,
1345
- trainer_state: Dict[str, Any],
1346
- *,
1347
- upload_to: Optional[str] = None,
1348
- ) -> None:
1349
- with self._temporary_wd(dir) as checkpoint_dir:
1350
- # Gather local FSDP flat params data to save.
1351
- # We also save some flat param metadata like the corresponding fully qualified names (fqns)
1352
- # of each original parameter so we can validate that the sharding is the same when loading
1353
- # one of these checkpoints.
1354
- log.info("Saving local FSDP flat params data...")
1355
- save_state_dict(
1356
- checkpoint_dir,
1357
- f"model/rank{get_global_rank()}.pt",
1358
- self._get_flat_param_state_to_save(fsdp_model),
1359
- upload_to=upload_to,
1360
- save_overwrite=self.cfg.save_overwrite,
1361
- )
1362
-
1363
- # Save optimizer state.
1364
- log.info("Saving local optimizer state...")
1365
- save_state_dict(
1366
- checkpoint_dir,
1367
- f"optim/rank{get_global_rank()}.pt",
1368
- optim.state_dict(),
1369
- upload_to=upload_to,
1370
- save_overwrite=self.cfg.save_overwrite,
1371
- )
1372
-
1373
- # Save trainer state.
1374
- log.info("Saving trainer state...")
1375
- save_state_dict(
1376
- checkpoint_dir,
1377
- f"train/rank{get_global_rank()}.pt",
1378
- trainer_state,
1379
- upload_to=upload_to,
1380
- save_overwrite=self.cfg.save_overwrite,
1381
- )
1382
-
1383
- # Save metadata.
1384
- self._save_metadata(checkpoint_dir, upload_to=upload_to)
1385
-
1386
- # Save config. We do this last b/c the presence of a config in a remote checkpoint
1387
- # "directory" indicates that the folder is valid, as a opposed to a partially
1388
- # uploaded checkpoint directory that failed before completing.
1389
- self._save_config(checkpoint_dir, upload_to=upload_to)
1390
-
1391
- def restore_checkpoint(
1392
- self,
1393
- load_path: PathOrStr,
1394
- fsdp_model: FSDP,
1395
- optim: Optimizer,
1396
- *,
1397
- local_cache: Optional[PathOrStr] = None,
1398
- load_optimizer_state: bool = True,
1399
- ) -> Dict[str, Any]:
1400
- # Load metadata and make sure checkpoint is compatible.
1401
- metadata = self._load_metadata(load_path, local_cache=local_cache)
1402
- assert metadata.world_size == get_world_size()
1403
-
1404
- # Load local FSDP flat param data.
1405
- log.info("Loading local FSDP flat params data...")
1406
- model_state = load_state_dict(
1407
- load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1408
- )
1409
- self._load_flat_param_state(fsdp_model, model_state)
1410
- del model_state
1411
-
1412
- # Load local optim state.
1413
- if load_optimizer_state:
1414
- log.info("Loading local optimizer state...")
1415
- optim_state = load_state_dict(
1416
- load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1417
- )
1418
- # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
1419
- # in every rank, and keep this in the optimizer state. But this causes issues when loading the
1420
- # state since torch sees the state is non-empty for some params which would normally be empty,
1421
- # and then assumes it should have all of the other state tensors for that param, which it doesn't.
1422
- # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
1423
- # Not the end of the world but there's probably a better way around this without resetting
1424
- # the metric.
1425
- for param_id in list(optim_state["state"].keys()):
1426
- state = optim_state["state"][param_id]
1427
- if "grad_norm_exp_avg" in state:
1428
- del state["grad_norm_exp_avg"]
1429
- if len(state) == 0:
1430
- del optim_state["state"][param_id]
1431
- optim.load_state_dict(optim_state)
1432
- del optim_state
1433
-
1434
- # Load local trainer state.
1435
- log.info("Loading local trainer state...")
1436
- trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1437
- barrier()
1438
- return trainer_state
1439
-
1440
- def _iter_flat_param_shards(
1441
- self, model_state: Dict[str, Any]
1442
- ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
1443
- for module_data in model_state["modules"]:
1444
- module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
1445
- for handle in module_data["handles"]:
1446
- flat_data = handle["flat_param.data"]
1447
- if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
1448
- # If there's padding in the flat param it should be on the right.
1449
- assert (flat_data[-num_padding:] == 0).all()
1450
- # NOTE: this changes depending on the torch version, but we don't do a version
1451
- # check since we might be trying to unshard an old checkpoint that was stored
1452
- # with a different torch version than we're currently running with.
1453
- if "flat_param._shard_indices" in handle:
1454
- # torch <=2.0.1
1455
- param_start = handle["flat_param._shard_indices"][0]
1456
- current_flat_index = 0
1457
- for relative_fqn, full_shape, (offset_start, offset_end) in zip(
1458
- handle["flat_param._fqns"][param_start:],
1459
- handle["flat_param._shapes"][param_start:],
1460
- handle["flat_param._shard_param_offsets"],
1461
- ):
1462
- root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1463
- numel_shard = offset_end - offset_start + 1
1464
- flat_param_shard = _FlatParamShard(
1465
- full_shape=full_shape,
1466
- shard_offsets=(offset_start, offset_end),
1467
- shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
1468
- )
1469
- current_flat_index += numel_shard
1470
- yield root_fqn, flat_param_shard
1471
- else:
1472
- # torch >=2.1.0
1473
- for relative_fqn, full_shape, shard_param_info in zip(
1474
- handle["flat_param._fqns"],
1475
- handle["flat_param._shapes"],
1476
- handle["flat_param._shard_param_infos"],
1477
- ):
1478
- if not shard_param_info.in_shard:
1479
- continue
1480
- root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1481
- flat_param_shard = _FlatParamShard(
1482
- full_shape=full_shape,
1483
- shard_offsets=(
1484
- shard_param_info.intra_param_start_idx,
1485
- shard_param_info.intra_param_end_idx,
1486
- ),
1487
- shard_data=flat_data[
1488
- shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
1489
- + shard_param_info.numel_in_shard
1490
- ],
1491
- )
1492
- yield root_fqn, flat_param_shard
1493
-
1494
- def unshard_checkpoint(
1495
- self,
1496
- load_path: PathOrStr,
1497
- *,
1498
- local_cache: Optional[PathOrStr] = None,
1499
- load_optimizer_state: bool = True,
1500
- load_trainer_state: bool = True,
1501
- device: Optional[torch.device] = None,
1502
- ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1503
- device = device or torch.device("cpu")
1504
- metadata = self._load_metadata(load_path, local_cache=local_cache)
1505
-
1506
- # Gather paths to model state, potentially downloading them.
1507
- log.info("Gathering model state dicts...")
1508
- model_state_paths = self._gather_state_dict_paths(
1509
- load_path, "model", metadata.world_size, local_cache=local_cache
1510
- )
1511
-
1512
- # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
1513
- log.info("Materializing full parameters...")
1514
- full_model_state: Dict[str, torch.Tensor] = {}
1515
- # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
1516
- # the full optimizer state below without having to reload the model state dicts.
1517
- flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
1518
- for rank, path in enumerate(model_state_paths):
1519
- log.info(f"Loading shards from rank {rank}...")
1520
- model_state = torch.load(path, map_location="cpu")
1521
- for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
1522
- if root_fqn not in full_model_state:
1523
- log.info(
1524
- f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
1525
- )
1526
- assert flat_param_shard.shard_data is not None
1527
- full_model_state[root_fqn] = torch.empty(
1528
- flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
1529
- )
1530
- # Fill with NaNs so we can validate that the whole parameter has been populated
1531
- # afterwards.
1532
- full_model_state[root_fqn].fill_(torch.nan)
1533
- # Copy over the local shard to the relevant part of the full parameter.
1534
- full_param = full_model_state[root_fqn]
1535
- log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
1536
- flat_param_shard.copy_into(full_param)
1537
- flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
1538
-
1539
- log.info("Validating full parameters...")
1540
- for key, tensor in full_model_state.items():
1541
- if torch.isnan(tensor).any():
1542
- raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
1543
-
1544
- trainer_state: Optional[Dict[str, Any]] = None
1545
- if load_trainer_state:
1546
- trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1547
-
1548
- if not load_optimizer_state:
1549
- return full_model_state, None, trainer_state
1550
-
1551
- log.info("Gathering optim state dicts...")
1552
- optim_state_paths = self._gather_state_dict_paths(
1553
- load_path, "optim", metadata.world_size, local_cache=local_cache
1554
- )
1555
-
1556
- log.info("Materializing full optim state...")
1557
- full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
1558
- fqn_to_id: Dict[str, int] = {}
1559
- id_to_fqn: Dict[int, str] = {}
1560
- for rank, path in enumerate(optim_state_paths):
1561
- log.info(f"Loading sharded optim state from rank {rank}...")
1562
- optim_state = torch.load(path, map_location="cpu")
1563
-
1564
- # Initialize param groups.
1565
- # We assume parameter groups are the same across all ranks.
1566
- # The only thing that differs across ranks is the state for each local sharded param.
1567
- if "param_groups" not in full_optim_state:
1568
- full_optim_state["param_groups"] = optim_state["param_groups"]
1569
- else:
1570
- assert full_optim_state["param_groups"] == optim_state["param_groups"]
1571
-
1572
- # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
1573
- if not fqn_to_id or not id_to_fqn:
1574
- for group in full_optim_state["param_groups"]:
1575
- for fqn, id in zip(group["param_names"], group["params"]):
1576
- fqn = fqn.replace("_fsdp_wrapped_module.", "")
1577
- fqn_to_id[fqn] = id
1578
- id_to_fqn[id] = fqn
1579
-
1580
- # Iterate over local shard state and copy into the full state.
1581
- for id, shard_state in optim_state["state"].items():
1582
- fqn = id_to_fqn[id]
1583
- flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
1584
- full_state = full_optim_state["state"][id]
1585
- for key, shard_value in shard_state.items():
1586
- assert isinstance(shard_value, torch.Tensor)
1587
- if shard_value.shape == torch.Size([]):
1588
- # Add singleton tensors directly to full state. These should be the same across
1589
- # all ranks.
1590
- assert key in ("step", "grad_norm_exp_avg") # sanity check
1591
- if key not in full_state:
1592
- full_state[key] = shard_value.to(device)
1593
- else:
1594
- assert full_state[key] == shard_value
1595
- else:
1596
- # Otherwise we have a sharded param state.
1597
- # If the corresponding full param state hasn't been materialized yet, do so now.
1598
- assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
1599
- if key not in full_state:
1600
- log.info(
1601
- f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
1602
- )
1603
- full_state[key] = torch.empty(
1604
- flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
1605
- )
1606
- full_state_value = full_state[key]
1607
-
1608
- # Copy over the local shard state to the relevant part of the full parameter state.
1609
- log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
1610
- replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
1611
-
1612
- # Lastly, clean up the parameter names in param groups.
1613
- for group in full_optim_state["param_groups"]:
1614
- group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
1615
-
1616
- return full_model_state, full_optim_state, trainer_state
1617
-
1618
- def _get_state_dict_path(
1619
- self,
1620
- load_path: PathOrStr,
1621
- state_dict_type: str,
1622
- rank: int,
1623
- *,
1624
- local_cache: Optional[PathOrStr] = None,
1625
- progress=None,
1626
- ) -> Tuple[int, Path]:
1627
- fname = f"{state_dict_type}/rank{rank}.pt"
1628
- return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
1629
-
1630
- def _gather_state_dict_paths(
1631
- self,
1632
- load_path: PathOrStr,
1633
- state_dict_type: str,
1634
- world_size: int,
1635
- *,
1636
- local_cache: Optional[PathOrStr] = None,
1637
- ) -> List[Path]:
1638
- progress = get_progress_bar()
1639
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
1640
- futures = []
1641
- for rank in range(world_size):
1642
- future = executor.submit(
1643
- self._get_state_dict_path,
1644
- load_path,
1645
- state_dict_type,
1646
- rank,
1647
- local_cache=local_cache,
1648
- progress=progress,
1649
- )
1650
- futures.append(future)
1651
-
1652
- results: Dict[int, Path] = {}
1653
- for future in as_completed(futures):
1654
- rank, path = future.result()
1655
- results[rank] = path
1656
-
1657
- return [results[rank] for rank in range(world_size)]
1658
-
1659
-
1660
- def build_sharded_checkpointer(
1661
- cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
1662
- ) -> Checkpointer:
1663
- name = name or cfg.sharded_checkpointer
1664
- if name == ShardedCheckpointerType.torch_new:
1665
- return TorchNewStyleShardedCheckpointer(cfg)
1666
- elif name == ShardedCheckpointerType.torch_legacy:
1667
- return TorchLegacyShardedCheckpointer(cfg)
1668
- elif name == ShardedCheckpointerType.local:
1669
- return LocalShardedCheckpointer(cfg)
1670
- else:
1671
- raise NotImplementedError(name)
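The `build_sharded_checkpointer` factory above dispatches on `ShardedCheckpointerType`, and `LocalShardedCheckpointer.unshard_checkpoint` reassembles the per-rank shards into full state dicts. A minimal usage sketch, assuming the deleted folder is importable as the `olmo_bitnet_1b` package; the config and checkpoint paths are illustrative:

```python
from olmo_bitnet_1b.checkpoint import build_sharded_checkpointer
from olmo_bitnet_1b.config import ShardedCheckpointerType, TrainConfig

# Illustrative paths; any TrainConfig works here.
cfg = TrainConfig.load("configs/train.yaml")

# Build the "local" sharded checkpointer and reassemble a sharded checkpoint
# into full (unsharded) model, optimizer, and trainer state dicts.
checkpointer = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.local)
model_state, optim_state, trainer_state = checkpointer.unshard_checkpoint(
    "checkpoints/step1000",
    load_optimizer_state=True,
    load_trainer_state=True,
)
```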
olmo_bitnet_1b/config.json DELETED
@@ -1,50 +0,0 @@
1
- {
2
- "activation_type": "swiglu",
3
- "alibi": false,
4
- "alibi_bias_max": 8.0,
5
- "architectures": [
6
- "OLMoModelForCausalLM"
7
- ],
8
- "attention_dropout": 0.0,
9
- "attention_layer_norm": false,
10
- "attention_layer_norm_with_affine": false,
11
- "bias_for_layer_norm": false,
12
- "block_group_size": 1,
13
- "block_type": "sequential",
14
- "clip_qkv": null,
15
- "d_model": 2048,
16
- "embedding_dropout": 0.0,
17
- "embedding_size": 50304,
18
- "eos_token_id": 50279,
19
- "flash_attention": true,
20
- "include_bias": false,
21
- "init_cutoff_factor": null,
22
- "init_device": "cpu",
23
- "init_fn": "mitchell",
24
- "init_std": 0.02,
25
- "layer_norm_type": "rms",
26
- "layer_norm_with_affine": true,
27
- "max_sequence_length": 2048,
28
- "mlp_hidden_size": null,
29
- "mlp_ratio": 8,
30
- "model_type": "olmo",
31
- "multi_query_attention": false,
32
- "n_heads": 16,
33
- "n_layers": 16,
34
- "pad_token_id": 1,
35
- "precision": "amp_bf16",
36
- "residual_dropout": 0.0,
37
- "rope": true,
38
- "rope_full_precision": true,
39
- "scale_logits": false,
40
- "ternary": true,
41
- "transformers_version": "4.38.2",
42
- "use_cache": true,
43
- "vocab_size": 50280,
44
- "inference_mode": false,
45
- "weight_tying": true,
46
- "auto_map": {
47
- "AutoConfig": "configuration_olmo.OLMoConfig",
48
- "AutoModelForCausalLM": "modeling_olmo.OLMoForCausalLM"
49
- }
50
- }
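Because the config declares an `auto_map`, the custom OLMo classes in this folder could be loaded through the `transformers` auto classes with `trust_remote_code=True`. A minimal sketch; the local directory path is illustrative:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# trust_remote_code=True is needed because OLMoConfig / OLMoForCausalLM live in
# this folder (see "auto_map" above), not in the transformers library itself.
config = AutoConfig.from_pretrained("./olmo_bitnet_1b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./olmo_bitnet_1b", trust_remote_code=True)
print(config.d_model, config.n_layers, config.ternary)  # 2048 16 True with the config above
```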
olmo_bitnet_1b/config.py DELETED
@@ -1,1106 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import asdict, dataclass, field
4
- from glob import glob
5
- from pathlib import Path
6
- from typing import (
7
- Any,
8
- Dict,
9
- Iterable,
10
- List,
11
- Optional,
12
- Tuple,
13
- Type,
14
- TypeVar,
15
- Union,
16
- cast,
17
- )
18
-
19
- import torch
20
- from omegaconf import DictConfig, ListConfig
21
- from omegaconf import OmegaConf as om
22
- from omegaconf.errors import OmegaConfBaseException
23
- from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
24
-
25
- from .aliases import PathOrStr
26
- from .beam_search import Sampler
27
- from .exceptions import OLMoConfigurationError
28
- from .util import StrEnum
29
-
30
- __all__ = [
31
- "ActivationType",
32
- "ActivationCheckpointingStrategy",
33
- "BlockType",
34
- "LayerNormType",
35
- "InitFnType",
36
- "ModelConfig",
37
- "OptimizerType",
38
- "OptimizerConfig",
39
- "SchedulerType",
40
- "SchedulerConfig",
41
- "DataConfig",
42
- "EvaluatorConfig",
43
- "TokenizerConfig",
44
- "TrainConfig",
45
- "PaddingDirection",
46
- "TruncationDirection",
47
- "SpeedMonitorConfig",
48
- "WandbConfig",
49
- "CompilerConfig",
50
- "WandbConfig",
51
- "FSDPPrecision",
52
- "FSDPWrapStrategy",
53
- "FSDPConfig",
54
- "CheckpointType",
55
- ]
56
-
57
- C = TypeVar("C", bound="BaseConfig")
58
- D = TypeVar("D", bound="DictConfig|ListConfig")
59
-
60
-
61
- class BaseConfig:
62
- @classmethod
63
- def _register_resolvers(cls, validate_paths: bool = True):
64
- # Expands path globs into a list.
65
- def path_glob(*paths) -> List[str]:
66
- out = []
67
- for path in paths:
68
- matches = sorted(glob(path))
69
- if not matches and validate_paths:
70
- raise FileNotFoundError(f"{path} does not match any files or dirs")
71
- out.extend(matches)
72
- return out
73
-
74
- # Chooses the first path in the arguments that exists.
75
- def path_choose(*paths) -> str:
76
- from .util import is_url
77
-
78
- for path in paths:
79
- if is_url(path) or Path(path).exists():
80
- return path
81
- if validate_paths:
82
- raise FileNotFoundError(", ".join(paths))
83
- else:
84
- return ""
85
-
86
- # Finds the latest checkpoint in a folder.
87
- def path_last_checkpoint(path) -> str:
88
- from .util import find_latest_checkpoint
89
-
90
- latest_checkpoint = find_latest_checkpoint(path)
91
- if latest_checkpoint is None:
92
- if validate_paths:
93
- raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
94
- else:
95
- return ""
96
- else:
97
- return str(latest_checkpoint)
98
-
99
- om.register_new_resolver("path.glob", path_glob, replace=True)
100
- om.register_new_resolver("path.choose", path_choose, replace=True)
101
- om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
102
-
103
- @classmethod
104
- def update_legacy_settings(cls, config: D) -> D:
105
- """
106
- Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
107
- """
108
- return config
109
-
110
- @classmethod
111
- def new(cls: Type[C], **kwargs) -> C:
112
- cls._register_resolvers()
113
- conf = om.structured(cls)
114
- try:
115
- if kwargs:
116
- conf = om.merge(conf, kwargs)
117
- return cast(C, om.to_object(conf))
118
- except OmegaConfBaseException as e:
119
- raise OLMoConfigurationError(str(e))
120
-
121
- @classmethod
122
- def load(
123
- cls: Type[C],
124
- path: PathOrStr,
125
- overrides: Optional[List[str]] = None,
126
- key: Optional[str] = None,
127
- validate_paths: bool = True,
128
- ) -> C:
129
- """Load from a YAML file."""
130
- cls._register_resolvers(validate_paths=validate_paths)
131
- schema = om.structured(cls)
132
- try:
133
- raw = om.load(str(path))
134
- if key is not None:
135
- raw = raw[key] # type: ignore
136
- raw = cls.update_legacy_settings(raw)
137
- conf = om.merge(schema, raw)
138
- if overrides:
139
- conf = om.merge(conf, om.from_dotlist(overrides))
140
- return cast(C, om.to_object(conf))
141
- except OmegaConfBaseException as e:
142
- raise OLMoConfigurationError(str(e))
143
-
144
- def save(self, path: PathOrStr) -> None:
145
- """Save to a YAML file."""
146
- om.save(config=self, f=str(path))
147
-
148
- def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
149
- out = asdict(self) # type: ignore
150
- if exclude is not None:
151
- for name in exclude:
152
- if name in out:
153
- del out[name]
154
- return out
155
-
156
-
157
- class LayerNormType(StrEnum):
158
- default = "default"
159
- """
160
- The default LayerNorm implementation, equivalent to PyTorch's built-in version.
161
- """
162
-
163
- low_precision = "low_precision"
164
- """
165
- A low-precision version of the default LayerNorm.
166
- """
167
-
168
- rms = "rms"
169
- """
170
- An RMSNorm implementation. When using ``torch.compile`` this is
171
- probably the fastest implementation.
172
- """
173
-
174
-
175
- class ActivationType(StrEnum):
176
- gelu = "gelu"
177
- relu = "relu"
178
- swiglu = "swiglu"
179
-
180
-
181
- class BlockType(StrEnum):
182
- sequential = "sequential"
183
-
184
- llama = "llama"
185
- """
186
- A block similar to the sequential block with slightly different
187
- implementations of operations like attention to imitate the behavior of Llama.
188
- """
189
-
190
-
191
- class InitFnType(StrEnum):
192
- mitchell = "mitchell"
193
- """
194
- The strategy suggested to us by Mitchell Wortsman from UW.
195
- This uses a truncated normal distribution with an adaptive standard deviation that depends
196
- on the size of the weights as well as the depth of the layer.
197
- """
198
-
199
- normal = "normal"
200
- """
201
- All weights are initialized from the same normal distribution.
202
- """
203
-
204
- kaiming_normal = "kaiming_normal"
205
- """
206
- All weights are initialized with the Kaiming method from a normal distribution.
207
- Note this currently won't work with FSDP.
208
- """
209
-
210
- fan_in = "fan_in"
211
- """
212
- "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
213
- is the input dimensionality of the kernel.
214
- """
215
-
216
- full_megatron = "full_megatron"
217
- """
218
- This is what metaseq calls "full megatron init". It is the init used for Llama 2.
219
- """
220
-
221
-
222
- @dataclass
223
- class ModelConfig(BaseConfig):
224
- """
225
- OLMo (model) configuration.
226
- """
227
-
228
- # Note that the defaults for these attributes are equivalent to the base GPT2 model.
229
-
230
- d_model: int = 768
231
- """
232
- The hidden size of the model.
233
- """
234
-
235
- n_heads: int = 12
236
- """
237
- The number of self-attention heads.
238
- """
239
-
240
- n_kv_heads: Optional[int] = None
241
- """
242
- The number of heads to use for keys and values. Defaults to `n_heads`.
243
- Set this to ``None`` or ``n_heads`` for normal multi-head attention.
244
- Set this to 1 for multi-query attention.
245
- Set it to some in-between value for Llama2-style grouped query attention.
246
- """
247
-
248
- clip_qkv: Optional[float] = None
249
- """
250
- Clip QKV to this value when set.
251
- """
252
-
253
- n_layers: int = 12
254
- """
255
- The number of layers/blocks.
256
- """
257
-
258
- mlp_ratio: int = 4
259
- """
260
- The ratio of the inner MLP dimensionality to ``d_model``.
261
- This is only used when ``mlp_hidden_size`` is not set.
262
- """
263
-
264
- mlp_hidden_size: Optional[int] = None
265
- """
266
- Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
267
- """
268
-
269
- activation_type: ActivationType = ActivationType.swiglu
270
- """
271
- The activation function to use within the MLP layers.
272
- """
273
-
274
- block_type: BlockType = BlockType.sequential
275
- """
276
- The transformer block implementation.
277
- """
278
-
279
- block_group_size: int = 1
280
- """
281
- The number of blocks to group together into a single parent block.
282
- This has no affect on the number of parameters in the model and is only used to wrap groups
283
- of blocks together with a single FSDP wrapper during training.
284
- """
285
-
286
- alibi: bool = False
287
- """
288
- If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
289
- """
290
-
291
- alibi_bias_max: float = 8.0
292
- """
293
- Maximum absolute value of ALiBi bias.
294
- """
295
-
296
- rope: bool = False
297
- """
298
- Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
299
- """
300
-
301
- rope_full_precision: bool = True
302
- """
303
- If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
304
- apply RoPE at the precision of the input.
305
- """
306
-
307
- flash_attention: bool = False
308
- """
309
- If ``True``, use ``FlashAttention``.
310
- """
311
-
312
- attention_dropout: float = 0.1
313
- """
314
- The dropout probability within the attention modules.
315
- """
316
-
317
- multi_query_attention: Optional[bool] = None
318
- """
319
- Deprecated. Use n_kv_heads instead.
320
- """
321
-
322
- attention_layer_norm: bool = False
323
- """
324
- Apply layer norm to the keys and queries within the attention mechanism.
325
- This can help stabilize training.
326
- """
327
-
328
- residual_dropout: float = 0.1
329
- """
330
- The dropout probability for the MLP and attention output within each block.
331
- """
332
-
333
- embedding_dropout: float = 0.1
334
- """
335
- The dropout probability for embeddings.
336
- """
337
-
338
- layer_norm_type: LayerNormType = LayerNormType.default
339
- """
340
- The layernorm implementation to use.
341
- """
342
-
343
- layer_norm_with_affine: bool = True
344
- """
345
- Whether to include bias and weight parameters for the layer norms.
346
- This only affects layer norms that are immediately followed by a linear layer in the forward pass,
347
- so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
348
- to ``False``.
349
- """
350
-
351
- attention_layer_norm_with_affine: bool = True
352
- """
353
- Toggle affine transform for the QK norms.
354
- """
355
-
356
- max_sequence_length: int = 1024
357
- """
358
- The maximum input sequence length supported by the model.
359
- """
360
-
361
- include_bias: bool = True
362
- """
363
- Whether or not to include bias parameters in linear layers.
364
- In PaLM, they got rid of all bias terms because they found that large
365
- models tend to have near 0 bias terms anyway.
366
- """
367
-
368
- bias_for_layer_norm: Optional[bool] = None
369
- """
370
- Whether or not to include bias parameters in layer norm.
371
- This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
372
- layer norm.
373
- When this is None (the default), it inherits the setting from include_bias.
374
- """
375
-
376
- scale_logits: bool = False
377
- """
378
- If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
379
- """
380
-
381
- vocab_size: int = 50257
382
- """
383
- Vocabulary size of the model.
384
- """
385
-
386
- embedding_size: Optional[int] = 50304
387
- """
388
- The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
389
- to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
390
- next multiple of 128 that's greater than ``vocab_size`` can improve throughput
391
- substantially.
392
- """
393
-
394
- weight_tying: bool = True
395
- """
396
- Whether to tie output linear weights to the input embedding.
397
- """
398
-
399
- eos_token_id: int = 50256
400
- """
401
- The ID of the end-of-sentence special token.
402
- """
403
-
404
- pad_token_id: int = 50256
405
- """
406
- The ID of the token to use for padding. Defaults to the ID of the EOS token.
407
- """
408
-
409
- init_device: Optional[str] = None
410
- """
411
- The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
412
- """
413
-
414
- init_fn: InitFnType = InitFnType.normal
415
- """
416
- The weight initialization strategy.
417
- """
418
-
419
- init_std: float = 0.02
420
- """
421
- The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
422
- as "normal".
423
- """
424
-
425
- init_cutoff_factor: Optional[float] = None
426
- """
427
- A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
428
- as "normal". Setting this to None means values are not cutoff.
429
- """
430
-
431
- precision: Optional[str] = None
432
- """
433
- Precision used to train/evaluate with. You shouldn't set this directly.
434
- See :data:`TrainConfig.precision` instead.
435
- """
436
-
437
- ternary: bool = False
438
- """
439
- Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf)
440
- """
441
-
442
- @property
443
- def effective_n_kv_heads(self) -> int:
444
- if self.n_kv_heads is None:
445
- if self.multi_query_attention is True:
446
- return 1
447
- else:
448
- return self.n_heads
449
- else:
450
- if self.multi_query_attention is None:
451
- return self.n_kv_heads
452
- if self.multi_query_attention:
453
- n_kv_heads_should_be = 1
454
- else:
455
- n_kv_heads_should_be = self.n_heads
456
- if self.n_kv_heads == n_kv_heads_should_be:
457
- return n_kv_heads_should_be
458
- else:
459
- raise OLMoConfigurationError(
460
- "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
461
- )
462
-
463
-
464
- class OptimizerType(StrEnum):
465
- lionw = "lionw"
466
- adamw = "adamw"
467
-
468
-
469
- @dataclass
470
- class OptimizerConfig(BaseConfig):
471
- name: OptimizerType = OptimizerType.lionw
472
- learning_rate: float = 1.0e-4
473
- weight_decay: float = 0.01
474
- betas: Tuple[float, float] = (0.9, 0.95)
475
-
476
- no_decay_norm_and_bias: Optional[bool] = None
477
- """
478
- Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
479
- """
480
-
481
- decay_norm_and_bias: bool = False
482
- decay_embeddings: bool = False
483
- metrics_log_interval: Optional[int] = None
484
- """
485
- The interval with which to collect and log detailed parameter-specific metrics.
486
- This only applies when logging to W&B, since these metrics won't be logged to the console.
487
- If not set, defaults to the wandb `log_interval`.
488
- """
489
-
490
- def __post_init__(self):
491
- self.betas = tuple(self.betas) # type: ignore[assignment]
492
-
493
- @classmethod
494
- def update_legacy_settings(cls, config: D) -> D:
495
- new_config = config.copy()
496
- if om.is_dict(new_config):
497
- assert isinstance(new_config, DictConfig)
498
-
499
- if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
500
- new_config.name = "lionw"
501
- if hasattr(new_config, "eps"):
502
- del new_config.eps
503
-
504
- return new_config
505
-
506
-
507
- class SchedulerType(StrEnum):
508
- cosine_with_warmup = "cosine_with_warmup"
509
- linear_with_warmup = "linear_with_warmup"
510
- inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
511
- max_scheduler = "max_scheduler"
512
- constant = "constant"
513
-
514
-
515
- class SchedulerUnits(StrEnum):
516
- steps = "steps"
517
- tokens = "tokens"
518
-
519
-
520
- @dataclass
521
- class SchedulerConfig(BaseConfig):
522
- name: SchedulerType = SchedulerType.cosine_with_warmup
523
- units: SchedulerUnits = SchedulerUnits.steps
524
- t_warmup: Union[int, float] = 100
525
- t_max: Optional[Union[int, float]] = None
526
- alpha_f: float = 0.1
527
-
528
- grad_clip_warmup_steps: Optional[Union[int, float]] = None
529
- """
530
- The warmup period for which the max grad norm (or norm ratio) will be set to its
531
- warmup value of `max_grad_norm * grad_clip_warmup_factor`.
532
- """
533
-
534
- grad_clip_warmup_factor: Optional[float] = None
535
- """
536
- The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
537
- vs after the warmup period.
538
- """
539
-
540
-
541
- class PaddingDirection(StrEnum):
542
- right = "right"
543
- left = "left"
544
-
545
-
546
- @dataclass
547
- class DataConfig(BaseConfig):
548
- paths: Optional[List[str]] = None
549
- datasets: Optional[Dict[str, List[str]]] = None
550
- label_mask_paths: Optional[List[str]] = None
551
- pad_direction: PaddingDirection = PaddingDirection.right
552
- generate_attention_mask: bool = False
553
- num_workers: int = 0
554
- drop_last: bool = False
555
- pin_memory: bool = False
556
- prefetch_factor: Optional[int] = None
557
- persistent_workers: bool = False
558
- timeout: int = 0
559
- seed: Optional[int] = None
560
-
561
-
562
- class EvaluatorType(StrEnum):
563
- downstream = "downstream"
564
- lm = "lm"
565
-
566
-
567
- @dataclass
568
- class EvaluatorConfig(BaseConfig):
569
- label: str
570
- type: EvaluatorType = EvaluatorType.lm
571
- data: DataConfig = field(default_factory=DataConfig)
572
- device_eval_batch_size: Optional[int] = None
573
- subset_num_batches: Optional[int] = None
574
-
575
-
576
- class TruncationDirection(StrEnum):
577
- right = "right"
578
- left = "left"
579
-
580
-
581
- @dataclass
582
- class TokenizerConfig(BaseConfig):
583
- identifier: str = "gpt2"
584
- truncate_direction: TruncationDirection = TruncationDirection.right
585
-
586
-
587
- @dataclass
588
- class WandbConfig(BaseConfig):
589
- project: Optional[str] = None
590
- entity: Optional[str] = "ai2-llm"
591
- group: Optional[str] = None
592
- name: Optional[str] = None
593
- tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
594
- log_artifacts: bool = False
595
- rank_zero_only: bool = True
596
- log_interval: int = 1
597
-
598
-
599
- @dataclass
600
- class SpeedMonitorConfig(BaseConfig):
601
- window_size: int = 100
602
- gpu_flops_available: Optional[Union[float, int]] = None
603
-
604
-
605
- @dataclass
606
- class CompilerConfig(BaseConfig):
607
- mode: Optional[str] = None
608
- """
609
- The mode to compile the model in. At the moment this can be "default",
610
- "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
611
- (the fastest for larger models, but takes a long time to compile).
612
- """
613
-
614
- fullgraph: bool = False
615
- """
616
- Whether it is OK to break model into several subgraphs when compiling.
617
- Note that this is not compatible with FSDP.
618
- """
619
-
620
- backend: str = "inductor"
621
- """
622
- The backend to use.
623
- """
624
-
625
-
626
- class FSDPWrapStrategy(StrEnum):
627
- by_block = "by_block"
628
- """
629
- Wrap each OLMo block with its own FSDP instance.
630
- """
631
-
632
- by_block_and_size = "by_block_and_size"
633
- """
634
- Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
635
- """
636
-
637
- by_block_group = "by_block_group"
638
- """
639
- Wrap each block group together into its own FSDP instance.
640
- This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
641
- """
642
-
643
- by_block_group_and_size = "by_block_group_and_size"
644
- """
645
- Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
646
- """
647
-
648
- size_based = "size_based"
649
- """
650
- Used PyTorch's default size-based auto wrap policy.
651
- """
652
-
653
- one_in_two = "one_in_two"
654
- one_in_three = "one_in_three"
655
- one_in_four = "one_in_four"
656
- one_in_five = "one_in_five"
657
-
658
-
659
- class FSDPPrecision(StrEnum):
660
- pure = "pure"
661
- """
662
- Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
663
- and ``buffer_dtype`` all set to the autocast precision data type.
664
- """
665
-
666
- mixed = "mixed"
667
- """
668
- Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
669
- set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
670
- """
671
-
672
-
673
- @dataclass
674
- class FSDPConfig(BaseConfig):
675
- use_orig_params: bool = True
676
- """
677
- This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
678
- """
679
-
680
- sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
681
-
682
- wrapping_strategy: Optional[FSDPWrapStrategy] = None
683
- """
684
- The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
685
- FSDP instance.
686
- """
687
-
688
- precision: FSDPPrecision = FSDPPrecision.pure
689
-
690
-
691
- class CheckpointType(StrEnum):
692
- sharded = "sharded"
693
- unsharded = "unsharded"
694
- sharded_ephemeral = "sharded_ephemeral"
695
-
696
-
697
- class ShardedCheckpointerType(StrEnum):
698
- torch_new = "torch_new"
699
- torch_legacy = "torch_legacy"
700
- local = "local"
701
-
702
-
703
- class ActivationCheckpointingStrategy(StrEnum):
704
- whole_layer = "whole_layer"
705
- """
706
- Checkpoint every transformer layer.
707
- """
708
-
709
- one_in_two = "one_in_two"
710
- """
711
- Checkpoint one in two transformer layers.
712
- """
713
-
714
- one_in_three = "one_in_three"
715
- """
716
- Checkpoint one in three transformer layers.
717
- """
718
-
719
- one_in_four = "one_in_four"
720
- """
721
- Checkpoint one in four transformer layers.
722
- """
723
-
724
- two_in_three = "two_in_three"
725
- """
726
- Checkpoint two out of every three transformer layers.
727
- """
728
-
729
- three_in_four = "three_in_four"
730
- """
731
- Checkpoint three out of four of every transformer layers.
732
- """
733
-
734
- fine_grained = "fine_grained"
735
- """
736
- Focus checkpointing on where it is cheap to recompute and saves most memory.
737
- """
738
-
739
-
740
- @dataclass
741
- class TrainConfig(BaseConfig):
742
- """
743
- OLMo training configuration.
744
- """
745
-
746
- run_name: Optional[str] = None
747
- """
748
- The name of the run.
749
- """
750
-
751
- seed: int = 6198
752
- """
753
- Used to seed all initial RNG states.
754
- """
755
-
756
- epoch: Optional[int] = None
757
- """
758
- Increment this when starting a new epoch.
759
- """
760
-
761
- dry_run: bool = False
762
- """
763
- If ``True``, don't actually train.
764
- """
765
-
766
- model: ModelConfig = field(default_factory=ModelConfig)
767
- """
768
- OLMo Model configuration.
769
- """
770
-
771
- optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
772
- """
773
- Optimizer configuration.
774
- """
775
-
776
- scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
777
- """
778
- Learning rate scheduler configuration.
779
- """
780
-
781
- data: DataConfig = field(default_factory=DataConfig)
782
- """
783
- Training data configuration.
784
- """
785
-
786
- restore_dataloader: bool = True
787
- """
788
- When restarting, restore the data loader to where it left off.
789
- If you restarting in order to train on a different dataset, set this to ``False``.
790
- """
791
-
792
- fast_forward_batches: Optional[int] = None
793
- """
794
- When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
795
- This can be useful when restarting due to a loss spike in order to skip the data that
796
- corresponded to the spike.
797
- """
798
-
799
- evaluators: List[EvaluatorConfig] = field(default_factory=list)
800
- """
801
- Evaluation configurations.
802
- """
803
-
804
- eval_interval: int = 1000
805
- """
806
- How often (in terms of batches) to run evaluations.
807
- """
808
-
809
- tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
810
- """
811
- Tokenizer configuration.
812
- """
813
-
814
- save_folder: str = "./"
815
- """
816
- The directory to save checkpoints to.
817
- """
818
-
819
- remote_save_folder: Optional[str] = None
820
- """
821
- A folder in a cloud bucket to upload saved checkpoints to.
822
- """
823
-
824
- canceled_check_interval: int = 50
825
- """
826
- How often (in batches) to check if the run has been canceled or reached its time limit.
827
- """
828
-
829
- save_interval: int = 1000
830
- """
831
- How often (in terms of steps) to save sharded training state checkpoints.
832
- """
833
-
834
- save_interval_unsharded: Optional[int] = None
835
- """
836
- How often (if at all) to save unsharded training state checkpoint.
837
- For large models it can be costly to save these, so it usually makes sense to save
838
- these less often than regular (sharded) training checkpoints.
839
- """
840
-
841
- save_interval_ephemeral: Optional[int] = None
842
- """
843
- How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
844
- as those saved every `save_interval` except that at most only the most recent one of these is kept.
845
- This is useful when you want to checkpoint often for restarts in case of failures, but don't
846
- want to keep the majority of these checkpoints.
847
-
848
- For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
849
- a temporary checkpoint every 100 steps in case your job fails. In that case you would
850
- set `save_interval=1000` and `save_interval_ephemeral=100`.
851
- """
852
-
853
- save_num_checkpoints_to_keep: int = -1
854
- """
855
- How many sharded checkpoints to keep.
856
- """
857
-
858
- save_num_unsharded_checkpoints_to_keep: int = -1
859
- """
860
- How many unsharded checkpoints to keep.
861
- """
862
-
863
- save_overwrite: bool = False
864
- """
865
- If ``True``, overwrite any conflicting checkpoint files.
866
- """
867
-
868
- force_save_unsharded: bool = False
869
- """
870
- Save an unsharded checkpoint before training (even during a dry run).
871
- Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
872
- checkpoint into an unsharded checkpoint.
873
- """
874
-
875
- no_pre_train_checkpoint: bool = False
876
- """
877
- Skip saving pre-train checkpoint.
878
- """
879
-
880
- load_path: Optional[str] = None
881
- """
882
- The path to a training checkpoint to restore/resume from.
883
-
884
- Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes
885
- a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
886
- For example,
887
-
888
- ```bash
889
- --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
890
- ```
891
- """
892
-
893
- load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
894
- """
895
- The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
896
- """
897
-
898
- reset_optimizer_state: bool = False
899
- """
900
- When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
901
- We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
902
- curve (according to the current learning rate schedule settings), and continues from there.
903
- """
904
-
905
- reset_trainer_state: bool = False
906
- """
907
- When this is set we don't restore the trainer state from a checkpoint.
908
- """
909
-
910
- sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
911
- """
912
- The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
913
- """
914
-
915
- new_style_checkpoints: Optional[bool] = None
916
- """
917
- Deprecated. Use ``sharded_checkpointer`` instead.
918
- """
919
-
920
- max_duration: Union[int, str] = 10000
921
- """
922
- How long to train for.
923
-
924
- If specified without a unit (the default), the units are assumed to be steps.
925
- You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
926
- 2 trillion tokens.
927
- """
928
-
929
- global_train_batch_size: int = 512
930
- """
931
- The effective global batch size.
932
- """
933
-
934
- device_train_batch_size: Optional[int] = None # calculated automatically
935
- """
936
- Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
937
- """
938
-
939
- device_train_microbatch_size: int = 16
940
- """
941
- The number of instances passed to the model in a single forward-backward pass. You should set
942
- this as large as you can based on available GPU memory.
943
- """
944
-
945
- device_eval_batch_size: int = 16
946
- """
947
- The number of evaluation instances passed to the model in a single forward pass on each device.
948
- """
949
-
950
- eval_subset_num_batches: int = -1
951
- """
952
- The number of batches to use for downstream evaluation from each dataset.
953
- """
954
-
955
- eval_on_load: bool = False
956
- """
957
- When resuming from a checkpoint, run the evaluation loop right away.
958
- """
959
-
960
- device_train_grad_accum: Optional[int] = None # calculated automatically
961
- """
962
- Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
963
- """
964
-
965
- max_grad_norm: Optional[float] = None
966
- """
967
- Clip gradient norms to this value if set.
968
- """
969
-
970
- max_grad_norm_ratio: Optional[float] = None
971
- """
972
- If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
973
- This takes priority over `max_grad_norm` when set.
974
- """
975
-
976
- precision: Optional[str] = None
977
- """
978
- Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
979
- """
980
-
981
- wandb: Optional[WandbConfig] = None
982
- """
983
- Weights & Biases configuration.
984
- """
985
-
986
- speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
987
- """
988
- Speed monitor configuration.
989
- """
990
-
991
- console_log_interval: int = 1
992
- """
993
- How often to log to the console.
994
- """
995
-
996
- compile: Optional[CompilerConfig] = None
997
- """
998
- Settings for compiling the model with ``torch.compile()``.
999
- """
1000
-
1001
- fsdp: FSDPConfig = field(default_factory=FSDPConfig)
1002
- """
1003
- Fully sharded data parallel settings.
1004
- """
1005
-
1006
- softmax_auxiliary_loss: bool = False
1007
- """
1008
- If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
1009
- normalizing term to be close to 0.
1010
- """
1011
-
1012
- time_limit: Optional[float] = 60 * 60 * 47.5
1013
- """
1014
- The maximum amount of time to train for before saving a checkpoint and ending early.
1015
- On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
1016
- to write out a final checkpoint.
1017
- """
1018
-
1019
- extra_steps_after_cancel: int = 10
1020
- """
1021
- Under certain conditions when a run is canceled we train for a few extra steps after saving
1022
- the final checkpoint so that when the run is restarted from the latest checkpoint we have some
1023
- overlap in metrics.
1024
- """
1025
-
1026
- early_stopping_factor: Optional[float] = None
1027
-
1028
- save_data_indices: bool = True
1029
- """
1030
- Save training data indices from each batch for each worker.
1031
- """
1032
-
1033
- python_profiling: bool = False
1034
- """
1035
- Whether to run the Python profiler on batches 6, 7, and 8.
1036
- """
1037
-
1038
- torch_profiling: bool = False
1039
- """
1040
- Whether to run the PyTorch profiler on batches 6, 7, and 8.
1041
- """
1042
-
1043
- stop_at: Optional[int] = None
1044
- """
1045
- Stop at a specific step.
1046
- """
1047
-
1048
- stop_after: Optional[int] = None
1049
- """
1050
- Stop after a specific number of steps.
1051
- """
1052
-
1053
- activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
1054
- """
1055
- The activation checkpointing strategy to use.
1056
- """
1057
-
1058
- fused_loss: Optional[bool] = None
1059
- """
1060
- Whether to use the fused CE loss function from `flash-attn`.
1061
- """
1062
-
1063
- @property
1064
- def autocast_precision(self) -> torch.dtype:
1065
- if self.precision == "amp_bf16":
1066
- return torch.bfloat16
1067
- elif self.precision == "amp_fp16":
1068
- return torch.float16
1069
- elif self.precision == "fp32":
1070
- return torch.float32
1071
- else:
1072
- raise ValueError(f"Unexpected precision type '{self.precision}'")
1073
-
1074
- @property
1075
- def fsdp_precision(self) -> MixedPrecision:
1076
- if self.fsdp.precision == FSDPPrecision.pure:
1077
- return MixedPrecision(
1078
- param_dtype=self.autocast_precision,
1079
- reduce_dtype=self.autocast_precision,
1080
- buffer_dtype=self.autocast_precision,
1081
- )
1082
- elif self.fsdp.precision == FSDPPrecision.mixed:
1083
- return MixedPrecision(
1084
- param_dtype=self.autocast_precision,
1085
- reduce_dtype=torch.float32,
1086
- buffer_dtype=self.autocast_precision,
1087
- )
1088
- else:
1089
- raise NotImplementedError(f"{self.fsdp.precision}")
1090
-
1091
- @classmethod
1092
- def update_legacy_settings(cls, config: D) -> D:
1093
- new_config = config.copy()
1094
- if om.is_dict(new_config):
1095
- assert isinstance(new_config, DictConfig)
1096
-
1097
- if hasattr(new_config, "activation_checkpointing"):
1098
- if new_config.activation_checkpointing is False:
1099
- new_config.activation_checkpointing = None
1100
- if new_config.activation_checkpointing is True:
1101
- new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
1102
-
1103
- if hasattr(new_config, "optimizer"):
1104
- new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
1105
-
1106
- return new_config
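A minimal sketch of loading a `TrainConfig` from YAML with dotlist overrides (the mechanism implemented by `BaseConfig.load` above); the YAML path and override values are illustrative:

```python
from olmo_bitnet_1b.config import TrainConfig

# Overrides use OmegaConf dotlist syntax and are merged on top of the YAML file.
cfg = TrainConfig.load(
    "configs/olmo_bitnet_1b.yaml",
    overrides=["model.ternary=true", "device_train_microbatch_size=8"],
)
print(cfg.model.d_model, cfg.optimizer.name, cfg.max_duration)
```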
olmo_bitnet_1b/configuration_olmo.py DELETED
@@ -1,52 +0,0 @@
1
- """
2
- OLMo configuration
3
- """
4
-
5
- from transformers import AutoConfig, PretrainedConfig
6
- from transformers.utils import logging
7
-
8
- from .config import ModelConfig
9
- from .aliases import PathOrStr
10
- from .beam_search import Sampler
11
- from .exceptions import OLMoError
12
- from .initialization import ModuleType
13
- from .optim import Optimizer
14
- from .util import StrEnum
15
- from .safetensors_util import STKey
16
- from .torch_util import seed_all
17
-
18
- logger = logging.get_logger(__name__)
19
-
20
-
21
- class OLMoConfig(PretrainedConfig):
22
- model_type = "olmo"
23
- keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
24
-
25
- def __init__(self, use_cache: bool = False, **kwargs):
26
- model_config = ModelConfig()
27
- all_kwargs = model_config.asdict()
28
- all_kwargs.update(kwargs)
29
- all_kwargs.update({"use_cache": use_cache})
30
- all_kwargs.update(
31
- {
32
- "architectures": all_kwargs.get("architectures", ["OLMoModelForCausalLM"])
33
- or ["OLMoModelForCausalLM"]
34
- }
35
- )
36
- super().__init__(**all_kwargs)
37
-
38
- @property
39
- def num_attention_heads(self):
40
- return self.n_heads
41
-
42
- @property
43
- def num_hidden_layers(self):
44
- return self.n_layers
45
-
46
- @property
47
- def hidden_size(self):
48
- return self.d_model
49
-
50
-
51
- # Register the config class so that it is available for transformer pipelines, auto-loading etc.
52
- # AutoConfig.register("olmo", OLMoConfig)
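A minimal sketch of constructing `OLMoConfig` directly; any `ModelConfig` field can be passed as a keyword argument, and the values below mirror `config.json` above:

```python
from olmo_bitnet_1b.configuration_olmo import OLMoConfig

cfg = OLMoConfig(d_model=2048, n_heads=16, n_layers=16, ternary=True, use_cache=True)
# The properties map OLMo field names onto the usual transformers names.
print(cfg.hidden_size, cfg.num_attention_heads, cfg.num_hidden_layers)  # 2048 16 16
```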
olmo_bitnet_1b/exceptions.py DELETED
@@ -1,50 +0,0 @@
1
- __all__ = [
2
- "OLMoError",
3
- "OLMoConfigurationError",
4
- "OLMoCliError",
5
- "OLMoEnvironmentError",
6
- "OLMoNetworkError",
7
- "OLMoCheckpointError",
8
- ]
9
-
10
-
11
- class OLMoError(Exception):
12
- """
13
- Base class for all custom OLMo exceptions.
14
- """
15
-
16
-
17
- class OLMoConfigurationError(OLMoError):
18
- """
19
- An error with a configuration file.
20
- """
21
-
22
-
23
- class OLMoCliError(OLMoError):
24
- """
25
- An error from incorrect CLI usage.
26
- """
27
-
28
-
29
- class OLMoEnvironmentError(OLMoError):
30
- """
31
- An error from incorrect environment variables.
32
- """
33
-
34
-
35
- class OLMoNetworkError(OLMoError):
36
- """
37
- An error with a network request.
38
- """
39
-
40
-
41
- class OLMoCheckpointError(OLMoError):
42
- """
43
- An error occurred reading or writing from a checkpoint.
44
- """
45
-
46
-
47
- class OLMoThreadError(Exception):
48
- """
49
- Raised when a thread fails.
50
- """
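Every custom exception except `OLMoThreadError` derives from `OLMoError`, so a single handler can catch the whole family. A small sketch:

```python
from olmo_bitnet_1b.exceptions import OLMoConfigurationError, OLMoError

try:
    raise OLMoConfigurationError("alibi and rope are mutually exclusive")
except OLMoError as e:  # also catches OLMoNetworkError, OLMoCheckpointError, ...
    print(type(e).__name__, e)
```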
olmo_bitnet_1b/initialization.py DELETED
@@ -1,95 +0,0 @@
1
- import math
2
- from typing import Optional, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from .config import InitFnType, ModelConfig
8
- from .util import StrEnum
9
-
10
- __all__ = ["init_weights", "ModuleType"]
11
-
12
-
13
- class ModuleType(StrEnum):
14
- in_module = "in"
15
- out_module = "out"
16
- emb = "emb"
17
- final_out = "final_out"
18
-
19
-
20
- def init_weights(
21
- config: ModelConfig,
22
- module: Union[nn.Linear, nn.Embedding],
23
- d: Optional[int] = None,
24
- layer_id: Optional[int] = None,
25
- std_factor: float = 1.0,
26
- type_of_module: Optional[ModuleType] = None,
27
- ) -> None:
28
- """
29
- Initialize weights of a linear or embedding module.
30
-
31
- :param config: The model config.
32
- :param module: The linear or embedding submodule to initialize.
33
- :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
34
- for fused layers.
35
- :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
36
- ``1 / sqrt(2 * (layer_id + 1))``.
37
- """
38
- d = d if d is not None else config.d_model
39
- if config.init_fn == InitFnType.normal:
40
- std = config.init_std * std_factor
41
- if config.init_cutoff_factor is not None:
42
- cutoff_value = config.init_cutoff_factor * std
43
- nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
44
- else:
45
- nn.init.normal_(module.weight, mean=0.0, std=std)
46
- elif config.init_fn == InitFnType.mitchell:
47
- std = std_factor / math.sqrt(d)
48
- if layer_id is not None:
49
- std = std / math.sqrt(2 * (layer_id + 1))
50
- nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
51
- elif config.init_fn == InitFnType.kaiming_normal:
52
- nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
53
- elif config.init_fn == InitFnType.fan_in:
54
- std = std_factor / math.sqrt(d)
55
- nn.init.normal_(module.weight, mean=0.0, std=std)
56
- elif config.init_fn == InitFnType.full_megatron:
57
- if type_of_module is None:
58
- raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
59
-
60
- cutoff_factor = config.init_cutoff_factor
61
- if cutoff_factor is None:
62
- cutoff_factor = 3
63
-
64
- if type_of_module == ModuleType.in_module:
65
- # for att_proj (same as QKV), ff_proj
66
- std = config.init_std
67
- elif type_of_module == ModuleType.out_module:
68
- # for attn_out, ff_out
69
- std = config.init_std / math.sqrt(2.0 * config.n_layers)
70
- elif type_of_module == ModuleType.emb:
71
- # positional embeddings (wpe)
72
- # token embeddings (wte)
73
- std = config.init_std
74
- elif type_of_module == ModuleType.final_out:
75
- # final output (ff_out)
76
- std = config.d_model**-0.5
77
- else:
78
- raise RuntimeError(f"Unknown module type '{type_of_module}'")
79
- nn.init.trunc_normal_(
80
- module.weight,
81
- mean=0.0,
82
- std=std,
83
- a=-cutoff_factor * std,
84
- b=cutoff_factor * std,
85
- )
86
- else:
87
- raise NotImplementedError(config.init_fn)
88
-
89
- if isinstance(module, nn.Linear):
90
- if module.bias is not None:
91
- nn.init.zeros_(module.bias)
92
-
93
- if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
94
- with torch.no_grad():
95
- module.weight.div_(math.sqrt(2 * config.n_layers))
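A minimal sketch of applying `init_weights` to a single linear layer with the "mitchell" scheme; the `ModelConfig` values are illustrative:

```python
import torch.nn as nn

from olmo_bitnet_1b.config import InitFnType, ModelConfig
from olmo_bitnet_1b.initialization import ModuleType, init_weights

cfg = ModelConfig(d_model=2048, n_layers=16, init_fn=InitFnType.mitchell)
proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

# For "mitchell", std = std_factor / sqrt(d), further scaled by
# 1 / sqrt(2 * (layer_id + 1)) when a layer_id is given.
init_weights(cfg, proj, d=cfg.d_model, layer_id=0, type_of_module=ModuleType.in_module)
```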