x54-729 commited on
Commit
1249bed
·
1 Parent(s): e6e6069

fix import error

Browse files
Files changed (1) hide show
  1. modeling_internlm.py +19 -5
modeling_internlm.py CHANGED
@@ -48,6 +48,20 @@ logger = logging.get_logger(__name__)
48
 
49
  _CONFIG_FOR_DOC = "InternLMConfig"
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def _get_unpad_data(attention_mask):
52
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
53
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
@@ -438,13 +452,11 @@ class InternLMFlashAttention2(InternLMAttention):
438
  softmax_scale (`float`, *optional*):
439
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
440
  """
441
- from flash_attn import flash_attn_func, flash_attn_varlen_func
442
- from flash_attn.bert_padding import pad_input
443
  # Contains at least one padding token in the sequence
444
  causal = self.is_causal and query_length != 1
445
  if attention_mask is not None:
446
  batch_size = query_states.shape[0]
447
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
448
  query_states, key_states, value_states, attention_mask, query_length
449
  )
450
 
@@ -472,8 +484,7 @@ class InternLMFlashAttention2(InternLMAttention):
472
 
473
  return attn_output
474
 
475
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
476
- from flash_attn.bert_padding import index_first_axis, unpad_input
477
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
478
  batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
479
 
@@ -762,6 +773,9 @@ class InternLMModel(InternLMPreTrainedModel):
762
 
763
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
764
 
 
 
 
765
  # retrieve input_ids and inputs_embeds
766
  if input_ids is not None and inputs_embeds is not None:
767
  raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
 
48
 
49
  _CONFIG_FOR_DOC = "InternLMConfig"
50
 
51
+ flash_attn_func, flash_attn_varlen_func = None, None
52
+ pad_input, index_first_axis, unpad_input = None, None, None
53
+ def _import_flash_attn():
54
+ global flash_attn_func, flash_attn_varlen_func
55
+ global pad_input, index_first_axis, unpad_input
56
+ try:
57
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
58
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
59
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
60
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
61
+ except ImportError:
62
+ raise ImportError("flash_attn is not installed.")
63
+
64
+
65
  def _get_unpad_data(attention_mask):
66
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
67
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 
452
  softmax_scale (`float`, *optional*):
453
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
454
  """
 
 
455
  # Contains at least one padding token in the sequence
456
  causal = self.is_causal and query_length != 1
457
  if attention_mask is not None:
458
  batch_size = query_states.shape[0]
459
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
460
  query_states, key_states, value_states, attention_mask, query_length
461
  )
462
 
 
484
 
485
  return attn_output
486
 
487
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
 
488
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
489
  batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
490
 
 
773
 
774
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
775
 
776
+ if self.config.attn_implementation == "flash_attention_2":
777
+ _import_flash_attn()
778
+
779
  # retrieve input_ids and inputs_embeds
780
  if input_ids is not None and inputs_embeds is not None:
781
  raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")