from xtuner.parallel.sequence import sequence_parallel_wrapper

from .utils import upad_qkv

SUPPORT_FLASH2 = False

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input
    SUPPORT_FLASH2 = True
except ImportError:
    pass


@sequence_parallel_wrapper
def flash_attn_wo_mask(
        query_states,  # bs, q_len, nhead, h_dim
        key_states,
        value_states,
        dropout_p=0.0,
        softmax_scale=None,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
):
    # No attention mask: every sequence in the batch is treated as un-padded,
    # so the fixed-length flash-attn kernel can be used directly.
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size)
    return attn_output
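
# Illustrative usage sketch (added for clarity, not part of the original
# module). Assuming no sequence-parallel process group is initialised, the
# wrapper above is a pass-through and the call reduces to a plain
# `flash_attn_func` call on fp16/bf16 CUDA tensors, e.g.
#
#     q = k = v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device='cuda')
#     out = flash_attn_wo_mask(q, k, v, causal=True)  # out: (2, 1024, 8, 64)
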


@sequence_parallel_wrapper
def flash_attn_w_mask(
        query_states,  # bs, q_len, nhead, h_dim
        key_states,
        value_states,
        attention_mask,
        causal=True,
        dropout_p=0.0,
        window_size=(-1, -1),  # -1 means infinite context window
):
    batch_size, q_len = query_states.shape[:2]
    # Drop padded positions so the varlen kernel only attends to real tokens.
    query_states, key_states, value_states, indices_q, \
        cu_seq_lens, max_seq_lens = upad_qkv(
            query_states, key_states, value_states, attention_mask, q_len)

    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout_p,
        causal=causal,
        window_size=window_size)
    # Scatter the outputs back to the padded (bs, q_len, nhead, h_dim) layout.
    attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
    return attn_output
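
# Illustrative note (added for clarity, not part of the original module).
# `attention_mask` is the usual (bs, q_len) padding mask of 1s/0s, e.g.
#
#     attention_mask = torch.tensor([[1, 1, 1, 0],
#                                    [1, 1, 1, 1]])
#
# `upad_qkv` packs only the real tokens before calling the varlen kernel and
# `pad_input` scatters the outputs back into the padded layout.
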


@sequence_parallel_wrapper
def varlen_flash_attn(
        query_states,  # 1, total_len, nhead, h_dim (packed sequences)
        key_states,
        value_states,
        cumulative_len,
        max_seqlen,
        dropout_p=0.,
        causal=True,
        window_size=(-1, -1),  # -1 means infinite context window
):
    # Collapse the leading batch dimension: the packed tokens of all sequences
    # are delimited by `cumulative_len` (cumulative sequence lengths, int32).
    q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
        0, 1), value_states.flatten(0, 1)
    attn_output = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cumulative_len,
        cumulative_len,
        max_seqlen,
        max_seqlen,
        dropout_p=dropout_p,
        return_attn_probs=False,
        causal=causal,
        window_size=window_size)
    # Restore the leading batch dimension expected by the caller.
    attn_output = attn_output.unsqueeze(0)
    return attn_output
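

# ---------------------------------------------------------------------------
# Illustrative smoke test (added, not part of the original module). It packs
# two sequences of lengths 3 and 5 into a single row and runs the varlen
# kernel. Assumes flash-attn is installed, a CUDA device is available, and
# that `sequence_parallel_wrapper` is a pass-through when no distributed /
# sequence-parallel process group is initialised.
if __name__ == '__main__':
    import torch

    if SUPPORT_FLASH2 and torch.cuda.is_available():
        nhead, h_dim = 8, 64
        # Two packed sequences (lengths 3 and 5) -> 8 tokens in total.
        q = torch.randn(1, 8, nhead, h_dim, dtype=torch.float16, device='cuda')
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        cumulative_len = torch.tensor([0, 3, 8], dtype=torch.int32, device='cuda')
        out = varlen_flash_attn(q, k, v, cumulative_len, max_seqlen=5)
        print(out.shape)  # expected: torch.Size([1, 8, 8, 64])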