Fixes flash-attn import with a try/except statement
Browse files
modeling_mixformer_sequential.py
CHANGED
@@ -32,7 +32,6 @@
|
|
32 |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
33 |
|
34 |
from __future__ import annotations
|
35 |
-
import importlib
|
36 |
|
37 |
import math
|
38 |
from typing import Any, Dict, Optional, Tuple, Union
|
@@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
49 |
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
50 |
|
51 |
|
52 |
-
|
53 |
-
return importlib.util.find_spec("flash_attn") is not None
|
54 |
-
|
55 |
-
|
56 |
-
if _is_flash_attn_available():
|
57 |
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
58 |
from flash_attn.ops.fused_dense import FusedDense
|
59 |
-
|
60 |
FlashRotaryEmbedding = None
|
61 |
FusedDense = None
|
62 |
|
@@ -549,9 +544,6 @@ class MHA(nn.Module):
|
|
549 |
bias: bool = True,
|
550 |
causal: bool = True,
|
551 |
softmax_scale: Optional[float] = None,
|
552 |
-
dropout: float = 0.0,
|
553 |
-
flash_rotary: bool = True,
|
554 |
-
fused_dense: bool = True,
|
555 |
layer_idx: Optional[int] = None,
|
556 |
return_residual: bool = False,
|
557 |
checkpointing: bool = False,
|
@@ -565,7 +557,7 @@ class MHA(nn.Module):
|
|
565 |
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
566 |
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
567 |
|
568 |
-
rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
|
569 |
if rotary_cls is None:
|
570 |
rotary_cls = RotaryEmbedding
|
571 |
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
@@ -575,7 +567,7 @@ class MHA(nn.Module):
|
|
575 |
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
576 |
hidden_size = config.n_embd
|
577 |
|
578 |
-
linear_cls = FusedDense if fused_dense else nn.Linear
|
579 |
if linear_cls is None:
|
580 |
linear_cls = nn.Linear
|
581 |
|
@@ -583,8 +575,8 @@ class MHA(nn.Module):
|
|
583 |
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
584 |
|
585 |
# Attention
|
586 |
-
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=
|
587 |
-
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=
|
588 |
|
589 |
self.layer_idx = layer_idx
|
590 |
self.return_residual = return_residual
|
|
|
32 |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
33 |
|
34 |
from __future__ import annotations
|
|
|
35 |
|
36 |
import math
|
37 |
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
48 |
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
49 |
|
50 |
|
51 |
+
try:
|
|
|
|
|
|
|
|
|
52 |
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
53 |
from flash_attn.ops.fused_dense import FusedDense
|
54 |
+
except:
|
55 |
FlashRotaryEmbedding = None
|
56 |
FusedDense = None
|
57 |
|
|
|
544 |
bias: bool = True,
|
545 |
causal: bool = True,
|
546 |
softmax_scale: Optional[float] = None,
|
|
|
|
|
|
|
547 |
layer_idx: Optional[int] = None,
|
548 |
return_residual: bool = False,
|
549 |
checkpointing: bool = False,
|
|
|
557 |
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
558 |
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
559 |
|
560 |
+
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
561 |
if rotary_cls is None:
|
562 |
rotary_cls = RotaryEmbedding
|
563 |
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
|
|
567 |
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
568 |
hidden_size = config.n_embd
|
569 |
|
570 |
+
linear_cls = FusedDense if config.fused_dense else nn.Linear
|
571 |
if linear_cls is None:
|
572 |
linear_cls = nn.Linear
|
573 |
|
|
|
575 |
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
576 |
|
577 |
# Attention
|
578 |
+
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
579 |
+
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
580 |
|
581 |
self.layer_idx = layer_idx
|
582 |
self.return_residual = return_residual
|