zR committed on
Commit bcf026a
1 Parent(s): a7cafb0

fix import error of flash attn

Files changed (1)
  1. modeling_chatglm.py +9 -5
modeling_chatglm.py CHANGED
@@ -21,16 +21,20 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging, is_torch_npu_available, is_flash_attn_greater_or_equal_2_10, \
-    is_flash_attn_2_available
+from transformers.utils import logging, is_torch_npu_available
 from transformers.generation.logits_process import LogitsProcessor
 from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

 from .configuration_chatglm import ChatGLMConfig

-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+try:
+    from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
+    if is_flash_attn_2_available():
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+except:
+    pass
+

 # flags required to enable jit fusion kernels
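
In effect, the patch wraps the flash-attn related imports in a try/except so that importing modeling_chatglm.py no longer fails on transformers versions that do not export is_flash_attn_2_available / is_flash_attn_greater_or_equal_2_10, or when the flash_attn package is not installed. A minimal, self-contained sketch of the same guard is shown below; the _FLASH_ATTN_OK flag is illustrative only and is not part of the actual file.

# Sketch of the optional-import guard adopted by this commit (illustrative).
_FLASH_ATTN_OK = False  # hypothetical flag, not present in modeling_chatglm.py

try:
    # Older transformers releases do not export this helper, so the import
    # itself can raise ImportError.
    from transformers.utils import is_flash_attn_2_available

    if is_flash_attn_2_available():
        # Only reached when the flash_attn package is installed and usable.
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        _FLASH_ATTN_OK = True
except Exception:
    # Missing helper or missing flash_attn: fall back to the eager attention path.
    pass

print("flash-attn fast path available:", _FLASH_ATTN_OK)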