daking commited on
Commit
38fcf74
1 Parent(s): 0261af7

LLM-foundry update October 17, 2023 23:04:30

Browse files
Files changed (12) hide show
  1. adapt_tokenizer.py +4 -5
  2. attention.py +131 -93
  3. blocks.py +20 -20
  4. configuration_mpt.py +32 -10
  5. custom_embedding.py +0 -1
  6. fc.py +7 -0
  7. ffn.py +39 -0
  8. hf_prefixlm_converter.py +16 -11
  9. meta_init_context.py +17 -12
  10. modeling_mpt.py +55 -51
  11. norm.py +10 -9
  12. param_init_fns.py +49 -51
adapt_tokenizer.py CHANGED
@@ -1,9 +1,8 @@
1
- from typing import Union
2
- from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4
  NUM_SENTINEL_TOKENS: int = 100
5
 
6
- def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7
  """Adds sentinel tokens and padding token (if missing).
8
 
9
  Expands the tokenizer vocabulary to include sentinel tokens
@@ -34,7 +33,7 @@ class AutoTokenizerForMOD(AutoTokenizer):
34
  """
35
 
36
  @classmethod
37
- def from_pretrained(cls, *args, **kwargs):
38
  """See `AutoTokenizer.from_pretrained` docstring."""
39
  tokenizer = super().from_pretrained(*args, **kwargs)
40
  adapt_tokenizer_for_denoising(tokenizer)
 
1
+ from typing import Any
2
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
3
  NUM_SENTINEL_TOKENS: int = 100
4
 
5
+ def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
6
  """Adds sentinel tokens and padding token (if missing).
7
 
8
  Expands the tokenizer vocabulary to include sentinel tokens
 
33
  """
34
 
35
  @classmethod
36
+ def from_pretrained(cls, *args: Any, **kwargs: Any) -> PreTrainedTokenizerBase:
37
  """See `AutoTokenizer.from_pretrained` docstring."""
38
  tokenizer = super().from_pretrained(*args, **kwargs)
39
  adapt_tokenizer_for_denoising(tokenizer)
attention.py CHANGED
@@ -1,15 +1,30 @@
1
  """Attention layers."""
2
  import math
3
  import warnings
4
- from typing import Optional
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
  from packaging import version
9
  from torch import nn
10
- from .norm import LPLayerNorm
 
11
 
12
- def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  if original_is_causal and num_query_tokens != num_key_tokens:
14
  if num_query_tokens != 1:
15
  raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
@@ -17,9 +32,27 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
17
  return False
18
  return original_is_causal
19
 
20
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
- kv_n_heads = 1 if multiquery else n_heads
23
  k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
24
  v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
25
  if past_key_value is not None:
@@ -29,6 +62,9 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
29
  past_key_value = (k, v)
30
  (b, _, s_q, d) = q.shape
31
  s_k = k.size(-1)
 
 
 
32
  if softmax_scale is None:
33
  softmax_scale = 1 / math.sqrt(d)
34
  attn_weight = q.matmul(k) * softmax_scale
@@ -42,11 +78,11 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
42
  min_val = torch.finfo(q.dtype).min
43
  if key_padding_mask is not None:
44
  if attn_bias is not None:
45
- warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
46
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
47
  if is_causal and (not q.size(2) == 1):
48
  s = max(s_q, s_k)
49
- causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
50
  causal_mask = causal_mask.tril()
51
  causal_mask = causal_mask.to(torch.bool)
52
  causal_mask = ~causal_mask
@@ -61,19 +97,27 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
61
  return (out, attn_weight, past_key_value)
62
  return (out, None, past_key_value)
63
 
64
- def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
 
 
65
  for tensor in tensors:
66
  if tensor.dtype not in valid_dtypes:
67
  raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
68
  if not tensor.is_cuda:
69
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
70
 
71
- def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
72
  try:
73
  from flash_attn import bert_padding, flash_attn_interface
74
  except:
75
- raise RuntimeError('Please install flash-attn==1.0.3.post0')
76
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
77
  if past_key_value is not None:
78
  if len(past_key_value) != 0:
79
  key = torch.cat([past_key_value[0], key], dim=1)
@@ -92,19 +136,27 @@ def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale
92
  (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
93
  query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
94
  (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
95
- key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
96
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
97
- value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
98
- if multiquery:
99
  key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
100
  value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
 
 
 
101
  dropout_p = dropout_p if training else 0.0
102
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
103
- output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
 
 
 
 
 
104
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
105
  return (output, None, past_key_value)
106
 
107
- def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
108
  try:
109
  from .flash_attn_triton import flash_attn_func
110
  except:
@@ -116,8 +168,14 @@ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softma
116
  except:
117
  _installed = False
118
  if not _installed:
119
- raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
120
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
121
  if past_key_value is not None:
122
  if len(past_key_value) != 0:
123
  key = torch.cat([past_key_value[0], key], dim=1)
@@ -129,6 +187,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softma
129
  attn_bias = attn_bias[:, :, _s_q:, _s_k:]
130
  if dropout_p:
131
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
 
132
  if needs_weights:
133
  raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
134
  if key_padding_mask is not None:
@@ -138,124 +197,103 @@ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softma
138
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
139
  attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
140
  query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
141
- key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
142
- value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
143
- if multiquery:
144
- key = key.expand(*key.shape[:2], n_heads, key.size(-1))
145
- value = value.expand(*value.shape[:2], n_heads, value.size(-1))
 
 
 
146
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
147
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
148
  output = attn_output.view(*attn_output.shape[:2], -1)
149
  return (output, None, past_key_value)
150
 
151
- class MultiheadAttention(nn.Module):
152
- """Multi-head self attention.
153
 
154
- Using torch or triton attention implemetation enables user to also use
155
- additive bias.
 
 
 
156
  """
157
 
158
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
159
  super().__init__()
160
  self.attn_impl = attn_impl
161
  self.clip_qkv = clip_qkv
162
  self.qk_ln = qk_ln
163
  self.d_model = d_model
164
  self.n_heads = n_heads
 
 
 
 
 
 
 
 
165
  self.softmax_scale = softmax_scale
166
  if self.softmax_scale is None:
167
  self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
168
  self.attn_dropout_p = attn_pdrop
169
- self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
170
- fuse_splits = (d_model, 2 * d_model)
 
 
 
171
  self.Wqkv._fused = (0, fuse_splits)
172
  if self.qk_ln:
173
- layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
174
- self.q_ln = layernorm_class(self.d_model, device=device)
175
- self.k_ln = layernorm_class(self.d_model, device=device)
176
  if self.attn_impl == 'flash':
177
  self.attn_fn = flash_attn_fn
178
  elif self.attn_impl == 'triton':
179
  self.attn_fn = triton_flash_attn_fn
180
- if verbose:
181
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
182
  elif self.attn_impl == 'torch':
183
  self.attn_fn = scaled_multihead_dot_product_attention
184
- if torch.cuda.is_available() and verbose:
185
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
186
  else:
187
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
188
- self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
189
  self.out_proj._is_residual = True
190
 
191
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
192
  qkv = self.Wqkv(x)
193
  if self.clip_qkv:
194
- qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
195
- (query, key, value) = qkv.chunk(3, dim=2)
196
  key_padding_mask = attention_mask
197
  if self.qk_ln:
198
  dtype = query.dtype
199
  query = self.q_ln(query).to(dtype)
200
  key = self.k_ln(key).to(dtype)
201
- (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
202
  return (self.out_proj(context), attn_weights, past_key_value)
203
 
204
- class MultiQueryAttention(nn.Module):
205
- """Multi-Query self attention.
206
 
207
- Using torch or triton attention implemetation enables user to also use
208
  additive bias.
209
  """
210
 
211
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
212
- super().__init__()
213
- self.attn_impl = attn_impl
214
- self.clip_qkv = clip_qkv
215
- self.qk_ln = qk_ln
216
- self.d_model = d_model
217
- self.n_heads = n_heads
218
- self.head_dim = d_model // n_heads
219
- self.softmax_scale = softmax_scale
220
- if self.softmax_scale is None:
221
- self.softmax_scale = 1 / math.sqrt(self.head_dim)
222
- self.attn_dropout_p = attn_pdrop
223
- self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
224
- fuse_splits = (d_model, d_model + self.head_dim)
225
- self.Wqkv._fused = (0, fuse_splits)
226
- if self.qk_ln:
227
- layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
228
- self.q_ln = layernorm_class(d_model, device=device)
229
- self.k_ln = layernorm_class(self.head_dim, device=device)
230
- if self.attn_impl == 'flash':
231
- self.attn_fn = flash_attn_fn
232
- elif self.attn_impl == 'triton':
233
- self.attn_fn = triton_flash_attn_fn
234
- if verbose:
235
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
236
- elif self.attn_impl == 'torch':
237
- self.attn_fn = scaled_multihead_dot_product_attention
238
- if torch.cuda.is_available() and verbose:
239
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
240
- else:
241
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
242
- self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
243
- self.out_proj._is_residual = True
244
 
245
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
246
- qkv = self.Wqkv(x)
247
- if self.clip_qkv:
248
- qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
249
- (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
250
- key_padding_mask = attention_mask
251
- if self.qk_ln:
252
- dtype = query.dtype
253
- query = self.q_ln(query).to(dtype)
254
- key = self.k_ln(key).to(dtype)
255
- (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
256
- return (self.out_proj(context), attn_weights, past_key_value)
257
 
258
- def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
259
  if attn_impl == 'flash':
260
  return None
261
  elif attn_impl in ['torch', 'triton']:
@@ -269,7 +307,7 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s
269
  else:
270
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
271
 
272
- def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
273
  if attn_impl == 'flash':
274
  return None
275
  elif attn_impl in ['torch', 'triton']:
@@ -280,7 +318,7 @@ def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=
280
  else:
281
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
282
 
283
- def gen_slopes(n_heads, alibi_bias_max=8, device=None):
284
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
285
  m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
286
  m = m.mul(alibi_bias_max / _n_heads)
@@ -289,7 +327,7 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
289
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
290
  return slopes.view(1, n_heads, 1, 1)
291
 
292
- def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
293
  alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
294
  if full:
295
  alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
@@ -297,4 +335,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
297
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
298
  alibi_bias = alibi_bias * slopes
299
  return alibi_bias.to(dtype=dtype)
300
- ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
 
1
  """Attention layers."""
2
  import math
3
  import warnings
4
+ from typing import Any, List, Optional, Tuple
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
  from packaging import version
9
  from torch import nn
10
+ from .fc import FC_CLASS_REGISTRY
11
+ from .norm import NORM_CLASS_REGISTRY
12
 
13
+ def is_flash_v2_installed():
14
+ try:
15
+ import flash_attn as flash_attn
16
+ except:
17
+ return False
18
+ return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
19
+
20
+ def is_flash_v1_installed():
21
+ try:
22
+ import flash_attn as flash_attn
23
+ except:
24
+ return False
25
+ return version.parse(flash_attn.__version__) < version.parse('2.0.0')
26
+
27
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
28
  if original_is_causal and num_query_tokens != num_key_tokens:
29
  if num_query_tokens != 1:
30
  raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
 
32
  return False
33
  return original_is_causal
34
 
35
+ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
36
+ """Perform repeat of kv heads along a particular dimension.
37
+
38
+ hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
39
+ n_rep: amount of repetitions of kv_n_heads
40
+ Unlike torch.repeat_interleave, this function avoids allocating new memory.
41
+ """
42
+ if n_rep == 1:
43
+ return hidden
44
+ (b, s, kv_n_heads, d) = hidden.shape
45
+ hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
46
+ return hidden.reshape(b, s, kv_n_heads * n_rep, d)
47
+
48
+ def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
49
+ if multiquery:
50
+ warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
51
+ kv_n_heads = 1
52
+ elif kv_n_heads is None:
53
+ warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
54
+ kv_n_heads = n_heads
55
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
 
56
  k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
57
  v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
58
  if past_key_value is not None:
 
62
  past_key_value = (k, v)
63
  (b, _, s_q, d) = q.shape
64
  s_k = k.size(-1)
65
+ if kv_n_heads > 1 and kv_n_heads < n_heads:
66
+ k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
67
+ v = repeat_kv_for_gqa(v.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
68
  if softmax_scale is None:
69
  softmax_scale = 1 / math.sqrt(d)
70
  attn_weight = q.matmul(k) * softmax_scale
 
78
  min_val = torch.finfo(q.dtype).min
79
  if key_padding_mask is not None:
80
  if attn_bias is not None:
81
+ warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
82
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
83
  if is_causal and (not q.size(2) == 1):
84
  s = max(s_q, s_k)
85
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
86
  causal_mask = causal_mask.tril()
87
  causal_mask = causal_mask.to(torch.bool)
88
  causal_mask = ~causal_mask
 
97
  return (out, attn_weight, past_key_value)
98
  return (out, None, past_key_value)
99
 
100
+ def check_valid_inputs(*tensors: torch.Tensor, valid_dtypes: Optional[List[torch.dtype]]=None):
101
+ if valid_dtypes is None:
102
+ valid_dtypes = [torch.float16, torch.bfloat16]
103
  for tensor in tensors:
104
  if tensor.dtype not in valid_dtypes:
105
  raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
106
  if not tensor.is_cuda:
107
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
108
 
109
+ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
110
  try:
111
  from flash_attn import bert_padding, flash_attn_interface
112
  except:
113
+ raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.2')
114
  check_valid_inputs(query, key, value)
115
+ if multiquery:
116
+ warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
117
+ kv_n_heads = 1
118
+ elif kv_n_heads is None:
119
+ warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
120
+ kv_n_heads = n_heads
121
  if past_key_value is not None:
122
  if len(past_key_value) != 0:
123
  key = torch.cat([past_key_value[0], key], dim=1)
 
136
  (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
137
  query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
138
  (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
139
+ key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
140
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
141
+ value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
142
+ if kv_n_heads == 1:
143
  key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
144
  value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
145
+ elif kv_n_heads < n_heads:
146
+ key_unpad = repeat_kv_for_gqa(key_unpad.view(batch_size, seqlen, kv_n_heads, -1), n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
147
+ value_unpad = repeat_kv_for_gqa(value_unpad.view(batch_size, seqlen, kv_n_heads, -1), n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
148
  dropout_p = dropout_p if training else 0.0
149
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
150
+ if is_flash_v1_installed():
151
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
152
+ elif is_flash_v2_installed():
153
+ output_unpad = flash_attn_interface.flash_attn_varlen_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
154
+ else:
155
+ raise RuntimeError('flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
156
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
157
  return (output, None, past_key_value)
158
 
159
+ def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
160
  try:
161
  from .flash_attn_triton import flash_attn_func
162
  except:
 
168
  except:
169
  _installed = False
170
  if not _installed:
171
+ raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + 'Note: (1) requires you have CMake and PyTorch already installed.')
172
  check_valid_inputs(query, key, value)
173
+ if multiquery:
174
+ warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
175
+ kv_n_heads = 1
176
+ elif kv_n_heads is None:
177
+ warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
178
+ kv_n_heads = n_heads
179
  if past_key_value is not None:
180
  if len(past_key_value) != 0:
181
  key = torch.cat([past_key_value[0], key], dim=1)
 
187
  attn_bias = attn_bias[:, :, _s_q:, _s_k:]
188
  if dropout_p:
189
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
190
+ dropout_p = dropout_p if training else 0.0
191
  if needs_weights:
192
  raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
193
  if key_padding_mask is not None:
 
197
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
198
  attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
199
  query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
200
+ key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
201
+ value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
202
+ if kv_n_heads == 1:
203
+ key = key.repeat(1, 1, n_heads, 1)
204
+ value = value.repeat(1, 1, n_heads, 1)
205
+ elif kv_n_heads < n_heads:
206
+ key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
207
+ value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
208
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
209
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
210
  output = attn_output.view(*attn_output.shape[:2], -1)
211
  return (output, None, past_key_value)
212
 
213
+ class GroupedQueryAttention(nn.Module):
214
+ """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
215
 
216
+ and Multi-query attention (MQA).
217
+
218
+ This allows the user to set a variable of number of kv_n_heads, rather than
219
+ just n_heads or 1, as in MHA and MQA. Using torch or triton attention
220
+ implementation enables user to also use additive bias.
221
  """
222
 
223
+ def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
224
  super().__init__()
225
  self.attn_impl = attn_impl
226
  self.clip_qkv = clip_qkv
227
  self.qk_ln = qk_ln
228
  self.d_model = d_model
229
  self.n_heads = n_heads
230
+ self.kv_n_heads = kv_n_heads
231
+ self.head_dim = d_model // n_heads
232
+ if self.kv_n_heads <= 0:
233
+ raise ValueError('kv_n_heads should be greater than zero.')
234
+ if self.kv_n_heads > self.n_heads:
235
+ raise ValueError('The number of KV heads should be less than or equal to Q heads.')
236
+ if self.n_heads % self.kv_n_heads != 0:
237
+ raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
238
  self.softmax_scale = softmax_scale
239
  if self.softmax_scale is None:
240
  self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
241
  self.attn_dropout_p = attn_pdrop
242
+ fc_kwargs: dict[str, Any] = {'bias': bias}
243
+ if fc_type != 'te':
244
+ fc_kwargs['device'] = device
245
+ self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
246
+ fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
247
  self.Wqkv._fused = (0, fuse_splits)
248
  if self.qk_ln:
249
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
250
+ self.q_ln = norm_class(self.d_model, device=device)
251
+ self.k_ln = norm_class(self.kv_n_heads * self.head_dim, device=device)
252
  if self.attn_impl == 'flash':
253
  self.attn_fn = flash_attn_fn
254
  elif self.attn_impl == 'triton':
255
  self.attn_fn = triton_flash_attn_fn
 
 
256
  elif self.attn_impl == 'torch':
257
  self.attn_fn = scaled_multihead_dot_product_attention
 
 
258
  else:
259
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
260
+ self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
261
  self.out_proj._is_residual = True
262
 
263
+ def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
264
  qkv = self.Wqkv(x)
265
  if self.clip_qkv:
266
+ qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
267
+ (query, key, value) = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
268
  key_padding_mask = attention_mask
269
  if self.qk_ln:
270
  dtype = query.dtype
271
  query = self.q_ln(query).to(dtype)
272
  key = self.k_ln(key).to(dtype)
273
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
274
  return (self.out_proj(context), attn_weights, past_key_value)
275
 
276
+ class MultiheadAttention(GroupedQueryAttention):
277
+ """Multi-head self attention.
278
 
279
+ Using torch or triton attention implementation enables user to also use
280
  additive bias.
281
  """
282
 
283
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
284
+ super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
+ class MultiQueryAttention(GroupedQueryAttention):
287
+ """Multi-Query self attention.
288
+
289
+ Using torch or triton attention implementation enables user to also use
290
+ additive bias.
291
+ """
292
+
293
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
294
+ super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias)
 
 
 
295
 
296
+ def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
297
  if attn_impl == 'flash':
298
  return None
299
  elif attn_impl in ['torch', 'triton']:
 
307
  else:
308
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
309
 
310
+ def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
311
  if attn_impl == 'flash':
312
  return None
313
  elif attn_impl in ['torch', 'triton']:
 
318
  else:
319
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
320
 
321
+ def gen_slopes(n_heads: int, alibi_bias_max: int=8, device: Optional[torch.device]=None) -> torch.Tensor:
322
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
323
  m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
324
  m = m.mul(alibi_bias_max / _n_heads)
 
327
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
328
  return slopes.view(1, n_heads, 1, 1)
329
 
330
+ def build_alibi_bias(n_heads: int, seq_len: int, full: bool=False, alibi_bias_max: int=8, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None) -> torch.Tensor:
331
  alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
332
  if full:
333
  alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
 
335
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
336
  alibi_bias = alibi_bias * slopes
337
  return alibi_bias.to(dtype=dtype)
338
+ ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention, 'grouped_query_attention': GroupedQueryAttention}
blocks.py CHANGED
@@ -1,41 +1,41 @@
1
  """GPT Blocks used for the GPT Model."""
2
- from typing import Dict, Optional, Tuple
3
  import torch
4
  import torch.nn as nn
5
  from .attention import ATTN_CLASS_REGISTRY
 
6
  from .norm import NORM_CLASS_REGISTRY
7
 
8
- class MPTMLP(nn.Module):
9
-
10
- def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11
- super().__init__()
12
- self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13
- self.act = nn.GELU(approximate='none')
14
- self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15
- self.down_proj._is_residual = True
16
-
17
- def forward(self, x):
18
- return self.down_proj(self.act(self.up_proj(x)))
19
-
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
 
 
 
 
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
 
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
 
 
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
- self.norm_2 = norm_class(d_model, device=device)
30
- self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
 
 
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33
 
34
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
- m = self.norm_2(x)
 
 
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
  return (x, attn_weights, past_key_value)
 
1
  """GPT Blocks used for the GPT Model."""
2
+ from typing import Any, Dict, Optional, Tuple
3
  import torch
4
  import torch.nn as nn
5
  from .attention import ATTN_CLASS_REGISTRY
6
+ from .ffn import FFN_CLASS_REGISTRY, build_ffn
7
  from .norm import NORM_CLASS_REGISTRY
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class MPTBlock(nn.Module):
10
 
11
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, no_bias: bool=False, **kwargs: Any):
12
+ if attn_config is None:
13
+ attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
14
+ if ffn_config is None:
15
+ ffn_config = {'ffn_type': 'mptmlp'}
16
  del kwargs
17
  super().__init__()
18
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
19
+ assert isinstance(attn_config['attn_type'], str)
20
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
21
+ args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
22
+ attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
23
  self.norm_1 = norm_class(d_model, device=device)
24
+ self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
25
+ self.norm_2 = None
26
+ if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
27
+ self.norm_2 = norm_class(d_model, device=device)
28
+ self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, **ffn_config)
29
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
30
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
31
 
32
+ def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
33
  a = self.norm_1(x)
34
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions)
35
  x = x + self.resid_attn_dropout(b)
36
+ m = x
37
+ if self.norm_2 is not None:
38
+ m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
  return (x, attn_weights, past_key_value)
configuration_mpt.py CHANGED
@@ -1,27 +1,29 @@
1
  """A HuggingFace-style model configuration."""
2
- from typing import Dict, Optional, Union
 
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
 
5
  init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
9
 
10
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
11
  """The MPT configuration class.
12
 
13
  Args:
14
  d_model (int): The size of the embedding dimension of the model.
15
  n_heads (int): The number of attention heads.
16
  n_layers (int): The number of layers in the model.
17
- expansion_ratio (int): The ratio of the up/down scale in the MLP.
18
  max_seq_len (int): The maximum sequence length of the model.
19
  vocab_size (int): The size of the vocabulary.
20
  resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
21
  emb_pdrop (float): The dropout probability for the embedding layer.
22
  learned_pos_emb (bool): Whether to use learned positional embeddings
23
- attn_config (Dict): A dictionary used to configure the model's attention module:
24
- attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
25
  attn_pdrop (float): The dropout probability for the attention layers.
26
  attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
27
  qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
@@ -38,13 +40,15 @@ class MPTConfig(PretrainedConfig):
38
  Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
39
  alibi (bool): Whether to use the alibi bias instead of position embeddings.
40
  alibi_bias_max (int): The maximum value of the alibi bias.
 
 
 
41
  init_device (str): The device to use for parameter initialization.
42
  logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
43
  no_bias (bool): Whether to use bias in all layers.
44
  verbose (int): The verbosity level. 0 is silent.
45
  embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
46
  norm_type (str): choose type of norm to use
47
- multiquery_attention (bool): Whether to use multiquery attention implementation.
48
  use_cache (bool): Whether or not the model should return the last key/values attentions
49
  init_config (Dict): A dictionary used to configure the model initialization:
50
  init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
@@ -61,6 +65,7 @@ class MPTConfig(PretrainedConfig):
61
  init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
62
  ---
63
  See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
 
64
  """
65
  self.d_model = d_model
66
  self.n_heads = n_heads
@@ -72,29 +77,36 @@ class MPTConfig(PretrainedConfig):
72
  self.emb_pdrop = emb_pdrop
73
  self.learned_pos_emb = learned_pos_emb
74
  self.attn_config = attn_config
 
75
  self.init_device = init_device
76
  self.logit_scale = logit_scale
77
  self.no_bias = no_bias
78
- self.verbose = verbose
79
  self.embedding_fraction = embedding_fraction
80
  self.norm_type = norm_type
81
  self.use_cache = use_cache
82
  self.init_config = init_config
 
 
 
83
  if 'name' in kwargs:
84
  del kwargs['name']
85
  if 'loss_fn' in kwargs:
86
  del kwargs['loss_fn']
 
 
 
87
  super().__init__(**kwargs)
88
  self._validate_config()
89
 
90
- def _set_config_defaults(self, config, config_defaults):
91
  for (k, v) in config_defaults.items():
92
  if k not in config:
93
  config[k] = v
94
  return config
95
 
96
- def _validate_config(self):
97
  self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
 
98
  self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
99
  if self.d_model % self.n_heads != 0:
100
  raise ValueError('d_model must be divisible by n_heads')
@@ -115,4 +127,14 @@ class MPTConfig(PretrainedConfig):
115
  if self.init_config.get('name', None) is None:
116
  raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
117
  if not self.learned_pos_emb and (not self.attn_config['alibi']):
118
- raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
 
 
 
 
 
 
 
 
 
 
 
1
  """A HuggingFace-style model configuration."""
2
+ import warnings
3
+ from typing import Any, Dict, Optional, Union
4
  from transformers import PretrainedConfig
5
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
6
+ ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
7
  init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
8
 
9
  class MPTConfig(PretrainedConfig):
10
  model_type = 'mpt'
11
 
12
+ def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, ffn_config: Dict=ffn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, fc_type: str='torch', verbose: Optional[int]=None, **kwargs: Any):
13
  """The MPT configuration class.
14
 
15
  Args:
16
  d_model (int): The size of the embedding dimension of the model.
17
  n_heads (int): The number of attention heads.
18
  n_layers (int): The number of layers in the model.
19
+ expansion_ratio (int): The ratio of the up/down scale in the ffn.
20
  max_seq_len (int): The maximum sequence length of the model.
21
  vocab_size (int): The size of the vocabulary.
22
  resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
23
  emb_pdrop (float): The dropout probability for the embedding layer.
24
  learned_pos_emb (bool): Whether to use learned positional embeddings
25
+ attn_config (Dict): A dictionary used to configure the model's attention module:
26
+ attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
27
  attn_pdrop (float): The dropout probability for the attention layers.
28
  attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
29
  qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
 
40
  Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
41
  alibi (bool): Whether to use the alibi bias instead of position embeddings.
42
  alibi_bias_max (int): The maximum value of the alibi bias.
43
+ kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
44
+ ffn_config (Dict): A dictionary used to configure the model's ffn module:
45
+ ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
46
  init_device (str): The device to use for parameter initialization.
47
  logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
48
  no_bias (bool): Whether to use bias in all layers.
49
  verbose (int): The verbosity level. 0 is silent.
50
  embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
51
  norm_type (str): choose type of norm to use
 
52
  use_cache (bool): Whether or not the model should return the last key/values attentions
53
  init_config (Dict): A dictionary used to configure the model initialization:
54
  init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
 
65
  init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
66
  ---
67
  See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
68
+ fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
69
  """
70
  self.d_model = d_model
71
  self.n_heads = n_heads
 
77
  self.emb_pdrop = emb_pdrop
78
  self.learned_pos_emb = learned_pos_emb
79
  self.attn_config = attn_config
80
+ self.ffn_config = ffn_config
81
  self.init_device = init_device
82
  self.logit_scale = logit_scale
83
  self.no_bias = no_bias
 
84
  self.embedding_fraction = embedding_fraction
85
  self.norm_type = norm_type
86
  self.use_cache = use_cache
87
  self.init_config = init_config
88
+ self.fc_type = fc_type
89
+ if verbose is not None:
90
+ warnings.warn(DeprecationWarning('verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'))
91
  if 'name' in kwargs:
92
  del kwargs['name']
93
  if 'loss_fn' in kwargs:
94
  del kwargs['loss_fn']
95
+ if self.attn_config.get('alibi', False):
96
+ self.learned_pos_emb = False
97
+ warnings.warn(f'alibi is turned on, setting `learned_pos_emb` to `False.`')
98
  super().__init__(**kwargs)
99
  self._validate_config()
100
 
101
+ def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
102
  for (k, v) in config_defaults.items():
103
  if k not in config:
104
  config[k] = v
105
  return config
106
 
107
+ def _validate_config(self) -> None:
108
  self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
109
+ self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
110
  self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
111
  if self.d_model % self.n_heads != 0:
112
  raise ValueError('d_model must be divisible by n_heads')
 
127
  if self.init_config.get('name', None) is None:
128
  raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
129
  if not self.learned_pos_emb and (not self.attn_config['alibi']):
130
+ warnings.warn(f'Positional information not being provided to the model using either learned_pos_emb or alibi.')
131
+ if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
132
+ try:
133
+ import transformer_engine.pytorch as te
134
+ del te
135
+ except:
136
+ raise ImportError('TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. ' + 'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n' + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156')
137
+ if self.ffn_config['ffn_type'] == 'mptmlp':
138
+ self.ffn_config['fc_type'] = self.fc_type
139
+ elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
140
+ self.ffn_config['bias'] = not self.no_bias
custom_embedding.py CHANGED
@@ -1,4 +1,3 @@
1
- import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from torch import Tensor
 
 
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
  from torch import Tensor
fc.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ FC_CLASS_REGISTRY = {'torch': nn.Linear}
3
+ try:
4
+ import transformer_engine.pytorch as te
5
+ FC_CLASS_REGISTRY['te'] = te.Linear
6
+ except:
7
+ pass
ffn.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPT Blocks used for the GPT Model."""
2
+ from typing import Any, Optional
3
+ import torch
4
+ import torch.nn as nn
5
+ from .fc import FC_CLASS_REGISTRY
6
+ try:
7
+ import transformer_engine.pytorch as te
8
+ except:
9
+ te = None
10
+
11
+ class MPTMLP(nn.Module):
12
+
13
+ def __init__(self, d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
14
+ super().__init__()
15
+ fc_kwargs: dict[str, Any] = {'bias': bias}
16
+ if fc_type != 'te':
17
+ fc_kwargs['device'] = device
18
+ self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, expansion_ratio * d_model, **fc_kwargs)
19
+ self.act = nn.GELU(approximate='none')
20
+ self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * d_model, d_model, **fc_kwargs)
21
+ self.down_proj._is_residual = True
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ return self.down_proj(self.act(self.up_proj(x)))
25
+ FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP}
26
+ if te is not None:
27
+ te.LayerNormMLP._has_norm = True
28
+ FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
29
+
30
+ def build_ffn(d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
31
+ ffn_type = kwargs.pop('ffn_type')
32
+ if ffn_type == 'mptmlp':
33
+ if len(kwargs) > 0:
34
+ raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
35
+ return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device, bias=bias)
36
+ elif ffn_type == 'te_ln_mlp':
37
+ assert te is not None
38
+ return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, bias=bias, **kwargs)
39
+ raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
hf_prefixlm_converter.py CHANGED
@@ -9,7 +9,7 @@ and treat the input prompt as the prefix in `generate`.
9
  import math
10
  import warnings
11
  from types import MethodType
12
- from typing import Any, Dict, List, Optional, Tuple, Union
13
  import torch
14
  from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
  from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
@@ -90,13 +90,14 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
 
93
  attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
94
  output = call_og_forward()
95
  for attn_module in attn_modules:
96
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
97
  return output
98
 
99
- def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
100
  """Wraps original generate to enable PrefixLM attention."""
101
  attn_modules = _get_attn_modules(model)
102
  for attn_module in attn_modules:
@@ -157,7 +158,7 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
157
  return alibi.to(dtype)
158
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
159
 
160
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
161
  if deprecated_arguments.pop('position_ids', False) is not False:
162
  warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
163
  if len(deprecated_arguments) > 0:
@@ -204,9 +205,9 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
204
  logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
205
  use_cache = False
206
 
207
- def create_custom_forward(module):
208
 
209
- def custom_forward(*inputs):
210
  return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
211
  return custom_forward
212
  outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
@@ -227,10 +228,10 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
227
  return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
228
  setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
229
  setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
230
- setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
231
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
232
 
233
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
234
  """Replacement forward method for BloomCausalLM."""
235
  if deprecated_arguments.pop('position_ids', False) is not False:
236
  warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
@@ -252,7 +253,8 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
252
  return (loss,) + output if loss is not None else output
253
  return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
254
 
255
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
 
256
  if past:
257
  input_ids = input_ids[:, -1].unsqueeze(-1)
258
  bidirectional_mask = None
@@ -282,19 +284,22 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
282
  setattr(model, '_original_generate', getattr(model, 'generate'))
283
  model.model.decoder.bidirectional_mask = None
284
 
285
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
286
  combined_attention_mask = None
287
  if input_shape[-1] > 1:
 
288
  if self.bidirectional_mask == 'g':
289
  (bsz, src_length) = input_shape
290
  combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
291
  else:
292
  combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
293
  if self.bidirectional_mask is not None:
 
294
  assert attention_mask.shape == self.bidirectional_mask.shape
295
  expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
296
  combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
297
  if attention_mask is not None:
 
298
  expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
299
  combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
300
  return combined_attention_mask
@@ -315,7 +320,7 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
315
  self.model.decoder.bidirectional_mask = None
316
  return outputs
317
 
318
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
319
  """Wraps original generate to enable PrefixLM-style attention."""
320
  self.model.decoder.bidirectional_mask = 'g'
321
  try:
@@ -398,7 +403,7 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
398
  else:
399
  raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
400
 
401
- def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
402
  """Attempts to add bidirectional_mask to batch if missing.
403
 
404
  Raises:
 
9
  import math
10
  import warnings
11
  from types import MethodType
12
+ from typing import Any, List, MutableMapping, Optional, Tuple, Union
13
  import torch
14
  from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
  from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
 
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
93
+ assert isinstance(attn_module.bias, torch.Tensor)
94
  attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
95
  output = call_og_forward()
96
  for attn_module in attn_modules:
97
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
98
  return output
99
 
100
+ def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
101
  """Wraps original generate to enable PrefixLM attention."""
102
  attn_modules = _get_attn_modules(model)
103
  for attn_module in attn_modules:
 
158
  return alibi.to(dtype)
159
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
160
 
161
+ def transformer_forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
162
  if deprecated_arguments.pop('position_ids', False) is not False:
163
  warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
164
  if len(deprecated_arguments) > 0:
 
205
  logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
206
  use_cache = False
207
 
208
+ def create_custom_forward(module: torch.nn.Module):
209
 
210
+ def custom_forward(*inputs: Any):
211
  return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
212
  return custom_forward
213
  outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
 
228
  return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
229
  setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
230
  setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
231
+ setattr(model.transformer, 'forward', MethodType(transformer_forward, model.transformer))
232
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
233
 
234
+ def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
235
  """Replacement forward method for BloomCausalLM."""
236
  if deprecated_arguments.pop('position_ids', False) is not False:
237
  warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
 
253
  return (loss,) + output if loss is not None else output
254
  return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
255
 
256
+ def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs: Any) -> dict:
257
+ del kwargs
258
  if past:
259
  input_ids = input_ids[:, -1].unsqueeze(-1)
260
  bidirectional_mask = None
 
284
  setattr(model, '_original_generate', getattr(model, 'generate'))
285
  model.model.decoder.bidirectional_mask = None
286
 
287
+ def _prepare_decoder_attention_mask(self: torch.nn.Module, attention_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], inputs_embeds: Optional[torch.Tensor], past_key_values_length: int):
288
  combined_attention_mask = None
289
  if input_shape[-1] > 1:
290
+ assert inputs_embeds is not None
291
  if self.bidirectional_mask == 'g':
292
  (bsz, src_length) = input_shape
293
  combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
294
  else:
295
  combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
296
  if self.bidirectional_mask is not None:
297
+ assert attention_mask is not None
298
  assert attention_mask.shape == self.bidirectional_mask.shape
299
  expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
300
  combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
301
  if attention_mask is not None:
302
+ assert inputs_embeds is not None
303
  expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
304
  combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
305
  return combined_attention_mask
 
320
  self.model.decoder.bidirectional_mask = None
321
  return outputs
322
 
323
+ def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Any):
324
  """Wraps original generate to enable PrefixLM-style attention."""
325
  self.model.decoder.bidirectional_mask = 'g'
326
  try:
 
403
  else:
404
  raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
405
 
406
+ def add_bidirectional_mask_if_missing(batch: MutableMapping):
407
  """Attempts to add bidirectional_mask to batch if missing.
408
 
409
  Raises:
meta_init_context.py CHANGED
@@ -1,4 +1,5 @@
1
  from contextlib import contextmanager
 
2
  import torch
3
  import torch.nn as nn
4
 
@@ -57,25 +58,29 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
57
  if include_buffers:
58
  old_register_buffer = nn.Module.register_buffer
59
 
60
- def register_empty_parameter(module, name, param):
61
- old_register_parameter(module, name, param)
62
  if param is not None:
63
- param_cls = type(module._parameters[name])
64
- kwargs = module._parameters[name].__dict__
65
- module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66
-
67
- def register_empty_buffer(module, name, buffer):
68
- old_register_buffer(module, name, buffer)
69
- if buffer is not None:
70
- module._buffers[name] = module._buffers[name].to(device)
 
 
 
 
71
  if include_buffers:
72
  tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
73
  else:
74
  tensor_constructors_to_patch = {}
75
 
76
- def patch_tensor_constructor(fn):
77
 
78
- def wrapper(*args, **kwargs):
79
  kwargs['device'] = device
80
  return fn(*args, **kwargs)
81
  return wrapper
 
1
  from contextlib import contextmanager
2
+ from typing import Any, Callable, Optional
3
  import torch
4
  import torch.nn as nn
5
 
 
58
  if include_buffers:
59
  old_register_buffer = nn.Module.register_buffer
60
 
61
+ def register_empty_parameter(self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]):
62
+ old_register_parameter(self, name, param)
63
  if param is not None:
64
+ parameter = self._parameters[name]
65
+ assert parameter is not None
66
+ param_cls = type(parameter)
67
+ kwargs = parameter.__dict__
68
+ self._parameters[name] = param_cls(parameter.to(device), **kwargs)
69
+
70
+ def register_empty_buffer(self: torch.nn.Module, name: str, tensor: Optional[torch.Tensor], persistent: bool=True):
71
+ old_register_buffer(self, name, tensor, persistent=persistent)
72
+ if tensor is not None:
73
+ named_buffer = self._buffers[name]
74
+ assert named_buffer is not None
75
+ self._buffers[name] = named_buffer.to(device)
76
  if include_buffers:
77
  tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
78
  else:
79
  tensor_constructors_to_patch = {}
80
 
81
+ def patch_tensor_constructor(fn: Callable):
82
 
83
+ def wrapper(*args: Any, **kwargs: Any):
84
  kwargs['device'] = device
85
  return fn(*args, **kwargs)
86
  return wrapper
modeling_mpt.py CHANGED
@@ -4,26 +4,31 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
  """
5
  import math
6
  import warnings
7
- from typing import List, Optional, Tuple, Union
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
- from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
  from .custom_embedding import SharedEmbedding
 
 
 
 
16
  from .norm import NORM_CLASS_REGISTRY
17
  from .configuration_mpt import MPTConfig
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
19
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
20
  from .meta_init_context import init_empty_weights
21
- from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
22
  try:
23
- from .flash_attn_triton import flash_attn_func
24
  except:
25
  pass
26
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
27
 
28
  class MPTPreTrainedModel(PreTrainedModel):
29
  config_class = MPTConfig
@@ -40,6 +45,7 @@ class MPTModel(MPTPreTrainedModel):
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
41
  self.alibi = config.attn_config['alibi']
42
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
 
43
  if config.init_device == 'mixed':
44
  if dist.get_local_rank() == 0:
45
  config.init_device = 'cpu'
@@ -51,13 +57,13 @@ class MPTModel(MPTPreTrainedModel):
51
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
52
  self.embedding_fraction = config.embedding_fraction
53
  self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
54
- if not self.alibi:
55
  self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
56
  self.emb_drop = nn.Dropout(config.emb_pdrop)
57
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
58
  self.norm_f = norm_class(config.d_model, device=config.init_device)
59
  if config.init_device != 'meta':
60
- print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
61
  self.apply(self.param_init_fn)
62
  self.is_causal = not self.prefix_lm
63
  self._attn_bias_initialized = False
@@ -66,25 +72,22 @@ class MPTModel(MPTPreTrainedModel):
66
  if config.no_bias:
67
  for module in self.modules():
68
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
69
- if config.verbose:
70
- warnings.warn(f'Removing bias ({module.bias}) from {module}.')
71
  module.register_parameter('bias', None)
72
- if config.verbose and config.verbose > 2:
73
- print(self)
74
- if 'verbose' not in self.config.init_config:
75
- self.config.init_config['verbose'] = self.config.verbose
76
- if self.config.init_config['verbose'] > 1:
77
- init_fn_name = self.config.init_config['name']
78
- warnings.warn(f'Using {init_fn_name} initialization.')
79
 
80
- def get_input_embeddings(self):
81
  return self.wte
82
 
83
- def set_input_embeddings(self, value):
84
  self.wte = value
85
 
86
  @torch.no_grad()
87
- def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
88
  if not self._attn_bias_initialized:
89
  if self.attn_bias_shape:
90
  self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
@@ -115,7 +118,7 @@ class MPTModel(MPTPreTrainedModel):
115
  attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
116
  return (attn_bias, None)
117
 
118
- def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
119
  (s_k, s_q) = attn_bias.shape[-2:]
120
  if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
121
  raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
@@ -130,7 +133,7 @@ class MPTModel(MPTPreTrainedModel):
130
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
131
  return attn_bias
132
 
133
- def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
134
  seq_len = sequence_id.shape[-1]
135
  if seq_len > self.config.max_seq_len:
136
  raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
@@ -140,7 +143,7 @@ class MPTModel(MPTPreTrainedModel):
140
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
141
  return attn_bias
142
 
143
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
144
  return_dict = return_dict if return_dict is not None else self.config.return_dict
145
  use_cache = use_cache if use_cache is not None else self.config.use_cache
146
  if attention_mask is not None:
@@ -152,7 +155,7 @@ class MPTModel(MPTPreTrainedModel):
152
  if output_attentions:
153
  if self.attn_impl != 'torch':
154
  raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
155
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
@@ -166,9 +169,7 @@ class MPTModel(MPTPreTrainedModel):
166
  S = input_ids.size(1)
167
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
168
  tok_emb = self.wte(input_ids)
169
- if self.alibi:
170
- x = tok_emb
171
- else:
172
  past_position = 0
173
  if past_key_values is not None:
174
  if len(past_key_values) != self.config.n_layers:
@@ -177,12 +178,14 @@ class MPTModel(MPTPreTrainedModel):
177
  if self.attn_impl == 'torch':
178
  past_position = past_key_values[0][0].size(3)
179
  if S + past_position > self.config.max_seq_len:
180
- raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
181
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
182
  if attention_mask is not None:
183
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
184
  pos_emb = self.wpe(pos)
185
  x = tok_emb + pos_emb
 
 
186
  if self.embedding_fraction == 1:
187
  x = self.emb_drop(x)
188
  else:
@@ -190,6 +193,7 @@ class MPTModel(MPTPreTrainedModel):
190
  assert isinstance(self.emb_drop, nn.Module)
191
  x = self.emb_drop(x_shrunk)
192
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
 
193
  if use_cache and past_key_values is None:
194
  past_key_values = [() for _ in range(self.config.n_layers)]
195
  all_hidden_states = () if output_hidden_states else None
@@ -199,9 +203,9 @@ class MPTModel(MPTPreTrainedModel):
199
  assert all_hidden_states is not None
200
  all_hidden_states = all_hidden_states + (x,)
201
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
202
- (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
203
- if past_key_values is not None:
204
- past_key_values[b_idx] = past_key_value
205
  if output_attentions:
206
  assert all_self_attns is not None
207
  all_self_attns = all_self_attns + (attn_weights,)
@@ -209,16 +213,16 @@ class MPTModel(MPTPreTrainedModel):
209
  if output_hidden_states:
210
  assert all_hidden_states is not None
211
  all_hidden_states = all_hidden_states + (x,)
212
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
213
 
214
- def param_init_fn(self, module):
215
  init_fn_name = self.config.init_config['name']
216
  MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
217
 
218
- def fsdp_wrap_fn(self, module):
219
  return isinstance(module, MPTBlock)
220
 
221
- def activation_checkpointing_fn(self, module):
222
  return isinstance(module, MPTBlock)
223
 
224
  class MPTForCausalLM(MPTPreTrainedModel):
@@ -227,8 +231,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
227
  super().__init__(config)
228
  if not config.tie_word_embeddings:
229
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
230
- print(f'Instantiating an MPTForCausalLM model from {__file__}')
231
- self.transformer = MPTModel(config)
232
  for child in self.transformer.children():
233
  if isinstance(child, torch.nn.ModuleList):
234
  continue
@@ -244,25 +248,25 @@ class MPTForCausalLM(MPTPreTrainedModel):
244
  raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
245
  self.logit_scale = logit_scale
246
 
247
- def get_input_embeddings(self):
248
  return self.transformer.wte
249
 
250
- def set_input_embeddings(self, value):
251
  self.transformer.wte = value
252
 
253
- def get_output_embeddings(self):
254
  return self.transformer.wte
255
 
256
- def set_output_embeddings(self, new_embeddings):
257
  self.transformer.wte = new_embeddings
258
 
259
- def set_decoder(self, decoder):
260
  self.transformer = decoder
261
 
262
- def get_decoder(self):
263
  return self.transformer
264
 
265
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
266
  return_dict = return_dict if return_dict is not None else self.config.return_dict
267
  use_cache = use_cache if use_cache is not None else self.config.use_cache
268
  if inputs_embeds is not None:
@@ -275,22 +279,22 @@ class MPTForCausalLM(MPTPreTrainedModel):
275
  logits *= self.logit_scale
276
  loss = None
277
  if labels is not None:
278
- labels = torch.roll(labels, shifts=-1)
279
- labels[:, -1] = -100
280
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
281
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
282
 
283
- def param_init_fn(self, module):
284
  init_fn_name = self.config.init_config['name']
285
  MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
286
 
287
- def fsdp_wrap_fn(self, module):
288
  return isinstance(module, MPTBlock)
289
 
290
- def activation_checkpointing_fn(self, module):
291
  return isinstance(module, MPTBlock)
292
 
293
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
294
  if inputs_embeds is not None:
295
  raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
296
  attention_mask = kwargs['attention_mask'].bool()
@@ -311,7 +315,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
311
  return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
312
 
313
  @staticmethod
314
- def _reorder_cache(past_key_values, beam_idx):
315
  """Used by HuggingFace generate when using beam search with kv-caching.
316
 
317
  See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
 
4
  """
5
  import math
6
  import warnings
7
+ from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
  from .custom_embedding import SharedEmbedding
16
+ from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
17
+ from .ffn import FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
18
+ from .ffn import MPTMLP as MPTMLP
19
+ from .ffn import build_ffn as build_ffn
20
  from .norm import NORM_CLASS_REGISTRY
21
  from .configuration_mpt import MPTConfig
22
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
23
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
24
  from .meta_init_context import init_empty_weights
25
+ from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
26
  try:
27
+ from .flash_attn_triton import flash_attn_func as flash_attn_func
28
  except:
29
  pass
30
+ import logging
31
+ log = logging.getLogger(__name__)
32
 
33
  class MPTPreTrainedModel(PreTrainedModel):
34
  config_class = MPTConfig
 
45
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
46
  self.alibi = config.attn_config['alibi']
47
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
48
+ self.learned_pos_emb = config.learned_pos_emb
49
  if config.init_device == 'mixed':
50
  if dist.get_local_rank() == 0:
51
  config.init_device = 'cpu'
 
57
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
58
  self.embedding_fraction = config.embedding_fraction
59
  self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
60
+ if self.learned_pos_emb:
61
  self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
62
  self.emb_drop = nn.Dropout(config.emb_pdrop)
63
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
64
  self.norm_f = norm_class(config.d_model, device=config.init_device)
65
  if config.init_device != 'meta':
66
+ log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
67
  self.apply(self.param_init_fn)
68
  self.is_causal = not self.prefix_lm
69
  self._attn_bias_initialized = False
 
72
  if config.no_bias:
73
  for module in self.modules():
74
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
75
+ log.info(f'Removing bias ({module.bias}) from {module}.')
 
76
  module.register_parameter('bias', None)
77
+ if hasattr(module, 'use_bias'):
78
+ log.info(f'Setting use_bias=False for {module}.')
79
+ module.use_bias = False
80
+ log.debug(self)
81
+ log.debug(f"Using {self.config.init_config['name']} initialization.")
 
 
82
 
83
+ def get_input_embeddings(self) -> nn.Embedding:
84
  return self.wte
85
 
86
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
87
  self.wte = value
88
 
89
  @torch.no_grad()
90
+ def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
91
  if not self._attn_bias_initialized:
92
  if self.attn_bias_shape:
93
  self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
 
118
  attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
119
  return (attn_bias, None)
120
 
121
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor:
122
  (s_k, s_q) = attn_bias.shape[-2:]
123
  if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
124
  raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
 
133
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
134
  return attn_bias
135
 
136
+ def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor) -> torch.Tensor:
137
  seq_len = sequence_id.shape[-1]
138
  if seq_len > self.config.max_seq_len:
139
  raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
 
143
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
144
  return attn_bias
145
 
146
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
147
  return_dict = return_dict if return_dict is not None else self.config.return_dict
148
  use_cache = use_cache if use_cache is not None else self.config.use_cache
149
  if attention_mask is not None:
 
155
  if output_attentions:
156
  if self.attn_impl != 'torch':
157
  raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
158
+ if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
159
  raise NotImplementedError('MPT does not support training with left padding.')
160
  if self.prefix_lm and prefix_mask is None:
161
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
169
  S = input_ids.size(1)
170
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
171
  tok_emb = self.wte(input_ids)
172
+ if self.learned_pos_emb:
 
 
173
  past_position = 0
174
  if past_key_values is not None:
175
  if len(past_key_values) != self.config.n_layers:
 
178
  if self.attn_impl == 'torch':
179
  past_position = past_key_values[0][0].size(3)
180
  if S + past_position > self.config.max_seq_len:
181
+ raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
182
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
183
  if attention_mask is not None:
184
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
185
  pos_emb = self.wpe(pos)
186
  x = tok_emb + pos_emb
187
+ else:
188
+ x = tok_emb
189
  if self.embedding_fraction == 1:
190
  x = self.emb_drop(x)
191
  else:
 
193
  assert isinstance(self.emb_drop, nn.Module)
194
  x = self.emb_drop(x_shrunk)
195
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
196
+ presents = () if use_cache else None
197
  if use_cache and past_key_values is None:
198
  past_key_values = [() for _ in range(self.config.n_layers)]
199
  all_hidden_states = () if output_hidden_states else None
 
203
  assert all_hidden_states is not None
204
  all_hidden_states = all_hidden_states + (x,)
205
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
206
+ (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
207
+ if presents is not None:
208
+ presents += (present,)
209
  if output_attentions:
210
  assert all_self_attns is not None
211
  all_self_attns = all_self_attns + (attn_weights,)
 
213
  if output_hidden_states:
214
  assert all_hidden_states is not None
215
  all_hidden_states = all_hidden_states + (x,)
216
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
217
 
218
+ def param_init_fn(self, module: nn.Module) -> None:
219
  init_fn_name = self.config.init_config['name']
220
  MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
221
 
222
+ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
223
  return isinstance(module, MPTBlock)
224
 
225
+ def activation_checkpointing_fn(self, module: nn.Module) -> bool:
226
  return isinstance(module, MPTBlock)
227
 
228
  class MPTForCausalLM(MPTPreTrainedModel):
 
231
  super().__init__(config)
232
  if not config.tie_word_embeddings:
233
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
234
+ log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
235
+ self.transformer: MPTModel = MPTModel(config)
236
  for child in self.transformer.children():
237
  if isinstance(child, torch.nn.ModuleList):
238
  continue
 
248
  raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
249
  self.logit_scale = logit_scale
250
 
251
+ def get_input_embeddings(self) -> nn.Embedding:
252
  return self.transformer.wte
253
 
254
+ def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
255
  self.transformer.wte = value
256
 
257
+ def get_output_embeddings(self) -> nn.Embedding:
258
  return self.transformer.wte
259
 
260
+ def set_output_embeddings(self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
261
  self.transformer.wte = new_embeddings
262
 
263
+ def set_decoder(self, decoder: MPTModel) -> None:
264
  self.transformer = decoder
265
 
266
+ def get_decoder(self) -> MPTModel:
267
  return self.transformer
268
 
269
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
270
  return_dict = return_dict if return_dict is not None else self.config.return_dict
271
  use_cache = use_cache if use_cache is not None else self.config.use_cache
272
  if inputs_embeds is not None:
 
279
  logits *= self.logit_scale
280
  loss = None
281
  if labels is not None:
282
+ _labels = torch.roll(labels, shifts=-1)
283
+ _labels[:, -1] = -100
284
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
285
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
286
 
287
+ def param_init_fn(self, module: nn.Module) -> None:
288
  init_fn_name = self.config.init_config['name']
289
  MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
290
 
291
+ def fsdp_wrap_fn(self, module: nn.Module) -> bool:
292
  return isinstance(module, MPTBlock)
293
 
294
+ def activation_checkpointing_fn(self, module: nn.Module) -> bool:
295
  return isinstance(module, MPTBlock)
296
 
297
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
298
  if inputs_embeds is not None:
299
  raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
300
  attention_mask = kwargs['attention_mask'].bool()
 
315
  return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
316
 
317
  @staticmethod
318
+ def _reorder_cache(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], beam_idx: torch.LongTensor) -> List[Tuple[torch.Tensor, ...]]:
319
  """Used by HuggingFace generate when using beam search with kv-caching.
320
 
321
  See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
norm.py CHANGED
@@ -1,6 +1,7 @@
 
1
  import torch
2
 
3
- def _cast_if_autocast_enabled(tensor):
4
  if torch.is_autocast_enabled():
5
  if tensor.device.type == 'cuda':
6
  dtype = torch.get_autocast_gpu_dtype()
@@ -13,10 +14,10 @@ def _cast_if_autocast_enabled(tensor):
13
 
14
  class LPLayerNorm(torch.nn.LayerNorm):
15
 
16
- def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17
  super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18
 
19
- def forward(self, x):
20
  module_device = x.device
21
  downcast_x = _cast_if_autocast_enabled(x)
22
  downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
@@ -24,7 +25,7 @@ class LPLayerNorm(torch.nn.LayerNorm):
24
  with torch.autocast(enabled=False, device_type=module_device.type):
25
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26
 
27
- def rms_norm(x, weight=None, eps=1e-05):
28
  output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29
  if weight is not None:
30
  return output * weight
@@ -32,7 +33,7 @@ def rms_norm(x, weight=None, eps=1e-05):
32
 
33
  class RMSNorm(torch.nn.Module):
34
 
35
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36
  super().__init__()
37
  self.eps = eps
38
  if weight:
@@ -40,17 +41,17 @@ class RMSNorm(torch.nn.Module):
40
  else:
41
  self.register_parameter('weight', None)
42
 
43
- def forward(self, x):
44
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45
 
46
  class LPRMSNorm(RMSNorm):
47
 
48
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49
  super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50
 
51
- def forward(self, x):
52
  downcast_x = _cast_if_autocast_enabled(x)
53
  downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54
  with torch.autocast(enabled=False, device_type=x.device.type):
55
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56
- NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
 
1
+ from typing import Dict, List, Optional, Type, Union
2
  import torch
3
 
4
+ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
5
  if torch.is_autocast_enabled():
6
  if tensor.device.type == 'cuda':
7
  dtype = torch.get_autocast_gpu_dtype()
 
14
 
15
  class LPLayerNorm(torch.nn.LayerNorm):
16
 
17
+ def __init__(self, normalized_shape: Union[int, List[int], torch.Size], eps: float=1e-05, elementwise_affine: bool=True, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
18
  super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
19
 
20
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
21
  module_device = x.device
22
  downcast_x = _cast_if_autocast_enabled(x)
23
  downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
 
25
  with torch.autocast(enabled=False, device_type=module_device.type):
26
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
27
 
28
+ def rms_norm(x: torch.Tensor, weight: Optional[torch.Tensor]=None, eps: float=1e-05) -> torch.Tensor:
29
  output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
30
  if weight is not None:
31
  return output * weight
 
33
 
34
  class RMSNorm(torch.nn.Module):
35
 
36
+ def __init__(self, normalized_shape: Union[int, List[int], torch.Size], eps: float=1e-05, weight: bool=True, dtype: Optional[torch.dtype]=None, device: Optional[torch.device]=None):
37
  super().__init__()
38
  self.eps = eps
39
  if weight:
 
41
  else:
42
  self.register_parameter('weight', None)
43
 
44
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
45
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
46
 
47
  class LPRMSNorm(RMSNorm):
48
 
49
+ def __init__(self, normalized_shape: Union[int, List[int], torch.Size], eps: float=1e-05, weight: bool=True, dtype: Optional[torch.dtype]=None, device: Optional[torch.device]=None):
50
  super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
51
 
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
  downcast_x = _cast_if_autocast_enabled(x)
54
  downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
55
  with torch.autocast(enabled=False, device_type=x.device.type):
56
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
57
+ NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
param_init_fns.py CHANGED
@@ -2,22 +2,26 @@ import math
2
  import warnings
3
  from collections.abc import Sequence
4
  from functools import partial
5
- from typing import Optional, Tuple, Union
6
  import torch
7
  from torch import nn
 
8
  from .norm import NORM_CLASS_REGISTRY
 
 
 
 
9
 
10
- def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
11
  del kwargs
12
- if verbose > 1:
13
- warnings.warn(f"Initializing network using module's reset_parameters attribute")
14
- if hasattr(module, 'reset_parameters'):
15
  module.reset_parameters()
16
 
17
- def fused_init_helper_(module: nn.Module, init_fn_):
18
  _fused = getattr(module, '_fused', None)
19
  if _fused is None:
20
  raise RuntimeError(f'Internal logic error')
 
21
  (dim, splits) = _fused
22
  splits = (0, *splits, module.weight.size(dim))
23
  for (s, e) in zip(splits[:-1], splits[1:]):
@@ -25,10 +29,8 @@ def fused_init_helper_(module: nn.Module, init_fn_):
25
  slice_indices[dim] = slice(s, e)
26
  init_fn_(module.weight[slice_indices])
27
 
28
- def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
29
  del kwargs
30
- if verbose > 1:
31
- warnings.warn(f'If model has bias parameters they are initialized to 0.')
32
  init_div_is_residual = init_div_is_residual
33
  if init_div_is_residual is False:
34
  div_is_residual = 1.0
@@ -36,20 +38,18 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
36
  div_is_residual = math.sqrt(2 * n_layers)
37
  elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
38
  div_is_residual = init_div_is_residual
39
- elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
40
  div_is_residual = float(init_div_is_residual)
41
  else:
42
  div_is_residual = 1.0
43
  raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
44
- if init_div_is_residual is not False:
45
- if verbose > 1:
46
- warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
47
- if isinstance(module, nn.Linear):
48
  if hasattr(module, '_fused'):
49
  fused_init_helper_(module, init_fn_)
50
  else:
51
  init_fn_(module.weight)
52
  if module.bias is not None:
 
53
  torch.nn.init.zeros_(module.bias)
54
  if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55
  with torch.no_grad():
@@ -60,8 +60,6 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
60
  if std == 0:
61
  warnings.warn(f'Embedding layer initialized to 0.')
62
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
63
- if verbose > 1:
64
- warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
65
  elif emb_init_uniform_lim is not None:
66
  lim = emb_init_uniform_lim
67
  if isinstance(lim, Sequence):
@@ -75,17 +73,13 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
75
  lim = [-lim, lim]
76
  (a, b) = lim
77
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
78
- if verbose > 1:
79
- warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
80
  else:
81
  emb_init_fn_ = init_fn_
82
  emb_init_fn_(module.weight)
83
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
84
- if verbose > 1:
85
- warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
86
- if hasattr(module, 'weight') and module.weight is not None:
87
  torch.nn.init.ones_(module.weight)
88
- if hasattr(module, 'bias') and module.bias is not None:
89
  torch.nn.init.zeros_(module.bias)
90
  elif isinstance(module, nn.MultiheadAttention):
91
  if module._qkv_same_embed_dim:
@@ -114,32 +108,45 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
114
  module.out_proj.weight.div_(div_is_residual)
115
  if module.out_proj.bias is not None:
116
  torch.nn.init.zeros_(module.out_proj.bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  else:
118
  for _ in module.parameters(recurse=False):
119
  raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
120
 
121
- def _normal_init_(std, mean=0.0):
122
  return partial(torch.nn.init.normal_, mean=mean, std=std)
123
 
124
- def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
125
  del kwargs
126
  init_fn_ = _normal_init_(std=std)
127
- if verbose > 1:
128
- warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
129
- generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
130
 
131
- def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
132
  del kwargs
133
  if init_std is None:
134
  raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
135
- _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
136
 
137
- def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
138
  del kwargs
139
  std = math.sqrt(2 / (5 * d_model))
140
- _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
141
 
142
- def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
143
  """From section 2.3.1 of GPT-NeoX-20B:
144
 
145
  An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
@@ -148,34 +155,25 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
148
  """
149
  del kwargs
150
  residual_div = n_layers / math.sqrt(10)
151
- if verbose > 1:
152
- warnings.warn(f'setting init_div_is_residual to {residual_div}')
153
- small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
154
 
155
- def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
156
  del kwargs
157
- if verbose > 1:
158
- warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
159
  kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
160
- generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
161
 
162
- def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
163
  del kwargs
164
- if verbose > 1:
165
- warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
166
  kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
167
- generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
168
 
169
- def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
170
  del kwargs
171
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
172
- if verbose > 1:
173
- warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
174
- generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
175
 
176
- def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
 
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
- if verbose > 1:
179
- warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
180
- generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
181
  MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
 
2
  import warnings
3
  from collections.abc import Sequence
4
  from functools import partial
5
+ from typing import Any, Callable, Optional, Tuple, Union
6
  import torch
7
  from torch import nn
8
+ from .fc import FC_CLASS_REGISTRY
9
  from .norm import NORM_CLASS_REGISTRY
10
+ try:
11
+ import transformer_engine.pytorch as te
12
+ except:
13
+ te = None
14
 
15
+ def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
16
  del kwargs
17
+ if hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
 
 
18
  module.reset_parameters()
19
 
20
+ def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
21
  _fused = getattr(module, '_fused', None)
22
  if _fused is None:
23
  raise RuntimeError(f'Internal logic error')
24
+ assert isinstance(module.weight, torch.Tensor)
25
  (dim, splits) = _fused
26
  splits = (0, *splits, module.weight.size(dim))
27
  for (s, e) in zip(splits[:-1], splits[1:]):
 
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
31
 
32
+ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
33
  del kwargs
 
 
34
  init_div_is_residual = init_div_is_residual
35
  if init_div_is_residual is False:
36
  div_is_residual = 1.0
 
38
  div_is_residual = math.sqrt(2 * n_layers)
39
  elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
40
  div_is_residual = init_div_is_residual
41
+ elif init_div_is_residual.isnumeric():
42
  div_is_residual = float(init_div_is_residual)
43
  else:
44
  div_is_residual = 1.0
45
  raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
46
+ if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
 
 
 
47
  if hasattr(module, '_fused'):
48
  fused_init_helper_(module, init_fn_)
49
  else:
50
  init_fn_(module.weight)
51
  if module.bias is not None:
52
+ assert isinstance(module.bias, torch.Tensor)
53
  torch.nn.init.zeros_(module.bias)
54
  if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55
  with torch.no_grad():
 
60
  if std == 0:
61
  warnings.warn(f'Embedding layer initialized to 0.')
62
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
 
 
63
  elif emb_init_uniform_lim is not None:
64
  lim = emb_init_uniform_lim
65
  if isinstance(lim, Sequence):
 
73
  lim = [-lim, lim]
74
  (a, b) = lim
75
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
 
 
76
  else:
77
  emb_init_fn_ = init_fn_
78
  emb_init_fn_(module.weight)
79
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
80
+ if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
 
 
81
  torch.nn.init.ones_(module.weight)
82
+ if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
83
  torch.nn.init.zeros_(module.bias)
84
  elif isinstance(module, nn.MultiheadAttention):
85
  if module._qkv_same_embed_dim:
 
108
  module.out_proj.weight.div_(div_is_residual)
109
  if module.out_proj.bias is not None:
110
  torch.nn.init.zeros_(module.out_proj.bias)
111
+ elif te is not None and isinstance(module, te.LayerNormMLP):
112
+ if isinstance(module.layer_norm_weight, torch.Tensor):
113
+ torch.nn.init.ones_(module.layer_norm_weight)
114
+ if isinstance(module.layer_norm_bias, torch.Tensor):
115
+ torch.nn.init.zeros_(module.layer_norm_bias)
116
+ init_fn_(module.fc1_weight)
117
+ if module.fc1_bias is not None:
118
+ assert isinstance(module.fc1_bias, torch.Tensor)
119
+ torch.nn.init.zeros_(module.fc1_bias)
120
+ init_fn_(module.fc2_weight)
121
+ if module.fc2_bias is not None:
122
+ assert isinstance(module.fc2_bias, torch.Tensor)
123
+ torch.nn.init.zeros_(module.fc2_bias)
124
+ with torch.no_grad():
125
+ module.fc2_weight.div_(div_is_residual)
126
  else:
127
  for _ in module.parameters(recurse=False):
128
  raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
129
 
130
+ def _normal_init_(std: float, mean: float=0.0) -> Callable:
131
  return partial(torch.nn.init.normal_, mean=mean, std=std)
132
 
133
+ def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
134
  del kwargs
135
  init_fn_ = _normal_init_(std=std)
136
+ generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
137
 
138
+ def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
139
  del kwargs
140
  if init_std is None:
141
  raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
142
+ _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
143
 
144
+ def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
145
  del kwargs
146
  std = math.sqrt(2 / (5 * d_model))
147
+ _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
148
 
149
+ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
150
  """From section 2.3.1 of GPT-NeoX-20B:
151
 
152
  An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
 
155
  """
156
  del kwargs
157
  residual_div = n_layers / math.sqrt(10)
158
+ small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
159
 
160
+ def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
161
  del kwargs
 
 
162
  kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
163
+ generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
164
 
165
+ def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
166
  del kwargs
 
 
167
  kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
168
+ generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
169
 
170
+ def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
171
  del kwargs
172
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
173
+ generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
174
 
175
+ def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
176
+ del kwargs
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
+ generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
179
  MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}