from xtuner.parallel.sequence import sequence_parallel_wrapper

from .utils import upad_qkv

SUPPORT_FLASH2 = False

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input
    SUPPORT_FLASH2 = True
except ImportError:
    # flash-attn is an optional dependency; when it is missing the flag stays
    # False so callers can fall back to a non-flash implementation.
    pass
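
# Note (illustrative, not part of the original module): SUPPORT_FLASH2 records
# whether flash-attn v2 is importable, so callers can gate dispatch on it,
# for example:
#
#     if SUPPORT_FLASH2:
#         out = flash_attn_wo_mask(q, k, v, causal=True)
#     else:
#         out = eager_attention(q, k, v)  # hypothetical fallback, not defined here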


@sequence_parallel_wrapper
def flash_attn_wo_mask(
        query_states,  # (bs, q_len, nhead, h_dim)
        key_states,
        value_states,
        dropout_p=0.0,
        softmax_scale=None,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
):
    # Dense path: no padding mask, every sequence in the batch is full length,
    # so the fixed-length flash_attn_func kernel can be used directly.
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size)
    return attn_output
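

# Illustrative usage sketch for flash_attn_wo_mask (added for clarity, not
# part of the original module). flash-attn expects fp16/bf16 tensors on a
# CUDA device, laid out as (batch, seq_len, num_heads, head_dim).
def _example_flash_attn_wo_mask():
    import torch

    bs, q_len, nhead, h_dim = 2, 128, 8, 64
    q = torch.randn(
        bs, q_len, nhead, h_dim, dtype=torch.float16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Causal self-attention over full-length (unpadded) sequences.
    out = flash_attn_wo_mask(q, k, v, causal=True)
    assert out.shape == (bs, q_len, nhead, h_dim)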


@sequence_parallel_wrapper
def flash_attn_w_mask(
        query_states,  # (bs, q_len, nhead, h_dim)
        key_states,
        value_states,
        attention_mask,  # (bs, q_len); 1 marks real tokens, 0 marks padding
        causal=True,
        dropout_p=0.0,
        window_size=(-1, -1),  # -1 means infinite context window
):
    batch_size, q_len = query_states.shape[:2]
    # Remove the padded positions so the varlen kernel only attends over real
    # tokens; the output is re-padded to (bs, q_len, nhead, h_dim) below.
    query_states, key_states, value_states, indices_q, \
        cu_seq_lens, max_seq_lens = upad_qkv(
            query_states, key_states, value_states, attention_mask, q_len)

    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout_p,
        causal=causal,
        window_size=window_size)
    attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
    return attn_output
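

# Illustrative usage sketch for flash_attn_w_mask (added for clarity, not part
# of the original module). It assumes attention_mask follows the Hugging Face
# convention handled by upad_qkv: shape (bs, q_len), 1 for real tokens and 0
# for padding.
def _example_flash_attn_w_mask():
    import torch

    bs, q_len, nhead, h_dim = 2, 16, 8, 64
    q = torch.randn(
        bs, q_len, nhead, h_dim, dtype=torch.float16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # The second sequence is padded to half length.
    attention_mask = torch.ones(bs, q_len, dtype=torch.long, device='cuda')
    attention_mask[1, q_len // 2:] = 0
    out = flash_attn_w_mask(q, k, v, attention_mask, causal=True)
    # The output is re-padded to the original (bs, q_len, nhead, h_dim) shape.
    assert out.shape == (bs, q_len, nhead, h_dim)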


@sequence_parallel_wrapper
def varlen_flash_attn(
        query_states,  # (1, total_len, nhead, h_dim), packed sub-sequences
        key_states,
        value_states,
        cumulative_len,  # (num_seqs + 1,) int32 cumulative sequence lengths
        max_seqlen,  # length of the longest packed sub-sequence
        dropout_p=0.,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
):
    # Collapse the batch and sequence dims into the (total_tokens, nhead,
    # h_dim) layout expected by the varlen kernel.
    q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
        0, 1), value_states.flatten(0, 1)
    attn_output = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cumulative_len,
        cumulative_len,
        max_seqlen,
        max_seqlen,
        dropout_p=dropout_p,
        return_attn_probs=False,
        causal=causal,
        window_size=window_size)
    # Restore the leading batch dim of 1 expected by the caller.
    attn_output = attn_output.unsqueeze(0)
    return attn_output
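

# Illustrative usage sketch for varlen_flash_attn (added for clarity, not part
# of the original module). In packed training the batch dim is 1 and several
# sub-sequences are concatenated along the sequence dim; cumulative_len holds
# their cumulative lengths and max_seqlen the longest sub-sequence.
def _example_varlen_flash_attn():
    import torch

    nhead, h_dim = 8, 64
    # Three packed sub-sequences of lengths 3, 4 and 5 (12 tokens in total).
    cumulative_len = torch.tensor(
        [0, 3, 7, 12], dtype=torch.int32, device='cuda')
    max_seqlen = 5
    total_len = 12
    q = torch.randn(
        1, total_len, nhead, h_dim, dtype=torch.float16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = varlen_flash_attn(q, k, v, cumulative_len, max_seqlen, causal=True)
    assert out.shape == (1, total_len, nhead, h_dim)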