import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input  # used to re-pad the output after varlen attention


# The model's original flash-attention forward method must be replaced with the one
# below first, to enable the window-size feature.
def flash_attention2_forward_with_window_size(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
window_size=[-1, -1],
return_attn_probs=False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
window_size ([Int, Int])
The left & right window size for Flash Attention. Default to [-1, -1] which means no window size is used.
return_attn_probs (`bool`, *optional*):
Whether to return the attention softmax logssumexp and probabilities. Default to False.
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad, softmax_lse, S_dmask = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
return_attn_probs=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output, softmax_lse, S_dmask = flash_attn_func(
query_states,
key_states,
value_states,
            dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
return_attn_probs=True,
)
if return_attn_probs:
return attn_output, softmax_lse, S_dmask
else:
return attn_output
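

# A minimal sketch of how the replacement above might be wired in, assuming a
# Hugging Face transformers model whose attention module exposes
# `_flash_attention_forward` (the class and module paths below are illustrative
# assumptions, not confirmed by this file):
#
#   import transformers
#   transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
#       flash_attention2_forward_with_window_size
#   )
#
# After such a patch, the `model_self._flash_attention_forward(...)` calls below can
# pass the extra `window_size` and `return_attn_probs` arguments.
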
def self_extend_flash_forward(
model_self,
query_position,
group_size_2,
neighbor_query_states,
neighbor_key_states,
group_query_states,
group_key_states,
value_states,
attention_mask,
bsz,
q_len,
kv_seq_len,
attn_dropout,
):
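    """
    SelfExtend forward pass on top of Flash Attention: compute a windowed
    "neighbor" attention over the nearest `group_size_2` positions and a "grouped"
    attention over the remaining positions, then merge the two partial results
    using their softmax log-sum-exp (lse) statistics.

    `neighbor_query_states`/`neighbor_key_states` carry the standard position
    encoding, while `group_query_states`/`group_key_states` carry the grouped
    position encoding; `query_position` is used to check whether any query can
    see beyond the neighbor window at all.
    """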
if query_position.max() >= group_size_2:
neighbor_attn_output, neighbor_softmax_lse_right_padded, neighbor_prob = model_self._flash_attention_forward(
neighbor_query_states,
neighbor_key_states,
value_states,
attention_mask,
q_len,
dropout=attn_dropout,
            window_size=[group_size_2 - 1, 0],
            # the right window size does not matter here (it could be -1 or > 0), because the causal mask already blocks attention to later positions
return_attn_probs=True,
)
        group_attention_len = (
            kv_seq_len - group_size_2
        )  # use kv_seq_len rather than max_kv_len, since q/k/v and attention_mask may contain padding
        group_attention_mask = attention_mask[:, :group_attention_len] if attention_mask is not None else None
group_attn_output, group_softmax_lse_right_padded, group_prob = model_self._flash_attention_forward(
group_query_states[:, -group_attention_len:, :, :],
group_key_states[:, :group_attention_len, :, :],
value_states[:, :group_attention_len, :, :],
group_attention_mask,
group_query_states[:, -group_attention_len:, :, :].shape[1],
dropout=attn_dropout,
window_size=[-1, -1],
return_attn_probs=True,
        )  # note that q and kv are indexed differently; the query length can also be much smaller than the kv length during generation, compared to prefilling
        # normalize the lse values first
        neighbor_seq_length = (
            torch.sum(attention_mask, dim=1, keepdim=True)
            if attention_mask is not None
            else torch.tensor([kv_seq_len], dtype=torch.long).expand(bsz, 1)
        )  # [batch_size, 1]
        group_seq_length = (
            torch.sum(attention_mask[:, :group_attention_len], dim=1, keepdim=True)
            if attention_mask is not None
            else torch.tensor([group_attention_len], dtype=torch.long).expand(bsz, 1)
        )  # [batch_size, 1]
        # convert the right-padded (left-aligned) lse to right-aligned; padded positions are left at 0
neighbor_softmax_lse = torch.zeros_like(neighbor_softmax_lse_right_padded)
group_softmax_lse = torch.zeros_like(group_softmax_lse_right_padded)
for idx in range(bsz):
if neighbor_seq_length[idx] > 0:
neighbor_softmax_lse[idx, :, -neighbor_seq_length[idx] :] = neighbor_softmax_lse_right_padded[
idx, :, : neighbor_seq_length[idx]
]
if group_seq_length[idx] > 0:
group_softmax_lse[idx, :, -group_seq_length[idx] :] = group_softmax_lse_right_padded[
idx, :, : group_seq_length[idx]
]
        # attn_output has size [batch_size, query_length (padded, not the true length), num_heads, head_dim]
        true_neighbor_seq_max_length = neighbor_softmax_lse.shape[-1]  # can be smaller than query_length due to the attention_mask
        true_group_seq_max_length = group_softmax_lse.shape[-1]  # can be smaller than group_query_states[:, -group_attention_len:, :, :].shape[1] due to attention_mask[:, :group_attention_len]
        neighbor_softmax_lse = neighbor_softmax_lse.transpose(1, 2).unsqueeze(-1)  # [batch_size, true_neighbor_seq_max_length, num_heads, 1]
        group_softmax_lse = group_softmax_lse.transpose(1, 2).unsqueeze(-1)  # [batch_size, true_group_seq_max_length, num_heads, 1]
        lse_gap = group_softmax_lse - neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :]
        neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :] = 1 / (1 + torch.exp(lse_gap))
        neighbor_softmax_lse[:, :-true_group_seq_max_length, :, :] = 1.0  # positions covered only by neighbor attention get full weight
        group_softmax_lse = 1 / (1 + torch.exp(-lse_gap))
neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] = (
neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] * neighbor_softmax_lse
)
group_attn_output[:, -true_group_seq_max_length:, ...] = (
group_attn_output[:, -true_group_seq_max_length:, ...] * group_softmax_lse
)
attn_output = torch.empty_like(neighbor_attn_output).copy_(
neighbor_attn_output
) # might be slightly faster than clone
        attn_output[:, -group_attention_len:, ...] += group_attn_output  # note: group_size_2 - kv_seq_len == -group_attention_len
attn_output = torch.nan_to_num(attn_output, nan=0)
else:
attn_output = model_self._flash_attention_forward(
neighbor_query_states,
neighbor_key_states,
value_states,
attention_mask,
q_len,
dropout=attn_dropout,
window_size=[-1, -1],
)
return attn_output
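

# A self-contained sanity check (a minimal sketch, not part of the patch itself) of
# the lse-merging rule used above: merging two attention passes over disjoint key
# sets with sigmoid(lse gap) weights must reproduce full softmax attention over the
# union of the keys. All names here are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(1, 8)  # one query vector, head_dim = 8
    k1, v1 = torch.randn(4, 8), torch.randn(4, 8)  # "neighbor" keys/values
    k2, v2 = torch.randn(6, 8), torch.randn(6, 8)  # "group" keys/values

    def attn(q, k, v):
        # plain softmax attention that also returns the log-sum-exp,
        # mirroring flash attention's return_attn_probs=True output
        scores = q @ k.T / q.shape[-1] ** 0.5
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)
        return (scores - lse).exp() @ v, lse

    out1, lse1 = attn(q, k1, v1)
    out2, lse2 = attn(q, k2, v2)
    lse_gap = lse2 - lse1
    merged = out1 / (1 + lse_gap.exp()) + out2 / (1 + (-lse_gap).exp())

    full_out, _ = attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
    assert torch.allclose(merged, full_out, atol=1e-6)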