gugarosa committed
Commit 0254d42
1 Parent(s): 0bbd68a

Fixes flash-attn import with a try/except statement

Files changed (1)
  1. modeling_mixformer_sequential.py +6 -14
modeling_mixformer_sequential.py CHANGED
@@ -32,7 +32,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
-import importlib
 
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from .configuration_mixformer_sequential import MixFormerSequentialConfig
 
 
-def _is_flash_attn_available() -> bool:
-    return importlib.util.find_spec("flash_attn") is not None
-
-
-if _is_flash_attn_available():
+try:
     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
     from flash_attn.ops.fused_dense import FusedDense
-else:
+except:
     FlashRotaryEmbedding = None
     FusedDense = None
 
@@ -549,9 +544,6 @@ class MHA(nn.Module):
         bias: bool = True,
         causal: bool = True,
         softmax_scale: Optional[float] = None,
-        dropout: float = 0.0,
-        flash_rotary: bool = True,
-        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
@@ -565,7 +557,7 @@
         if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
             rotary_kwargs["scale_base"] = rotary_emb_scale_base
 
-        rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+        rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
         if rotary_cls is None:
             rotary_cls = RotaryEmbedding
         self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
@@ -575,7 +567,7 @@
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
-        linear_cls = FusedDense if fused_dense else nn.Linear
+        linear_cls = FusedDense if config.fused_dense else nn.Linear
         if linear_cls is None:
             linear_cls = nn.Linear
 
@@ -583,8 +575,8 @@
         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
-        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
-        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
+        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
 
         self.layer_idx = layer_idx
         self.return_residual = return_residual
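
The change above follows the usual optional-dependency pattern: the flash-attn classes are imported when the package is installed, and the later `if ... is None` checks fall back to stock PyTorch modules. Below is a minimal, self-contained sketch of that pattern; the helper name `pick_linear_cls` and the narrower `except ImportError` are illustrative choices, not part of the committed file (the commit itself uses a bare `except:`).

import torch.nn as nn

# Optional import: FusedDense is only defined when flash-attn is installed.
try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None


def pick_linear_cls(fused_dense: bool) -> type:
    # Prefer the fused kernel when requested (as config.fused_dense does in the
    # diff), but degrade gracefully to nn.Linear if flash-attn is unavailable.
    linear_cls = FusedDense if fused_dense else nn.Linear
    if linear_cls is None:
        linear_cls = nn.Linear
    return linear_cls


if __name__ == "__main__":
    # Prints "FusedDense" when flash-attn is installed, "Linear" otherwise.
    print(pick_linear_cls(fused_dense=True).__name__)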