Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
try: | |
from flash_attn.bert_padding import index_first_axis, unpad_input | |
except ImportError: | |
pass | |
def _get_unpad_data(attention_mask): | |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) | |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() | |
max_seqlen_in_batch = seqlens_in_batch.max().item() | |
cu_seqlens = F.pad( | |
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) | |
return ( | |
indices, | |
cu_seqlens, | |
max_seqlen_in_batch, | |
) | |
def upad_qkv(query_layer, key_layer, value_layer, attention_mask,
             query_length):
    """Remove padded positions from query/key/value tensors.

    Args:
        query_layer: Query tensor. Depending on ``query_length`` it is
            reshaped to ``(batch_size * kv_seq_len, -1, head_dim)``,
            squeezed along dim 1, or passed to ``unpad_input``.
        key_layer: 4-D key tensor unpacked as
            ``(batch_size, kv_seq_len, num_key_value_heads, head_dim)``.
        value_layer: Value tensor, reshaped the same way as ``key_layer``.
        attention_mask: Mask whose non-zero entries mark positions to keep.
        query_length: Number of query positions; selects which of the three
            unpadding strategies below is used.

    Returns:
        A tuple ``(query_layer, key_layer, value_layer, indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k))``.
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
        attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    # Flatten (batch, seq) into one axis, then gather the kept positions.
    key_layer = index_first_axis(
        key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                          head_dim), indices_k)
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                            head_dim), indices_k)

    if query_length == kv_seq_len:
        # Different from the origin version as sequence parallel change
        # the number of attention heads, so the head count is inferred (-1)
        # instead of being taken from the key tensor.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
            indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # One query position per batch row: offsets are simply 0..batch_size.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        # Newer flash-attn releases return an extra trailing value from
        # ``unpad_input``; star-unpacking the tail keeps this working with
        # both the 4-value and the 5-value return.
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = \
            unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )