import logging
from typing import Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)
from torch.nn.attention import SDPBackend

from .nested_tensor import NestedTensor

log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # TODO: Figure out whether masks are actually supported for this layout or not
        raise ValueError("Masks are not yet supported!")
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )
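

# Illustrative sketch only (not part of the original module): one way to build
# jagged-layout inputs that would pass _validate_sdpa_input above. The shapes
# and values below are made up for illustration.
#
#     components = [torch.randn(seq_len, 8, 64) for seq_len in (3, 5, 2)]
#     nt = torch.nested.nested_tensor(components, layout=torch.jagged)
#     q = k = v = nt.transpose(1, 2)  # logical shape [B, num_heads, j1, head_dim]
#
# Passing dense tensors, or inputs with mismatched dtypes/devices/ragged dims,
# raises ValueError here.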


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # This is expected to be called after check_tensor_shapes ensuring that the
    # size() calls won't error since the inputs are all 4 dimensional
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)
    # num_heads logic for nested input is checked in
    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
    # num_heads is not ragged
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True
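

# For reference, the check above accepts head dims such as 64, 80, 96, 128, or
# 256 (equal across q/k/v, divisible by 8, and at most 256), while e.g. 100 or
# 257 would be rejected for the flash kernel.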


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"
    if param._ragged_idx == 1:
        # num_head_dims is ragged
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False
    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
    if param._min_seqlen == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False
    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True
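

# For example, num-head sizes (q, k, v) = (8, 1, 8) are considered broadcastable
# by the check above (the 1 can be expanded to 8), whereas (8, 4, 8) is not and
# returns False.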


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # When this function is called we are assured that the nt is dim==4
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not q_is_safe:
        return False
    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not k_is_safe:
        return False
    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not v_is_safe:
        return False
    # We now know none of the inputs have ragged num_heads, so we can safely
    # access .size(1)
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
    ):
        return SDPBackend.ERROR
    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    )
    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH
    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    return SDPBackend.ERROR
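

# Rough usage sketch (illustrative only; q/k/v are assumed to be jagged NTs as
# built by this file's callers). The selection above prefers flash, then
# memory-efficient, then math, and logs the reason each kernel was rejected
# before returning SDPBackend.ERROR:
#
#     backend = _select_sdp_backend(q, k, v, None, 0.0, False)
#     if backend == SDPBackend.ERROR:
#         ...  # no fused or math path is usable for these inputs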


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    # This function is used to calculate two pieces of metadata that are needed
    # for use with flash-attention and efficient_attention kernels. They are the
    # cumulative sequence_length over a batch of sequences and the maximum
    # sequence length.
    #
    # It returns a tuple of (cumulative_seqlen, max_seqlen, n_elem), where n_elem
    # is the total number of rows across the batch, i.e. the last element of
    # cumulative_seqlen.
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
    if qkv.lengths() is None:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._max_seqlen
        n_elem = qkv.values().shape[0]
    else:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        batch_size = qkv.size(0)
        max_seqlen = qkv._max_seqlen
        # TODO: Explore performance impact when compiling
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem
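

# Illustrative example (values are made up): for a jagged NT with offsets
# [0, 2, 5, 9] and no holes (lengths() is None), the helper above returns
# cumulative_seqlen = tensor([0, 2, 5, 9], dtype=torch.int32),
# max_seqlen = qkv._max_seqlen (here 4), and n_elem = 9.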


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # This function checks if a nested tensor is valid for
    # use with the flash-attention and efficient_attention kernels without
    # needing to call contiguous on the nested tensor input.
    # It checks that the storage offsets' adjacent differences are a constant
    # multiple of the previous tensor in the nested tensor and that the strides
    # are monotonically decreasing. This check is done after calling transpose on
    # the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim].
    #
    # Returns True if the storage can be used directly; False means contiguous()
    # must be called on the input first.
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides
    n_tensors = offsets.size(0) - 1
    if n_tensors <= 1:
        return True
    # Check initially that the tensor strides are in strictly descending order
    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # This would mean that the last stride is greater than the seq_len
            # stride
            return False
        prev_stride = stride
    # Congrats you made it!
    return True


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
    if tensor.is_nested:
        return tensor.values()
    return tensor.view(Nnz, num_heads, head_dim)
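

# For a jagged NT of shape [B, {seq_len}, num_heads, head_dim], .values() above
# is already the packed [Nnz, num_heads, head_dim] buffer; for a dense tensor
# the same packed layout is produced explicitly via .view(Nnz, num_heads, head_dim).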


# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
#     # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
#     # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     q_batch_size = query.size(0)
#     k_batch_size = key.size(0)
#     v_batch_size = value.size(0)
#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
#     q_num_heads = query.size(1)
#     k_num_heads = key.size(1)
#     v_num_heads = value.size(1)
#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
#     head_dim_qk = query.size(3)
#     head_dim_v = value.size(3)
#     q_t = query.transpose(1, 2)
#     k_t = key.transpose(1, 2)
#     v_t = value.transpose(1, 2)
#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
#     # output_batch_size/num_heads then they are 1
#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size
#     # If {*}_batch_size_needs_broadcast, then
#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
#     #     this is because needs_broadcast indicates that the batch_size is 1
#     #     and hence there is only 1 value for seq_len
#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
#     #     ..., output_batch_size * {*}_t.size(1)]
#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
#     if q_batch_size_needs_broadcast or not q_t.is_nested:
#         max_seqlen_batch_q = q_t.size(1)
#         cumulative_sequence_length_q = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_q,
#             max_seqlen_batch_q,
#             device=q_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_q = output_batch_size * max_seqlen_batch_q
#     else:
#         (
#             cumulative_sequence_length_q,
#             max_seqlen_batch_q,
#             Nnz_q,
#         ) = _cumulative_and_max_seq_len_nnz(q_t)
#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
#         assert k_t.size(1) == v_t.size(1)
#         max_seqlen_batch_kv = k_t.size(1)
#         cumulative_sequence_length_kv = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_kv,
#             max_seqlen_batch_kv,
#             device=k_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
#     else:
#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
#             _cumulative_and_max_seq_len_nnz(v_t)
#             if k_batch_size_needs_broadcast
#             else _cumulative_and_max_seq_len_nnz(k_t)
#         )
#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads
#     if not q_t.is_nested:
#         query_buffer_reshaped = q_t.expand(
#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
#         )
#         query_buffer_reshaped = query_buffer_reshaped.reshape(
#             Nnz_q, output_num_heads, head_dim_qk
#         )
#     else:
#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
#             q_t = q_t.contiguous()
#         # If we are broadcasting then Nnz_q will be the output_batch_size since
#         # seq_len is 1
#         effective_batch_size_q = (
#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
#         )
#         query_buffer_reshaped = _view_as_dense(
#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
#         )
#     # If the physical layout of the NestedTensor's storage
#     # is not: batch, {seq_len}, num_heads, head_dim then we need
#     # to call contiguous
#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
#         k_t = k_t.contiguous()
#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
#         v_t = v_t.contiguous()
#     effective_batch_size_k = (
#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
#     )
#     key_buffer_reshaped = _view_as_dense(
#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
#     )
#     effective_batch_size_v = (
#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
#     )
#     value_buffer_reshaped = _view_as_dense(
#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
#     )
#     if not q_batch_size_needs_broadcast:
#         output_shape = q_t._size
#         if head_dim_v != head_dim_qk:
#             output_shape[-1] = head_dim_v
#         if q_num_heads_needs_broadcast:
#             output_shape[1] = output_num_heads
#     else:
#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
#         output_shape[0] = q_t.size(1)
#         output_shape[1] = output_num_heads
#         output_shape[2] = head_dim_v
#     return (
#         query_buffer_reshaped,
#         key_buffer_reshaped,
#         value_buffer_reshaped,
#         cumulative_sequence_length_q,
#         cumulative_sequence_length_kv,
#         max_seqlen_batch_q,
#         max_seqlen_batch_kv,
#         output_shape,
#     )


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
    # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)
    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)
    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )
        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)
    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)
    # [TODO] K and V have to have the same Nnz, should probably torch_check
    # assume in order to not iterate over v
    # If the physical layout of the NestedTensor's storage
    # is not: batch, {seq_len}, num_heads, head_dim then we need
    # to call contiguous
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()
    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
    output_nt_info = {
        "offsets": q_t.offsets(),
        "_max_seqlen": q_t._max_seqlen,
        "_min_seqlen": q_t._min_seqlen,
    }
    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )
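

# Shape sketch (illustrative): given q/k/v jagged NTs of logical shape
# [B, num_heads, {seq_len}, head_dim], the preprocessing above hands back packed
# [Nnz, num_heads, head_dim] buffers plus int32 cumulative sequence lengths and
# max sequence lengths in the form the varlen flash / efficient kernels expect,
# along with output_nt_info used to rebuild a jagged NT from the kernel output.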


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttentionV2 requires that the head dimension be a multiple of 8.
    # This was previously done within the kernel; however, doing it there can
    # cause the kernel to alias query, key, and value. So instead we pad the
    # head dimension to a multiple of 8 in the composite region.
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor
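

# For example (illustrative numbers): a head dim of 100 with alignment_size=8 is
# padded with 4 zero columns to 104; with slice=True the helper above slices
# straight back to the original 100 columns, while slice=False keeps the padded
# width for the kernel call.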


# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale
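

# For example, with head_dim = 64 and scale=None the helper above yields
# softmax_scale = sqrt(1 / 64) = 0.125, matching the conventional 1/sqrt(d_k)
# scaling of attention logits.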


def _post_process_flash_output(out: torch.Tensor, og_size):
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )
    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
    # second batch dim instead). For this case, we can just send the dense buffers through
    # vanilla SDPA.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        from torch.nested._internal.ops import extract_kwargs

        output = F.scaled_dot_product_attention(
            query._values,
            key._values,
            value._values,
            attn_mask=(
                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        return NestedTensor(output, **extract_kwargs(query))
    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal
    )
    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # We need to calculate the scale based off the OG head dim size
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
        (
            attention,
            logsumexp,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )
        # Reshape output to convert nnz to batch_size and seq_len
        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

        attention = nested_view_from_values_offsets(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )
        # Reshape output to convert nnz to batch_size and seq_len
        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

        return nested_view_from_values_offsets(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
        offsets = query.offsets()
        d1 = query._size[1]
        d2 = value._size[-1]

        # convert jagged layout Nested Tensor to strided layout Nested Tensor,
        # which supports the math implementation of SDPA
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = transpose.values().split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = nested_view_from_values_offsets(attn_out, offsets)
        attn_out = attn_out.transpose(1, 2)
        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )
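

# End-to-end usage sketch (illustrative only; the shapes, sizes, dtype, and CUDA
# device below are assumptions). In practice this function is reached via the
# nested-tensor dispatch for scaled_dot_product_attention, but it can be
# exercised directly with jagged inputs of logical shape
# [B, num_heads, {seq_len}, head_dim]:
#
#     components = [
#         torch.randn(s, 8, 64, device="cuda", dtype=torch.float16) for s in (3, 5, 2)
#     ]
#     nt = torch.nested.nested_tensor(components, layout=torch.jagged)
#     q = k = v = nt.transpose(1, 2)
#     out = jagged_scaled_dot_product_attention(q, k, v, is_causal=True)
#     # out is a jagged NT with the same [B, num_heads, {seq_len}, head_dim] layout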