gugarosa committed
Commit 3128bb6
1 Parent(s): 4a426d8

Support for `attention_mask` in forward pass.


This commit implements the following:

- Cleans up unused arguments and definitions.
- Adds support for `attention_mask`.
- Adds support for cached inference.
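
For context, here is a minimal usage sketch of what the new code path enables at inference time: batched generation over padded prompts, with the padding reported through `attention_mask`. The checkpoint id, pad-token choice, and prompts below are illustrative assumptions, not part of this commit.

```python
# Minimal sketch (assumptions noted inline): batched generation with padding,
# which relies on the `attention_mask` support added by this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumption: the checkpoint that ships these files
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # assumption: no dedicated pad token is defined
tokenizer.padding_side = "left"            # keeps the last prompt token aligned for decoding

model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

prompts = ["def fibonacci(n):", "Write a short docstring for a sort function:"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # now forwarded down to the attention layers
        max_new_tokens=64,                        # beam search (num_beams > 1) is still unsupported
    )

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```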

README.md CHANGED
@@ -118,7 +118,7 @@ text = tokenizer.batch_decode(outputs)[0]
 print(text)
 ```
 
-**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1) and `attention_mask' parameters.
+**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1).
 Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
 
 ### Citation
config.json CHANGED
@@ -1,13 +1,6 @@
 {
   "_name_or_path": "phi-1.5-half",
   "activation_function": "gelu_new",
-  "architecture": {
-    "block_cls": "parallel",
-    "mixer": {},
-    "mlp": {
-      "mlp_cls": "mlp"
-    }
-  },
   "architectures": [
     "MixFormerSequentialForCausalLM"
   ],
@@ -15,7 +8,6 @@
     "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
     "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
   },
-  "embd_layer": "default",
   "embd_pdrop": 0.0,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
@@ -25,7 +17,6 @@
   "n_inner": null,
   "n_layer": 24,
   "n_positions": 2048,
-  "phyagi_version": "0.0.4.dev",
   "resid_pdrop": 0.0,
   "rotary_dim": 32,
   "tie_word_embeddings": false,
configuration_mixformer_sequential.py CHANGED
@@ -17,8 +17,6 @@ class MixFormerSequentialConfig(PretrainedConfig):
         "hidden_size": "n_embd",
         "num_attention_heads": "n_head",
         "num_hidden_layers": "n_layer",
-        "input_emb_layer": "embd_layer",  # `input_emb_layer` key is for backward compatibility
-        "blocks": "architecture",  # `blocks` key is for backward compatibility
     }
 
     def __init__(
@@ -31,8 +29,6 @@ class MixFormerSequentialConfig(PretrainedConfig):
         n_head: Optional[int] = 16,
         rotary_dim: Optional[int] = 32,
         activation_function: Optional[str] = "gelu_new",
-        embd_layer: Optional[str] = "default",
-        architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
         embd_pdrop: Optional[float] = 0.0,
         resid_pdrop: Optional[float] = 0.0,
         layer_norm_epsilon: Optional[float] = 1e-5,
@@ -49,8 +45,6 @@ class MixFormerSequentialConfig(PretrainedConfig):
         self.n_head = n_head
         self.rotary_dim = min(rotary_dim, n_embd // n_head)
         self.activation_function = activation_function
-        self.embd_layer = embd_layer
-        self.architecture = architecture
         self.embd_pdrop = embd_pdrop
         self.resid_pdrop = resid_pdrop
         self.layer_norm_epsilon = layer_norm_epsilon
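
Since the block layout is no longer configurable, the configuration can now be constructed without the removed arguments. A small sketch, assuming `configuration_mixformer_sequential.py` is importable locally; the values mirror `config.json` above.

```python
# Sketch (assumption: the module file is on the import path). The removed
# `architecture` and `embd_layer` arguments are simply no longer passed;
# every block is a ParallelBlock with a plain MLP, so nothing about the
# layout has to be described in the config.
from configuration_mixformer_sequential import MixFormerSequentialConfig

config = MixFormerSequentialConfig(
    n_positions=2048,
    n_layer=24,
    rotary_dim=32,
    activation_function="gelu_new",
    embd_pdrop=0.0,
    resid_pdrop=0.0,
    layer_norm_epsilon=1e-05,
)

print(config.n_layer, config.rotary_dim)  # 24 32
```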
modeling_mixformer_sequential.py CHANGED
@@ -1,6 +1,6 @@
1
  # Copyright (c) Microsoft Corporation.
2
  # Licensed under the MIT license.
3
-
4
  # BSD 3-Clause License
5
  #
6
  # Copyright (c) 2022, Tri Dao, [email protected].
@@ -50,16 +50,38 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
50
 
51
  @dataclass
52
  class InferenceParams:
53
- """Inference parameters that are passed to the main model in order
54
- to efficienly calculate and store the context during inference.
55
- Adapted from https://github.com/Dao-AILab/flash-attention."""
56
- max_sequence_len: int
57
- max_batch_size: int
58
- sequence_len_offset: int = 0
59
- batch_size_offset: int = 0
60
- key_value_memory_dict: dict = field(default_factory=dict)
61
- fused_ft_kernel: bool = False
62
- lengths_per_sample: Optional[torch.Tensor] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  class Embedding(nn.Module):
@@ -80,14 +102,19 @@ class Embedding(nn.Module):
80
 
81
  return hidden_states
82
 
 
83
  class RotaryEmbedding(nn.Module):
84
- """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
85
- Adapted from https://github.com/Dao-AILab/flash-attention."""
 
 
 
 
86
 
87
  def __init__(
88
  self,
89
  dim: int,
90
- base: Optional[int] = 10000,
91
  scale_base: Optional[float] = None,
92
  device: Optional[str] = None,
93
  **kwargs,
@@ -119,7 +146,7 @@ class RotaryEmbedding(nn.Module):
119
  self._cos_k_cached = None
120
  self._sin_k_cached = None
121
 
122
- def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None:
123
  # Reset the tables if the sequence length has changed,
124
  # or if we're on a new device (possibly due to tracing for instance)
125
  seqlen = x.shape[1] + seqlen_offset
@@ -153,7 +180,7 @@ class RotaryEmbedding(nn.Module):
153
  self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
154
  self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
155
 
156
- def apply_rotary_emb_qkv(
157
  self,
158
  qkv: torch.FloatTensor,
159
  sin: torch.FloatTensor,
@@ -189,7 +216,6 @@ class RotaryEmbedding(nn.Module):
189
 
190
  # Computes the new keys and queries, recasting to original dtype
191
  q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
192
-
193
  k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
194
 
195
  return torch.cat(
@@ -202,47 +228,9 @@ class RotaryEmbedding(nn.Module):
202
  )
203
 
204
  def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
205
- """Perform the forward pass.
206
-
207
- Args:
208
- qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
209
- seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
210
-
211
- Returns:
212
- New `qkv` and the cached sinusoids.
213
-
214
- """
215
-
216
  self._update_cos_sin_cache(qkv, seqlen_offset)
217
-
218
- return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
219
-
220
- def _update_kv_cache(kv, inference_params, layer_idx):
221
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
222
- Adapted from https://github.com/Dao-AILab/flash-attention."""
223
- # Pre-allocate memory for key-values for inference.
224
- num_heads, head_dim = kv.shape[-2:]
225
- if layer_idx not in inference_params.key_value_memory_dict:
226
- kv_cache = torch.empty(
227
- inference_params.max_batch_size, inference_params.max_sequence_len, 2,
228
- num_heads, head_dim, dtype=kv.dtype, device=kv.device
229
- )
230
- inference_params.key_value_memory_dict[layer_idx] = kv_cache
231
- else:
232
- kv_cache = inference_params.key_value_memory_dict[layer_idx]
233
-
234
- # Adjust key and value for inference
235
- batch_start = inference_params.batch_size_offset
236
- batch_end = batch_start + kv.shape[0]
237
- sequence_start = inference_params.sequence_len_offset
238
- sequence_end = sequence_start + kv.shape[1]
239
- assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
240
- assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
241
-
242
- assert kv_cache is not None
243
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
244
- kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
245
- return kv
246
 
247
 
248
  class MLP(nn.Module):
@@ -267,17 +255,6 @@ class MLP(nn.Module):
267
  self.fc2 = nn.Linear(n_inner, config.n_embd)
268
  self.act = ACT2FN[act_fn]
269
 
270
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
271
- old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"]
272
- new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"]
273
-
274
- if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys):
275
- # Older version of `MLP` saved with different key names.
276
- for old_key, new_key in zip(old_keys, new_keys):
277
- state_dict[new_key] = state_dict.pop(old_key)
278
-
279
- return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
280
-
281
  def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
282
  hidden_states = self.fc1(hidden_states)
283
  hidden_states = self.act(hidden_states)
@@ -286,132 +263,114 @@ class MLP(nn.Module):
286
  return hidden_states
287
 
288
 
289
- class FusedMLP(nn.Module):
290
- """Fused Multi-Layer Perceptron from `flash-attn`.
291
-
292
  Reference:
293
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
294
 
295
  """
296
- def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None,
297
- raise_on_missing: bool = False) -> None:
298
- super().__init__()
299
-
300
- act_fn = config.activation_function if act_fn is None else act_fn
301
- assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
302
-
303
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
304
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
305
-
306
- gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
307
- activation = "gelu_approx" if act_fn in gelu_activations else "relu"
308
-
309
- self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
310
-
311
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
312
- return self.mlp(hidden_states)
313
 
314
- class SelfAttention(nn.Module):
315
- """Implement the scaled dot product attention with softmax.
316
- Adapted from https://github.com/Dao-AILab/flash-attention.
317
- Arguments
318
- ---------
319
- softmax_scale: The temperature to use for the softmax attention.
320
- (default: 1/sqrt(d_keys) where d_keys is computed at
321
- runtime)
322
- attention_dropout: The dropout rate to apply to the attention
323
- (default: 0.0)
324
- """
325
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
326
  super().__init__()
 
327
  self.causal = causal
328
  self.softmax_scale = softmax_scale
329
  self.drop = nn.Dropout(attention_dropout)
330
 
331
- def forward(self, qkv, causal=None, key_padding_mask=None):
332
- """Implements the multihead softmax attention.
333
- Arguments
334
- ---------
335
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
336
- causal: if passed, will override self.causal
337
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
338
- False means to mask out. (B, S)
339
- """
340
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
341
  causal = self.causal if causal is None else causal
 
342
  q, k, v = qkv.unbind(dim=2)
 
343
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
344
- scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
345
- if key_padding_mask is not None:
346
- padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
347
- device=scores.device)
348
- padding_mask.masked_fill_(key_padding_mask, 0.0)
349
- # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
350
- scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
 
351
  if causal:
352
- # "triu_tril_cuda_template" not implemented for 'BFloat16'
353
- # So we have to construct the mask in float
354
- causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
355
- # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
356
  scores = scores + causal_mask.to(dtype=scores.dtype)
 
357
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
358
- attention_drop = self.drop(attention)
359
- output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
 
 
360
  return output
361
 
362
 
363
  class CrossAttention(nn.Module):
364
- """Implement the scaled dot product attention with softmax.
365
- Adapted from https://github.com/Dao-AILab/flash-attention.
366
- Arguments
367
- ---------
368
- softmax_scale: The temperature to use for the softmax attention.
369
- (default: 1/sqrt(d_keys) where d_keys is computed at
370
- runtime)
371
- attention_dropout: The dropout rate to apply to the attention
372
- (default: 0.0)
373
  """
374
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
 
 
 
 
 
 
375
  super().__init__()
 
376
  self.causal = causal
377
  self.softmax_scale = softmax_scale
378
  self.drop = nn.Dropout(attention_dropout)
379
 
380
- def forward(self, q, kv, causal=None, key_padding_mask=None):
381
- """Implements the multihead softmax attention.
382
- Arguments
383
- ---------
384
- q: The tensor containing the query. (B, Sq, H, D)
385
- kv: The tensor containing the key and value. (B, Sk, 2, H, D)
386
- causal: if passed, will override self.causal
387
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
388
- False means to mask out. (B, Sk)
389
- """
390
- batch_size, seqlen_q = q.shape[0], q.shape[1]
391
  causal = self.causal if causal is None else causal
392
- seqlen_k = kv.shape[1]
393
  assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
 
 
394
  k, v = kv.unbind(dim=2)
 
395
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
396
- scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
397
- if key_padding_mask is not None:
398
- padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
399
- device=scores.device)
400
- padding_mask.masked_fill_(key_padding_mask, 0.0)
401
- # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
402
- scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
 
403
  if causal:
404
- # "triu_tril_cuda_template" not implemented for 'BFloat16'
405
- # So we have to construct the mask in float
406
- causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
407
- device=scores.device), 1)
408
- # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
409
  scores = scores + causal_mask.to(dtype=scores.dtype)
 
410
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
411
- attention_drop = self.drop(attention)
412
- output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
 
 
413
  return output
414
 
 
415
  def find_mha_dims(
416
  config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
417
  ) -> Tuple[int, int]:
@@ -445,152 +404,163 @@ def find_mha_dims(
445
  return n_head, head_dim
446
 
447
 
448
  class MHA(nn.Module):
449
- """Multi-head attention layer.
450
- Adapted from https://github.com/Dao-AILab/flash-attention."""
451
 
452
  def __init__(
453
  self,
454
  config: PretrainedConfig,
 
 
455
  rotary_dim: Optional[int] = None,
 
456
  n_head: Optional[int] = None,
457
  head_dim: Optional[int] = None,
458
- bias: Optional[bool] = True,
459
- dropout: Optional[float] = 0.0,
460
  softmax_scale: Optional[float] = None,
461
- causal: Optional[bool] = True,
462
  layer_idx: Optional[int] = None,
463
- rotary_emb_scale_base: Optional[float] = None,
464
- return_residual: Optional[bool] = False,
465
- checkpointing: Optional[bool] = False,
466
- device: Optional[str] = None,
467
- dtype: Optional[torch.dtype] = None,
468
- fused_dense: Optional[bool] = True,
469
- flash_attn: Optional[bool] = True,
470
- cutlass_attn: Optional[bool] = False,
471
- flash_rotary: Optional[bool] = True,
472
- raise_on_missing: Optional[bool] = False
473
  ) -> None:
474
  super().__init__()
475
 
476
- factory_kwargs = {"device": device, "dtype": dtype}
477
- n_head, head_dim = find_mha_dims(config, n_head, head_dim)
478
-
479
- self.hidden_size = config.n_embd
480
- self.n_head = n_head
481
- self.head_dim = head_dim
482
- self.op_size = n_head * head_dim
483
-
484
- self.causal = causal
485
- self.layer_idx = layer_idx
486
  self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
487
- self.fused_dense = fused_dense
488
- self.flash_attn = flash_attn
489
- self.cutlass_attn = cutlass_attn
490
- self.flash_rotary = flash_rotary
491
- self.return_residual = return_residual
492
- self.checkpointing = checkpointing
493
-
494
  if self.rotary_emb_dim > 0:
495
  rotary_kwargs = {"device": device}
496
  if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
497
  rotary_kwargs["scale_base"] = rotary_emb_scale_base
498
-
499
  self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
500
- else:
501
- pass
 
 
 
502
 
503
- self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)
504
- self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs)
505
 
 
506
  self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
507
  self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
508
 
509
- def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
510
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
511
- Adapted from https://github.com/Dao-AILab/flash-attention."""
512
-
513
- assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
514
-
515
- return _update_kv_cache(kv, inference_params, self.layer_idx)
516
 
517
  def forward(
518
  self,
519
  x: torch.FloatTensor,
520
- x_kv: Optional[torch.FloatTensor] = None,
521
- key_padding_mask: Optional[torch.BoolTensor] = None,
522
  cu_seqlens: Optional[torch.LongTensor] = None,
523
  max_seqlen: Optional[int] = None,
524
- mixer_subset: Optional[torch.LongTensor] = None,
525
- past_cache: Optional[InferenceParams] = None,
526
- **kwargs
527
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
528
- """Perform the forward pass.
529
-
530
- Args:
531
- x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
532
- cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
533
- is the is the sum of the sequence lengths in the batch.
534
- x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
535
- key_padding_mask: boolean mask, True means to keep, False means to mask out.
536
- (batch, seqlen). Only applicable when not using FlashAttention.
537
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
538
- of the sequences in the batch, used to index into x. Only applicable when using
539
- FlashAttention.
540
- max_seqlen: int. Maximum sequence length in the batch.
541
- mixer_subset: for cross-attention only. If not None, will take a subset of x
542
- before applying the query projection. Useful for e.g., ViT where we only care
543
- about the CLS token in the last layer.
544
- past_cache: For generation only.
545
-
546
- Returns:
547
- (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
548
- else (total, hidden_dim) where total is the is the sum of the sequence lengths
549
- in the batch.
550
-
551
- """
552
-
553
- if cu_seqlens is not None:
554
- assert max_seqlen is not None
555
- assert key_padding_mask is None
556
- assert self.flash_attn
557
- assert self.rotary_emb_dim == 0
558
-
559
- if key_padding_mask is not None:
560
- assert cu_seqlens is None
561
- assert max_seqlen is None
562
- assert not self.flash_attn
563
-
564
- if past_cache is not None:
565
- assert key_padding_mask is None
566
- assert cu_seqlens is None and max_seqlen is None
567
-
568
- attn_kwargs = {"key_padding_mask": key_padding_mask}
569
-
570
- assert x_kv is None and mixer_subset is None
571
-
572
  qkv = self.Wqkv(x)
573
  qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
574
 
575
- if past_cache is None:
576
- if self.rotary_emb_dim > 0:
577
- qkv = self.rotary_emb(qkv)
578
- context = self.inner_attn(qkv, **attn_kwargs)
 
 
 
 
 
 
579
 
 
 
 
 
 
 
 
580
  else:
581
- if self.rotary_emb_dim > 0:
582
- qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
583
  q = qkv[:, :, 0]
584
- kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
585
- # If we're processing the prompt, causal=None (use self.causal).
586
- # If we're decoding, then causal=False.
587
- causal = None if past_cache.sequence_len_offset == 0 else False
588
- context = self.inner_cross_attn(q, kv, causal=causal)
589
 
590
- out = rearrange(context, "... h d -> ... (h d)")
591
- out = self.out_proj(out)
592
 
593
- return out if not self.return_residual else (out, x)
594
 
595
  class ParallelBlock(nn.Module):
596
  """Parallel block.
@@ -602,8 +572,6 @@ class ParallelBlock(nn.Module):
602
  def __init__(
603
  self,
604
  config: PretrainedConfig,
605
- mixer: Optional[Dict[str, Any]] = None,
606
- mlp: Optional[Dict[str, Any]] = None,
607
  block_idx: Optional[int] = None,
608
  ) -> None:
609
  super().__init__()
@@ -612,19 +580,20 @@ class ParallelBlock(nn.Module):
612
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
613
  self.block_idx = block_idx
614
 
615
- self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
616
- mlp_cls = mlp.pop('mlp_cls')
617
- if mlp_cls == 'fused_mlp':
618
- self.mlp = FusedMLP(config=config, **mlp)
619
- else:
620
- self.mlp = MLP(config=config, **mlp)
621
 
622
- def forward(self, hidden_states: torch.FloatTensor,
623
- past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
 
 
 
 
 
624
  residual = hidden_states
625
  hidden_states = self.ln(hidden_states)
626
 
627
- attn_outputs = self.mixer(hidden_states, past_cache=past_cache)
628
  if isinstance(attn_outputs, tuple):
629
  attn_outputs = attn_outputs[0]
630
 
@@ -635,6 +604,7 @@ class ParallelBlock(nn.Module):
635
 
636
  return hidden_states
637
 
 
638
  class CausalLMHead(nn.Module):
639
  """Causal Language Modeling head.
640
 
@@ -666,7 +636,7 @@ class CausalLMLoss(nn.Module):
666
 
667
  """
668
 
669
- def __init__(self, shift_labels: Optional[bool] = True) -> None:
670
  super().__init__()
671
 
672
  self.shift_labels = shift_labels
@@ -681,6 +651,7 @@ class CausalLMLoss(nn.Module):
681
 
682
  return loss
683
 
 
684
  class MixFormerSequentialPreTrainedModel(PreTrainedModel):
685
  """MixFormer (sequential for DeepSpeed) pre-trained model."""
686
 
@@ -691,9 +662,35 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
691
  def __init__(self, *inputs, **kwargs) -> None:
692
  super().__init__(*inputs, **kwargs)
693
 
694
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]:
695
- if "use_cache" in kwargs and not kwargs["use_cache"]:
696
- return {"input_ids": input_ids}
697
 
698
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
699
  past_key_values = InferenceParams(
@@ -705,11 +702,15 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
705
  key_value_memory_dict={},
706
  )
707
  else:
708
- # assume past_key_values has cached all but last token in input_ids
709
  past_key_values.sequence_len_offset = len(input_ids[0]) - 1
710
  input_ids = input_ids[:, -1].unsqueeze(-1)
711
 
712
- return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
 
 
 
 
713
 
714
 
715
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
@@ -723,23 +724,7 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
723
  super().__init__(config)
724
 
725
  modules = [Embedding(config)]
726
- block_config = config.architecture
727
-
728
- if not isinstance(block_config, list):
729
- block_config = [block_config for _ in range(config.n_layer)]
730
-
731
- if config.n_layer != len(block_config):
732
- config.n_layer = len(block_config)
733
-
734
- for block_idx, block in enumerate(block_config):
735
- # `block_cls` with `legacy` value is for backward compatibility
736
- # `path` key is for backward compatibility
737
- block = copy.deepcopy(block) or {"block_cls": "parallel"}
738
- block_cls = block.pop("path", None) or block.pop("block_cls", None)
739
-
740
- block["block_idx"] = block_idx
741
- modules.append(ParallelBlock(config, **block))
742
-
743
  modules.append(CausalLMHead(config))
744
 
745
  self.layers = nn.Sequential(*modules)
@@ -760,20 +745,26 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
760
  self.layers[-1].linear = new_embeddings
761
 
762
  def forward(
763
- self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None,
764
- past_key_values: Optional[torch.FloatTensor] = None, **kwargs
 
 
 
 
765
  ) -> CausalLMOutputWithPast:
 
 
766
 
767
- if not past_key_values:
768
  lm_logits = self.layers(input_ids)
769
  else:
770
  hidden_layer = self.layers[0](input_ids)
771
  for module in self.layers[1:-1]:
772
- hidden_layer = module(hidden_layer, past_cache=past_key_values)
773
  lm_logits = self.layers[-1](hidden_layer)
774
 
775
  loss = None
776
  if labels is not None:
777
  loss = self.loss(lm_logits, labels)
778
-
779
  return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
 
1
  # Copyright (c) Microsoft Corporation.
2
  # Licensed under the MIT license.
3
+ #
4
  # BSD 3-Clause License
5
  #
6
  # Copyright (c) 2022, Tri Dao, [email protected].
 
50
 
51
  @dataclass
52
  class InferenceParams:
53
+ """Inference parameters passed to model to efficiently calculate
54
+ and store context during inference.
55
+
56
+ Reference:
57
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
58
+
59
+ Args:
60
+ max_sequence_len: Maximum sequence length.
61
+ max_batch_size: Maximum batch size.
62
+ sequence_len_offset: Sequence length offset.
63
+ batch_size_offset: Batch size offset.
64
+ key_value_memory_dict: Key value memory dictionary.
65
+ fused_ft_kernel: Whether to use fused kernel for fast inference.
66
+ lengths_per_sample: Lengths per sample.
67
+
68
+ """
69
+
70
+ max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
71
+
72
+ max_batch_size: int = field(metadata={"help": "Maximum batch size."})
73
+
74
+ sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
75
+
76
+ batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
77
+
78
+ key_value_memory_dict: Dict[str, Any] = field(
79
+ default_factory=dict, metadata={"help": "Key value memory dictionary."}
80
+ )
81
+
82
+ fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
83
+
84
+ lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
85
 
86
 
87
  class Embedding(nn.Module):
 
102
 
103
  return hidden_states
104
 
105
+
106
  class RotaryEmbedding(nn.Module):
107
+ """Rotary embeddings.
108
+
109
+ Reference:
110
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
111
+
112
+ """
113
 
114
  def __init__(
115
  self,
116
  dim: int,
117
+ base: int = 10000,
118
  scale_base: Optional[float] = None,
119
  device: Optional[str] = None,
120
  **kwargs,
 
146
  self._cos_k_cached = None
147
  self._sin_k_cached = None
148
 
149
+ def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
150
  # Reset the tables if the sequence length has changed,
151
  # or if we're on a new device (possibly due to tracing for instance)
152
  seqlen = x.shape[1] + seqlen_offset
 
180
  self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
181
  self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
182
 
183
+ def _apply_rotary_emb_qkv(
184
  self,
185
  qkv: torch.FloatTensor,
186
  sin: torch.FloatTensor,
 
216
 
217
  # Computes the new keys and queries, recasting to original dtype
218
  q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
 
219
  k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
220
 
221
  return torch.cat(
 
228
  )
229
 
230
  def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
231
+ # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
 
 
 
 
 
 
 
 
 
 
232
  self._update_cos_sin_cache(qkv, seqlen_offset)
233
+ return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
234
 
235
 
236
  class MLP(nn.Module):
 
255
  self.fc2 = nn.Linear(n_inner, config.n_embd)
256
  self.act = ACT2FN[act_fn]
257
 
 
 
 
 
 
 
 
 
 
 
 
258
  def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
259
  hidden_states = self.fc1(hidden_states)
260
  hidden_states = self.act(hidden_states)
 
263
  return hidden_states
264
 
265
 
266
+ class SelfAttention(nn.Module):
267
+ """Self-attention layer (compatible with PyTorch).
268
+
269
  Reference:
270
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
271
 
272
  """
273
 
274
+ def __init__(
275
+ self,
276
+ causal: bool = True,
277
+ softmax_scale: Optional[float] = None,
278
+ attention_dropout: float = 0.0,
279
+ ) -> None:
 
 
 
 
 
 
280
  super().__init__()
281
+
282
  self.causal = causal
283
  self.softmax_scale = softmax_scale
284
  self.drop = nn.Dropout(attention_dropout)
285
 
286
+ def forward(
287
+ self,
288
+ qkv: torch.FloatTensor,
289
+ causal: bool = None,
290
+ attention_mask: Optional[torch.BoolTensor] = None,
291
+ **kwargs,
292
+ ) -> torch.FloatTensor:
 
 
 
293
  causal = self.causal if causal is None else causal
294
+ batch_size, seq_len = qkv.shape[0], qkv.shape[1]
295
  q, k, v = qkv.unbind(dim=2)
296
+
297
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
298
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
299
+
300
+ if attention_mask is not None:
301
+ padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
302
+ padding_mask.masked_fill_(attention_mask, 0.0)
303
+
304
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
305
+
306
  if causal:
307
+ causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
 
 
 
308
  scores = scores + causal_mask.to(dtype=scores.dtype)
309
+
310
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
311
+ attention = self.drop(attention)
312
+
313
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
314
+
315
  return output
316
 
317
 
318
  class CrossAttention(nn.Module):
319
+ """Cross-attention layer (compatible with PyTorch).
320
+
321
+ Reference:
322
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
323
+
 
 
 
 
324
  """
325
+
326
+ def __init__(
327
+ self,
328
+ causal: bool = True,
329
+ softmax_scale: Optional[float] = None,
330
+ attention_dropout: float = 0.0,
331
+ ) -> None:
332
  super().__init__()
333
+
334
  self.causal = causal
335
  self.softmax_scale = softmax_scale
336
  self.drop = nn.Dropout(attention_dropout)
337
 
338
+ def forward(
339
+ self,
340
+ q: torch.FloatTensor,
341
+ kv: torch.FloatTensor,
342
+ causal: bool = None,
343
+ attention_mask: Optional[torch.BoolTensor] = None,
344
+ **kwargs,
345
+ ) -> torch.FloatTensor:
 
 
 
346
  causal = self.causal if causal is None else causal
347
+ batch_size, seq_len_q = q.shape[0], q.shape[1]
348
  assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
349
+
350
+ seq_len_k = kv.shape[1]
351
  k, v = kv.unbind(dim=2)
352
+
353
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
354
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
355
+
356
+ if attention_mask is not None:
357
+ padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
358
+ padding_mask.masked_fill_(attention_mask, 0.0)
359
+
360
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
361
+
362
  if causal:
363
+ causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
 
 
 
 
364
  scores = scores + causal_mask.to(dtype=scores.dtype)
365
+
366
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
367
+ attention = self.drop(attention)
368
+
369
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
370
+
371
  return output
372
 
373
+
374
  def find_mha_dims(
375
  config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
376
  ) -> Tuple[int, int]:
 
404
  return n_head, head_dim
405
 
406
 
407
+ def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
408
+ """Update the key-value cache for inference.
409
+
410
+ Reference:
411
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
412
+
413
+ Args:
414
+ kv: Key-value tensor.
415
+ inference_params: Inference parameters.
416
+ layer_idx: Layer index.
417
+
418
+ Returns:
419
+ Updated key-value tensor.
420
+
421
+ """
422
+
423
+ num_heads, head_dim = kv.shape[-2:]
424
+
425
+ if layer_idx not in inference_params.key_value_memory_dict:
426
+ kv_cache = torch.empty(
427
+ inference_params.max_batch_size,
428
+ inference_params.max_sequence_len,
429
+ 2,
430
+ num_heads,
431
+ head_dim,
432
+ dtype=kv.dtype,
433
+ device=kv.device,
434
+ )
435
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
436
+ else:
437
+ if not inference_params.fused_ft_kernel:
438
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
439
+ else:
440
+ k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
441
+ kv_cache = None
442
+
443
+ batch_start = inference_params.batch_size_offset
444
+ batch_end = batch_start + kv.shape[0]
445
+ assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
446
+
447
+ sequence_start = inference_params.sequence_len_offset
448
+ sequence_end = sequence_start + kv.shape[1]
449
+ assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
450
+
451
+ if not inference_params.fused_ft_kernel:
452
+ assert kv_cache is not None
453
+
454
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
455
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
456
+
457
+ return kv
458
+
459
+ assert inference_params.sequence_len_offset == 0
460
+ assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
461
+
462
+ packsize = 4 if kv.dtype == torch.float32 else 8
463
+
464
+ if kv_cache is not None:
465
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
466
+ k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
467
+ v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
468
+ inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
469
+ else:
470
+ k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
471
+ kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
472
+ )
473
+ v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
474
+
475
+ return kv
476
+
477
+
478
  class MHA(nn.Module):
479
+ """Multi-head attention layer."""
 
480
 
481
  def __init__(
482
  self,
483
  config: PretrainedConfig,
484
+ dtype: Optional[torch.dtype] = None,
485
+ device: Optional[str] = None,
486
  rotary_dim: Optional[int] = None,
487
+ rotary_emb_scale_base: Optional[float] = None,
488
  n_head: Optional[int] = None,
489
  head_dim: Optional[int] = None,
490
+ bias: bool = True,
491
+ causal: bool = True,
492
  softmax_scale: Optional[float] = None,
493
+ dropout: float = 0.0,
494
  layer_idx: Optional[int] = None,
495
+ return_residual: bool = False,
496
+ checkpointing: bool = False,
 
 
 
 
 
 
 
 
497
  ) -> None:
498
  super().__init__()
499
 
500
+ # Rotary embedding
 
 
 
 
 
 
 
 
 
501
  self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
 
 
 
 
 
 
 
502
  if self.rotary_emb_dim > 0:
503
  rotary_kwargs = {"device": device}
504
  if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
505
  rotary_kwargs["scale_base"] = rotary_emb_scale_base
 
506
  self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
507
+
508
+ # MLP
509
+ self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
510
+ op_size = self.n_head * self.head_dim
511
+ hidden_size = config.n_embd
512
 
513
+ self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
514
+ self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
515
 
516
+ # Attention
517
  self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
518
  self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
519
 
520
+ self.layer_idx = layer_idx
521
+ self.return_residual = return_residual
522
+ self.checkpointing = checkpointing
 
 
 
 
523
 
524
  def forward(
525
  self,
526
  x: torch.FloatTensor,
527
+ past_key_values: Optional[InferenceParams] = None,
528
+ attention_mask: Optional[torch.BoolTensor] = None,
529
  cu_seqlens: Optional[torch.LongTensor] = None,
530
  max_seqlen: Optional[int] = None,
531
+ **kwargs,
 
 
532
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
533
  qkv = self.Wqkv(x)
534
  qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
535
 
536
+ seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
537
+ if self.rotary_emb_dim > 0:
538
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
539
+
540
+ if past_key_values is not None:
541
+ kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
542
+
543
+ if attention_mask is not None:
544
+ attention_mask, cu_seqlens, max_seqlen = attention_mask
545
+ attention_mask = attention_mask.to(qkv.device)
546
 
547
+ attention_kwargs = {"attention_mask": attention_mask}
548
+
549
+ if past_key_values is None or seqlen_offset == 0:
550
+ if self.checkpointing:
551
+ attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
552
+ else:
553
+ attn_output = self.inner_attn(qkv, **attention_kwargs)
554
  else:
 
 
555
  q = qkv[:, :, 0]
556
+ causal = None if past_key_values.sequence_len_offset == 0 else False
557
+ attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
558
+
559
+ output = rearrange(attn_output, "... h d -> ... (h d)")
560
+ output = self.out_proj(output)
561
 
562
+ return output if not self.return_residual else (output, x)
 
563
 
 
564
 
565
  class ParallelBlock(nn.Module):
566
  """Parallel block.
 
572
  def __init__(
573
  self,
574
  config: PretrainedConfig,
 
 
575
  block_idx: Optional[int] = None,
576
  ) -> None:
577
  super().__init__()
 
580
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
581
  self.block_idx = block_idx
582
 
583
+ self.mixer = MHA(config, layer_idx=block_idx)
584
+ self.mlp = MLP(config)
 
 
 
 
585
 
586
+ def forward(
587
+ self,
588
+ hidden_states: torch.FloatTensor,
589
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
590
+ attention_mask: Optional[torch.BoolTensor] = None,
591
+ **kwargs,
592
+ ) -> torch.FloatTensor:
593
  residual = hidden_states
594
  hidden_states = self.ln(hidden_states)
595
 
596
+ attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask)
597
  if isinstance(attn_outputs, tuple):
598
  attn_outputs = attn_outputs[0]
599
 
 
604
 
605
  return hidden_states
606
 
607
+
608
  class CausalLMHead(nn.Module):
609
  """Causal Language Modeling head.
610
 
 
636
 
637
  """
638
 
639
+ def __init__(self, shift_labels: bool = True) -> None:
640
  super().__init__()
641
 
642
  self.shift_labels = shift_labels
 
651
 
652
  return loss
653
 
654
+
655
  class MixFormerSequentialPreTrainedModel(PreTrainedModel):
656
  """MixFormer (sequential for DeepSpeed) pre-trained model."""
657
 
 
662
  def __init__(self, *inputs, **kwargs) -> None:
663
  super().__init__(*inputs, **kwargs)
664
 
665
+ def _init_weights(self, module: nn.Module) -> None:
666
+ if isinstance(module, (nn.Linear,)):
667
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
668
+ if module.bias is not None:
669
+ module.bias.data.zero_()
670
+ elif isinstance(module, nn.Embedding):
671
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
672
+ if module.padding_idx is not None:
673
+ module.weight.data[module.padding_idx].zero_()
674
+ elif isinstance(module, nn.LayerNorm):
675
+ module.bias.data.zero_()
676
+ module.weight.data.fill_(1.0)
677
+
678
+ def prepare_inputs_for_generation(
679
+ self,
680
+ input_ids: torch.LongTensor,
681
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
682
+ attention_mask: Optional[torch.BoolTensor] = None,
683
+ **kwargs,
684
+ ) -> Dict[str, Any]:
685
+ if attention_mask is not None and torch.any(~attention_mask.bool()):
686
+ total_seq_len = torch.sum(attention_mask, dim=1)
687
+ max_seq_len = torch.max(total_seq_len)
688
+
689
+ total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
690
+ cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
691
+ attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
692
+ else:
693
+ attention_mask = None
694
 
695
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
696
  past_key_values = InferenceParams(
 
702
  key_value_memory_dict={},
703
  )
704
  else:
705
+ # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
706
  past_key_values.sequence_len_offset = len(input_ids[0]) - 1
707
  input_ids = input_ids[:, -1].unsqueeze(-1)
708
 
709
+ return {
710
+ "input_ids": input_ids,
711
+ "past_key_values": past_key_values,
712
+ "attention_mask": attention_mask,
713
+ }
714
 
715
 
716
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
 
724
  super().__init__(config)
725
 
726
  modules = [Embedding(config)]
727
+ modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
728
  modules.append(CausalLMHead(config))
729
 
730
  self.layers = nn.Sequential(*modules)
 
745
  self.layers[-1].linear = new_embeddings
746
 
747
  def forward(
748
+ self,
749
+ input_ids: torch.LongTensor,
750
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
751
+ attention_mask: Optional[torch.BoolTensor] = None,
752
+ labels: Optional[torch.LongTensor] = None,
753
+ **kwargs,
754
  ) -> CausalLMOutputWithPast:
755
+ if attention_mask is not None and self.training:
756
+ raise ValueError("`attention_mask` is not supported during training.")
757
 
758
+ if past_key_values is None and attention_mask is None:
759
  lm_logits = self.layers(input_ids)
760
  else:
761
  hidden_layer = self.layers[0](input_ids)
762
  for module in self.layers[1:-1]:
763
+ hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
764
  lm_logits = self.layers[-1](hidden_layer)
765
 
766
  loss = None
767
  if labels is not None:
768
  loss = self.loss(lm_logits, labels)
769
+
770
  return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
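
For readers tracing the new cached-inference path end to end, the following rough sketch drives the model step by step instead of going through `generate()`. It reuses the `model` and `tokenizer` from the first sketch; the prompt and greedy decoding are illustrative assumptions, not part of the commit.

```python
# Rough sketch (assumes `model` and `tokenizer` from the earlier example).
import torch

prompt_ids = tokenizer("def add(a, b):", return_tensors="pt")["input_ids"]

# Step 1: the full prompt. `prepare_inputs_for_generation` allocates a fresh
# `InferenceParams` cache (sequence_len_offset == 0), so each layer runs
# `SelfAttention` over the whole prompt and stores its keys/values via
# `update_kv_cache`.
inputs = model.prepare_inputs_for_generation(prompt_ids)
out = model(**inputs)
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

# Step 2: a single new token. The cached `InferenceParams` comes back in,
# `sequence_len_offset` is advanced to len(input_ids) - 1, only the last token
# is re-encoded, and the layers switch to the `CrossAttention` path against
# the cached keys/values.
input_ids = torch.cat([prompt_ids, next_token], dim=-1)
inputs = model.prepare_inputs_for_generation(input_ids, past_key_values=out.past_key_values)
out = model(**inputs)

print(out.logits.shape)  # logits only for the newly appended position
```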