Update modeling_llama.py
modeling_llama.py  CHANGED  (+6, -6)
@@ -60,14 +60,10 @@ def is_flash_attn_available():
         return False
 
     # Let's add an extra check to see if cuda is available
-    import torch
 
     return _is_package_available("flash_attn") and torch.cuda.is_available()
 
-
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-# from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+
 
 
 
@@ -232,7 +228,10 @@ class LlamaAttention(nn.Module):
 
             attention_mask: [bsz, q_len]
         """
-
+        if is_flash_attn_available():
+            from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+            # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+            from flash_attn.bert_padding import unpad_input, pad_input
         bsz, q_len, *_ = qkv.size()
 
         if key_padding_mask is None:
@@ -342,6 +341,7 @@ class LlamaAttention(nn.Module):
             return attn_output, attn_weights, past_key_value
         # use flash attention
        elif past_key_value is not None:
+            from flash_attn.flash_attn_interface import flash_attn_with_kvcache
             output = flash_attn_with_kvcache(
                 query_states.transpose(1, 2),
                 key_states.transpose(1, 2),
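Net effect of the change: the flash_attn imports move from module scope into lazy imports guarded by is_flash_attn_available(), so modeling_llama.py can be imported on machines that lack flash_attn or a CUDA device. A minimal sketch of the same guarded-import pattern follows; it is an illustration under assumptions, not the actual file: importlib.util.find_spec stands in for the _is_package_available helper used above, and attention_forward is a hypothetical wrapper, not LlamaAttention.forward.

import importlib.util

import torch


def is_flash_attn_available() -> bool:
    # flash_attn needs both the installed package and a CUDA-capable GPU.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    # Extra check to see if cuda is available, mirroring the diff above.
    return torch.cuda.is_available()


def attention_forward(qkv: torch.Tensor):
    # Hypothetical wrapper for illustration only (not LlamaAttention.forward).
    # Importing flash_attn lazily keeps module import free of the dependency.
    if is_flash_attn_available():
        from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func

        # Fast path: fused FlashAttention kernel on the packed (b, s, 3, h, d) tensor.
        return flash_attn_qkvpacked_func(qkv, causal=True)

    # Fallback path: a standard PyTorch attention implementation would go here.
    raise NotImplementedError("non-flash_attn fallback not shown in this sketch")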