import math
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel


@dataclass
class PhariaConfig:
    pad_token_id: Optional[int] = None
    bos_token_id: int = 1
    eos_token_id: int = 2
    hidden_act: str = "gelu"
    hidden_size: int = 512
    initializer_range: float = 0.02
    intermediate_size: int = 2048
    max_position_embeddings: int = 8192
    num_attention_heads: int = 4
    num_hidden_layers: int = 4
    num_key_value_heads: int = 2
    torch_dtype: str = "bfloat16"
    transformers_version: str = "4.31.0.dev0"
    use_cache: bool = True
    vocab_size: int = -1
    mlp_bias: bool = True
    attention_bias: bool = True
    tie_word_embeddings: bool = False
    attention_dropout: float = 0.0
    rope_theta: int = 1000000
    rope_scaling: Optional[Any] = None
    # Fallback flags read by PhariaModel.forward when the corresponding
    # arguments are explicitly passed as None.
    output_attentions: bool = False
    output_hidden_states: bool = False
    use_return_dict: bool = False
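
# Minimal construction sketch (assumes the @dataclass decorator above):
# vocab_size defaults to -1 as a placeholder and must be set to the real
# tokenizer vocabulary size before building a model, e.g.
#
#     config = PhariaConfig(vocab_size=128000)  # illustrative value only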


class PhariaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
                / self.dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = freqs.repeat_interleave(2, dim=-1, output_size=self.dim)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
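
# Note: the returned cos/sin have shape [batch, seq_len, dim] (dim == head_dim);
# they are computed in float32 and cast back to x.dtype. repeat_interleave(2, dim=-1)
# lays the frequencies out as [f0, f0, f1, f1, ...], matching the interleaved
# rotate_half() defined below rather than the concatenated layout of stock Llama RoPE.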


class PhariaLinearScalingRotaryEmbedding(PhariaRotaryEmbedding):
    """PhariaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: a scaling factor is applied to the position ids
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin


class PhariaDynamicNTKScalingRotaryEmbedding(PhariaRotaryEmbedding):
    """PhariaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device)
                    / self.dim
                )
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin


def rotate_half(x):
    """Rotates half the hidden dims of the input (interleaved)."""
    y = torch.empty_like(x)
    y[..., ::2] = -x[..., 1::2]
    y[..., 1::2] = x[..., ::2]
    return y
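
# Illustrative example: for a last dimension [x0, x1, x2, x3], rotate_half returns
# [-x1, x0, -x3, x2]. Combined with cos/sin in apply_rotary_pos_emb below, this
# rotates each adjacent pair (x_{2i}, x_{2i+1}) as a 2D vector.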


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
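
# With q/k of shape [batch, num_heads, seq_len, head_dim] and cos/sin of shape
# [batch, seq_len, head_dim] (as produced by PhariaRotaryEmbedding), the default
# unsqueeze_dim=1 broadcasts cos/sin across the head dimension.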


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
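
# With the config defaults above (num_attention_heads=4, num_key_value_heads=2),
# the attention below calls repeat_kv with n_rep=2, i.e. each key/value head is
# shared by two query heads (grouped-query attention).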


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: PhariaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # if layer_idx is None:
        #     logger.warning_once(
        #         f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
        #         "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
        #         "when creating this class."
        #     )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=config.attention_bias
        )
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = PhariaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = PhariaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = PhariaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # cache_position needed for the static cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
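
# Note: this attention path is eager only (explicit matmul + softmax), which is
# consistent with the `_supports_sdpa = False` / `_supports_flash_attn_2 = False`
# flags on PhariaPreTrainedModel below.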


class PhariaMLP(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=config.mlp_bias
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        o = self.down_proj(self.act_fn(self.up_proj(x)))
        return o
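
# Unlike Llama's gated (SwiGLU) MLP, this is a plain two-layer feed-forward block:
# up_proj -> activation (GELU with the default config) -> down_proj, with biases
# enabled via `mlp_bias`.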


class PhariaDecoderLayer(nn.Module):
    def __init__(self, config: PhariaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = PhariaMLP(config, layer_idx=layer_idx)
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
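
# The decoder layer uses a pre-norm residual structure (normalization before the
# attention and MLP sub-blocks) with standard nn.LayerNorm rather than RMSNorm.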


class PhariaPreTrainedModel(nn.Module):
    config_class = PhariaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PhariaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
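
# _init_weights follows the usual Hugging Face convention: normal(0, initializer_range)
# for weights, zeros for biases and for the padding embedding row. Note the class
# carries PreTrainedModel-style flags but currently subclasses nn.Module directly.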


class PhariaModel(nn.Module):
    config_class = PhariaConfig

    def __init__(self, config: PhariaConfig):
        # super().__init__(config)
        super(PhariaModel, self).__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                PhariaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = nn.LayerNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        # if self.gradient_checkpointing and self.training and use_cache:
        #     # logger.warning_once(
        #     #     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        #     # )
        #     use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(
            past_key_values, Cache
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # if self.gradient_checkpointing and self.training:
            #     layer_outputs = self._gradient_checkpointing_func(
            #         decoder_layer.__call__,
            #         hidden_states,
            #         causal_mask,
            #         position_ids,
            #         past_key_values,
            #         output_attentions,
            #         use_cache,
            #         cache_position,
            #     )
            # else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        hidden_states = self.head(hidden_states)
        return hidden_states
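
        # NOTE: everything below this point is unreachable because of the early
        # return of the vocabulary logits above; it is the remainder of the
        # Hugging Face-style return path and is kept for reference.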
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        # return BaseModelOutputWithPast(
        #     last_hidden_state=hidden_states,
        #     past_key_values=next_cache,
        #     hidden_states=all_hidden_states,
        #     attentions=all_self_attns,
        # )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        # Removed by Tristan.
        # if self.config._attn_implementation == "flash_attention_2":
        #     if attention_mask is not None and 0.0 in attention_mask:
        #         return attention_mask
        #     return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        # if (
        #     self.config._attn_implementation == "sdpa"
        #     and not using_static_cache
        #     and not output_attentions
        # ):
        #     if AttentionMaskConverter._ignore_causal_mask_sdpa(
        #         attention_mask,
        #         inputs_embeds=input_tensor,
        #         past_key_values_length=past_seen_tokens,
        #         is_training=self.training,
        #     ):
        #         return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0"
                )
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        # if (
        #     self.config._attn_implementation == "sdpa"
        #     and attention_mask is not None
        #     and attention_mask.device.type == "cuda"
        #     and not output_attentions
        # ):
        #     # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        #     # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        #     # Details: https://github.com/pytorch/pytorch/issues/110213
        #     causal_mask = AttentionMaskConverter._unmask_unattended(
        #         causal_mask, min_dtype
        #     )

        return causal_mask
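

# A minimal smoke-test sketch, not part of the original module. It assumes the
# @dataclass config above; the vocab size is an arbitrary illustrative value.
if __name__ == "__main__":
    config = PhariaConfig(vocab_size=1000)
    model = PhariaModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    logits = model(input_ids=input_ids)
    print(logits.shape)  # expected: torch.Size([1, 8, 1000])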