Text Generation
Transformers
Safetensors
Czech
mpt
custom_code
text-generation-inference
Inference Endpoints
mfajcik committed on
Commit cfa6500
1 Parent(s): a888a41

Update modeling_mpt.py

Files changed (1)
  1. modeling_mpt.py +26 -3
modeling_mpt.py CHANGED
@@ -10,10 +10,28 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .attention import is_flash_v1_installed, is_flash_v2_installed
+
+
+# Global variable to store the result
+is_flash_attn_ge_2_7_0 = None
 if is_flash_v2_installed():
     try:
         from flash_attn import bert_padding
         from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
+
+        import flash_attn
+        from packaging import version
+
+
+        # Function to check the version and set the global variable
+        def check_flash_attn_version_gte270():
+            global is_flash_attn_ge_2_7_0
+            installed_version = flash_attn.__version__
+            is_flash_attn_ge_2_7_0 = version.parse(installed_version) >= version.parse("2.7.0")
+
+        # Call the function to set the global variable
+        check_flash_attn_version_gte270()
+
     except Exception as e:
         raise e
 if is_flash_v1_installed():
@@ -140,9 +158,14 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
         key_padding_mask = attention_mask_in_length
         query_padding_mask = attention_mask_in_length
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
-    (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
-    (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
-    (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    if is_flash_attn_ge_2_7_0:
+        (_, indices_q, cu_seqlens_q, max_seqlen_q, _) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
+        (_, indices_k, cu_seqlens_k, max_seqlen_k, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+        (_, indices_v, _, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    else:
+        (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
+        (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+        (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
     flash_attn_padding_info['indices_v'] = indices_v
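
What the change does: the first hunk records at import time, via packaging.version, whether the installed flash-attn is at least 2.7.0; the second hunk branches on that flag inside gen_flash_attn_padding_info, because from 2.7.0 onward the bert_padding unpadding helpers return a fifth value, which the new branch discards. Below is a minimal, self-contained sketch of the same compatibility idea; it is not part of the commit, unpad_input_compat is a hypothetical helper name, and it assumes flash_attn and packaging are importable.

# Sketch only (not from the commit): normalize the return arity of a flash-attn
# unpadding helper so callers can always unpack four values.
import flash_attn
from packaging import version

# True when the installed flash-attn is 2.7.0 or newer (assumption: flash_attn is installed)
FLASH_ATTN_GE_2_7_0 = version.parse(flash_attn.__version__) >= version.parse("2.7.0")

def unpad_input_compat(unpadding_function, hidden_states, padding_mask):
    """Call `unpadding_function` and return (unpadded, indices, cu_seqlens, max_seqlen),
    dropping the extra fifth element that flash-attn >= 2.7.0 appends."""
    out = unpadding_function(hidden_states, padding_mask)
    if FLASH_ATTN_GE_2_7_0:
        unpadded, indices, cu_seqlens, max_seqlen, _ = out
    else:
        unpadded, indices, cu_seqlens, max_seqlen = out
    return unpadded, indices, cu_seqlens, max_seqlen

A version-agnostic alternative would be to unpack only the first four elements of whatever the helper returns (out[:4]); the commit instead makes the 2.7.0 boundary explicit with packaging.version, in the same spirit as the existing is_flash_v1_installed / is_flash_v2_installed checks.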