feat: check flash-attn version if installed
#14, opened by reedcli

modeling_yi.py CHANGED (+8 -1)
@@ -29,6 +29,11 @@ try:
 except Exception:
     is_flash_attn_available = False

+if is_flash_attn_available:
+    from flash_attn import __version__
+
+    assert __version__ >= "2.3.0", "please update your flash_attn version"
+
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "YiConfig"
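
Note on the hunk above: the added check compares version strings lexicographically, so a hypothetical release like "2.10.0" would compare as lower than "2.3.0". A minimal sketch of a numeric comparison instead, assuming the packaging package is available (it is not part of this diff):

# Sketch only: numeric version gate instead of string comparison.
# Assumes the `packaging` package is installed; not part of this PR.
from packaging import version

try:
    from flash_attn import __version__ as flash_attn_version
    is_flash_attn_available = True
except Exception:
    is_flash_attn_available = False

if is_flash_attn_available:
    assert version.parse(flash_attn_version) >= version.parse("2.3.0"), (
        "please update your flash_attn version"
    )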
@@ -539,7 +544,9 @@ class YiModel(YiPreTrainedModel):
     def _prepare_decoder_attention_mask(
         self, attention_mask, input_ids, inputs_embeds, past_key_values_length
     ):
-        input_shape =
+        input_shape = (
+            input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
+        )
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
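
For context on the second hunk: the patched expression falls back to inputs_embeds.shape[:-1] when input_ids is None, so callers that pass embeddings directly still yield a [bsz, seq_len] shape for the causal mask. A small illustration with made-up tensor sizes (not from this PR):

# Illustration only, with made-up sizes: how the patched expression
# recovers [bsz, seq_len] when the caller passes embeddings directly.
import torch

batch_size, seq_len, hidden_size = 2, 16, 4096
input_ids = None  # caller skipped token ids
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)

input_shape = (
    input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
)
print(input_shape)  # torch.Size([2, 16])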