feat: check flash-attn version if installed

#14
by reedcli - opened
Files changed (1)
  1. modeling_yi.py +8 -1
modeling_yi.py CHANGED
@@ -29,6 +29,11 @@ try:
 except Exception:
     is_flash_attn_available = False
 
+if is_flash_attn_available:
+    from flash_attn import __version__
+
+    assert __version__ >= "2.3.0", "please update your flash_attn version"
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "YiConfig"
@@ -539,7 +544,9 @@ class YiModel(YiPreTrainedModel):
     def _prepare_decoder_attention_mask(
         self, attention_mask, input_ids, inputs_embeds, past_key_values_length
     ):
-        input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
+        input_shape = (
+            input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
+        )
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
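
A side note on the added assertion: comparing version strings lexicographically happens to work for the releases this targets, but it misorders versions once a multi-digit component appears (as strings, "2.10.0" < "2.3.0"). Below is a minimal sketch of a more robust check, assuming the packaging library is installed; it is not part of this PR, just an illustration.

# Sketch only: parse versions instead of comparing raw strings.
# Assumes the `packaging` library is available (it typically ships alongside pip).
from packaging.version import Version

is_flash_attn_available = True
try:
    from flash_attn import __version__ as flash_attn_version
except Exception:
    is_flash_attn_available = False

if is_flash_attn_available:
    # Version("2.10.0") > Version("2.3.0"), whereas the string "2.10.0" < "2.3.0".
    assert Version(flash_attn_version) >= Version("2.3.0"), (
        "please update your flash_attn version to >= 2.3.0"
    )

Because Version compares components numerically, the check keeps working once flash-attn reaches a release like 2.10.x.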