naveensp committed
Commit a816a73 · verified · 1 Parent(s): 7496fe7

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +1824 -0
model.py ADDED
@@ -0,0 +1,1824 @@
1
+ """
2
+ Adapted from
3
+ [MosaicML](https://github.com/mosaicml/examples.git) and
4
+ [minGPT](https://github.com/karpathy/minGPT.git)
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import math
11
+ import sys
12
+ from abc import abstractmethod
13
+ from collections import defaultdict
14
+ from functools import partial
15
+ from typing import (
16
+ Callable,
17
+ Dict,
18
+ Iterable,
19
+ List,
20
+ NamedTuple,
21
+ Optional,
22
+ Sequence,
23
+ Set,
24
+ Tuple,
25
+ cast,
26
+ )
27
+
28
+ import torch
29
+ import torch.backends.cuda
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch import einsum
33
+
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast
35
+
36
+ from .aliases import PathOrStr
37
+ from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
38
+ from .config import (
39
+ ActivationCheckpointingStrategy,
40
+ ActivationType,
41
+ BlockType,
42
+ CheckpointType,
43
+ FSDPWrapStrategy,
44
+ LayerNormType,
45
+ ModelConfig,
46
+ )
47
+ from .exceptions import OLMoConfigurationError
48
+ from .initialization import ModuleType, init_weights
49
+ from .torch_util import ensure_finite_
50
+
51
+ import copy
52
+ if sys.version_info.minor > 8:
53
+ from collections.abc import MutableMapping
54
+ elif sys.version_info.minor == 8:
55
+ from typing import MutableMapping
56
+ else:
57
+ raise SystemExit("This script supports Python 3.8 or higher")
58
+
59
+ __all__ = [
60
+ "LayerNormBase",
61
+ "LayerNorm",
62
+ "RMSLayerNorm",
63
+ "RotaryEmbedding",
64
+ "Activation",
65
+ "GELU",
66
+ "ReLU",
67
+ "SwiGLU",
68
+ "BitLinear158",
69
+ "OLMoBlock",
70
+ "OLMoSequentialBlock",
71
+ "OLMoParallelBlock",
72
+ "OLMo",
73
+ "OLMoOutput",
74
+ "OLMoGenerateOutput",
75
+ ]
76
+
77
+
78
+ log = logging.getLogger(__name__)
79
+
80
+
81
+ def activation_checkpoint_function(cfg: ModelConfig):
82
+ preserve_rng_state = (
83
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
84
+ )
85
+ from torch.utils.checkpoint import checkpoint
86
+
87
+ return partial(
88
+ checkpoint,
89
+ preserve_rng_state=preserve_rng_state,
90
+ use_reentrant=False,
91
+ )
92
+
93
+
94
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
95
+ """
96
+ Cache for attention biases and other things that would normally be stored as buffers.
97
+ We avoid using buffers because we've run into various issues doing so with FSDP.
98
+ In general it appears the way FSDP handles buffers is not well-defined.
99
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
100
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
101
+ NaNs when they're synchronized due to casting or some other issue.
102
+ """
103
+
104
+
105
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
106
+ if config.init_device is not None and config.init_device != "meta":
107
+ return torch.device(config.init_device)
108
+ else:
109
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
110
+
111
+
112
+ class Dropout(nn.Dropout):
113
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
114
+ if self.p == 0.0:
115
+ return input
116
+ else:
117
+ return F.dropout(input, self.p, self.training, self.inplace)
118
+
119
+
120
+ class LayerNormBase(nn.Module):
121
+ def __init__(
122
+ self,
123
+ config: ModelConfig,
124
+ *,
125
+ size: Optional[int] = None,
126
+ elementwise_affine: Optional[bool] = True,
127
+ eps: float = 1e-05,
128
+ ):
129
+ super().__init__()
130
+ self.config = config
131
+ self.eps = eps
132
+ self.normalized_shape = (size or config.d_model,)
133
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
134
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
135
+ use_bias = self.config.bias_for_layer_norm
136
+ if use_bias is None:
137
+ use_bias = self.config.include_bias
138
+ if use_bias:
139
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
140
+ else:
141
+ self.register_parameter("bias", None)
142
+ else:
143
+ self.register_parameter("bias", None)
144
+ self.register_parameter("weight", None)
145
+
146
+ @abstractmethod
147
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
148
+ raise NotImplementedError
149
+
150
+ @classmethod
151
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
152
+ if config.layer_norm_type == LayerNormType.default:
153
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
154
+ elif config.layer_norm_type == LayerNormType.low_precision:
155
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
156
+ elif config.layer_norm_type == LayerNormType.rms:
157
+ return RMSLayerNorm(config, size=size, **kwargs)
158
+ else:
159
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
160
+
161
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
162
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
163
+ # `is_autocast_cpu_enabled()` for CPU autocast.
164
+ # See https://github.com/pytorch/pytorch/issues/110966.
165
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
166
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
167
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
168
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
169
+ else:
170
+ return tensor
171
+
172
+ def reset_parameters(self):
173
+ if self.weight is not None:
174
+ torch.nn.init.ones_(self.weight) # type: ignore
175
+ if self.bias is not None:
176
+ torch.nn.init.zeros_(self.bias) # type: ignore
177
+
178
+
179
+ class LayerNorm(LayerNormBase):
180
+ """
181
+ The default :class:`LayerNorm` implementation which can optionally run in low precision.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ config: ModelConfig,
187
+ size: Optional[int] = None,
188
+ low_precision: bool = False,
189
+ elementwise_affine: Optional[bool] = None,
190
+ eps: float = 1e-05,
191
+ ):
192
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
193
+ self.low_precision = low_precision
194
+
195
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
196
+ if self.low_precision:
197
+ module_device = x.device
198
+ downcast_x = self._cast_if_autocast_enabled(x)
199
+ downcast_weight = (
200
+ self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
201
+ )
202
+ downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
203
+ with torch.autocast(enabled=False, device_type=module_device.type):
204
+ return F.layer_norm(
205
+ downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
206
+ )
207
+ else:
208
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
209
+
210
+
211
+ class RMSLayerNorm(LayerNormBase):
212
+ """
213
+ RMS layer norm, a simplified :class:`LayerNorm` implementation.
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ config: ModelConfig,
219
+ size: Optional[int] = None,
220
+ elementwise_affine: Optional[bool] = None,
221
+ eps: float = 1e-5,
222
+ ):
223
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
224
+
225
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
226
+ with torch.autocast(enabled=False, device_type=x.device.type):
227
+ og_dtype = x.dtype
228
+ x = x.to(torch.float32)
229
+ variance = x.pow(2).mean(-1, keepdim=True)
230
+ x = x * torch.rsqrt(variance + self.eps)
231
+ x = x.to(og_dtype)
232
+
233
+ if self.weight is not None:
234
+ if self.bias is not None:
235
+ return self.weight * x + self.bias
236
+ else:
237
+ return self.weight * x
238
+ else:
239
+ return x
240
+
241
+
242
+ class RotaryEmbedding(nn.Module):
243
+ """
244
+ [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
245
+ """
246
+
247
+ def __init__(self, config: ModelConfig, cache: BufferCache):
248
+ super().__init__()
249
+ self.config = config
250
+ self.__cache = cache
251
+ # Warm up cache.
252
+ self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
253
+
254
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
255
+ if (
256
+ (pos_sin := self.__cache.get("rope_pos_sin")) is not None
257
+ and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
258
+ and pos_sin.shape[-2] >= seq_len
259
+ and pos_cos.shape[-2] >= seq_len
260
+ ):
261
+ if pos_sin.device != device:
262
+ pos_sin = pos_sin.to(device)
263
+ self.__cache["rope_pos_sin"] = pos_sin
264
+ if pos_cos.device != device:
265
+ pos_cos = pos_cos.to(device)
266
+ self.__cache["rope_pos_cos"] = pos_cos
267
+ return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
268
+
269
+ with torch.autocast(device.type, enabled=False):
270
+ dim = self.config.d_model // self.config.n_heads
271
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
272
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
273
+ freqs = einsum("i , j -> i j", seq, inv_freq)
274
+ positions = torch.cat((freqs, freqs), dim=-1)
275
+ pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
276
+ self.__cache["rope_pos_sin"] = pos_sin
277
+ self.__cache["rope_pos_cos"] = pos_cos
278
+ return pos_sin, pos_cos
279
+
280
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
281
+ B, nh, T, hs = x.size()
282
+ x = x.view(B, nh, T, 2, hs // 2)
283
+ x1, x2 = x.unbind(dim=-2)
284
+ return torch.cat((-x2, x1), dim=-1)
285
+
286
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
287
+ return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
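+ # NOTE: positions = cat(freqs, freqs), so feature i shares a frequency with feature i + hs/2.
+ # The expression above rotates each such pair by the position-dependent angle pos * inv_freq[i]:
+ # first-half entries become x1*cos - x2*sin and second-half entries become x2*cos + x1*sin
+ # (the "rotate_half" RoPE variant).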
288
+
289
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
290
+ if self.config.rope_full_precision:
291
+ q_, k_ = q.float(), k.float()
292
+ else:
293
+ q_, k_ = q, k
294
+
295
+ with torch.autocast(q.device.type, enabled=False):
296
+ query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
297
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
298
+ pos_sin = pos_sin.type_as(q_)
299
+ pos_cos = pos_cos.type_as(q_)
300
+ q_ = self.apply_rotary_pos_emb(
301
+ pos_sin[:, :, key_len - query_len : key_len, :],
302
+ pos_cos[:, :, key_len - query_len : key_len, :],
303
+ q_,
304
+ )
305
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
306
+ return q_.type_as(q), k_.type_as(k)
307
+
308
+
309
+ class Activation(nn.Module):
310
+ def __init__(self, config: ModelConfig):
311
+ super().__init__()
312
+ self.config = config
313
+
314
+ @abstractmethod
315
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
316
+ raise NotImplementedError
317
+
318
+ @property
319
+ @abstractmethod
320
+ def output_multiplier(self) -> float:
321
+ raise NotImplementedError
322
+
323
+ @classmethod
324
+ def build(cls, config: ModelConfig) -> Activation:
325
+ if config.activation_type == ActivationType.gelu:
326
+ return cast(Activation, GELU(approximate="none"))
327
+ elif config.activation_type == ActivationType.relu:
328
+ return cast(Activation, ReLU(inplace=False))
329
+ elif config.activation_type == ActivationType.swiglu:
330
+ return SwiGLU(config)
331
+ else:
332
+ raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
333
+
334
+
335
+ class GELU(nn.GELU):
336
+ @property
337
+ def output_multiplier(self) -> float:
338
+ return 1.0
339
+
340
+
341
+ class ReLU(nn.ReLU):
342
+ @property
343
+ def output_multiplier(self) -> float:
344
+ return 1.0
345
+
346
+
347
+ class SwiGLU(Activation):
348
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
349
+ x, gate = x.chunk(2, dim=-1)
350
+ return F.silu(gate) * x
351
+
352
+ @property
353
+ def output_multiplier(self) -> float:
354
+ return 0.5
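+ # NOTE: SwiGLU splits its input in half (x.chunk(2, dim=-1)), so its output has half as many
+ # features as its input; output_multiplier = 0.5 is what OLMoBlock uses to size ff_out's input
+ # dimension via int(output_multiplier * hidden_size).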
355
+
356
+
357
+ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
358
+ att_bias = torch.triu(
359
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
360
+ diagonal=1,
361
+ )
362
+ att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
363
+ return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
364
+
365
+
366
+ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
367
+ if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
368
+ if causal_bias.device != device:
369
+ causal_bias = causal_bias.to(device)
370
+ cache["causal_attention_bias"] = causal_bias
371
+ return causal_bias
372
+ with torch.autocast(device.type, enabled=False):
373
+ causal_bias = causal_attention_bias(seq_len, device)
374
+ cache["causal_attention_bias"] = causal_bias
375
+ return causal_bias
376
+
377
+
378
+ def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
379
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
380
+
381
+ # shape: (1, 1, seq_len, seq_len)
382
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
383
+ alibi_bias.abs_().mul_(-1)
384
+
385
+ # shape: (n_heads,)
386
+ m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
387
+ m.mul_(config.alibi_bias_max / config.n_heads)
388
+
389
+ # shape: (1, n_heads, seq_len, seq_len)
390
+ return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
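+ # NOTE: head h (1-indexed) gets slope 1 / 2**(h * alibi_bias_max / n_heads), applied to the
+ # negative distance matrix built above; with the common default alibi_bias_max = 8 this
+ # reproduces the geometric slope schedule from the ALiBi paper.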
391
+
392
+ def activation_quant(x):
393
+ """Per−token quantization to 8 bits. No grouping is needed for quantization.
394
+ Args:
395
+ x: an activation tensor with shape [n, d]
396
+ Returns:
397
+ y: a quantized activation tensor with shape [n, d]
398
+ """
399
+ scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
400
+ y = (x * scale).round().clamp_(-128, 127) / scale
401
+ return y
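+ # Worked example (illustrative values): for a token x = [0.1, -0.5, 2.0], max(|x|) = 2.0, so
+ # scale = 127 / 2.0 = 63.5; round(x * scale) = [6, -32, 127], and dividing by scale gives
+ # y ≈ [0.094, -0.504, 2.0], i.e. x snapped to an 8-bit grid with step max(|x|) / 127.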
402
+
403
+ def weight_quant(w):
404
+ """Per−tensor quantization to 1.58 bits. No grouping is needed for quantization.
405
+ Args:
406
+ w: a weight tensor with shape [d, k]
407
+ Returns:
408
+ u: a quantized weight with shape [d, k]
409
+ """
410
+ scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
411
+ u = (w * scale).round().clamp_(-1, 1) / scale
412
+ return u
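+ # Worked example (illustrative values): for w = [0.3, -0.8, 0.05, 1.2], mean(|w|) = 0.5875, so
+ # scale ≈ 1.70; round(w * scale) = [1, -1, 0, 2], clamped to [1, -1, 0, 1], and dividing by
+ # scale gives u ≈ [0.59, -0.59, 0.0, 0.59]: every weight lands in {-mean(|w|), 0, +mean(|w|)},
+ # i.e. the ternary (log2(3) ≈ 1.58-bit) set used by BitNet b1.58.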
413
+
414
+ def activation_norm_quant(x):
415
+ """
416
+ Same as activation_quant, but returns y and scale separately.
417
+ Args:
418
+ x: an activation tensor with shape [n, d]
419
+ Returns:
420
+ y: a quantized activation tensor with shape [n, d]
421
+ scale: a scalar for dequantization with shape [1]
422
+ """
423
+ scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
424
+ y = (x * scale).round().clamp_(-128, 127)
425
+ return y, scale
426
+
427
+ def gemm_lowbit_kernel(x, w):
428
+ y = F.linear(x, w)
429
+ return y
430
+
431
+ class BitLinear158(nn.Linear):
432
+ """
433
+ This is only for training, and kernel optimization is needed for efficiency.
434
+ """
435
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
436
+ device=None, dtype=None, config=None):
437
+ super().__init__(in_features, out_features, bias, device, dtype)
438
+ self.norm = RMSLayerNorm(config, elementwise_affine=False)
439
+
440
+ def forward(self, x):
441
+ """
442
+ Args:
443
+ x: an input tensor with shape [n, d]
444
+ Returns:
445
+ y: an output tensor with shape [n, d]
446
+ """
447
+ w = self.weight # a weight tensor with shape [d, k]
448
+ x_norm = self.norm(x)
449
+ # A trick for implementing the Straight-Through Estimator (STE) using detach()
450
+ x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
451
+ w_quant = w + (weight_quant(w) - w).detach()
452
+ y = F.linear(x_quant, w_quant)
453
+ return y
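+ # NOTE: the detach() pattern above is the straight-through estimator: numerically,
+ # x_norm + (activation_quant(x_norm) - x_norm).detach() equals activation_quant(x_norm), so the
+ # forward pass uses quantized activations and weights, while the detached term carries no
+ # gradient, so gradients flow to x_norm and w as if no rounding had occurred.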
454
+
455
+ class BitLinear158_inference(nn.Linear):
456
+ """
457
+ Use quantized weights for inference.
458
+ """
459
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
460
+ device=None, dtype=None, config=None):
461
+ super().__init__(in_features, out_features, bias, device, dtype)
462
+ self.norm = RMSLayerNorm(config, elementwise_affine=False)
463
+ self.weight_scale = nn.Parameter(torch.ones(1))
464
+
465
+ def forward(self, x):
466
+ """
467
+ Args:
468
+ x: an input tensor with shape [n, d]
469
+ Returns:
470
+ y: an output tensor with shape [n, d]
471
+ """
472
+ w = self.weight # a 1.58-bit weight tensor with shape [d, k]
473
+ w_scale = self.weight_scale # a full-precision weight scale tensor with shape [1]
474
+ x_norm = self.norm(x)
475
+ x_quant, x_scale = activation_norm_quant(x_norm)
476
+ y = gemm_lowbit_kernel(x_quant, w) / w_scale / x_scale
477
+ return y
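+ # NOTE: dividing by w_scale and x_scale dequantizes the integer-domain product back to the real
+ # scale of w and x. gemm_lowbit_kernel above currently just calls F.linear; presumably it is a
+ # stand-in for an optimized ternary-weight / int8-activation GEMM kernel.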
478
+
479
+
480
+ class OLMoBlock(nn.Module):
481
+ """
482
+ A base class for transformer block implementations.
483
+ """
484
+
485
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
486
+ super().__init__()
487
+ self.layer_id = layer_id
488
+ self.config = config
489
+ self.hidden_size = (
490
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
491
+ )
492
+ self.__cache = cache
493
+ assert config.d_model % config.n_heads == 0
494
+
495
+ self._activation_checkpoint_fn = None
496
+
497
+ Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
498
+
499
+ # Dropout.
500
+ self.dropout = Dropout(config.residual_dropout)
501
+
502
+ # Layer norms.
503
+ self.k_norm: Optional[LayerNormBase] = None
504
+ self.q_norm: Optional[LayerNormBase] = None
505
+ if config.attention_layer_norm:
506
+ self.k_norm = LayerNormBase.build(
507
+ config,
508
+ size=config.d_model // config.n_heads if config.multi_query_attention else None,
509
+ elementwise_affine=config.attention_layer_norm_with_affine,
510
+ )
511
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
512
+
513
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
514
+ if config.clip_qkv is not None:
515
+ assert config.clip_qkv > 0
516
+
517
+ # Activation function.
518
+ self.act = Activation.build(config)
519
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
520
+
521
+ # Attention output projection.
522
+ self.attn_out = Linear(
523
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device,
524
+ config=config
525
+ )
526
+
527
+ # Feed-forward output projection.
528
+ self.ff_out = Linear(
529
+ int(self.act.output_multiplier * self.hidden_size),
530
+ config.d_model,
531
+ bias=config.include_bias,
532
+ device=config.init_device,
533
+ config=config,
534
+ )
535
+ self.ff_out._is_residual = True # type: ignore
536
+
537
+ # Rotary embeddings.
538
+ if self.config.rope:
539
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
540
+
541
+ def reset_parameters(self):
542
+ if self.k_norm is not None:
543
+ self.k_norm.reset_parameters()
544
+ if self.q_norm is not None:
545
+ self.q_norm.reset_parameters()
546
+ init_weights(
547
+ self.config,
548
+ self.attn_out,
549
+ d=self.config.d_model,
550
+ layer_id=self.layer_id,
551
+ type_of_module=ModuleType.out_module,
552
+ )
553
+ init_weights(
554
+ self.config,
555
+ self.ff_out,
556
+ d=self.ff_out.in_features,
557
+ layer_id=self.layer_id,
558
+ type_of_module=ModuleType.out_module,
559
+ )
560
+
561
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
562
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
563
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
564
+ else:
565
+ self._activation_checkpoint_fn = None
566
+
567
+ @classmethod
568
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
569
+ target_dtype = input_dtype
570
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
571
+ # `is_autocast_cpu_enabled()` for CPU autocast.
572
+ # See https://github.com/pytorch/pytorch/issues/110966.
573
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
574
+ target_dtype = torch.get_autocast_gpu_dtype()
575
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
576
+ target_dtype = torch.get_autocast_cpu_dtype()
577
+ if bias.dtype != target_dtype:
578
+ bias = bias.to(target_dtype)
579
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
580
+ return bias
581
+
582
+ def _scaled_dot_product_attention(
583
+ self,
584
+ q: torch.Tensor,
585
+ k: torch.Tensor,
586
+ v: torch.Tensor,
587
+ attn_mask: Optional[torch.Tensor] = None,
588
+ dropout_p: float = 0.0,
589
+ is_causal: bool = False,
590
+ ) -> torch.Tensor:
591
+ """
592
+ Computes scaled dot product attention on query, key and value tensors, using an optional
593
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
594
+
595
+ This method is based on PyTorch's `scaled_dot_product_attention`.
596
+ """
597
+ return F.scaled_dot_product_attention(
598
+ q,
599
+ k,
600
+ v,
601
+ attn_mask=attn_mask,
602
+ dropout_p=dropout_p,
603
+ is_causal=is_causal,
604
+ )
605
+
606
+ def attention(
607
+ self,
608
+ q: torch.Tensor,
609
+ k: torch.Tensor,
610
+ v: torch.Tensor,
611
+ attention_bias: Optional[torch.Tensor] = None,
612
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
613
+ use_cache: bool = False,
614
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
615
+ B, T, C = q.size() # batch size, sequence length, d_model
616
+ dtype = k.dtype
617
+
618
+ # Optionally apply layer norm to keys and queries.
619
+ if self.q_norm is not None and self.k_norm is not None:
620
+ q = self.q_norm(q).to(dtype=dtype)
621
+ k = self.k_norm(k).to(dtype=dtype)
622
+
623
+ # Move head forward to be next to the batch dim.
624
+ # shape: (B, nh, T, hs)
625
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
626
+ if self.config.multi_query_attention:
627
+ # shape: (B, 1, T, hs)
628
+ k = k.view(B, T, 1, C // self.config.n_heads).transpose(1, 2)
629
+ # shape: (B, 1, T, hs)
630
+ v = v.view(B, T, 1, C // self.config.n_heads).transpose(1, 2)
631
+ else:
632
+ # shape: (B, nh, T, hs)
633
+ k = k.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
634
+ # shape: (B, nh, T, hs)
635
+ v = v.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
636
+
637
+ if layer_past is not None:
638
+ past_key, past_value = layer_past
639
+ k = torch.cat((past_key, k), dim=-2)
640
+ v = torch.cat((past_value, v), dim=-2)
641
+
642
+ present = (k, v) if use_cache else None
643
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
644
+
645
+ if self.config.rope:
646
+ # Apply rotary embeddings.
647
+ q, k = self.rotary_emb(q, k)
648
+
649
+ if attention_bias is not None:
650
+ # Resize and cast attention bias.
651
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
652
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
653
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
654
+ # cause the SDP attn function to produce NaNs.
655
+ attention_bias = self._cast_attn_bias(
656
+ attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
657
+ )
658
+
659
+ # Get the attention scores.
660
+ # shape: (B, nh, T, hs)
661
+ att = self._scaled_dot_product_attention(
662
+ q,
663
+ k,
664
+ v,
665
+ attn_mask=attention_bias,
666
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
667
+ is_causal=attention_bias is None,
668
+ )
669
+
670
+ # Re-assemble all head outputs side-by-side.
671
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
672
+
673
+ # Apply output projection.
674
+ return self.attn_out(att), present
675
+
676
+ @abstractmethod
677
+ def forward(
678
+ self,
679
+ x: torch.Tensor,
680
+ attention_bias: Optional[torch.FloatTensor] = None,
681
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
682
+ use_cache: bool = False,
683
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
684
+ raise NotImplementedError
685
+
686
+ @classmethod
687
+ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock:
688
+ if config.block_type == BlockType.sequential:
689
+ return OLMoSequentialBlock(layer_id, config, cache)
690
+ elif config.block_type == BlockType.parallel:
691
+ return OLMoParallelBlock(layer_id, config, cache)
692
+ elif config.block_type == BlockType.llama:
693
+ return OLMoLlamaBlock(layer_id, config, cache)
694
+ else:
695
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
696
+
697
+
698
+ class OLMoSequentialBlock(OLMoBlock):
699
+ """
700
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
701
+ (plus another skip connection).
702
+ """
703
+
704
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
705
+ super().__init__(layer_id, config, cache)
706
+ # Layer norms.
707
+ self.attn_norm = LayerNorm.build(config)
708
+ self.ff_norm = LayerNorm.build(config)
709
+ Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
710
+ # Attention input projection. Projects x -> (q, k, v)
711
+ if config.multi_query_attention:
712
+ self.fused_dims = (config.d_model, config.d_model // config.n_heads, config.d_model // config.n_heads)
713
+ else:
714
+ self.fused_dims = (config.d_model, config.d_model, config.d_model)
715
+ self.att_proj = Linear(
716
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device,
717
+ config=config
718
+ )
719
+ # Feed-forward input projection.
720
+ self.ff_proj = Linear(
721
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device,
722
+ config=config
723
+ )
724
+
725
+ def reset_parameters(self):
726
+ super().reset_parameters()
727
+ self.attn_norm.reset_parameters()
728
+ self.ff_norm.reset_parameters()
729
+ # NOTE: the standard deviation for these weights does not depend on the layer.
730
+ init_weights(
731
+ self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
732
+ )
733
+ init_weights(
734
+ self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
735
+ )
736
+
737
+ def forward(
738
+ self,
739
+ x: torch.Tensor,
740
+ attention_bias: Optional[torch.Tensor] = None,
741
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
742
+ use_cache: bool = False,
743
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
744
+ # Get query, key, value projections.
745
+ # shape:
746
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
747
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
748
+ # k, v: (batch_size, seq_len, d_model // n_heads)
749
+ if self._activation_checkpoint_fn is not None:
750
+ qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
751
+ else:
752
+ qkv = self.att_proj(self.attn_norm(x))
753
+
754
+ if self.config.clip_qkv is not None:
755
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
756
+
757
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
758
+
759
+ # Get attention scores.
760
+ if self._activation_checkpoint_fn is not None:
761
+ att, cache = self._activation_checkpoint_fn( # type: ignore
762
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
763
+ )
764
+ else:
765
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
766
+
767
+ # Add attention scores.
768
+ # shape: (B, T, C)
769
+ x = x + self.dropout(att)
770
+
771
+ # Add feed-forward projection.
772
+ # shape: (batch_size, seq_len, d_model)
773
+ og_x = x
774
+ if self._activation_checkpoint_fn is not None:
775
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
776
+ else:
777
+ x = self.ff_norm(x)
778
+ x = self.ff_proj(x)
779
+ if self._activation_checkpoint_fn is not None:
780
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
781
+ else:
782
+ x = self.act(x)
783
+ x = self.ff_out(x)
784
+ x = self.dropout(x)
785
+ x = og_x + x
786
+
787
+ return x, cache
788
+
789
+
790
+ class OLMoParallelBlock(OLMoBlock):
791
+ """
792
+ This is a transformer block where the output is computed as ``MLP(LN(x)) + Attention(LN(x))``
793
+ as in the PaLM architecture, as opposed to the typical ``MLP(LN(x + Attention(LN(x))))``
794
+ as in :class:`OLMoSequentialBlock` (ignoring some skip connections).
795
+
796
+ The decoupling of the MLP and Attention functions allows us to fuse the separate input projections
797
+ into a single linear layer to increase throughput. In this configuration it's also straight-forward
798
+ to fuse the output projections, but we found that didn't help.
799
+ """
800
+
801
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
802
+ super().__init__(layer_id, config, cache)
803
+ self.norm = LayerNorm.build(config)
804
+ Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
805
+
806
+ # Fused attention and feed-forward projection.
807
+ # NOTE: we could also fuse the attention and feed-forward output projections but we
808
+ # found that didn't help, possibly because of the overhead of joining the `att` and
809
+ # `ff` activations together. See https://github.com/allenai/LLM/pull/79 for details.
810
+ if config.multi_query_attention:
811
+ self.fused_dims = (
812
+ config.d_model,
813
+ config.d_model // config.n_heads,
814
+ config.d_model // config.n_heads,
815
+ self.hidden_size,
816
+ )
817
+ else:
818
+ self.fused_dims = (config.d_model, config.d_model, config.d_model, self.hidden_size)
819
+ self.fused_attn_ff_proj = Linear(
820
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device,
821
+ config=config
822
+ )
823
+
824
+ def reset_parameters(self):
825
+ super().reset_parameters()
826
+ self.norm.reset_parameters()
827
+ # NOTE: the standard deviation for these weights does not depend on the layer.
828
+ init_weights(
829
+ self.config,
830
+ self.fused_attn_ff_proj,
831
+ d=self.config.d_model,
832
+ layer_id=None,
833
+ type_of_module=ModuleType.in_module,
834
+ )
835
+
836
+ def forward(
837
+ self,
838
+ x: torch.Tensor,
839
+ attention_bias: Optional[torch.Tensor] = None,
840
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
841
+ use_cache: bool = False,
842
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
843
+ # Get query, key, value, and feed-forward projections.
844
+ # shape of q, k, v:
845
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
846
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
847
+ # k, v: (batch_size, seq_len, d_model // n_heads)
848
+ # shape of ff: (batch_size, seq_len, hidden_size)
849
+ if self._activation_checkpoint_fn is not None:
850
+ q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
851
+ self.fused_dims, dim=-1
852
+ )
853
+ else:
854
+ q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
855
+
856
+ if self.config.clip_qkv is not None:
857
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
858
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
859
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
860
+
861
+ # Get attention scores.
862
+ # shape: (B, T, C)
863
+ if self._activation_checkpoint_fn is not None:
864
+ att, cache = self._activation_checkpoint_fn( # type: ignore
865
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
866
+ )
867
+ else:
868
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
869
+
870
+ # Apply output projections (and activation function) and sum the results.
871
+ # We keep these projections separate because we found that we got better throughput this
872
+ # way compared to fusing them.
873
+ if self._activation_checkpoint_fn is not None:
874
+ return (
875
+ x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
876
+ cache,
877
+ )
878
+ else:
879
+ return (
880
+ x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
881
+ cache,
882
+ )
883
+
884
+
885
+ class OLMoLlamaBlock(OLMoBlock):
886
+ """
887
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
888
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
889
+ but some operations have slightly different implementations to imitate the
890
+ behavior of Llama.
891
+ """
892
+
893
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
894
+ super().__init__(layer_id, config, cache)
895
+ # Layer norms.
896
+ self.attn_norm = LayerNorm.build(config)
897
+ self.ff_norm = LayerNorm.build(config)
898
+ self.__cache = cache
899
+ Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
900
+
901
+
902
+ # Attention input projection. Projects x -> (q, k, v)
903
+ if config.multi_query_attention:
904
+ q_proj_out_dim = config.d_model
905
+ k_proj_out_dim = config.d_model // config.n_heads
906
+ v_proj_out_dim = config.d_model // config.n_heads
907
+ else:
908
+ q_proj_out_dim = config.d_model
909
+ k_proj_out_dim = config.d_model
910
+ v_proj_out_dim = config.d_model
911
+ self.q_proj = Linear(
912
+ config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device,
913
+ config=config
914
+ )
915
+ self.k_proj = Linear(
916
+ config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device,
917
+ config=config
918
+ )
919
+ self.v_proj = Linear(
920
+ config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device,
921
+ config=config
922
+ )
923
+
924
+ # Feed-forward input projection.
925
+ self.ff_proj = Linear(
926
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device,
927
+ config=config
928
+ )
929
+
930
+ def reset_parameters(self):
931
+ super().reset_parameters()
932
+ if self.attn_norm:
933
+ self.attn_norm.reset_parameters()
934
+ self.ff_norm.reset_parameters()
935
+ # NOTE: the standard deviation for these weights does not depend on the layer.
936
+ init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
937
+ init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
938
+ init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
939
+ init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
940
+
941
+ def _scaled_dot_product_attention(
942
+ self,
943
+ q: torch.Tensor,
944
+ k: torch.Tensor,
945
+ v: torch.Tensor,
946
+ attn_mask: Optional[torch.Tensor] = None,
947
+ dropout_p: float = 0.0,
948
+ is_causal: bool = False,
949
+ ) -> torch.Tensor:
950
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
951
+
952
+ if is_causal:
953
+ assert attn_mask is None
954
+
955
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
956
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
957
+ elif attn_mask is not None:
958
+ attn_bias = attn_mask.to(q.dtype)
959
+ else:
960
+ attn_bias = torch.zeros_like(attn_weights)
961
+
962
+ attn_weights += attn_bias
963
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
964
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
965
+ return torch.matmul(attn_weights, v)
966
+
967
+ def forward(
968
+ self,
969
+ x: torch.Tensor,
970
+ attention_bias: Optional[torch.Tensor] = None,
971
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
972
+ use_cache: bool = False,
973
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
974
+ # Get query, key, value projections.
975
+ # shape:
976
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
977
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
978
+ # k, v: (batch_size, seq_len, d_model // n_heads)
979
+ x_normed = self.attn_norm(x)
980
+ q = self.q_proj(x_normed)
981
+ k = self.k_proj(x_normed)
982
+ v = self.v_proj(x_normed)
983
+
984
+ if self.config.clip_qkv is not None:
985
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
986
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
987
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
988
+
989
+ # Get attention scores.
990
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
991
+
992
+ # Add attention scores.
993
+ # shape: (B, T, C)
994
+ x = x + self.dropout(att)
995
+
996
+ # Add feed-forward projection.
997
+ # shape: (batch_size, seq_len, d_model)
998
+ og_x = x
999
+ if self._activation_checkpoint_fn is not None:
1000
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
1001
+ else:
1002
+ x = self.ff_norm(x)
1003
+ x = self.ff_proj(x)
1004
+ if self._activation_checkpoint_fn is not None:
1005
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
1006
+ else:
1007
+ x = self.act(x)
1008
+ x = self.ff_out(x)
1009
+ x = self.dropout(x)
1010
+ x = og_x + x
1011
+
1012
+ return x, cache
1013
+
1014
+
1015
+ class OLMoOutput(NamedTuple):
1016
+ logits: torch.FloatTensor
1017
+ """
1018
+ A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
1019
+ for the next token *before* normalization via (log) softmax.
1020
+ """
1021
+
1022
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
1023
+ """
1024
+ Attention keys and values from each block.
1025
+ """
1026
+
1027
+ hidden_states: Optional[Tuple[torch.Tensor]]
1028
+ """
1029
+ Hidden states from each block.
1030
+ """
1031
+
1032
+
1033
+ class OLMoGenerateOutput(NamedTuple):
1034
+ token_ids: torch.LongTensor
1035
+ """
1036
+ The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
1037
+ These do *not* include the original input IDs.
1038
+ """
1039
+
1040
+ scores: torch.FloatTensor
1041
+ """
1042
+ The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
1043
+ """
1044
+
1045
+
1046
+ class OLMoBlockGroup(nn.ModuleList):
1047
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
1048
+ super().__init__(modules)
1049
+ self.config = config
1050
+ self.layer_offset = layer_offset
1051
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1052
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
1053
+
1054
+ def forward(
1055
+ self,
1056
+ x: torch.Tensor,
1057
+ attention_bias: Optional[torch.FloatTensor] = None,
1058
+ layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
1059
+ use_cache: bool = False,
1060
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
1061
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1062
+ for block_idx, block in enumerate(self):
1063
+ layer_past = None if layers_past is None else layers_past[block_idx]
1064
+ block_idx += self.layer_offset
1065
+ if (
1066
+ (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1067
+ or (
1068
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1069
+ and block_idx % 2 == 0
1070
+ )
1071
+ or (
1072
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1073
+ and block_idx % 3 == 0
1074
+ )
1075
+ or (
1076
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1077
+ and block_idx % 4 == 0
1078
+ )
1079
+ ):
1080
+ # shape: (batch_size, seq_len, d_model)
1081
+ x, cache = self._activation_checkpoint_fn( # type: ignore
1082
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1083
+ )
1084
+ else:
1085
+ # shape: (batch_size, seq_len, d_model)
1086
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1087
+ if attn_key_values is not None:
1088
+ assert cache is not None
1089
+ attn_key_values.append(cache)
1090
+ return x, attn_key_values
1091
+
1092
+ def reset_parameters(self):
1093
+ for block in self:
1094
+ block.reset_parameters()
1095
+
1096
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1097
+ self.activation_checkpointing_strategy = strategy
1098
+ for block in self:
1099
+ block.set_activation_checkpointing(strategy)
1100
+
1101
+
1102
+ class OLMo(nn.Module):
1103
+ def __init__(self, config: ModelConfig, init_params: bool = True):
1104
+ super().__init__()
1105
+ self.config = config
1106
+ self.__cache = BufferCache()
1107
+
1108
+ # Validate config.
1109
+ if self.config.alibi and self.config.flash_attention:
1110
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
1111
+
1112
+ if self.config.alibi and self.config.rope:
1113
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
1114
+
1115
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
1116
+ if self.config.embedding_size < self.config.vocab_size:
1117
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
1118
+ elif self.config.embedding_size % 128 != 0:
1119
+ import warnings
1120
+
1121
+ warnings.warn(
1122
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1123
+ )
1124
+
1125
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1126
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1127
+
1128
+ if not (
1129
+ 0 < self.config.block_group_size <= self.config.n_layers
1130
+ and self.config.n_layers % self.config.block_group_size == 0
1131
+ ):
1132
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
1133
+
1134
+ torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
1135
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
1136
+
1137
+ self.transformer = nn.ModuleDict(
1138
+ dict(
1139
+ wte=nn.Embedding(
1140
+ config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1141
+ ),
1142
+ emb_drop=Dropout(config.embedding_dropout),
1143
+ ln_f=LayerNorm.build(config),
1144
+ )
1145
+ )
1146
+
1147
+ blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
1148
+ if self.config.block_group_size > 1:
1149
+ block_groups = [
1150
+ OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
1151
+ for i in range(0, config.n_layers, config.block_group_size)
1152
+ ]
1153
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
1154
+ else:
1155
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
1156
+
1157
+ if not (self.config.alibi or self.config.rope):
1158
+ self.transformer.update(
1159
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
1160
+ )
1161
+ if not config.weight_tying:
1162
+ self.transformer.update(
1163
+ {
1164
+ "ff_out": nn.Linear(
1165
+ config.d_model,
1166
+ config.embedding_size or config.vocab_size,
1167
+ bias=config.include_bias,
1168
+ device=config.init_device,
1169
+ )
1170
+ }
1171
+ )
1172
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1173
+ if init_params and self.config.init_device != "meta":
1174
+ self.reset_parameters()
1175
+ self.__num_fwd_flops: Optional[int] = None
1176
+
1177
+ # Warm up cache.
1178
+ if self.config.alibi:
1179
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1180
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1181
+
1182
+ def embed_tokens(self, input_ids):
1183
+ return self.transformer.wte(input_ids)
1184
+
1185
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1186
+ self.activation_checkpointing_strategy = strategy
1187
+ if self.config.block_group_size != 1:
1188
+ for block_group in self.transformer.block_groups:
1189
+ block_group.set_activation_checkpointing(strategy)
1190
+ else:
1191
+ for block in self.transformer.blocks:
1192
+ block.set_activation_checkpointing(strategy)
1193
+
1194
+ @property
1195
+ def device(self) -> torch.device:
1196
+ device: torch.device = self.transformer.wte.weight.device # type: ignore
1197
+ if device.type == "meta":
1198
+ return _non_meta_init_device(self.config)
1199
+ else:
1200
+ return device
1201
+
1202
+ def reset_parameters(self):
1203
+ log.info("Initializing model parameters...")
1204
+ # Top-level embeddings / linear layers.
1205
+ init_weights(
1206
+ self.config,
1207
+ self.transformer.wte, # type: ignore
1208
+ std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
1209
+ type_of_module=ModuleType.emb,
1210
+ )
1211
+ if hasattr(self.transformer, "wpe"):
1212
+ init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
1213
+
1214
+ # Top-level layer norm.
1215
+ self.transformer.ln_f.reset_parameters() # type: ignore
1216
+
1217
+ # Output weights.
1218
+ if hasattr(self.transformer, "ff_out"):
1219
+ init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
1220
+
1221
+ # Let the blocks handle themselves.
1222
+ if self.config.block_group_size == 1:
1223
+ for block in self.transformer.blocks:
1224
+ block.reset_parameters()
1225
+ else:
1226
+ for block_group in self.transformer.block_groups:
1227
+ block_group.reset_parameters()
1228
+
1229
+ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
1230
+ if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
1231
+ -1
1232
+ ] >= seq_len:
1233
+ if alibi_bias.device != device:
1234
+ alibi_bias = alibi_bias.to(device)
1235
+ self.__cache["alibi_attention_bias"] = alibi_bias
1236
+ return alibi_bias
1237
+ with torch.autocast(device.type, enabled=False):
1238
+ alibi_bias = alibi_attention_bias(seq_len, self.config, device)
1239
+ self.__cache["alibi_attention_bias"] = alibi_bias
1240
+ return alibi_bias
1241
+
1242
+ def forward(
1243
+ self,
1244
+ input_ids: torch.LongTensor,
1245
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1246
+ attention_mask: Optional[torch.Tensor] = None,
1247
+ attention_bias: Optional[torch.Tensor] = None,
1248
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
1249
+ use_cache: bool = False,
1250
+ last_logits_only: bool = False,
1251
+ output_hidden_states: Optional[bool] = None,
1252
+ ) -> BaseModelOutputWithPast:
1253
+ """
1254
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1255
+ :param inputs_embeds: A tensor of shape `(batch_size, seq_len, d_model)` with input
1256
+ embeddings. When provided, it is treated as the output of the input embedding layer.
1257
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1258
+ which input IDs are masked. A `1` value in the mask means that
1259
+ the corresponding input ID should *not* be ignored. A `0` means
1260
+ that the corresponding input ID is masked.
1261
+
1262
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1263
+ library.
1264
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1265
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1266
+ to introduce causal or other biases.
1267
+
1268
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1269
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
1270
+ element in the sequence.
1271
+
1272
+ If the tensor is a float tensor, it will just be added to the attention
1273
+ scores before the softmax.
1274
+
1275
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1276
+ :param past_key_values: Pre-computed keys and values for each attention block.
1277
+ Can be used to speed up sequential decoding. The `input_ids` which have
1278
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
1279
+ :param use_cache: If `True`, return key and value tensors for each block.
1280
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1281
+ This can speed up decoding when you only care about the next token.
1282
+ """
1283
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1284
+
1285
+ if past_key_values:
1286
+ assert len(past_key_values) == self.config.n_layers
1287
+
1288
+ batch_size, seq_len = input_ids.size() if inputs_embeds is None else inputs_embeds.size()[:2]
1289
+ if past_key_values is None:
1290
+ past_length = 0
1291
+ else:
1292
+ past_length = past_key_values[0][0].size(-2)
1293
+
1294
+ # Get embeddings of input.
1295
+ # shape: (batch_size, seq_len, d_model)
1296
+ x = self.transformer.wte(input_ids) if inputs_embeds is None else inputs_embeds # type: ignore
1297
+
1298
+ if not (self.config.alibi or self.config.rope):
1299
+ # Get positional embeddings.
1300
+ # shape: (1, seq_len)
1301
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1302
+ # shape: (1, seq_len, d_model)
1303
+ pos_emb = self.transformer.wpe(pos) # type: ignore
1304
+ x = pos_emb + x
1305
+
1306
+ # Add input + positional embeddings and apply dropout.
1307
+ # shape: (batch_size, seq_len, d_model)
1308
+ x = self.transformer.emb_drop(x) # type: ignore
1309
+
1310
+ # Transform the attention mask into what the blocks expect.
1311
+ if attention_mask is not None:
1312
+ # shape: (batch_size, 1, 1, seq_len)
1313
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1314
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1315
+
1316
+ # Merge attention mask with attention bias.
1317
+ if (
1318
+ attention_bias is not None
1319
+ or attention_mask is not None
1320
+ or self.config.alibi
1321
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1322
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1323
+ # scores correctly.
1324
+ or past_key_values is not None
1325
+ ):
1326
+ if attention_bias is None and self.config.alibi:
1327
+ attention_bias = get_causal_attention_bias(
1328
+ self.__cache, past_length + seq_len, x.device
1329
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1330
+ elif attention_bias is None:
1331
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1332
+ elif attention_bias.dtype in (torch.int8, torch.bool):
1333
+ attention_bias = attention_bias.to(dtype=torch.float)
1334
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1335
+
1336
+ # Transform to the right shape and data type.
1337
+ mask_len = seq_len
1338
+ if attention_mask is not None:
1339
+ mask_len = attention_mask.shape[-1]
1340
+ elif past_key_values is not None:
1341
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
1342
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1343
+
1344
+ # Add in the masking bias.
1345
+ if attention_mask is not None:
1346
+ attention_bias = attention_bias + attention_mask
1347
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1348
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1349
+ # it can produce NaNs.
1350
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1351
+
1352
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1353
+
1354
+ # decoder layers
1355
+ all_hidden_states = []
1356
+
1357
+ # Apply blocks one-by-one.
1358
+ if self.config.block_group_size == 1:
1359
+ for block_idx, block in enumerate(self.transformer.blocks):
1360
+ if output_hidden_states:
1361
+ # add hidden states
1362
+ all_hidden_states.append(x)
1363
+
1364
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
1365
+ if (
1366
+ (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1367
+ or (
1368
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1369
+ and block_idx % 2 == 0
1370
+ )
1371
+ or (
1372
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1373
+ and block_idx % 3 == 0
1374
+ )
1375
+ or (
1376
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1377
+ and block_idx % 4 == 0
1378
+ )
1379
+ ):
1380
+ # shape: (batch_size, seq_len, d_model)
1381
+ x, cache = self._activation_checkpoint_fn(
1382
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1383
+ )
1384
+ else:
1385
+ # shape: (batch_size, seq_len, d_model)
1386
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1387
+ if attn_key_values is not None:
1388
+ assert cache is not None
1389
+ attn_key_values.append(cache)
1390
+ else:
1391
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1392
+ if output_hidden_states:
1393
+ # add hidden states
1394
+ all_hidden_states.append(x)
1395
+
1396
+ layers_past = (
1397
+ None
1398
+ if past_key_values is None
1399
+ else past_key_values[
1400
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1401
+ ]
1402
+ )
1403
+ x, cache = block_group(
1404
+ x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
1405
+ )
1406
+ if attn_key_values is not None:
1407
+ assert cache is not None
1408
+ attn_key_values.extend(cache)
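+
+ # For example, with n_layers=32 and block_group_size=4 there are 8 block groups,
+ # and group g consumes past_key_values[4 * g : 4 * (g + 1)], so the flat
+ # per-layer cache layout is identical in both branches above.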
1409
+
1410
+ if last_logits_only:
1411
+ # shape: (batch_size, 1, d_model)
1412
+ x = x[:, -1, :].unsqueeze(1)
1413
+
1414
+ # Apply final layer norm.
1415
+ # shape: (batch_size, seq_len or 1, d_model)
1416
+ x = self.transformer.ln_f(x) # type: ignore
1417
+ if output_hidden_states:
1418
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1419
+ all_hidden_states.append(x)
1420
+
1421
+ # Get logits.
1422
+ # shape: (batch_size, seq_len or 1, vocab_size)
1423
+ if self.config.weight_tying:
1424
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1425
+ else:
1426
+ logits = self.transformer.ff_out(x) # type: ignore
1427
+ if self.config.scale_logits:
1428
+ logits.mul_(1 / math.sqrt(self.config.d_model))
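+
+ # Shape note: with weight tying the unembedding reuses wte.weight of shape
+ # (vocab_size, d_model), so F.linear(x, wte.weight) maps
+ #   (batch_size, seq_len or 1, d_model) -> (batch_size, seq_len or 1, vocab_size),
+ # equivalent to x @ wte.weight.T.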
1429
+
1430
+ return BaseModelOutputWithPast(
1431
+ last_hidden_state=x,
1432
+ past_key_values=tuple(attn_key_values) if attn_key_values is not None else None,
1433
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
1434
+ )
1435
+
1436
+ def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
1437
+ if wrap_strategy is None:
1438
+ return None
1439
+
1440
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
1441
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
1442
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
1443
+ # but not other linear layers within a block.
1444
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
1445
+ # return True in 'recurse' mode for simplicity.
1446
+ size_based_module_to_wrap = {self.transformer.wte}
1447
+ if hasattr(self.transformer, "ff_out"):
1448
+ size_based_module_to_wrap.add(self.transformer.ff_out)
1449
+
1450
+ if wrap_strategy == FSDPWrapStrategy.by_block:
1451
+
1452
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1453
+ del nonwrapped_numel
1454
+ wrap = isinstance(module, OLMoBlock)
1455
+ if recurse:
1456
+ return True
1457
+ else:
1458
+ return wrap
1459
+
1460
+ return fsdp_wrap_fn
1461
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
1462
+
1463
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1464
+ del nonwrapped_numel
1465
+ wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap
1466
+ if recurse:
1467
+ return True
1468
+ else:
1469
+ return wrap
1470
+
1471
+ return fsdp_wrap_fn
1472
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
1473
+ if self.config.block_group_size <= 1:
1474
+ raise OLMoConfigurationError(
1475
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
1476
+ )
1477
+
1478
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1479
+ del nonwrapped_numel
1480
+ wrap = isinstance(module, OLMoBlockGroup)
1481
+ if recurse:
1482
+ return True
1483
+ else:
1484
+ return wrap
1485
+
1486
+ return fsdp_wrap_fn
1487
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
1488
+ if self.config.block_group_size <= 1:
1489
+ raise OLMoConfigurationError(
1490
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
1491
+ )
1492
+
1493
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1494
+ del nonwrapped_numel
1495
+ wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap
1496
+ if recurse:
1497
+ return True
1498
+ else:
1499
+ return wrap
1500
+
1501
+ return fsdp_wrap_fn
1502
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
1503
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
1504
+
1505
+ return size_based_auto_wrap_policy
1506
+ elif wrap_strategy in {
1507
+ FSDPWrapStrategy.one_in_two,
1508
+ FSDPWrapStrategy.one_in_three,
1509
+ FSDPWrapStrategy.one_in_four,
1510
+ FSDPWrapStrategy.one_in_five,
1511
+ }:
1512
+ c = {
1513
+ FSDPWrapStrategy.one_in_two: 2,
1514
+ FSDPWrapStrategy.one_in_three: 3,
1515
+ FSDPWrapStrategy.one_in_four: 4,
1516
+ FSDPWrapStrategy.one_in_five: 5,
1517
+ }[wrap_strategy]
1518
+
1519
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1520
+ del nonwrapped_numel
1521
+ wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0
1522
+ if recurse:
1523
+ return True
1524
+ else:
1525
+ return wrap
1526
+
1527
+ return fsdp_wrap_fn
1528
+ else:
1529
+ raise NotImplementedError(wrap_strategy)
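+
+ # A sketch of how these policies are typically consumed; FSDP arguments beyond
+ # auto_wrap_policy depend on your setup:
+ #
+ #   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ #   policy = model.get_fsdp_wrap_policy(FSDPWrapStrategy.by_block)
+ #   fsdp_model = FSDP(model, auto_wrap_policy=policy)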
1530
+
1531
+ def num_params(self, include_embedding: bool = True) -> int:
1532
+ """
1533
+ Get the total number of parameters.
1534
+ """
1535
+ params = (np for np in self.named_parameters())
1536
+ if not include_embedding:
1537
+ params = filter( # type: ignore
1538
+ lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
1539
+ params,
1540
+ )
1541
+ return sum(p.numel() for _, p in params)
1542
+
1543
+ @property
1544
+ def num_fwd_flops(self):
1545
+ if self.__num_fwd_flops:
1546
+ return self.__num_fwd_flops
1547
+ n_params = self.num_params()
1548
+ # the number of parameters is approximately the number of multiply-accumulates (MACs) per token in the network
1549
+ # each MAC costs 2 FLOPs, so we multiply by 2, i.e. 2 * n_params
1550
+ # this gets us FLOPs / token
1551
+ params_flops_per_token = 2 * n_params
1552
+ params_flops_per_seq = params_flops_per_token * self.config.max_sequence_length
1553
+ # again 2 FLOPs per MAC; attention adds two matmuls per layer, A = Q @ K^T and out = A @ V (hence the two factors of 2)
1554
+ attn_flops_per_seq = (
1555
+ self.config.n_layers * 2 * 2 * (self.config.d_model * (self.config.max_sequence_length**2))
1556
+ )
1557
+ self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq
1558
+ return self.__num_fwd_flops
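+
+ # Worked example: for a hypothetical config with n_params ~ 7e9, n_layers=32,
+ # d_model=4096, max_sequence_length=2048:
+ #   params_flops_per_seq = 2 * 7e9 * 2048              ~= 2.87e13
+ #   attn_flops_per_seq   = 32 * 2 * 2 * 4096 * 2048**2 ~= 2.20e12
+ # so the attention term adds roughly 8% on top of the parameter term.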
1559
+
1560
+ def generate(
1561
+ self,
1562
+ input_ids: torch.LongTensor,
1563
+ attention_mask: Optional[torch.Tensor] = None,
1564
+ attention_bias: Optional[torch.Tensor] = None,
1565
+ max_steps: int = 10,
1566
+ beam_size: int = 1,
1567
+ per_node_beam_size: Optional[int] = None,
1568
+ sampler: Optional[Sampler] = None,
1569
+ min_steps: Optional[int] = None,
1570
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
1571
+ constraints: Optional[List[Constraint]] = None,
1572
+ ) -> OLMoGenerateOutput:
1573
+ """
1574
+ Generate token IDs using beam search.
1575
+
1576
+ Note that by default ``beam_size`` is set to 1, which is greedy decoding.
1577
+
1578
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1579
+ :param attention_mask: An optional tensor of shape `(batch_size, seq_len)`, the same
1580
+ as for the forward method.
1581
+ :param attention_bias: A tensor of shape
1582
+ `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
1583
+ the same as for the forward method, except that only this one shape is accepted here.
1584
+
1585
+ For an explanation of the other arguments, see :class:`BeamSearch`.
1586
+ """
1587
+ beam_search = BeamSearch(
1588
+ self.config.eos_token_id,
1589
+ max_steps=max_steps,
1590
+ beam_size=beam_size,
1591
+ per_node_beam_size=per_node_beam_size,
1592
+ sampler=sampler,
1593
+ min_steps=min_steps,
1594
+ final_sequence_scorer=final_sequence_scorer,
1595
+ constraints=constraints,
1596
+ )
1597
+
1598
+ # Validate inputs.
1599
+ batch_size, seq_len = input_ids.shape
1600
+ if attention_mask is not None:
1601
+ assert attention_mask.shape == (batch_size, seq_len)
1602
+ if attention_bias is not None:
1603
+ assert len(attention_bias.shape) == 4
1604
+ assert attention_bias.shape[:2] == (batch_size, 1)
1605
+ assert (
1606
+ seq_len + beam_search.max_steps
1607
+ <= attention_bias.shape[2]
1608
+ == attention_bias.shape[3]
1609
+ <= self.config.max_sequence_length
1610
+ )
1611
+
1612
+ tokens_generated = 0
1613
+
1614
+ def flatten_past_key_values(
1615
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
1616
+ ) -> Dict[str, torch.Tensor]:
1617
+ out = {}
1618
+ for i, (key, value) in enumerate(past_key_values):
1619
+ out[f"past_key_{i}"] = key
1620
+ out[f"past_value_{i}"] = value
1621
+ return out
1622
+
1623
+ def unflatten_past_key_values(
1624
+ past_key_values: Dict[str, torch.Tensor],
1625
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
1626
+ out = []
1627
+ for i in range(self.config.n_layers):
1628
+ past_key = past_key_values[f"past_key_{i}"]
1629
+ past_value = past_key_values[f"past_value_{i}"]
1630
+ out.append((past_key, past_value))
1631
+ return out
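+
+ # For n_layers=2 the flattened state uses the keys
+ #   past_key_0, past_value_0, past_key_1, past_value_1
+ # and unflatten_past_key_values reassembles them into
+ #   [(past_key_0, past_value_0), (past_key_1, past_value_1)].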
1632
+
1633
+ def step(
1634
+ last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
1635
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
1636
+ nonlocal tokens_generated
1637
+
1638
+ attention_mask = state.get("attention_mask")
1639
+ attention_bias = state.get("attention_bias")
1640
+
1641
+ if tokens_generated > 0:
1642
+ past_key_values = unflatten_past_key_values(state)
1643
+ input_ids = last_predictions.unsqueeze(1)
1644
+ if attention_mask is not None:
1645
+ group_size = input_ids.shape[0]
1646
+ attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
1647
+ else:
1648
+ past_key_values = None
1649
+ input_ids = state["input_ids"]
1650
+
1651
+ tokens_generated += 1
1652
+
1653
+ # Run forward pass of model to get logits, then normalize to get log probs.
1654
+ output = self(
1655
+ input_ids,
1656
+ attention_mask=attention_mask,
1657
+ attention_bias=attention_bias,
1658
+ past_key_values=past_key_values,
1659
+ use_cache=True,
1660
+ last_logits_only=True,
1661
+ )
1662
+ log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)
1663
+
1664
+ # Create new state.
1665
+ state = flatten_past_key_values(output.attn_key_values)
1666
+ if attention_mask is not None:
1667
+ state["attention_mask"] = attention_mask
1668
+ if attention_bias is not None:
1669
+ state["attention_bias"] = attention_bias
1670
+
1671
+ return log_probs, state
1672
+
1673
+ initial_preds = input_ids.new_zeros((batch_size,)) # The value is arbitrary; it isn't actually used.
1674
+ state: dict[str, torch.Tensor] = {"input_ids": input_ids}
1675
+ if attention_mask is not None:
1676
+ state["attention_mask"] = attention_mask
1677
+ if attention_bias is not None:
1678
+ state["attention_bias"] = attention_bias
1679
+ with torch.no_grad():
1680
+ token_ids, scores = beam_search.search(initial_preds, state, step)
1681
+
1682
+ return OLMoGenerateOutput(
1683
+ token_ids=token_ids, # type: ignore[arg-type]
1684
+ scores=scores, # type: ignore[arg-type]
1685
+ )
1686
+
1687
+ @classmethod
1688
+ def from_checkpoint(
1689
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
1690
+ ) -> OLMo:
1691
+ """
1692
+ Load an OLMo model from a checkpoint.
1693
+ """
1694
+ from .util import resource_path
1695
+
1696
+ # Guess checkpoint type.
1697
+ if checkpoint_type is None:
1698
+ try:
1699
+ if resource_path(checkpoint_dir, "model.pt").is_file():
1700
+ checkpoint_type = CheckpointType.unsharded
1701
+ else:
1702
+ checkpoint_type = CheckpointType.sharded
1703
+ except FileNotFoundError:
1704
+ checkpoint_type = CheckpointType.sharded
1705
+
1706
+ # Load config.
1707
+ config_path = resource_path(checkpoint_dir, "config.yaml")
1708
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
1709
+
1710
+ if checkpoint_type == CheckpointType.unsharded:
1711
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1712
+ model_config.init_device = "cpu"
1713
+ model = OLMo(model_config)
1714
+
1715
+ # Load the state dict on CPU first, then move the model to the target device.
1716
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
1717
+ state_dict = torch.load(state_dict_path, map_location="cpu")
1718
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
1719
+ model = model.to(torch.device(device))
1720
+ else:
1721
+ from .checkpoint import load_model_state
1722
+
1723
+ # Initialize model on target device. In this case the state dict is loaded in-place
1724
+ # so it's not necessary to start on CPU if the target device is a GPU.
1725
+ model_config.init_device = device
1726
+ model = OLMo(model_config)
1727
+
1728
+ # Load state dict in place.
1729
+ load_model_state(checkpoint_dir, model)
1730
+
1731
+ return model.eval()
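+
+ # Example usage (the checkpoint path here is hypothetical):
+ #
+ #   model = OLMo.from_checkpoint("checkpoints/step1000-unsharded", device="cuda")
+ #   output = model(input_ids.to("cuda"))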
1732
+
1733
+ def _make_state_dict_compatible(
1734
+ self, state_dict: Dict[str, torch.Tensor]
1735
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
1736
+ """
1737
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
1738
+ be loaded.
1739
+
1740
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
1741
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
1742
+ to make a corresponding optimizer state dict compatible as well.
1743
+ """
1744
+ import re
1745
+ from fnmatch import fnmatch
1746
+
1747
+ new_keys_to_og_keys: Dict[str, str] = {}
1748
+
1749
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
1750
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
1751
+ # fine without the prefixes. This also simplifies the other steps below.
1752
+ for key in list(state_dict.keys()):
1753
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
1754
+ new_keys_to_og_keys[new_key] = key
1755
+
1756
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
1757
+ if self.config.block_type == BlockType.sequential:
1758
+ for key in list(state_dict.keys()):
1759
+ if fnmatch(key, "transformer.*.norm.weight"):
1760
+ tensor = state_dict.pop(key)
1761
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
1762
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1763
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
1764
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1765
+ del new_keys_to_og_keys[key]
1766
+ elif fnmatch(key, "transformer.*.norm.bias"):
1767
+ tensor = state_dict.pop(key)
1768
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
1769
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1770
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
1771
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1772
+ del new_keys_to_og_keys[key]
1773
+
1774
+ # For loading a state dict that was saved with a different `block_group_size`.
1775
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
1776
+ state_dict_block_group_size = len(
1777
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
1778
+ )
1779
+ else:
1780
+ state_dict_block_group_size = 1
1781
+ if self.config.block_group_size != state_dict_block_group_size:
1782
+ log.info(
1783
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
1784
+ f"group size {self.config.block_group_size}"
1785
+ )
1786
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
1787
+ # and then (re-)group them into the right block sizes.
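+
+ # Worked example: with state_dict_block_group_size=2 and config.block_group_size=4,
+ # the key "transformer.block_groups.3.1.attn_out.weight" first flattens to
+ # block_idx = 3 * 2 + 1 = 7, i.e. "transformer.blocks.7.attn_out.weight", and is then
+ # regrouped as group 7 // 4 = 1, sub-block 7 % 4 = 3, i.e.
+ # "transformer.block_groups.1.3.attn_out.weight".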
1788
+ if state_dict_block_group_size > 1:
1789
+ for key in list(state_dict.keys()):
1790
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
1791
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
1792
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
1793
+ state_dict[
1794
+ (
1795
+ new_key := key.replace(
1796
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
1797
+ )
1798
+ )
1799
+ ] = state_dict.pop(key)
1800
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1801
+
1802
+ if self.config.block_group_size > 1:
1803
+ # Group the state dict blocks into the right block size.
1804
+ for key in list(state_dict.keys()):
1805
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
1806
+ block_idx = int(m.group(1))
1807
+ group_idx, group_block_idx = (
1808
+ block_idx // self.config.block_group_size,
1809
+ block_idx % self.config.block_group_size,
1810
+ )
1811
+ state_dict[
1812
+ (
1813
+ new_key := key.replace(
1814
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
1815
+ )
1816
+ )
1817
+ ] = state_dict.pop(key)
1818
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1819
+
1820
+ og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
1821
+ for new_key, og_key in new_keys_to_og_keys.items():
1822
+ og_keys_to_new[og_key].add(new_key)
1823
+
1824
+ return state_dict, og_keys_to_new