from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input  # used below to re-pad the unpadded attention output
import torch


# The original flash forward method must be replaced with the following one first,
# to enable the window-size feature.
def flash_attention2_forward_with_window_size(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    window_size=[-1, -1],
    return_attn_probs=False,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
    first unpad the input, then compute the attention scores and pad the final attention scores.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to the Flash Attention API.
        key_states (`torch.Tensor`):
            Input key states to be passed to the Flash Attention API.
        value_states (`torch.Tensor`):
            Input value states to be passed to the Flash Attention API.
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        query_length (`int`):
            The length of the query sequence.
        dropout (`float`, *optional*):
            Attention dropout probability.
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        window_size (`[int, int]`, *optional*):
            The left and right window sizes for Flash Attention. Defaults to `[-1, -1]`, which means no window is used.
        return_attn_probs (`bool`, *optional*):
            Whether to return the attention softmax logsumexp and probabilities. Defaults to False.
    """
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
        # For details, please see the comment in LlamaFlashAttention2 __init__.
        causal = self.is_causal and query_length != 1
    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad, softmax_lse, S_dmask = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )

        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
    else:
        attn_output, softmax_lse, S_dmask = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )

    if return_attn_probs:
        return attn_output, softmax_lse, S_dmask
    else:
        return attn_output
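

# Usage sketch (an assumption for illustration, not part of the original file): the comment at the top says
# the original flash forward method must be replaced first. With transformers' LLaMA implementation
# (the LlamaFlashAttention2 class referenced in the causal-mask comment above), that patch could look like
# the following; exact module paths and method names vary across transformers versions.
#
#   import transformers
#   transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
#       flash_attention2_forward_with_window_size
#   )

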
def self_extend_flash_forward(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):
    if query_position.max() >= group_size_2:
        neighbor_attn_output, neighbor_softmax_lse_right_padded, neighbor_prob = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[group_size_2 - 1, 0],
            # The right window size does not matter here and can be -1 or > 0, thanks to the causal mask.
            return_attn_probs=True,
        )

        # Use kv_seq_len rather than max_kv_len, since qkv and attention_mask contain paddings.
        group_attention_len = kv_seq_len - group_size_2

        group_attention_mask = attention_mask[:, :group_attention_len] if attention_mask is not None else None
        group_attn_output, group_softmax_lse_right_padded, group_prob = model_self._flash_attention_forward(
            group_query_states[:, -group_attention_len:, :, :],
            group_key_states[:, :group_attention_len, :, :],
            value_states[:, :group_attention_len, :, :],
            group_attention_mask,
            group_query_states[:, -group_attention_len:, :, :].shape[1],
            dropout=attn_dropout,
            window_size=[-1, -1],
            return_attn_probs=True,
        )  # Note that kv and q are indexed differently! Also, the query size can differ from the kv length
        # and is very small during generation compared to prefilling.
        # Normalize the lse first.
        neighbor_seq_length = (
            torch.Tensor([kv_seq_len]).long().expand(bsz, 1)
            if attention_mask is None
            else torch.sum(attention_mask, axis=1, keepdim=True)
        )  # [batch_size, 1]
        group_seq_length = (
            torch.Tensor([group_attention_len]).long().expand(bsz, 1)
            if attention_mask is None
            else torch.sum(attention_mask[:, :group_attention_len], axis=1, keepdim=True)
        )  # [batch_size, 1]

        # Convert the left-aligned lse entries to right-aligned ones and convert exp(0) to 0.
        neighbor_softmax_lse = torch.zeros_like(neighbor_softmax_lse_right_padded)
        group_softmax_lse = torch.zeros_like(group_softmax_lse_right_padded)
        for idx in range(bsz):
            if neighbor_seq_length[idx] > 0:
                neighbor_softmax_lse[idx, :, -neighbor_seq_length[idx] :] = neighbor_softmax_lse_right_padded[
                    idx, :, : neighbor_seq_length[idx]
                ]
            if group_seq_length[idx] > 0:
                group_softmax_lse[idx, :, -group_seq_length[idx] :] = group_softmax_lse_right_padded[
                    idx, :, : group_seq_length[idx]
                ]
        # attn_output size is [batch_size, max_seq_len (not the true one), query_length, dim]
        true_neighbor_seq_max_length = neighbor_softmax_lse.shape[-1]  # can be smaller than query_length due to the attention_mask
        true_group_seq_max_length = group_softmax_lse.shape[-1]  # can be smaller than group_query_states[:, -group_attention_len:, :, :].shape[1] due to attention_mask[:, :group_attention_len]

        neighbor_softmax_lse = neighbor_softmax_lse.transpose(1, 2).unsqueeze(-1)  # [batch_size, true_neighbor_seq_max_length, self.num_heads, 1]
        group_softmax_lse = group_softmax_lse.transpose(1, 2).unsqueeze(-1)  # [batch_size, true_group_seq_max_length, self.num_heads, 1]

        # Merge the two branches: each branch is reweighted by its share of the total softmax mass,
        # i.e. the sigmoid of the lse difference, so that the sum equals attention over the union of keys.
        lse_gap = group_softmax_lse - neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :]
        # if torch.isinf(neighbor_softmax_lse).any() or torch.isnan(neighbor_softmax_lse).any():
        #     import pdb; pdb.set_trace()

        neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :] = 1 / (1 + torch.exp(lse_gap))
        neighbor_softmax_lse[:, :-true_group_seq_max_length, :, :] = 1.0
        group_softmax_lse = 1 / (1 + torch.exp(-lse_gap))

        neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] = (
            neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] * neighbor_softmax_lse
        )
        group_attn_output[:, -true_group_seq_max_length:, ...] = (
            group_attn_output[:, -true_group_seq_max_length:, ...] * group_softmax_lse
        )
        attn_output = torch.empty_like(neighbor_attn_output).copy_(
            neighbor_attn_output
        )  # might be slightly faster than clone
        # attn_output[:, group_size_2:, ...] += group_attn_output
        attn_output[:, group_size_2 - kv_seq_len:, ...] += group_attn_output
        attn_output = torch.nan_to_num(attn_output, nan=0)
    else:
        attn_output = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[-1, -1],
        )

    return attn_output
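

# A minimal sanity check (added for illustration; not part of the original file). It verifies the identity
# `self_extend_flash_forward` relies on when it merges the neighbor and grouped branches via `lse_gap`:
# two softmax-attention outputs over disjoint key sets can be combined with weights
# 1 / (1 + exp(lse_other - lse_self)) to recover attention over the union of keys.
def _lse_merge_sanity_check():
    torch.manual_seed(0)
    head_dim = 8
    q = torch.randn(1, head_dim)
    k = torch.randn(16, head_dim)
    v = torch.randn(16, head_dim)
    scores = q @ k.T / head_dim ** 0.5  # [1, 16]
    full = torch.softmax(scores, dim=-1) @ v  # attention over all 16 keys at once

    # Split the keys into two disjoint halves and attend to each half separately.
    s_a, s_b = scores[:, :8], scores[:, 8:]
    lse_a = torch.logsumexp(s_a, dim=-1, keepdim=True)
    lse_b = torch.logsumexp(s_b, dim=-1, keepdim=True)
    out_a = torch.softmax(s_a, dim=-1) @ v[:8]
    out_b = torch.softmax(s_b, dim=-1) @ v[8:]

    # Same weighting as the `lse_gap` logic above: each branch gets its share of the softmax mass.
    w_a = 1 / (1 + torch.exp(lse_b - lse_a))
    w_b = 1 / (1 + torch.exp(lse_a - lse_b))
    merged = w_a * out_a + w_b * out_b
    assert torch.allclose(merged, full, atol=1e-5)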