"""Defines utilities for interacting with scaled_dot_product_attention"""
import math
from typing import List, Optional

import torch

__all__: List[str] = []


def _input_requires_grad(*tensors: torch.Tensor) -> bool:
    """Returns True if any of the tensors requires grad."""
    return any(t.requires_grad for t in tensors)
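

# Illustrative sketch (hypothetical demo, not part of the original module):
# _input_requires_grad is the kind of check a backend dispatcher can use to
# decide whether a non-differentiable fused kernel must be avoided.
def _demo_input_requires_grad() -> None:
    q = torch.randn(2, 8, requires_grad=True)
    k = torch.randn(2, 8)
    v = torch.randn(2, 8)
    assert _input_requires_grad(q, k, v)  # True: q participates in autograd
    assert not _input_requires_grad(k, v)  # False: no tensor requires grad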


def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
    """Handles the unpadding of the last dimension back to ``og_size``."""
    if inpt_tensor.size(-1) != og_size:
        return inpt_tensor[..., :og_size]
    return inpt_tensor
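

# Illustrative sketch (hypothetical demo, not part of the original module):
# FlashAttention-style kernels may require the head dimension to be a multiple
# of 8, so callers pad the last dimension before the kernel runs and slice the
# padding back off afterwards with _postprocess_flash_output.
def _demo_postprocess_flash_output() -> None:
    og_size = 60
    padded = torch.zeros(2, 4, 128, 64)  # head_dim padded from 60 up to 64
    out = _postprocess_flash_output(padded, og_size)
    assert out.shape == (2, 4, 128, 60)
    # Tensors that were never padded are returned unchanged.
    unpadded = torch.zeros(2, 4, 128, og_size)
    assert _postprocess_flash_output(unpadded, og_size) is unpadded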


def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
    """
    For FlashAttention we pad the head dimension up to a multiple of 8, so the
    softmax scale must be computed from the original head size, not the padded
    one.
    """
    if scale is not None:
        return scale
    return 1.0 / math.sqrt(head_dim_size)
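

# Illustrative sketch (hypothetical demo, not part of the original module):
# with no explicit scale, attention uses 1/sqrt(head_dim). The head size
# passed here must be the original one, since padding head_dim for the kernel
# would otherwise change the softmax temperature.
def _demo_calculate_scale() -> None:
    assert _calculate_scale(64, None) == 1.0 / math.sqrt(64)  # default scale
    assert _calculate_scale(64, 0.5) == 0.5  # an explicit scale wins
    # The padded and original head sizes give different scales; use the original.
    assert _calculate_scale(60, None) != _calculate_scale(64, None)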


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> None:
    """Validates that query, key, and value share a dtype and device and are
    at least 2-dimensional; raises ValueError otherwise. The remaining
    arguments mirror scaled_dot_product_attention and are not checked here.
    """
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to be on the same device, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2-dimensional, "
            f"but got query.dim: {query.dim()}, key.dim: {key.dim()}, "
            f"and value.dim: {value.dim()} instead."
        )
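

# Illustrative sketch (hypothetical demo, not part of the original module):
# _validate_sdpa_input mirrors the argument list of
# torch.nn.functional.scaled_dot_product_attention and rejects mismatched
# inputs before any kernel is selected.
def _demo_validate_sdpa_input() -> None:
    q = torch.randn(2, 4, 8, 64)
    k = torch.randn(2, 4, 8, 64)
    v = torch.randn(2, 4, 8, 64)
    _validate_sdpa_input(q, k, v)  # same dtype/device, >= 2 dims: passes
    try:
        _validate_sdpa_input(q, k.to(torch.float16), v)  # dtype mismatch
    except ValueError:
        pass
    else:
        raise AssertionError("expected a ValueError for mismatched dtypes")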