English
naveensp commited on
Commit
1412507
·
verified ·
1 Parent(s): 17776d2

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -1824
model.py DELETED
@@ -1,1824 +0,0 @@
1
- """
2
- Adapted from
3
- [MosaiclML](https://github.com/mosaicml/examples.git) and
4
- [minGPT](https://github.com/karpathy/minGPT.git)
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import logging
10
- import math
11
- import sys
12
- from abc import abstractmethod
13
- from collections import defaultdict
14
- from functools import partial
15
- from typing import (
16
- Callable,
17
- Dict,
18
- Iterable,
19
- List,
20
- NamedTuple,
21
- Optional,
22
- Sequence,
23
- Set,
24
- Tuple,
25
- cast,
26
- )
27
-
28
- import torch
29
- import torch.backends.cuda
30
- import torch.nn as nn
31
- import torch.nn.functional as F
32
- from torch import einsum
33
-
34
- from transformers.modeling_outputs import BaseModelOutputWithPast
35
-
36
- from .aliases import PathOrStr
37
- from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
38
- from .config import (
39
- ActivationCheckpointingStrategy,
40
- ActivationType,
41
- BlockType,
42
- CheckpointType,
43
- FSDPWrapStrategy,
44
- LayerNormType,
45
- ModelConfig,
46
- )
47
- from .exceptions import OLMoConfigurationError
48
- from .initialization import ModuleType, init_weights
49
- from .torch_util import ensure_finite_
50
-
51
- import copy
52
- if sys.version_info.minor > 8:
53
- from collections.abc import MutableMapping
54
- elif sys.version_info.minor == 8:
55
- from typing import MutableMapping
56
- else:
57
- raise SystemExit("This script supports Python 3.8 or higher")
58
-
59
- __all__ = [
60
- "LayerNormBase",
61
- "LayerNorm",
62
- "RMSLayerNorm",
63
- "RotaryEmbedding",
64
- "Activation",
65
- "GELU",
66
- "ReLU",
67
- "SwiGLU",
68
- "BitLinear158",
69
- "OLMoBlock",
70
- "OLMoSequentialBlock",
71
- "OLMoParallelBlock",
72
- "OLMo",
73
- "OLMoOutput",
74
- "OLMoGenerateOutput",
75
- ]
76
-
77
-
78
- log = logging.getLogger(__name__)
79
-
80
-
81
- def activation_checkpoint_function(cfg: ModelConfig):
82
- preserve_rng_state = (
83
- (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
84
- )
85
- from torch.utils.checkpoint import checkpoint
86
-
87
- return partial(
88
- checkpoint,
89
- preserve_rng_state=preserve_rng_state,
90
- use_reentrant=False,
91
- )
92
-
93
-
94
- class BufferCache(dict, MutableMapping[str, torch.Tensor]):
95
- """
96
- Cache for attention biases and other things that would normally be stored as buffers.
97
- We avoid using buffers because we've run into various issues doing so with FSDP.
98
- In general it appears the way FSDP handles buffers is not well-defined.
99
- It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
100
- since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
101
- NaNs when they're synchronized due to casting or some other issue.
102
- """
103
-
104
-
105
- def _non_meta_init_device(config: ModelConfig) -> torch.device:
106
- if config.init_device is not None and config.init_device != "meta":
107
- return torch.device(config.init_device)
108
- else:
109
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
110
-
111
-
112
- class Dropout(nn.Dropout):
113
- def forward(self, input: torch.Tensor) -> torch.Tensor:
114
- if self.p == 0.0:
115
- return input
116
- else:
117
- return F.dropout(input, self.p, self.training, self.inplace)
118
-
119
-
120
- class LayerNormBase(nn.Module):
121
- def __init__(
122
- self,
123
- config: ModelConfig,
124
- *,
125
- size: Optional[int] = None,
126
- elementwise_affine: Optional[bool] = True,
127
- eps: float = 1e-05,
128
- ):
129
- super().__init__()
130
- self.config = config
131
- self.eps = eps
132
- self.normalized_shape = (size or config.d_model,)
133
- if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
134
- self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
135
- use_bias = self.config.bias_for_layer_norm
136
- if use_bias is None:
137
- use_bias = self.config.include_bias
138
- if use_bias:
139
- self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
140
- else:
141
- self.register_parameter("bias", None)
142
- else:
143
- self.register_parameter("bias", None)
144
- self.register_parameter("weight", None)
145
-
146
- @abstractmethod
147
- def forward(self, x: torch.Tensor) -> torch.Tensor:
148
- raise NotImplementedError
149
-
150
- @classmethod
151
- def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
152
- if config.layer_norm_type == LayerNormType.default:
153
- return LayerNorm(config, size=size, low_precision=False, **kwargs)
154
- elif config.layer_norm_type == LayerNormType.low_precision:
155
- return LayerNorm(config, size=size, low_precision=True, **kwargs)
156
- elif config.layer_norm_type == LayerNormType.rms:
157
- return RMSLayerNorm(config, size=size, **kwargs)
158
- else:
159
- raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
160
-
161
- def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
162
- # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
163
- # `is_autocast_cpu_enabled()` for CPU autocast.
164
- # See https://github.com/pytorch/pytorch/issues/110966.
165
- if tensor.device.type == "cuda" and torch.is_autocast_enabled():
166
- return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
167
- elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
168
- return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
169
- else:
170
- return tensor
171
-
172
- def reset_parameters(self):
173
- if self.weight is not None:
174
- torch.nn.init.ones_(self.weight) # type: ignore
175
- if self.bias is not None:
176
- torch.nn.init.zeros_(self.bias) # type: ignore
177
-
178
-
179
- class LayerNorm(LayerNormBase):
180
- """
181
- The default :class:`LayerNorm` implementation which can optionally run in low precision.
182
- """
183
-
184
- def __init__(
185
- self,
186
- config: ModelConfig,
187
- size: Optional[int] = None,
188
- low_precision: bool = False,
189
- elementwise_affine: Optional[bool] = None,
190
- eps: float = 1e-05,
191
- ):
192
- super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
193
- self.low_precision = low_precision
194
-
195
- def forward(self, x: torch.Tensor) -> torch.Tensor:
196
- if self.low_precision:
197
- module_device = x.device
198
- downcast_x = self._cast_if_autocast_enabled(x)
199
- downcast_weight = (
200
- self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
201
- )
202
- downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
203
- with torch.autocast(enabled=False, device_type=module_device.type):
204
- return F.layer_norm(
205
- downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
206
- )
207
- else:
208
- return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
209
-
210
-
211
- class RMSLayerNorm(LayerNormBase):
212
- """
213
- RMS layer norm, a simplified :class:`LayerNorm` implementation
214
- """
215
-
216
- def __init__(
217
- self,
218
- config: ModelConfig,
219
- size: Optional[int] = None,
220
- elementwise_affine: Optional[bool] = None,
221
- eps: float = 1e-5,
222
- ):
223
- super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
224
-
225
- def forward(self, x: torch.Tensor) -> torch.Tensor:
226
- with torch.autocast(enabled=False, device_type=x.device.type):
227
- og_dtype = x.dtype
228
- x = x.to(torch.float32)
229
- variance = x.pow(2).mean(-1, keepdim=True)
230
- x = x * torch.rsqrt(variance + self.eps)
231
- x = x.to(og_dtype)
232
-
233
- if self.weight is not None:
234
- if self.bias is not None:
235
- return self.weight * x + self.bias
236
- else:
237
- return self.weight * x
238
- else:
239
- return x
240
-
241
-
242
- class RotaryEmbedding(nn.Module):
243
- """
244
- [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
245
- """
246
-
247
- def __init__(self, config: ModelConfig, cache: BufferCache):
248
- super().__init__()
249
- self.config = config
250
- self.__cache = cache
251
- # Warm up cache.
252
- self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
253
-
254
- def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
255
- if (
256
- (pos_sin := self.__cache.get("rope_pos_sin")) is not None
257
- and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
258
- and pos_sin.shape[-2] >= seq_len
259
- and pos_cos.shape[-2] >= seq_len
260
- ):
261
- if pos_sin.device != device:
262
- pos_sin = pos_sin.to(device)
263
- self.__cache["rope_pos_sin"] = pos_sin
264
- if pos_cos.device != device:
265
- pos_cos = pos_cos.to(device)
266
- self.__cache["rope_pos_cos"] = pos_cos
267
- return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
268
-
269
- with torch.autocast(device.type, enabled=False):
270
- dim = self.config.d_model // self.config.n_heads
271
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
272
- seq = torch.arange(seq_len, device=device, dtype=torch.float)
273
- freqs = einsum("i , j -> i j", seq, inv_freq)
274
- positions = torch.cat((freqs, freqs), dim=-1)
275
- pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
276
- self.__cache["rope_pos_sin"] = pos_sin
277
- self.__cache["rope_pos_cos"] = pos_cos
278
- return pos_sin, pos_cos
279
-
280
- def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
281
- B, nh, T, hs = x.size()
282
- x = x.view(B, nh, T, 2, hs // 2)
283
- x1, x2 = x.unbind(dim=-2)
284
- return torch.cat((-x2, x1), dim=-1)
285
-
286
- def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
287
- return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
288
-
289
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
290
- if self.config.rope_full_precision:
291
- q_, k_ = q.float(), k.float()
292
- else:
293
- q_, k_ = q, k
294
-
295
- with torch.autocast(q.device.type, enabled=False):
296
- query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
297
- pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
298
- pos_sin = pos_sin.type_as(q_)
299
- pos_cos = pos_cos.type_as(q_)
300
- q_ = self.apply_rotary_pos_emb(
301
- pos_sin[:, :, key_len - query_len : key_len, :],
302
- pos_cos[:, :, key_len - query_len : key_len, :],
303
- q_,
304
- )
305
- k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
306
- return q_.type_as(q), k_.type_as(k)
307
-
308
-
309
- class Activation(nn.Module):
310
- def __init__(self, config: ModelConfig):
311
- super().__init__()
312
- self.config = config
313
-
314
- @abstractmethod
315
- def forward(self, x: torch.Tensor) -> torch.Tensor:
316
- raise NotImplementedError
317
-
318
- @property
319
- @abstractmethod
320
- def output_multiplier(self) -> float:
321
- raise NotImplementedError
322
-
323
- @classmethod
324
- def build(cls, config: ModelConfig) -> Activation:
325
- if config.activation_type == ActivationType.gelu:
326
- return cast(Activation, GELU(approximate="none"))
327
- elif config.activation_type == ActivationType.relu:
328
- return cast(Activation, ReLU(inplace=False))
329
- elif config.activation_type == ActivationType.swiglu:
330
- return SwiGLU(config)
331
- else:
332
- raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
333
-
334
-
335
- class GELU(nn.GELU):
336
- @property
337
- def output_multiplier(self) -> float:
338
- return 1.0
339
-
340
-
341
- class ReLU(nn.ReLU):
342
- @property
343
- def output_multiplier(self) -> float:
344
- return 1.0
345
-
346
-
347
- class SwiGLU(Activation):
348
- def forward(self, x: torch.Tensor) -> torch.Tensor:
349
- x, gate = x.chunk(2, dim=-1)
350
- return F.silu(gate) * x
351
-
352
- @property
353
- def output_multiplier(self) -> float:
354
- return 0.5
355
-
356
-
357
- def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
358
- att_bias = torch.triu(
359
- torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
360
- diagonal=1,
361
- )
362
- att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
363
- return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
364
-
365
-
366
- def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
367
- if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
368
- if causal_bias.device != device:
369
- causal_bias = causal_bias.to(device)
370
- cache["causal_attention_bias"] = causal_bias
371
- return causal_bias
372
- with torch.autocast(device.type, enabled=False):
373
- causal_bias = causal_attention_bias(seq_len, device)
374
- cache["causal_attention_bias"] = causal_bias
375
- return causal_bias
376
-
377
-
378
- def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
379
- alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
380
-
381
- # shape: (1, 1, seq_len, seq_len)
382
- alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
383
- alibi_bias.abs_().mul_(-1)
384
-
385
- # shape: (n_heads,)
386
- m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
387
- m.mul_(config.alibi_bias_max / config.n_heads)
388
-
389
- # shape: (1, n_heads, seq_len, seq_len)
390
- return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
391
-
392
- def activation_quant(x):
393
- """Per−token quantization to 8 bits. No grouping is needed for quantization.
394
- Args:
395
- x: an activation tensor with shape [n, d]
396
- Returns:
397
- y: a quantized activation tensor with shape [n, d]
398
- """
399
- scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
400
- y = (x * scale).round().clamp_(-128, 127) / scale
401
- return y
402
-
403
- def weight_quant(w):
404
- """Per−tensor quantization to 1.58 bits. No grouping is needed for quantization.
405
- Args:
406
- w: a weight tensor with shape [d, k]
407
- Returns:
408
- u: a quantized weight with shape [d, k]
409
- """
410
- scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
411
- u = (w * scale).round().clamp_(-1, 1) / scale
412
- return u
413
-
414
- def activation_norm_quant(x):
415
- """
416
- same as activation_quant definition - but returning y and scale seperately
417
- Args:
418
- x: an activation tensor with shape [n, d]
419
- Returns:
420
- y: a quantized activation tensor with shape [n, d]
421
- scale: a scalar for dequantization with shape [1]
422
- """
423
- scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
424
- y = (x * scale).round().clamp_(-128, 127)
425
- return y, scale
426
-
427
- def gemm_lowbit_kernel(x, w):
428
- y = F.linear(x, w)
429
- return y
430
-
431
- class BitLinear158(nn.Linear):
432
- """
433
- This is only for training, and kernel optimization is needed for efficiency.
434
- """
435
- def __init__(self, in_features: int, out_features: int, bias: bool = True,
436
- device=None, dtype=None, config=None):
437
- super().__init__(in_features, out_features, bias, device, dtype)
438
- self.norm = RMSLayerNorm(config, elementwise_affine=False)
439
-
440
- def forward(self, x):
441
- """
442
- Args:
443
- x: an input tensor with shape [n, d]
444
- Returns:
445
- y: an output tensor with shape [n, d]
446
- """
447
- w = self.weight # a weight tensor with shape [d, k]
448
- x_norm = self.norm(x)
449
- # Atrick for implementing Straight−Through−Estimator (STE) using detach()
450
- x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
451
- w_quant = w + (weight_quant(w) - w).detach()
452
- y = F.linear(x_quant, w_quant)
453
- return y
454
-
455
- class BitLinear158_inference(nn.Linear):
456
- """
457
- Use quantized weights for inference .
458
- """
459
- def __init__(self, in_features: int, out_features: int, bias: bool = True,
460
- device=None, dtype=None, config=None):
461
- super().__init__(in_features, out_features, bias, device, dtype)
462
- self.norm = RMSLayerNorm(config, elementwise_affine=False)
463
- self.weight_scale = nn.Parameter(torch.ones(1))
464
-
465
- def forward(self, x):
466
- """
467
- Args:
468
- x: an input tensor with shape [n, d]
469
- Returns:
470
- y: an output tensor with shape [n, d]
471
- """
472
- w = self.weight # a 1.58−bit weight tensor with shape [d, k]
473
- w_scale = self.weight_scale # a full−precision weight scale tensor with shape [1]
474
- x_norm = self.norm(x)
475
- x_quant, x_scale = activation_norm_quant(x_norm)
476
- y = gemm_lowbit_kernel(x_quant, w) / w_scale / x_scale
477
- return y
478
-
479
-
480
- class OLMoBlock(nn.Module):
481
- """
482
- A base class for transformer block implementations.
483
- """
484
-
485
- def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
486
- super().__init__()
487
- self.layer_id = layer_id
488
- self.config = config
489
- self.hidden_size = (
490
- config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
491
- )
492
- self.__cache = cache
493
- assert config.d_model % config.n_heads == 0
494
-
495
- self._activation_checkpoint_fn = None
496
-
497
- Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
498
-
499
- # Dropout.
500
- self.dropout = Dropout(config.residual_dropout)
501
-
502
- # Layer norms.
503
- self.k_norm: Optional[LayerNormBase] = None
504
- self.q_norm: Optional[LayerNormBase] = None
505
- if config.attention_layer_norm:
506
- self.k_norm = LayerNormBase.build(
507
- config,
508
- size=config.d_model // config.n_heads if config.multi_query_attention else None,
509
- elementwise_affine=config.attention_layer_norm_with_affine,
510
- )
511
- self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
512
-
513
- # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
514
- if config.clip_qkv is not None:
515
- assert config.clip_qkv > 0
516
-
517
- # Activation function.
518
- self.act = Activation.build(config)
519
- assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
520
-
521
- # Attention output projection.
522
- self.attn_out = Linear(
523
- config.d_model, config.d_model, bias=config.include_bias, device=config.init_device,
524
- config=config
525
- )
526
-
527
- # Feed-forward output projection.
528
- self.ff_out = Linear(
529
- int(self.act.output_multiplier * self.hidden_size),
530
- config.d_model,
531
- bias=config.include_bias,
532
- device=config.init_device,
533
- config=config,
534
- )
535
- self.ff_out._is_residual = True # type: ignore
536
-
537
- # Rotary embeddings.
538
- if self.config.rope:
539
- self.rotary_emb = RotaryEmbedding(config, self.__cache)
540
-
541
- def reset_parameters(self):
542
- if self.k_norm is not None:
543
- self.k_norm.reset_parameters()
544
- if self.q_norm is not None:
545
- self.q_norm.reset_parameters()
546
- init_weights(
547
- self.config,
548
- self.attn_out,
549
- d=self.config.d_model,
550
- layer_id=self.layer_id,
551
- type_of_module=ModuleType.out_module,
552
- )
553
- init_weights(
554
- self.config,
555
- self.ff_out,
556
- d=self.ff_out.in_features,
557
- layer_id=self.layer_id,
558
- type_of_module=ModuleType.out_module,
559
- )
560
-
561
- def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
562
- if strategy == ActivationCheckpointingStrategy.fine_grained:
563
- self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
564
- else:
565
- self._activation_checkpoint_fn = None
566
-
567
- @classmethod
568
- def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
569
- target_dtype = input_dtype
570
- # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
571
- # `is_autocast_cpu_enabled()` for CPU autocast.
572
- # See https://github.com/pytorch/pytorch/issues/110966.
573
- if bias.device.type == "cuda" and torch.is_autocast_enabled():
574
- target_dtype = torch.get_autocast_gpu_dtype()
575
- elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
576
- target_dtype = torch.get_autocast_cpu_dtype()
577
- if bias.dtype != target_dtype:
578
- bias = bias.to(target_dtype)
579
- ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
580
- return bias
581
-
582
- def _scaled_dot_product_attention(
583
- self,
584
- q: torch.Tensor,
585
- k: torch.Tensor,
586
- v: torch.Tensor,
587
- attn_mask: Optional[torch.Tensor] = None,
588
- dropout_p: float = 0.0,
589
- is_causal: bool = False,
590
- ) -> torch.Tensor:
591
- """
592
- Computes scaled dot product attention on query, key and value tensors, using an optional
593
- attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
594
-
595
- This method is based on PyTorch's `scaled_dot_product_attention`.
596
- """
597
- return F.scaled_dot_product_attention(
598
- q,
599
- k,
600
- v,
601
- attn_mask=attn_mask,
602
- dropout_p=dropout_p,
603
- is_causal=is_causal,
604
- )
605
-
606
- def attention(
607
- self,
608
- q: torch.Tensor,
609
- k: torch.Tensor,
610
- v: torch.Tensor,
611
- attention_bias: Optional[torch.Tensor] = None,
612
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
613
- use_cache: bool = False,
614
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
615
- B, T, C = q.size() # batch size, sequence length, d_model
616
- dtype = k.dtype
617
-
618
- # Optionally apply layer norm to keys and queries.
619
- if self.q_norm is not None and self.k_norm is not None:
620
- q = self.q_norm(q).to(dtype=dtype)
621
- k = self.k_norm(k).to(dtype=dtype)
622
-
623
- # Move head forward to be next to the batch dim.
624
- # shape: (B, nh, T, hs)
625
- q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
626
- if self.config.multi_query_attention:
627
- # shape: (B, 1, T, hs)
628
- k = k.view(B, T, 1, C // self.config.n_heads).transpose(1, 2)
629
- # shape: (B, 1, T, hs)
630
- v = v.view(B, T, 1, C // self.config.n_heads).transpose(1, 2)
631
- else:
632
- # shape: (B, nh, T, hs)
633
- k = k.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
634
- # shape: (B, nh, T, hs)
635
- v = v.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
636
-
637
- if layer_past is not None:
638
- past_key, past_value = layer_past
639
- k = torch.cat((past_key, k), dim=-2)
640
- v = torch.cat((past_value, v), dim=-2)
641
-
642
- present = (k, v) if use_cache else None
643
- query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
644
-
645
- if self.config.rope:
646
- # Apply rotary embeddings.
647
- q, k = self.rotary_emb(q, k)
648
-
649
- if attention_bias is not None:
650
- # Resize and cast attention bias.
651
- # The current dtype of the attention bias might not match the dtype that the SDP attn function will
652
- # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
653
- # as down-casting the attention bias to the autocast precision will result in -infs, which will
654
- # cause the SDP attn function to produce NaNs.
655
- attention_bias = self._cast_attn_bias(
656
- attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
657
- )
658
-
659
- # Get the attention scores.
660
- # shape: (B, nh, T, hs)
661
- att = self._scaled_dot_product_attention(
662
- q,
663
- k,
664
- v,
665
- attn_mask=attention_bias,
666
- dropout_p=0.0 if not self.training else self.config.attention_dropout,
667
- is_causal=attention_bias is None,
668
- )
669
-
670
- # Re-assemble all head outputs side-by-side.
671
- att = att.transpose(1, 2).contiguous().view(B, T, C)
672
-
673
- # Apply output projection.
674
- return self.attn_out(att), present
675
-
676
- @abstractmethod
677
- def forward(
678
- self,
679
- x: torch.Tensor,
680
- attention_bias: Optional[torch.FloatTensor] = None,
681
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
682
- use_cache: bool = False,
683
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
684
- raise NotImplementedError
685
-
686
- @classmethod
687
- def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock:
688
- if config.block_type == BlockType.sequential:
689
- return OLMoSequentialBlock(layer_id, config, cache)
690
- elif config.block_type == BlockType.parallel:
691
- return OLMoParallelBlock(layer_id, config, cache)
692
- elif config.block_type == BlockType.llama:
693
- return OLMoLlamaBlock(layer_id, config, cache)
694
- else:
695
- raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
696
-
697
-
698
- class OLMoSequentialBlock(OLMoBlock):
699
- """
700
- This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
701
- (plus another skip connection).
702
- """
703
-
704
- def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
705
- super().__init__(layer_id, config, cache)
706
- # Layer norms.
707
- self.attn_norm = LayerNorm.build(config)
708
- self.ff_norm = LayerNorm.build(config)
709
- Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
710
- # Attention input projection. Projects x -> (q, k, v)
711
- if config.multi_query_attention:
712
- self.fused_dims = (config.d_model, config.d_model // config.n_heads, config.d_model // config.n_heads)
713
- else:
714
- self.fused_dims = (config.d_model, config.d_model, config.d_model)
715
- self.att_proj = Linear(
716
- config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device,
717
- config=config
718
- )
719
- # Feed-forward input projection.
720
- self.ff_proj = Linear(
721
- config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device,
722
- config=config
723
- )
724
-
725
- def reset_parameters(self):
726
- super().reset_parameters()
727
- self.attn_norm.reset_parameters()
728
- self.ff_norm.reset_parameters()
729
- # NOTE: the standard deviation for these weights does not depend on the layer.
730
- init_weights(
731
- self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
732
- )
733
- init_weights(
734
- self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
735
- )
736
-
737
- def forward(
738
- self,
739
- x: torch.Tensor,
740
- attention_bias: Optional[torch.Tensor] = None,
741
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
742
- use_cache: bool = False,
743
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
744
- # Get query, key, value projections.
745
- # shape:
746
- # - for regular attn q, k, v: (batch_size, seq_len, d_model)
747
- # - for multi-query attn q: (batch_size, seq_len, d_model)
748
- # k, v: (batch_size, seq_len, d_model // n_heads)
749
- if self._activation_checkpoint_fn is not None:
750
- qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
751
- else:
752
- qkv = self.att_proj(self.attn_norm(x))
753
-
754
- if self.config.clip_qkv is not None:
755
- qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
756
-
757
- q, k, v = qkv.split(self.fused_dims, dim=-1)
758
-
759
- # Get attention scores.
760
- if self._activation_checkpoint_fn is not None:
761
- att, cache = self._activation_checkpoint_fn( # type: ignore
762
- self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
763
- )
764
- else:
765
- att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
766
-
767
- # Add attention scores.
768
- # shape: (B, T, C)
769
- x = x + self.dropout(att)
770
-
771
- # Add feed-forward projection.
772
- # shape: (batch_size, seq_len, d_model)
773
- og_x = x
774
- if self._activation_checkpoint_fn is not None:
775
- x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
776
- else:
777
- x = self.ff_norm(x)
778
- x = self.ff_proj(x)
779
- if self._activation_checkpoint_fn is not None:
780
- x = self._activation_checkpoint_fn(self.act, x) # type: ignore
781
- else:
782
- x = self.act(x)
783
- x = self.ff_out(x)
784
- x = self.dropout(x)
785
- x = og_x + x
786
-
787
- return x, cache
788
-
789
-
790
- class OLMoParallelBlock(OLMoBlock):
791
- """
792
- This is a transformer block where the output is computed as ``MLP(LN(x)) + Attention(LN(x))``
793
- as in the PaLM architecture, as opposed to the typical ``MLP(LN(x + Attention(LN(x))))``
794
- as in :class:`OLMoSequentialBlock` (ignoring some skip connections).
795
-
796
- The decoupling of the MLP and Attention functions allow us to fuse the separate input projections
797
- into a single linear layer to increase throughput. In this configuration it's also straight-forward
798
- to fuse the output projections, but we found that didn't help.
799
- """
800
-
801
- def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
802
- super().__init__(layer_id, config, cache)
803
- self.norm = LayerNorm.build(config)
804
- Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
805
-
806
- # Fused attention and feed-forward projection.
807
- # NOTE: we could also fuse the attention and feed-forward output projections but we
808
- # found that didn't help, possibly because of the overhead of joining the `att` and
809
- # `ff` activations together. See https://github.com/allenai/LLM/pull/79 for details.
810
- if config.multi_query_attention:
811
- self.fused_dims = (
812
- config.d_model,
813
- config.d_model // config.n_heads,
814
- config.d_model // config.n_heads,
815
- self.hidden_size,
816
- )
817
- else:
818
- self.fused_dims = (config.d_model, config.d_model, config.d_model, self.hidden_size)
819
- self.fused_attn_ff_proj = Linear(
820
- config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device,
821
- config=config
822
- )
823
-
824
- def reset_parameters(self):
825
- super().reset_parameters()
826
- self.norm.reset_parameters()
827
- # NOTE: the standard deviation for these weights does not depend on the layer.
828
- init_weights(
829
- self.config,
830
- self.fused_attn_ff_proj,
831
- d=self.config.d_model,
832
- layer_id=None,
833
- type_of_module=ModuleType.in_module,
834
- )
835
-
836
- def forward(
837
- self,
838
- x: torch.Tensor,
839
- attention_bias: Optional[torch.Tensor] = None,
840
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
841
- use_cache: bool = False,
842
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
843
- # Get query, key, value, and feed-forward projections.
844
- # shape of q, k, v:
845
- # - for regular attn q, k, v: (batch_size, seq_len, d_model)
846
- # - for multi-query attn q: (batch_size, seq_len, d_model)
847
- # k, v: (batch_size, seq_len, d_model // n_heads)
848
- # shape of ff: (batch_size, seq_len, hidden_size)
849
- if self._activation_checkpoint_fn is not None:
850
- q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
851
- self.fused_dims, dim=-1
852
- )
853
- else:
854
- q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
855
-
856
- if self.config.clip_qkv is not None:
857
- q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
858
- k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
859
- v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
860
-
861
- # Get attention scores.
862
- # shape: (B, T, C)
863
- if self._activation_checkpoint_fn is not None:
864
- att, cache = self._activation_checkpoint_fn( # type: ignore
865
- self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
866
- )
867
- else:
868
- att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
869
-
870
- # Apply output projections (and activation function) and sum the results.
871
- # We keep these projections separate because we found that we got better throughput this
872
- # way compared to fusing them.
873
- if self._activation_checkpoint_fn is not None:
874
- return (
875
- x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
876
- cache,
877
- )
878
- else:
879
- return (
880
- x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
881
- cache,
882
- )
883
-
884
-
885
- class OLMoLlamaBlock(OLMoBlock):
886
- """
887
- This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
888
- (plus another skip connection). This block is similar to `OLMoSequentialBlock`
889
- but some operations have slightly different implementations to imitate the
890
- behavior of Llama.
891
- """
892
-
893
- def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
894
- super().__init__(layer_id, config, cache)
895
- # Layer norms.
896
- self.attn_norm = LayerNorm.build(config)
897
- self.ff_norm = LayerNorm.build(config)
898
- self.__cache = cache
899
- Linear = BitLinear158_inference if config.inference_mode else BitLinear158 if config.ternary else nn.Linear
900
-
901
-
902
- # Attention input projection. Projects x -> (q, k, v)
903
- if config.multi_query_attention:
904
- q_proj_out_dim = config.d_model
905
- k_proj_out_dim = config.d_model // config.n_heads
906
- v_proj_out_dim = config.d_model // config.n_heads
907
- else:
908
- q_proj_out_dim = config.d_model
909
- k_proj_out_dim = config.d_model
910
- v_proj_out_dim = config.d_model
911
- self.q_proj = Linear(
912
- config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device,
913
- config=config
914
- )
915
- self.k_proj = Linear(
916
- config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device,
917
- config=config
918
- )
919
- self.v_proj = Linear(
920
- config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device,
921
- config=config
922
- )
923
-
924
- # Feed-forward input projection.
925
- self.ff_proj = Linear(
926
- config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device,
927
- config=config
928
- )
929
-
930
- def reset_parameters(self):
931
- super().reset_parameters()
932
- if self.attn_norm:
933
- self.attn_norm.reset_parameters()
934
- self.ff_norm.reset_parameters()
935
- # NOTE: the standard deviation for these weights does not depend on the layer.
936
- init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
937
- init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
938
- init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
939
- init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
940
-
941
- def _scaled_dot_product_attention(
942
- self,
943
- q: torch.Tensor,
944
- k: torch.Tensor,
945
- v: torch.Tensor,
946
- attn_mask: Optional[torch.Tensor] = None,
947
- dropout_p: float = 0.0,
948
- is_causal: bool = False,
949
- ) -> torch.Tensor:
950
- attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
951
-
952
- if is_causal:
953
- assert attn_mask is None
954
-
955
- query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
956
- attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
957
- elif attn_mask is not None:
958
- attn_bias = attn_mask.to(q.dtype)
959
- else:
960
- attn_bias = torch.zeros_like(attn_weights)
961
-
962
- attn_weights += attn_bias
963
- attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
964
- attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
965
- return torch.matmul(attn_weights, v)
966
-
967
- def forward(
968
- self,
969
- x: torch.Tensor,
970
- attention_bias: Optional[torch.Tensor] = None,
971
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
972
- use_cache: bool = False,
973
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
974
- # Get query, key, value projections.
975
- # shape:
976
- # - for regular attn q, k, v: (batch_size, seq_len, d_model)
977
- # - for multi-query attn q: (batch_size, seq_len, d_model)
978
- # k, v: (batch_size, seq_len, d_model // n_heads)
979
- x_normed = self.attn_norm(x)
980
- q = self.q_proj(x_normed)
981
- k = self.k_proj(x_normed)
982
- v = self.v_proj(x_normed)
983
-
984
- if self.config.clip_qkv is not None:
985
- q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
986
- k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
987
- v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
988
-
989
- # Get attention scores.
990
- att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
991
-
992
- # Add attention scores.
993
- # shape: (B, T, C)
994
- x = x + self.dropout(att)
995
-
996
- # Add feed-forward projection.
997
- # shape: (batch_size, seq_len, d_model)
998
- og_x = x
999
- if self._activation_checkpoint_fn is not None:
1000
- x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
1001
- else:
1002
- x = self.ff_norm(x)
1003
- x = self.ff_proj(x)
1004
- if self._activation_checkpoint_fn is not None:
1005
- x = self._activation_checkpoint_fn(self.act, x) # type: ignore
1006
- else:
1007
- x = self.act(x)
1008
- x = self.ff_out(x)
1009
- x = self.dropout(x)
1010
- x = og_x + x
1011
-
1012
- return x, cache
1013
-
1014
-
1015
- class OLMoOutput(NamedTuple):
1016
- logits: torch.FloatTensor
1017
- """
1018
- A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
1019
- for the next token *before* normalization via (log) softmax.
1020
- """
1021
-
1022
- attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
1023
- """
1024
- Attention keys and values from each block.
1025
- """
1026
-
1027
- hidden_states: Optional[Tuple[torch.Tensor]]
1028
- """
1029
- Hidden states from each block.
1030
- """
1031
-
1032
-
1033
- class OLMoGenerateOutput(NamedTuple):
1034
- token_ids: torch.LongTensor
1035
- """
1036
- The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
1037
- These do *not* include the original input IDs.
1038
- """
1039
-
1040
- scores: torch.FloatTensor
1041
- """
1042
- The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
1043
- """
1044
-
1045
-
1046
- class OLMoBlockGroup(nn.ModuleList):
1047
- def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
1048
- super().__init__(modules)
1049
- self.config = config
1050
- self.layer_offset = layer_offset
1051
- self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1052
- self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
1053
-
1054
- def forward(
1055
- self,
1056
- x: torch.Tensor,
1057
- attention_bias: Optional[torch.FloatTensor] = None,
1058
- layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
1059
- use_cache: bool = False,
1060
- ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
1061
- attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1062
- for block_idx, block in enumerate(self):
1063
- layer_past = None if layers_past is None else layers_past[block_idx]
1064
- block_idx += self.layer_offset
1065
- if (
1066
- (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1067
- or (
1068
- self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1069
- and block_idx % 2 == 0
1070
- )
1071
- or (
1072
- self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1073
- and block_idx % 3 == 0
1074
- )
1075
- or (
1076
- self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1077
- and block_idx % 4 == 0
1078
- )
1079
- ):
1080
- # shape: (batch_size, seq_len, d_model)
1081
- x, cache = self._activation_checkpoint_fn( # type: ignore
1082
- block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1083
- )
1084
- else:
1085
- # shape: (batch_size, seq_len, d_model)
1086
- x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1087
- if attn_key_values is not None:
1088
- assert cache is not None
1089
- attn_key_values.append(cache)
1090
- return x, attn_key_values
1091
-
1092
- def reset_parameters(self):
1093
- for block in self:
1094
- block.reset_parameters()
1095
-
1096
- def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1097
- self.activation_checkpointing_strategy = strategy
1098
- for block in self:
1099
- block.set_activation_checkpointing(strategy)
1100
-
1101
-
1102
- class OLMo(nn.Module):
1103
- def __init__(self, config: ModelConfig, init_params: bool = True):
1104
- super().__init__()
1105
- self.config = config
1106
- self.__cache = BufferCache()
1107
-
1108
- # Validate config.
1109
- if self.config.alibi and self.config.flash_attention:
1110
- raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
1111
-
1112
- if self.config.alibi and self.config.rope:
1113
- raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
1114
-
1115
- if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
1116
- if self.config.embedding_size < self.config.vocab_size:
1117
- raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
1118
- elif self.config.embedding_size % 128 != 0:
1119
- import warnings
1120
-
1121
- warnings.warn(
1122
- "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1123
- )
1124
-
1125
- self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1126
- self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1127
-
1128
- if not (
1129
- 0 < self.config.block_group_size <= self.config.n_layers
1130
- and self.config.n_layers % self.config.block_group_size == 0
1131
- ):
1132
- raise OLMoConfigurationError("n layers must be divisible by block group size")
1133
-
1134
- torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
1135
- torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
1136
-
1137
- self.transformer = nn.ModuleDict(
1138
- dict(
1139
- wte=nn.Embedding(
1140
- config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1141
- ),
1142
- emb_drop=Dropout(config.embedding_dropout),
1143
- ln_f=LayerNorm.build(config),
1144
- )
1145
- )
1146
-
1147
- blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
1148
- if self.config.block_group_size > 1:
1149
- block_groups = [
1150
- OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
1151
- for i in range(0, config.n_layers, config.block_group_size)
1152
- ]
1153
- self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
1154
- else:
1155
- self.transformer.update({"blocks": nn.ModuleList(blocks)})
1156
-
1157
- if not (self.config.alibi or self.config.rope):
1158
- self.transformer.update(
1159
- {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
1160
- )
1161
- if not config.weight_tying:
1162
- self.transformer.update(
1163
- {
1164
- "ff_out": nn.Linear(
1165
- config.d_model,
1166
- config.embedding_size or config.vocab_size,
1167
- bias=config.include_bias,
1168
- device=config.init_device,
1169
- )
1170
- }
1171
- )
1172
- # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1173
- if init_params and self.config.init_device != "meta":
1174
- self.reset_parameters()
1175
- self.__num_fwd_flops: Optional[int] = None
1176
-
1177
- # Warm up cache.
1178
- if self.config.alibi:
1179
- get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1180
- self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1181
-
1182
- def embed_tokens(self, input_ids):
1183
- return self.transformer.wte(input_ids)
1184
-
1185
- def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1186
- self.activation_checkpointing_strategy = strategy
1187
- if self.config.block_group_size != 1:
1188
- for block_group in self.transformer.block_groups:
1189
- block_group.set_activation_checkpointing(strategy)
1190
- else:
1191
- for block in self.transformer.blocks:
1192
- block.set_activation_checkpointing(strategy)
1193
-
1194
- @property
1195
- def device(self) -> torch.device:
1196
- device: torch.device = self.transformer.wte.weight.device # type: ignore
1197
- if device.type == "meta":
1198
- return _non_meta_init_device(self.config)
1199
- else:
1200
- return device
1201
-
1202
- def reset_parameters(self):
1203
- log.info("Initializing model parameters...")
1204
- # Top-level embeddings / linear layers.
1205
- init_weights(
1206
- self.config,
1207
- self.transformer.wte, # type: ignore
1208
- std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
1209
- type_of_module=ModuleType.emb,
1210
- )
1211
- if hasattr(self.transformer, "wpe"):
1212
- init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
1213
-
1214
- # Top-level layer norm.
1215
- self.transformer.ln_f.reset_parameters() # type: ignore
1216
-
1217
- # Output weights.
1218
- if hasattr(self.transformer, "ff_out"):
1219
- init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
1220
-
1221
- # Let the blocks handle themselves.
1222
- if self.config.block_group_size == 1:
1223
- for block in self.transformer.blocks:
1224
- block.reset_parameters()
1225
- else:
1226
- for block_group in self.transformer.block_groups:
1227
- block_group.reset_parameters()
1228
-
1229
- def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
1230
- if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
1231
- -1
1232
- ] >= seq_len:
1233
- if alibi_bias.device != device:
1234
- alibi_bias = alibi_bias.to(device)
1235
- self.__cache["alibi_attention_bias"] = alibi_bias
1236
- return alibi_bias
1237
- with torch.autocast(device.type, enabled=False):
1238
- alibi_bias = alibi_attention_bias(seq_len, self.config, device)
1239
- self.__cache["alibi_attention_bias"] = alibi_bias
1240
- return alibi_bias
1241
-
1242
- def forward(
1243
- self,
1244
- input_ids: torch.LongTensor,
1245
- inputs_embeds: Optional[torch.FloatTensor] = None,
1246
- attention_mask: Optional[torch.Tensor] = None,
1247
- attention_bias: Optional[torch.Tensor] = None,
1248
- past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
1249
- use_cache: bool = False,
1250
- last_logits_only: bool = False,
1251
- output_hidden_states: Optional[bool] = None,
1252
- ) -> OLMoOutput:
1253
- """
1254
- :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1255
- :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
1256
- embeddings. When provided, it is treated as the output of the input embedding layer.
1257
- :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1258
- which input IDs are masked. A `1` value in the mask means that
1259
- the corresponding input ID should *not* be ignored. A `0` means
1260
- that the corresponding input ID is masked.
1261
-
1262
- This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1263
- library.
1264
- :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1265
- `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1266
- to introduce causal or other biases.
1267
-
1268
- If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1269
- indicates that the i-th element in the sequence is allowed to attend to the j-th
1270
- element in the sequence.
1271
-
1272
- If the tensor is a float tensor, it will just be added to the attention
1273
- scores before the softmax.
1274
-
1275
- The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1276
- :param past_key_values: Pre-computed keys and values for each attention block.
1277
- Can be used to speed up sequential decoding. The `input_ids` which have
1278
- their past given to this model should not be passed as `input_ids` as they have already been computed.
1279
- :param use_cache: If `True`, return key and value tensors for each block.
1280
- :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1281
- This can speed up decoding when you only care about the next token.
1282
- """
1283
- output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1284
-
1285
- if past_key_values:
1286
- assert len(past_key_values) == self.config.n_layers
1287
-
1288
- batch_size, seq_len = input_ids.size() if inputs_embeds is None else inputs_embeds.size()[:2]
1289
- if past_key_values is None:
1290
- past_length = 0
1291
- else:
1292
- past_length = past_key_values[0][0].size(-2)
1293
-
1294
- # Get embeddings of input.
1295
- # shape: (batch_size, seq_len, d_model)
1296
- x = self.transformer.wte(input_ids) if inputs_embeds is None else inputs_embeds # type: ignore
1297
-
1298
- if not (self.config.alibi or self.config.rope):
1299
- # Get positional embeddings.
1300
- # shape: (1, seq_len)
1301
- pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1302
- # shape: (1, seq_len, d_model)
1303
- pos_emb = self.transformer.wpe(pos) # type: ignore
1304
- x = pos_emb + x
1305
-
1306
- # Add input + positional embeddings and apply dropout.
1307
- # shape: (batch_size, seq_len, d_model)
1308
- x = self.transformer.emb_drop(x) # type: ignore
1309
-
1310
- # Transform the attention mask into what the blocks expect.
1311
- if attention_mask is not None:
1312
- # shape: (batch_size, 1, 1, seq_len)
1313
- attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1314
- attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1315
-
1316
- # Merge attention mask with attention bias.
1317
- if (
1318
- attention_bias is not None
1319
- or attention_mask is not None
1320
- or self.config.alibi
1321
- # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1322
- # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1323
- # scores correctly.
1324
- or past_key_values is not None
1325
- ):
1326
- if attention_bias is None and self.config.alibi:
1327
- attention_bias = get_causal_attention_bias(
1328
- self.__cache, past_length + seq_len, x.device
1329
- ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1330
- elif attention_bias is None:
1331
- attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1332
- elif attention_bias.dtype in (torch.int8, torch.bool):
1333
- attention_bias = attention_bias.to(dtype=torch.float)
1334
- attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1335
-
1336
- # Transform to the right shape and data type.
1337
- mask_len = seq_len
1338
- if attention_mask is not None:
1339
- mask_len = attention_mask.shape[-1]
1340
- elif past_key_values is not None:
1341
- mask_len = past_key_values[0][0].shape[-2] + seq_len
1342
- attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1343
-
1344
- # Add in the masking bias.
1345
- if attention_mask is not None:
1346
- attention_bias = attention_bias + attention_mask
1347
- # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1348
- # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1349
- # it can produce NaNs.
1350
- ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1351
-
1352
- attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1353
-
1354
- # decoder layers
1355
- all_hidden_states = []
1356
-
1357
- # Apply blocks one-by-one.
1358
- if self.config.block_group_size == 1:
1359
- for block_idx, block in enumerate(self.transformer.blocks):
1360
- if output_hidden_states:
1361
- # add hidden states
1362
- all_hidden_states.append(x)
1363
-
1364
- layer_past = None if past_key_values is None else past_key_values[block_idx]
1365
- if (
1366
- (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1367
- or (
1368
- self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1369
- and block_idx % 2 == 0
1370
- )
1371
- or (
1372
- self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1373
- and block_idx % 3 == 0
1374
- )
1375
- or (
1376
- self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1377
- and block_idx % 4 == 0
1378
- )
1379
- ):
1380
- # shape: (batch_size, seq_len, d_model)
1381
- x, cache = self._activation_checkpoint_fn(
1382
- block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1383
- )
1384
- else:
1385
- # shape: (batch_size, seq_len, d_model)
1386
- x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1387
- if attn_key_values is not None:
1388
- assert cache is not None
1389
- attn_key_values.append(cache)
1390
- else:
1391
- for group_idx, block_group in enumerate(self.transformer.block_groups):
1392
- if output_hidden_states:
1393
- # add hidden states
1394
- all_hidden_states.append(x)
1395
-
1396
- layers_past = (
1397
- None
1398
- if past_key_values is None
1399
- else past_key_values[
1400
- group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1401
- ]
1402
- )
1403
- x, cache = block_group(
1404
- x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
1405
- )
1406
- if attn_key_values is not None:
1407
- assert cache is not None
1408
- attn_key_values.extend(cache)
1409
-
1410
- if last_logits_only:
1411
- # shape: (batch_size, 1, d_model)
1412
- x = x[:, -1, :].unsqueeze(1)
1413
-
1414
- # Apply final layer norm.
1415
- # shape: (batch_size, seq_len or 1, d_model)
1416
- x = self.transformer.ln_f(x) # type: ignore
1417
- if output_hidden_states:
1418
- # add final hidden state post-final-layernorm, following HuggingFace's convention
1419
- all_hidden_states.append(x)
1420
-
1421
- # Get logits.
1422
- # shape: (batch_size, seq_len or 1, vocab_size)
1423
- if self.config.weight_tying:
1424
- logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1425
- else:
1426
- logits = self.transformer.ff_out(x) # type: ignore
1427
- if self.config.scale_logits:
1428
- logits.mul_(1 / math.sqrt(self.config.d_model))
1429
-
1430
- return BaseModelOutputWithPast(
1431
- last_hidden_state=x,
1432
- past_key_values=tuple(attn_key_values) if attn_key_values is not None else None,
1433
- hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
1434
- )
1435
-
1436
- def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
1437
- if wrap_strategy is None:
1438
- return None
1439
-
1440
- # The 'recurse' mode for the wrap function does not behave like you'd expect.
1441
- # Even if we return False, it may still recurse because PyTorch does what it wants,
1442
- # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
1443
- # but not other linear layers within a block.
1444
- # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
1445
- # return True in 'recurse' mode for simplicity.
1446
- size_based_module_to_wrap = {self.transformer.wte}
1447
- if hasattr(self.transformer, "ff_out"):
1448
- size_based_module_to_wrap.add(self.transformer.ff_out)
1449
-
1450
- if wrap_strategy == FSDPWrapStrategy.by_block:
1451
-
1452
- def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1453
- del nonwrapped_numel
1454
- wrap = isinstance(module, OLMoBlock)
1455
- if recurse:
1456
- return True
1457
- else:
1458
- return wrap
1459
-
1460
- return fsdp_wrap_fn
1461
- elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
1462
-
1463
- def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1464
- del nonwrapped_numel
1465
- wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap
1466
- if recurse:
1467
- return True
1468
- else:
1469
- return wrap
1470
-
1471
- return fsdp_wrap_fn
1472
- elif wrap_strategy == FSDPWrapStrategy.by_block_group:
1473
- if self.config.block_group_size <= 1:
1474
- raise OLMoConfigurationError(
1475
- "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
1476
- )
1477
-
1478
- def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1479
- del nonwrapped_numel
1480
- wrap = isinstance(module, OLMoBlockGroup)
1481
- if recurse:
1482
- return True
1483
- else:
1484
- return wrap
1485
-
1486
- return fsdp_wrap_fn
1487
- elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
1488
- if self.config.block_group_size <= 1:
1489
- raise OLMoConfigurationError(
1490
- "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
1491
- )
1492
-
1493
- def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1494
- del nonwrapped_numel
1495
- wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap
1496
- if recurse:
1497
- return True
1498
- else:
1499
- return wrap
1500
-
1501
- return fsdp_wrap_fn
1502
- elif wrap_strategy == FSDPWrapStrategy.size_based:
1503
- from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
1504
-
1505
- return size_based_auto_wrap_policy
1506
- elif wrap_strategy in {
1507
- FSDPWrapStrategy.one_in_two,
1508
- FSDPWrapStrategy.one_in_three,
1509
- FSDPWrapStrategy.one_in_four,
1510
- FSDPWrapStrategy.one_in_five,
1511
- }:
1512
- c = {
1513
- FSDPWrapStrategy.one_in_two: 2,
1514
- FSDPWrapStrategy.one_in_three: 3,
1515
- FSDPWrapStrategy.one_in_four: 4,
1516
- FSDPWrapStrategy.one_in_five: 5,
1517
- }[wrap_strategy]
1518
-
1519
- def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1520
- del nonwrapped_numel
1521
- wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0
1522
- if recurse:
1523
- return True
1524
- else:
1525
- return wrap
1526
-
1527
- return fsdp_wrap_fn
1528
- else:
1529
- raise NotImplementedError(wrap_strategy)
1530
-
1531
- def num_params(self, include_embedding: bool = True) -> int:
1532
- """
1533
- Get the total number of parameters.
1534
- """
1535
- params = (np for np in self.named_parameters())
1536
- if not include_embedding:
1537
- params = filter( # type: ignore
1538
- lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
1539
- params,
1540
- )
1541
- return sum(p.numel() for _, p in params)
1542
-
1543
- @property
1544
- def num_fwd_flops(self):
1545
- if self.__num_fwd_flops:
1546
- return self.__num_fwd_flops
1547
- n_params = self.num_params()
1548
- # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
1549
- # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
1550
- # this gets us FLOPs / token
1551
- params_flops_per_token = 2 * n_params
1552
- params_flops_per_seq = params_flops_per_token * self.config.max_sequence_length
1553
- # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)
1554
- attn_flops_per_seq = (
1555
- self.config.n_layers * 2 * 2 * (self.config.d_model * (self.config.max_sequence_length**2))
1556
- )
1557
- self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq
1558
- return self.__num_fwd_flops
1559
-
1560
- def generate(
1561
- self,
1562
- input_ids: torch.LongTensor,
1563
- attention_mask: Optional[torch.Tensor] = None,
1564
- attention_bias: Optional[torch.Tensor] = None,
1565
- max_steps: int = 10,
1566
- beam_size: int = 1,
1567
- per_node_beam_size: Optional[int] = None,
1568
- sampler: Optional[Sampler] = None,
1569
- min_steps: Optional[int] = None,
1570
- final_sequence_scorer: Optional[FinalSequenceScorer] = None,
1571
- constraints: Optional[List[Constraint]] = None,
1572
- ) -> OLMoGenerateOutput:
1573
- """
1574
- Generate token IDs using beam search.
1575
-
1576
- Note that by default ``beam_size`` is set to 1, which is greedy decoding.
1577
-
1578
- :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1579
- :param attention_mask: A optional tensor of shape `(batch_size, seq_len)`, the same
1580
- as for the forward method.
1581
- :param attention_bias: A tensor of shape
1582
- `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
1583
- the same as for the forward method except only one shape is excepted here.
1584
-
1585
- For an explanation of the other arguments, see :class:`BeamSearch`.
1586
- """
1587
- beam_search = BeamSearch(
1588
- self.config.eos_token_id,
1589
- max_steps=max_steps,
1590
- beam_size=beam_size,
1591
- per_node_beam_size=per_node_beam_size,
1592
- sampler=sampler,
1593
- min_steps=min_steps,
1594
- final_sequence_scorer=final_sequence_scorer,
1595
- constraints=constraints,
1596
- )
1597
-
1598
- # Validate inputs.
1599
- batch_size, seq_len = input_ids.shape
1600
- if attention_mask is not None:
1601
- assert attention_mask.shape == (batch_size, seq_len)
1602
- if attention_bias is not None:
1603
- assert len(attention_bias.shape) == 4
1604
- assert attention_bias.shape[:2] == (batch_size, 1)
1605
- assert (
1606
- seq_len + beam_search.max_steps
1607
- <= attention_bias.shape[2]
1608
- == attention_bias.shape[3]
1609
- <= self.config.max_sequence_length
1610
- )
1611
-
1612
- tokens_generated = 0
1613
-
1614
- def flatten_past_key_values(
1615
- past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
1616
- ) -> Dict[str, torch.Tensor]:
1617
- out = {}
1618
- for i, (key, value) in enumerate(past_key_values):
1619
- out[f"past_key_{i}"] = key
1620
- out[f"past_value_{i}"] = value
1621
- return out
1622
-
1623
- def unflatten_past_key_values(
1624
- past_key_values: Dict[str, torch.Tensor],
1625
- ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
1626
- out = []
1627
- for i in range(self.config.n_layers):
1628
- past_key = past_key_values[f"past_key_{i}"]
1629
- past_value = past_key_values[f"past_value_{i}"]
1630
- out.append((past_key, past_value))
1631
- return out
1632
-
1633
- def step(
1634
- last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
1635
- ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
1636
- nonlocal tokens_generated
1637
-
1638
- attention_mask = state.get("attention_mask")
1639
- attention_bias = state.get("attention_bias")
1640
-
1641
- if tokens_generated > 0:
1642
- past_key_values = unflatten_past_key_values(state)
1643
- input_ids = last_predictions.unsqueeze(1)
1644
- if attention_mask is not None:
1645
- group_size = input_ids.shape[0]
1646
- attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
1647
- else:
1648
- past_key_values = None
1649
- input_ids = state["input_ids"]
1650
-
1651
- tokens_generated += 1
1652
-
1653
- # Run forward pass of model to get logits, then normalize to get log probs.
1654
- output = self(
1655
- input_ids,
1656
- attention_mask=attention_mask,
1657
- attention_bias=attention_bias,
1658
- past_key_values=past_key_values,
1659
- use_cache=True,
1660
- last_logits_only=True,
1661
- )
1662
- log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)
1663
-
1664
- # Create new state.
1665
- state = flatten_past_key_values(output.attn_key_values)
1666
- if attention_mask is not None:
1667
- state["attention_mask"] = attention_mask
1668
- if attention_bias is not None:
1669
- state["attention_bias"] = attention_bias
1670
-
1671
- return log_probs, state
1672
-
1673
- initial_preds = input_ids.new_zeros((batch_size,)) # This is arbitrary, we won't use this.
1674
- state: dict[str, torch.Tensor] = {"input_ids": input_ids}
1675
- if attention_mask is not None:
1676
- state["attention_mask"] = attention_mask
1677
- if attention_bias is not None:
1678
- state["attention_bias"] = attention_bias
1679
- with torch.no_grad():
1680
- token_ids, scores = beam_search.search(initial_preds, state, step)
1681
-
1682
- return OLMoGenerateOutput(
1683
- token_ids=token_ids, # type: ignore[arg-type]
1684
- scores=scores, # type: ignore[arg-type]
1685
- )
1686
-
1687
- @classmethod
1688
- def from_checkpoint(
1689
- cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
1690
- ) -> OLMo:
1691
- """
1692
- Load an OLMo model from a checkpoint.
1693
- """
1694
- from .util import resource_path
1695
-
1696
- # Guess checkpoint type.
1697
- if checkpoint_type is None:
1698
- try:
1699
- if resource_path(checkpoint_dir, "model.pt").is_file():
1700
- checkpoint_type = CheckpointType.unsharded
1701
- else:
1702
- checkpoint_type = CheckpointType.sharded
1703
- except FileNotFoundError:
1704
- checkpoint_type = CheckpointType.sharded
1705
-
1706
- # Load config.
1707
- config_path = resource_path(checkpoint_dir, "config.yaml")
1708
- model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
1709
-
1710
- if checkpoint_type == CheckpointType.unsharded:
1711
- # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1712
- model_config.init_device = "cpu"
1713
- model = OLMo(model_config)
1714
-
1715
- # Load state dict directly to target device.
1716
- state_dict_path = resource_path(checkpoint_dir, "model.pt")
1717
- state_dict = torch.load(state_dict_path, map_location="cpu")
1718
- model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
1719
- model = model.to(torch.device(device))
1720
- else:
1721
- from .checkpoint import load_model_state
1722
-
1723
- # Initialize model on target device. In this case the state dict is loaded in-place
1724
- # so it's not necessary to start on CPU if the target device is a GPU.
1725
- model_config.init_device = device
1726
- model = OLMo(model_config)
1727
-
1728
- # Load state dict in place.
1729
- load_model_state(checkpoint_dir, model)
1730
-
1731
- return model.eval()
1732
-
1733
- def _make_state_dict_compatible(
1734
- self, state_dict: Dict[str, torch.Tensor]
1735
- ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
1736
- """
1737
- Handles some cases where the state dict is valid yet may need to be transformed in order to
1738
- be loaded.
1739
-
1740
- This modifies the state dict in-place and also returns it, along with a mapping of original key
1741
- names to new key names in cases where the keys were simply renamed. That mapping can be used
1742
- to make a corresponding optimizer state dict compatible as well.
1743
- """
1744
- import re
1745
- from fnmatch import fnmatch
1746
-
1747
- new_keys_to_og_keys: Dict[str, str] = {}
1748
-
1749
- # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
1750
- # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
1751
- # fine without the prefixes. This also simplifies the other steps below.
1752
- for key in list(state_dict.keys()):
1753
- state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
1754
- new_keys_to_og_keys[new_key] = key
1755
-
1756
- # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
1757
- if self.config.block_type == BlockType.sequential:
1758
- for key in list(state_dict.keys()):
1759
- if fnmatch(key, "transformer.*.norm.weight"):
1760
- tensor = state_dict.pop(key)
1761
- state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
1762
- new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1763
- state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
1764
- new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1765
- del new_keys_to_og_keys[key]
1766
- elif fnmatch(key, "transformer.*.norm.bias"):
1767
- tensor = state_dict.pop(key)
1768
- state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
1769
- new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1770
- state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
1771
- new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1772
- del new_keys_to_og_keys[key]
1773
-
1774
- # For loading a state dict that was saved with a different `block_group_size`.
1775
- if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
1776
- state_dict_block_group_size = len(
1777
- [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
1778
- )
1779
- else:
1780
- state_dict_block_group_size = 1
1781
- if self.config.block_group_size != state_dict_block_group_size:
1782
- log.info(
1783
- f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
1784
- f"group size {self.config.block_group_size}"
1785
- )
1786
- # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
1787
- # and then (re-)group them into the right block sizes.
1788
- if state_dict_block_group_size > 1:
1789
- for key in list(state_dict.keys()):
1790
- if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
1791
- group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
1792
- block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
1793
- state_dict[
1794
- (
1795
- new_key := key.replace(
1796
- f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
1797
- )
1798
- )
1799
- ] = state_dict.pop(key)
1800
- new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1801
-
1802
- if self.config.block_group_size > 1:
1803
- # Group the state dict blocks into the right block size.
1804
- for key in list(state_dict.keys()):
1805
- if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
1806
- block_idx = int(m.group(1))
1807
- group_idx, group_block_idx = (
1808
- block_idx // self.config.block_group_size,
1809
- block_idx % self.config.block_group_size,
1810
- )
1811
- state_dict[
1812
- (
1813
- new_key := key.replace(
1814
- f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
1815
- )
1816
- )
1817
- ] = state_dict.pop(key)
1818
- new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1819
-
1820
- og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
1821
- for new_key, og_key in new_keys_to_og_keys.items():
1822
- og_keys_to_new[og_key].add(new_key)
1823
-
1824
- return state_dict, og_keys_to_new