import logging
from typing import Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)
from torch.nn.attention import SDPBackend

from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer

log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # Validate the mask's dtype before rejecting it outright, so that a
        # malformed mask surfaces the more specific error first. Masks are not
        # yet supported for the jagged layout, so a well-formed mask still
        # raises below.
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )
        raise ValueError("Masks are not yet supported!")


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # Assumes the inputs have already been validated as 4-D nested tensors, so
    # size(0) below is the batch dimension.
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)

    # Unlike num_heads, the batch size is not broadcastable for nested inputs,
    # so all three must match exactly.
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True
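
# Illustrative examples of the check above: a head dim of 64 or 128 passes
# (multiple of 8, <= 256); 65 fails the multiple-of-8 requirement; 320 exceeds
# the 256 cap; and q/k/v with mismatched last dims fail the same-size check.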


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"

    if param._ragged_idx == 1:
        # The ragged dimension is dim 1, i.e. num_heads varies per sequence.
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False

    # The fused kernels cannot handle zero-length sequences within the batch.
    if param._min_seqlen == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False

    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True
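
# Illustrative examples: sizes (8, 1, 8) broadcast to 8 and pass; (8, 2, 8)
# fail because 2 is neither 1 nor the max; (1, 1, 1) trivially pass.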


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # Check query, key, and value individually; dense (non-nested) inputs are
    # trivially safe, and each check short-circuits on the first failure.
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    if not q_is_safe:
        return False

    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    if not k_is_safe:
        return False

    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    if not v_is_safe:
        return False

    # At this point all inputs are either dense or safely-shaped NTs; the
    # number of heads must match exactly, or be broadcastable when no input
    # requires grad.
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads

    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True
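
# Illustrative example: a query with 8 heads and key/value with 1 head pass
# via _try_broadcast_param_size during inference, but the same inputs are
# rejected whenever any of them requires grad, since the fused kernels do not
# support training with broadcasted NT inputs.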


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
    ):
        return SDPBackend.ERROR

    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    )

    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)

    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH

    # No backend was eligible: re-run each check in debug mode so that the
    # reasons are logged before reporting failure.
    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    return SDPBackend.ERROR
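
# Illustrative behavior: with a head dim of 65, the flash checks fail (not a
# multiple of 8) and EFFICIENT_ATTENTION is returned if its own checks pass;
# only when all three backends are ineligible do the debug passes above re-run
# the checks with warnings enabled before returning SDPBackend.ERROR.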


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    # Calculates the metadata needed by the flash-attention and
    # efficient-attention kernels: the cumulative sequence lengths over the
    # batch, the maximum sequence length, and the total number of elements
    # (Nnz) across all sequences.
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")

    if qkv.lengths() is None:
        # Without explicit lengths, the offsets already form the cumulative
        # sequence lengths.
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._max_seqlen
        n_elem = qkv.values().shape[0]
    else:
        # With explicit lengths, derive the cumulative sequence lengths from
        # the per-sequence lengths.
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        max_seqlen = qkv._max_seqlen
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem
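
# Illustrative example: for sequence lengths [2, 3, 1], the offsets are
# [0, 2, 5, 6], so the offsets-based branch returns cu_seqlens = [0, 2, 5, 6]
# (int32), max_seqlen = 3, and Nnz = 6: exactly the varlen metadata the
# flash/efficient kernels consume.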


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # Determines whether the underlying storage of a jagged NestedTensor can
    # be reinterpreted directly as a dense tensor without a copy. This is only
    # safe when the constituent tensors share one regular layout; we check
    # that here by requiring the non-batch strides to be strictly decreasing,
    # as they would be for a standard contiguous tensor.
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides

    n_tensors = offsets.size(0) - 1
    # A single constituent tensor is always safe to view as dense storage.
    if n_tensors <= 1:
        return True

    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # Non-decreasing strides indicate a transposed or otherwise
            # permuted layout whose storage cannot be viewed as a dense
            # buffer directly.
            return False
        prev_stride = stride

    return True


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
    if tensor.is_nested:
        return buffer_from_jagged(tensor)
    return tensor.view(Nnz, num_heads, head_dim)
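
# Note: both branches produce an (Nnz, num_heads, head_dim)-shaped result; for
# nested inputs the jagged values buffer already has that shape, while dense
# inputs are reshaped explicitly.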


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
    # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)

    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)

    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )

    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    # Transpose to (Batch x {seq_len} x Num_heads x Dim_per_head), the layout
    # expected by the varlen fused kernels.
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)

    # Only materialize a contiguous copy when the storage cannot safely be
    # viewed as a dense tensor as-is.
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()

    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)

    # Metadata needed to rebuild the jagged output NT from the dense kernel
    # output.
    output_nt_info = {
        "offsets": q_t.offsets(),
        "_max_seqlen": q_t._max_seqlen,
        "_min_seqlen": q_t._min_seqlen,
    }

    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )
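
# Illustrative shape flow: a query NT of shape (B, num_heads, j_seq, head_dim)
# is transposed to (B, j_seq, num_heads, head_dim) and flattened to a dense
# (Nnz_q, num_heads, head_dim) buffer, while cu_seqlens/max_seqlen preserve
# the per-sequence boundaries for the varlen kernels.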


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttention requires the head dimension to be a multiple of 8, so we
    # zero-pad the last dim up to the next multiple here (rather than inside
    # the kernel) and optionally slice back to the original size.
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor
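
# Illustrative example: with alignment_size=8, a last dim of 37 is zero-padded
# to 40 (pad_count=3), while 40 is returned unchanged; slice=True narrows the
# padded tensor back to the original 37 columns.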


def _calculate_scale(query, scale):
    # Default to the standard softmax scale of 1/sqrt(head_dim) when no
    # explicit scale is provided; sym_sqrt keeps this symbolic-shape friendly.
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale
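
# Illustrative example: for head_dim 64 and scale=None this returns
# sqrt(1/64) = 0.125, the standard 1/sqrt(head_dim) softmax scaling.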


def _post_process_flash_output(out: torch.Tensor, og_size):
    # Slice off any head-dim padding added by _pad_last_dim for the flash
    # kernel.
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # _validate_sdpa_input has already ensured nested inputs; this assert
    # narrows the types for readers and type checkers.
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )

    # Special path for a non-ragged sequence length (i.e. the ragged dim is
    # dim 1 rather than the seq_len dim): the dense underlying buffers can be
    # sent through regular SDPA directly and the output rewrapped as an NT.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        from torch.nested._internal.ops import extract_kwargs

        output = F.scaled_dot_product_attention(
            query._values,
            key._values,
            value._values,
            attn_mask=(
                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

        return NestedTensor(output, **extract_kwargs(query))

    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal
    )

    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # Compute the scale from the original, unpadded head size.
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)

        (
            attention,
            logsumexp,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )

        # Reshape output to convert Nnz back to batch_size and seq_len.
        attention = ViewNestedFromBuffer.apply(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )

        # Reshape output to convert Nnz back to batch_size and seq_len.
        return ViewNestedFromBuffer.apply(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # Save the offsets and shapes of the inputs so the final output can be
        # reshaped back into a jagged layout NT.
        offsets = query.offsets()
        d1 = query._size[1]
        d2 = value._size[-1]

        # Convert the jagged layout NT to a strided layout NT, which the math
        # implementation of SDPA supports.
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        # Convert the strided layout NT back to a jagged layout NT.
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = ViewNestedFromBuffer.apply(attn_out, offsets)
        attn_out = attn_out.transpose(1, 2)

        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )
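
# Illustrative usage sketch (assumptions: a jagged-layout NT built via the
# public torch.nested API; the tensors below are hypothetical, not part of
# this module):
#
#   nt = torch.nested.nested_tensor(
#       [torch.randn(5, 8, 64), torch.randn(3, 8, 64)],
#       layout=torch.jagged,
#   ).transpose(1, 2)  # (B, num_heads, j_seq, head_dim)
#   out = jagged_scaled_dot_product_attention(nt, nt, nt, is_causal=True)
#
# Backend selection then proceeds flash -> efficient -> math per the checks
# above.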