suriyagunasekar commited on
Commit
d655135
1 Parent(s): 07a048e

Upload MixFormerSequentialForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_mixformer_sequential.py +41 -5
modeling_mixformer_sequential.py CHANGED
@@ -1,6 +1,36 @@
1
  # Copyright (c) Microsoft Corporation.
2
  # Licensed under the MIT license.
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from __future__ import annotations
5
 
6
  import math
@@ -21,7 +51,8 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
21
  @dataclass
22
  class InferenceParams:
23
  """Inference parameters that are passed to the main model in order
24
- to efficienly calculate and store the context during inference."""
 
25
  max_sequence_len: int
26
  max_batch_size: int
27
  sequence_len_offset: int = 0
@@ -50,7 +81,8 @@ class Embedding(nn.Module):
50
  return hidden_states
51
 
52
  class RotaryEmbedding(nn.Module):
53
- """PyTorch implementation of `flash-attn` RotaryEmbedding layer."""
 
54
 
55
  def __init__(
56
  self,
@@ -187,7 +219,7 @@ class RotaryEmbedding(nn.Module):
187
 
188
  def _update_kv_cache(kv, inference_params, layer_idx):
189
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
190
- """
191
  # Pre-allocate memory for key-values for inference.
192
  num_heads, head_dim = kv.shape[-2:]
193
  if layer_idx not in inference_params.key_value_memory_dict:
@@ -281,6 +313,7 @@ class FusedMLP(nn.Module):
281
 
282
  class SelfAttention(nn.Module):
283
  """Implement the scaled dot product attention with softmax.
 
284
  Arguments
285
  ---------
286
  softmax_scale: The temperature to use for the softmax attention.
@@ -329,6 +362,7 @@ class SelfAttention(nn.Module):
329
 
330
  class CrossAttention(nn.Module):
331
  """Implement the scaled dot product attention with softmax.
 
332
  Arguments
333
  ---------
334
  softmax_scale: The temperature to use for the softmax attention.
@@ -412,7 +446,8 @@ def find_mha_dims(
412
 
413
 
414
  class MHA(nn.Module):
415
- """Multi-head attention layer."""
 
416
 
417
  def __init__(
418
  self,
@@ -472,7 +507,8 @@ class MHA(nn.Module):
472
  self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
473
 
474
  def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
475
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
 
476
 
477
  assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
478
 
 
1
  # Copyright (c) Microsoft Corporation.
2
  # Licensed under the MIT license.
3
 
4
+ # BSD 3-Clause License
5
+ #
6
+ # Copyright (c) 2022, Tri Dao, [email protected].
7
+ # All rights reserved.
8
+ #
9
+ # Redistribution and use in source and binary forms, with or without
10
+ # modification, are permitted provided that the following conditions are met:
11
+ #
12
+ # * Redistributions of source code must retain the above copyright notice, this
13
+ # list of conditions and the following disclaimer.
14
+ #
15
+ # * Redistributions in binary form must reproduce the above copyright notice,
16
+ # this list of conditions and the following disclaimer in the documentation
17
+ # and/or other materials provided with the distribution.
18
+ #
19
+ # * Neither the name of the copyright holder nor the names of its
20
+ # contributors may be used to endorse or promote products derived from
21
+ # this software without specific prior written permission.
22
+ #
23
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
+
34
  from __future__ import annotations
35
 
36
  import math
 
51
  @dataclass
52
  class InferenceParams:
53
  """Inference parameters that are passed to the main model in order
54
+ to efficienly calculate and store the context during inference.
55
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
56
  max_sequence_len: int
57
  max_batch_size: int
58
  sequence_len_offset: int = 0
 
81
  return hidden_states
82
 
83
  class RotaryEmbedding(nn.Module):
84
+ """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
85
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
86
 
87
  def __init__(
88
  self,
 
219
 
220
  def _update_kv_cache(kv, inference_params, layer_idx):
221
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
222
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
223
  # Pre-allocate memory for key-values for inference.
224
  num_heads, head_dim = kv.shape[-2:]
225
  if layer_idx not in inference_params.key_value_memory_dict:
 
313
 
314
  class SelfAttention(nn.Module):
315
  """Implement the scaled dot product attention with softmax.
316
+ Adapted from https://github.com/Dao-AILab/flash-attention.
317
  Arguments
318
  ---------
319
  softmax_scale: The temperature to use for the softmax attention.
 
362
 
363
  class CrossAttention(nn.Module):
364
  """Implement the scaled dot product attention with softmax.
365
+ Adapted from https://github.com/Dao-AILab/flash-attention.
366
  Arguments
367
  ---------
368
  softmax_scale: The temperature to use for the softmax attention.
 
446
 
447
 
448
  class MHA(nn.Module):
449
+ """Multi-head attention layer.
450
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
451
 
452
  def __init__(
453
  self,
 
507
  self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
508
 
509
  def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
510
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
511
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
512
 
513
  assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
514