Adds support for flash-attn rotary embedding and fused dense layers.
modeling_mixformer_sequential.py (+59, -19)
@@ -32,6 +32,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
+import importlib
 
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -48,6 +49,18 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from .configuration_mixformer_sequential import MixFormerSequentialConfig
 
 
+def _is_flash_attn_available() -> bool:
+    return importlib.util.find_spec("flash_attn") is not None
+
+
+if _is_flash_attn_available():
+    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+    from flash_attn.ops.fused_dense import FusedDense
+else:
+    FlashRotaryEmbedding = None
+    FusedDense = None
+
+
 @dataclass
 class InferenceParams:
     """Inference parameters passed to model to efficiently calculate
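A note on the import guard added above: `importlib.util.find_spec` returns a module spec when a package is importable and `None` otherwise, so `FlashRotaryEmbedding` and `FusedDense` resolve to `None` on machines without flash-attn instead of failing at import time. A minimal standalone sketch of the same availability-check pattern (the `has_package` helper below is illustrative, not part of the file):

import importlib.util

def has_package(name: str) -> bool:
    # find_spec returns a ModuleSpec when the package can be imported, else None,
    # without executing the package's __init__.
    return importlib.util.find_spec(name) is not None

# Choose an implementation without importing the optional dependency eagerly.
backend = "flash-attn fused kernels" if has_package("flash_attn") else "pure PyTorch fallback"
print(backend)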
@@ -213,6 +226,7 @@ class RotaryEmbedding(nn.Module):
         dim: int,
         base: int = 10000,
         scale_base: Optional[float] = None,
+        pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -221,15 +235,17 @@ class RotaryEmbedding(nn.Module):
         if scale_base is not None:
             raise NotImplementedError
 
-        # Generate and save the inverse frequency buffer (non-trainable)
         self.dim = dim
-        self.base = base
+        self.base = float(base)
         self.scale_base = scale_base
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
         self.device = device
 
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+        # Generate and save the inverse frequency buffer (non-trainable)
+        inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+        # Generate and save the scale buffer (non-trainable)
         scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base is not None
@@ -243,23 +259,37 @@ class RotaryEmbedding(nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
+    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
+        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+
     def _update_cos_sin_cache(
         self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
     ) -> None:
-        # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-        ):
+        # Reset the tables if the sequence length has changed, if we are on a
+        # new device or if we are switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
             self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=torch.float32)
 
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            # fp32 is preferred since the output of `torch.arange` can be quite large
+            # and bf16 would lose a lot of precision
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+
+            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+            freqs = torch.outer(t, inv_freq)
         if self.scale is None:
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)
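For reference, `_compute_inv_freq` above produces the standard rotary frequencies inv_freq[k] = base ** (-2k / dim), and the cache stores the cosine and sine of every position/frequency product. A small standalone sketch of the cached tensors under the new fp32-position path (toy sizes; variable names are illustrative):

import torch

dim, base, seqlen = 32, 10000.0, 128

# inv_freq[k] = base ** (-2k / dim): one frequency per pair of channels
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Positions kept in fp32 so large indices do not lose precision in bf16
t = torch.arange(seqlen, dtype=torch.float32)

# freqs[s, k] = s * inv_freq[k]; torch.outer keeps the computation in fp32
freqs = torch.outer(t, inv_freq)
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)
print(cos_cached.shape)  # torch.Size([128, 16])

Keeping the position index in fp32 matters because indices in the tens of thousands cannot be represented exactly in bf16, which is the precision issue the new `pos_idx_in_fp32` flag addresses.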
@@ -269,7 +299,7 @@ class RotaryEmbedding(nn.Module):
             ) / self.scale_base
             scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
-            # We want the multiplication by scale to happen in fp32
+            # Force the scale multiplication to happen in fp32
             self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
@@ -520,6 +550,8 @@ class MHA(nn.Module):
         causal: bool = True,
         softmax_scale: Optional[float] = None,
         dropout: float = 0.0,
+        flash_rotary: bool = True,
+        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
@@ -532,15 +564,23 @@ class MHA(nn.Module):
             rotary_kwargs = {"device": device}
             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
+
+            rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+            if rotary_cls is None:
+                rotary_cls = RotaryEmbedding
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
-        self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
+        linear_cls = FusedDense if fused_dense else nn.Linear
+        if linear_cls is None:
+            linear_cls = nn.Linear
+
+        self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
+        self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
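With the new constructor defaults (`flash_rotary=True`, `fused_dense=True`), `MHA` picks the flash-attn rotary embedding and `FusedDense` whenever the package is installed, and falls back to the file's own `RotaryEmbedding` and `nn.Linear` otherwise. A minimal sketch of that selection logic in isolation (the stand-in globals, placeholder class, and `pick_layer_classes` helper are illustrative only):

import torch.nn as nn

# Stand-ins for the module-level conditional imports: both resolve to None when
# flash-attn is not installed, mirroring the guard at the top of the file.
FlashRotaryEmbedding = None
FusedDense = None

class RotaryEmbedding(nn.Module):
    # Placeholder for the file's own pure-PyTorch implementation.
    def __init__(self, dim: int, **kwargs) -> None:
        super().__init__()
        self.dim = dim

def pick_layer_classes(flash_rotary: bool = True, fused_dense: bool = True):
    # Prefer the flash-attn classes when requested, and fall back when they are
    # None (package missing), the same resolution the MHA constructor performs.
    rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
    if rotary_cls is None:
        rotary_cls = RotaryEmbedding
    linear_cls = FusedDense if fused_dense else nn.Linear
    if linear_cls is None:
        linear_cls = nn.Linear
    return rotary_cls, linear_cls

print(pick_layer_classes())  # without flash-attn: (RotaryEmbedding, torch.nn.Linear)

The guard means environments without flash-attn keep working unchanged, while environments that have it get the fused rotary and dense implementations by default.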