suriyagunasekar
commited on
Commit
•
d655135
1
Parent(s):
07a048e
Upload MixFormerSequentialForCausalLM
Browse files
modeling_mixformer_sequential.py
CHANGED
@@ -1,6 +1,36 @@
|
|
1 |
# Copyright (c) Microsoft Corporation.
|
2 |
# Licensed under the MIT license.
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from __future__ import annotations
|
5 |
|
6 |
import math
|
@@ -21,7 +51,8 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
|
21 |
@dataclass
|
22 |
class InferenceParams:
|
23 |
"""Inference parameters that are passed to the main model in order
|
24 |
-
to efficienly calculate and store the context during inference.
|
|
|
25 |
max_sequence_len: int
|
26 |
max_batch_size: int
|
27 |
sequence_len_offset: int = 0
|
@@ -50,7 +81,8 @@ class Embedding(nn.Module):
|
|
50 |
return hidden_states
|
51 |
|
52 |
class RotaryEmbedding(nn.Module):
|
53 |
-
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
|
|
|
54 |
|
55 |
def __init__(
|
56 |
self,
|
@@ -187,7 +219,7 @@ class RotaryEmbedding(nn.Module):
|
|
187 |
|
188 |
def _update_kv_cache(kv, inference_params, layer_idx):
|
189 |
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
190 |
-
"""
|
191 |
# Pre-allocate memory for key-values for inference.
|
192 |
num_heads, head_dim = kv.shape[-2:]
|
193 |
if layer_idx not in inference_params.key_value_memory_dict:
|
@@ -281,6 +313,7 @@ class FusedMLP(nn.Module):
|
|
281 |
|
282 |
class SelfAttention(nn.Module):
|
283 |
"""Implement the scaled dot product attention with softmax.
|
|
|
284 |
Arguments
|
285 |
---------
|
286 |
softmax_scale: The temperature to use for the softmax attention.
|
@@ -329,6 +362,7 @@ class SelfAttention(nn.Module):
|
|
329 |
|
330 |
class CrossAttention(nn.Module):
|
331 |
"""Implement the scaled dot product attention with softmax.
|
|
|
332 |
Arguments
|
333 |
---------
|
334 |
softmax_scale: The temperature to use for the softmax attention.
|
@@ -412,7 +446,8 @@ def find_mha_dims(
|
|
412 |
|
413 |
|
414 |
class MHA(nn.Module):
|
415 |
-
"""Multi-head attention layer.
|
|
|
416 |
|
417 |
def __init__(
|
418 |
self,
|
@@ -472,7 +507,8 @@ class MHA(nn.Module):
|
|
472 |
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
473 |
|
474 |
def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
|
475 |
-
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
|
|
476 |
|
477 |
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
478 |
|
|
|
1 |
# Copyright (c) Microsoft Corporation.
|
2 |
# Licensed under the MIT license.
|
3 |
|
4 |
+
# BSD 3-Clause License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2022, Tri Dao, [email protected].
|
7 |
+
# All rights reserved.
|
8 |
+
#
|
9 |
+
# Redistribution and use in source and binary forms, with or without
|
10 |
+
# modification, are permitted provided that the following conditions are met:
|
11 |
+
#
|
12 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
13 |
+
# list of conditions and the following disclaimer.
|
14 |
+
#
|
15 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
16 |
+
# this list of conditions and the following disclaimer in the documentation
|
17 |
+
# and/or other materials provided with the distribution.
|
18 |
+
#
|
19 |
+
# * Neither the name of the copyright holder nor the names of its
|
20 |
+
# contributors may be used to endorse or promote products derived from
|
21 |
+
# this software without specific prior written permission.
|
22 |
+
#
|
23 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
24 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
25 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
26 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
27 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
28 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
29 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
30 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
31 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
32 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
33 |
+
|
34 |
from __future__ import annotations
|
35 |
|
36 |
import math
|
|
|
51 |
@dataclass
|
52 |
class InferenceParams:
|
53 |
"""Inference parameters that are passed to the main model in order
|
54 |
+
to efficienly calculate and store the context during inference.
|
55 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
56 |
max_sequence_len: int
|
57 |
max_batch_size: int
|
58 |
sequence_len_offset: int = 0
|
|
|
81 |
return hidden_states
|
82 |
|
83 |
class RotaryEmbedding(nn.Module):
|
84 |
+
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
|
85 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
86 |
|
87 |
def __init__(
|
88 |
self,
|
|
|
219 |
|
220 |
def _update_kv_cache(kv, inference_params, layer_idx):
|
221 |
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
222 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
223 |
# Pre-allocate memory for key-values for inference.
|
224 |
num_heads, head_dim = kv.shape[-2:]
|
225 |
if layer_idx not in inference_params.key_value_memory_dict:
|
|
|
313 |
|
314 |
class SelfAttention(nn.Module):
|
315 |
"""Implement the scaled dot product attention with softmax.
|
316 |
+
Adapted from https://github.com/Dao-AILab/flash-attention.
|
317 |
Arguments
|
318 |
---------
|
319 |
softmax_scale: The temperature to use for the softmax attention.
|
|
|
362 |
|
363 |
class CrossAttention(nn.Module):
|
364 |
"""Implement the scaled dot product attention with softmax.
|
365 |
+
Adapted from https://github.com/Dao-AILab/flash-attention.
|
366 |
Arguments
|
367 |
---------
|
368 |
softmax_scale: The temperature to use for the softmax attention.
|
|
|
446 |
|
447 |
|
448 |
class MHA(nn.Module):
|
449 |
+
"""Multi-head attention layer.
|
450 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
451 |
|
452 |
def __init__(
|
453 |
self,
|
|
|
507 |
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
508 |
|
509 |
def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
|
510 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
511 |
+
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
512 |
|
513 |
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
514 |
|