Commit: "Update modeling_phi.py" — changed file: modeling_phi.py (+7 lines, −4 lines)
@@ -47,10 +47,13 @@ from transformers.utils import (
 from .configuration_phi import PhiConfig
 
 
-try:
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
-except ImportError:
+try:  # noqa: SIM105
+    if is_flash_attn_2_available():
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+except ImportError:
+    # Workaround for https://github.com/huggingface/transformers/issues/28459,
+    # don't move to contextlib.suppress(ImportError)
     pass
 
 

(Note: the added "+" lines above are exactly as shown in the original page; the removed "-" lines' content was not preserved by the page extraction and is reconstructed from the surrounding context — verify against the repository's commit history.)