|
from typing import List, Optional, Tuple, Union |
|
import torch |
|
from packaging import version |
|
import importlib.metadata |
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
|
|
from transformers.utils.import_utils import _is_package_available |
|
|
|
def is_transformers_attn_greater_or_equal_4_39():
    """Tell whether the installed `transformers` distribution is version 4.39.0 or newer.

    Used in this file to pick which call signature of `AttentionMaskConverter._unmask_unattended` to use.

    Returns:
        `False` when `_is_package_available` reports that `transformers` is missing; otherwise the result of
        comparing the installed distribution version with "4.39.0".
    """
    if _is_package_available("transformers"):
        installed = version.parse(importlib.metadata.version("transformers"))
        minimum = version.parse("4.39.0")
        return installed >= minimum

    return False
|
|
|
def _prepare_4d_attention_mask_for_sdpa( |
|
attention_mask: Optional[torch.Tensor], |
|
input_shape: Union[torch.Size, Tuple, List], |
|
inputs_embeds: torch.Tensor, |
|
past_key_values_length: int, |
|
sliding_window: Optional[int] = None, |
|
): |
|
attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window) |
|
|
|
key_value_length = input_shape[-1] + past_key_values_length |
|
batch_size, query_length = input_shape |
|
|
|
|
|
|
|
|
|
is_tracing = torch.jit.is_tracing() |
|
|
|
if attention_mask is not None: |
|
if torch.all(attention_mask == 1): |
|
if is_tracing: |
|
pass |
|
elif query_length == 1: |
|
|
|
attention_mask = None |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
pass |
|
elif query_length > 1 and key_value_length != query_length: |
|
|
|
|
|
attention_mask = True |
|
elif is_tracing: |
|
raise ValueError( |
|
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' |
|
) |
|
|
|
if attention_mask is None: |
|
expanded_4d_mask = None |
|
elif attention_mask is True: |
|
expanded_4d_mask = attn_mask_converter.to_causal_4d( |
|
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
|
) |
|
else: |
|
expanded_4d_mask = attn_mask_converter.to_4d( |
|
attention_mask, |
|
input_shape[-1], |
|
dtype=inputs_embeds.dtype, |
|
key_value_length=key_value_length, |
|
) |
|
|
|
|
|
|
|
if query_length > 1: |
|
if is_transformers_attn_greater_or_equal_4_39(): |
|
expanded_4d_mask = AttentionMaskConverter._unmask_unattended( |
|
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min |
|
) |
|
else: |
|
expanded_4d_mask = AttentionMaskConverter._unmask_unattended( |
|
expanded_4d_mask, attention_mask, unmasked_value=0.0 |
|
) |
|
|
|
return expanded_4d_mask |
|
|
|
|
|
def _prepare_4d_attention_mask( |
|
attention_mask: Optional[torch.Tensor], |
|
input_shape: Union[torch.Size, Tuple, List], |
|
inputs_embeds: torch.Tensor, |
|
past_key_values_length: int, |
|
sliding_window: Optional[int] = None, |
|
): |
|
attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window) |
|
|
|
key_value_length = input_shape[-1] + past_key_values_length |
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask = attn_mask_converter.to_4d( |
|
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype |
|
) |
|
else: |
|
attention_mask = attn_mask_converter.to_causal_4d( |
|
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
|
) |
|
|
|
return attention_mask |
|
|
|
|
|
def _prepare_4d_causal_attention_mask( |
|
attention_mask: Optional[torch.Tensor], |
|
input_shape: Union[torch.Size, Tuple, List], |
|
inputs_embeds: torch.Tensor, |
|
past_key_values_length: int, |
|
sliding_window: Optional[int] = None, |
|
): |
|
attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window) |
|
|
|
key_value_length = input_shape[-1] + past_key_values_length |
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask = attn_mask_converter.to_4d( |
|
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype |
|
) |
|
else: |
|
attention_mask = attn_mask_converter.to_causal_4d( |
|
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
|
) |
|
|
|
return attention_mask |
|
|
|
|
|
def _prepare_4d_causal_attention_mask_for_sdpa( |
|
attention_mask: Optional[torch.Tensor], |
|
input_shape: Union[torch.Size, Tuple, List], |
|
inputs_embeds: torch.Tensor, |
|
past_key_values_length: int, |
|
sliding_window: Optional[int] = None, |
|
): |
|
""" |
|
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. |
|
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and |
|
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, |
|
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). |
|
""" |
|
attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window) |
|
|
|
key_value_length = input_shape[-1] + past_key_values_length |
|
batch_size, query_length = input_shape |
|
|
|
|
|
|
|
|
|
is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) |
|
|
|
if attention_mask is not None: |
|
|
|
if len(attention_mask.shape) == 4: |
|
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) |
|
if tuple(attention_mask.shape) != expected_shape: |
|
raise ValueError( |
|
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." |
|
) |
|
else: |
|
|
|
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) |
|
attention_mask = inverted_mask.masked_fill( |
|
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min |
|
) |
|
return attention_mask |
|
|
|
elif not is_tracing and torch.all(attention_mask == 1): |
|
if query_length == 1: |
|
|
|
attention_mask = None |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
pass |
|
elif query_length > 1 and key_value_length != query_length: |
|
|
|
|
|
attention_mask = True |
|
elif is_tracing: |
|
raise ValueError( |
|
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.' |
|
) |
|
|
|
if attention_mask is None: |
|
expanded_4d_mask = None |
|
elif attention_mask is True: |
|
expanded_4d_mask = attn_mask_converter.to_causal_4d( |
|
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
|
) |
|
else: |
|
expanded_4d_mask = attn_mask_converter.to_4d( |
|
attention_mask, |
|
input_shape[-1], |
|
dtype=inputs_embeds.dtype, |
|
key_value_length=key_value_length, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if query_length > 1 and not is_tracing: |
|
if is_transformers_attn_greater_or_equal_4_39(): |
|
expanded_4d_mask = AttentionMaskConverter._unmask_unattended( |
|
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min |
|
) |
|
else: |
|
expanded_4d_mask = AttentionMaskConverter._unmask_unattended( |
|
expanded_4d_mask, attention_mask, unmasked_value=0.0 |
|
) |
|
|
|
return expanded_4d_mask |
|
|