Upload DogeForCausalLM
Files changed:
- config.json (+1 -0)
- configuration_doge.py (+8 -0)
- model.safetensors (+2 -2)
- modeling_doge.py (+126 -29)
config.json CHANGED
@@ -9,6 +9,7 @@
     "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
   },
   "bos_token_id": 0,
+  "dynamic_mask_ratio": 0.0,
   "eos_token_id": 1,
   "expert_retrieval_size": 256,
   "hidden_act": "silu",
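For reference, a minimal sketch of how the new `dynamic_mask_ratio` entry surfaces once this config is loaded; the repository id is a placeholder, and `trust_remote_code=True` is assumed because the config and model classes ship with the repo via `auto_map`:

from transformers import AutoConfig

# Placeholder repo id; substitute the actual Doge checkpoint.
config = AutoConfig.from_pretrained("your-org/doge-checkpoint", trust_remote_code=True)

# New field added by this commit; 0.0 leaves the ratio-based masking disabled.
print(config.dynamic_mask_ratio)  # -> 0.0

# Like any other config field, it can be overridden at load time.
config = AutoConfig.from_pretrained(
    "your-org/doge-checkpoint", trust_remote_code=True, dynamic_mask_ratio=0.1
)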
configuration_doge.py CHANGED
@@ -111,6 +111,8 @@ class DogeConfig(PretrainedConfig):
             If it is not specified, will default to `num_attention_heads`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        dynamic_mask_ratio (`float`, *optional*, defaults to 0.0, range [0, 1]):
+            The ratio to control the proportion of the dynamic mask filled with the minimum value.
         is_moe (`bool`, *optional*, defaults to `False`):
             Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
         num_cdmmoe_experts (`int`, *optional*, defaults to 2048):
@@ -154,6 +156,7 @@ class DogeConfig(PretrainedConfig):
         num_attention_heads=8,
         num_key_value_heads=None,
         attention_dropout=0.0,
+        dynamic_mask_ratio=0.0,
         is_moe=False,
         num_cdmmoe_experts=2048,
         num_cdmmoe_heads=4,
@@ -183,6 +186,7 @@ class DogeConfig(PretrainedConfig):
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.attention_dropout = attention_dropout
+        self.dynamic_mask_ratio = dynamic_mask_ratio
         self.is_moe = is_moe
         self.num_cdmmoe_experts = num_cdmmoe_experts
         self.num_cdmmoe_heads = num_cdmmoe_heads
@@ -195,6 +199,10 @@ class DogeConfig(PretrainedConfig):
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)

+        # for backward compatibility
+        if num_key_value_heads is None:
+            self.num_key_value_heads = num_attention_heads
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
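The backward-compatibility block only matters when `num_key_value_heads` is omitted: `self.num_key_value_heads` is first assigned `None` and then overwritten with `num_attention_heads` before `super().__init__()` runs. A rough sketch of the resulting behaviour, assuming the repo's `configuration_doge.py` is importable locally (values illustrative):

from configuration_doge import DogeConfig  # module file shipped with this repo

cfg = DogeConfig(num_attention_heads=8)             # num_key_value_heads not given
assert cfg.num_key_value_heads == 8                 # falls back to num_attention_heads

cfg = DogeConfig(num_attention_heads=8, num_key_value_heads=4)
assert cfg.num_key_value_heads == 4                 # explicit value is kept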
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3632a5c94bc7d3cf66602318b168603ec19f1025e0aef01c286d65e30ed55e8b
+size 52482152
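The safetensors entry is a Git LFS pointer, so only its `oid` (sha256 of the weight file) and `size` (in bytes) change here. A small, self-contained sketch for checking a locally downloaded `model.safetensors` against the values in the new pointer (the file path is a placeholder):

import hashlib
from pathlib import Path

EXPECTED_OID = "3632a5c94bc7d3cf66602318b168603ec19f1025e0aef01c286d65e30ed55e8b"
EXPECTED_SIZE = 52482152  # bytes, from the LFS pointer

path = Path("model.safetensors")  # placeholder: path to the downloaded weights
digest = hashlib.sha256(path.read_bytes()).hexdigest()

assert path.stat().st_size == EXPECTED_SIZE, "size mismatch"
assert digest == EXPECTED_OID, "sha256 mismatch"
print("weights match the LFS pointer")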
modeling_doge.py CHANGED
@@ -39,6 +39,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_greater_or_equal,
     logging,
     replace_return_docstrings,
 )
@@ -49,6 +50,9 @@ try:
 except ImportError:
     einx_add = None

+if is_torch_greater_or_equal("2.5"):
+    from torch.nn.attention.flex_attention import flex_attention
+

 logger = logging.get_logger(__name__)

@@ -216,14 +220,15 @@ class DogeDynamicMaskAttention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_dropout = config.attention_dropout
+        self.dynamic_mask_ratio = config.dynamic_mask_ratio

         # Q K V O projections
         self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=config.hidden_bias)
         self.k_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
         # dynamic mask for the QK^T attention score matrix
         self.A = nn.Parameter(torch.ones(self.num_heads))
-        self.dt_proj = nn.Linear(self.
-        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.dt_proj = nn.Linear(self.num_key_value_heads * self.head_dim, self.num_heads, bias=config.hidden_bias)
         self.o_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)

     def forward(
@@ -254,6 +259,10 @@
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
         # repeat key and value states
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -262,12 +271,13 @@
         attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.head_dim)

         # add mask to attention scores
-
-
-        dynamic_mask
-
-
-
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=0.1,
+            attention_mask=attention_mask,
+        )
+        attn_weights = attn_weights + attn_mask

         # upcast attention scores to fp32
         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -282,8 +292,35 @@

         return attn_output, past_key_value

+    def prepare_dynamic_mask(
+        self,
+        hidden_states: torch.Tensor,
+        dynamic_mask: torch.Tensor,
+        dynamic_mask_ratio: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
+
+        Args:
+            hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
+            dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
+            dynamic_mask_ratio (`float`, *optional*): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
+            attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
+        """
+        min_type = torch.finfo(hidden_states.dtype).min
+        attn_mask = dynamic_mask[:, :, None, :]
+        if 0.0 < dynamic_mask_ratio < 1.0:
+            num_dynamic_mask = int(attn_mask.shape[-1] * dynamic_mask_ratio)
+            if num_dynamic_mask > 0:
+                rate_value = torch.kthvalue(attn_mask, num_dynamic_mask, dim=-1, keepdim=True).values
+                attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
+        if attention_mask is not None:
+            attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
+        return attn_mask

-class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
+
+class DogeSdpaDynamicMaskAttention(DogeDynamicMaskAttention):

     def forward(
         self,
@@ -312,34 +349,31 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)

-
-
-
-
-
-
-        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
-        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        dynamic_mask = dynamic_mask < 1.0
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )

         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
         # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
         torch.backends.cuda.enable_cudnn_sdp(False)
         attn_output = F.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
-            attn_mask=
+            attn_mask=attn_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-
+            enable_gqa=True,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -349,9 +383,70 @@
         return attn_output, past_key_value


+class DogeFlexDynamicMaskAttention(DogeDynamicMaskAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
+        # TODO: flex_attention: Captured buffers that require grad are not yet supported.
+        # NOTE: So we only use flex_attention in inference mode.
+        def dynamic_mask_mod(score, batch, head, q_idx, kv_idx):
+            score = score + attn_mask[batch][head][q_idx][kv_idx]
+            return score
+
+        attn_output = flex_attention(
+            query_states,
+            key_states,
+            value_states,
+            score_mod=dynamic_mask_mod,
+            enable_gqa=True,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, past_key_value
+
+
 DOGE_ATTENTION_CLASSES = {
+    "flex_attention": DogeFlexDynamicMaskAttention,
     "eager": DogeDynamicMaskAttention,
-    "sdpa": DogeSdpaDynamicMaskAttn,
+    "sdpa": DogeSdpaDynamicMaskAttention,
 }

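With `DogeFlexDynamicMaskAttention` now registered under the `"flex_attention"` key (and `_supports_flex_attn = True` added in the next hunk), the backend can be selected when loading the model. A hedged usage sketch: the repository id is a placeholder, and this path assumes torch >= 2.5 plus a transformers release that accepts `attn_implementation="flex_attention"`; per the TODO above it is intended for inference rather than training:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-org/doge-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,
    attn_implementation="flex_attention",  # or "sdpa" / "eager" on older stacks
    torch_dtype=torch.bfloat16,
)
model.eval()  # flex_attention path: captured buffers that require grad are unsupported

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))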
@@ -519,6 +614,7 @@ class DogePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DogeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flex_attn = True
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
@@ -693,7 +789,7 @@ class DogeModel(DogePreTrainedModel):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

@@ -877,7 +973,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[torch.
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -886,7 +982,7 @@
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -920,6 +1016,7 @@
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = outputs[0]
@@ -929,7 +1026,7 @@

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
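Taken together, the new attention path computes a per-head, per-key additive bias from the value states and merges it with the causal mask before softmax or SDPA. Below is a small, self-contained re-implementation of that mask construction with toy shapes, written only to make the tensor shapes explicit; names and dimensions are illustrative, and the mask-combination step is simplified slightly (an explicit `expand`) relative to the repo's `prepare_dynamic_mask`:

import torch
import torch.nn.functional as F

# Toy dimensions, purely illustrative.
bsz, num_heads, num_kv_heads, head_dim, seq_len = 2, 8, 4, 16, 10
min_val = torch.finfo(torch.float32).min

# Stand-ins for the parameters the diff adds to DogeDynamicMaskAttention.
A = torch.ones(num_heads)
dt_proj = torch.nn.Linear(num_kv_heads * head_dim, num_heads)

# value_states after the KV-cache update: (bsz, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# 1) Dynamic mask from value_states (same expression as in the diff).
dt_states = dt_proj(value_states.transpose(1, 2).reshape(bsz, seq_len, -1))   # (bsz, seq_len, num_heads)
dynamic_mask = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)         # (bsz, num_heads, seq_len)

# 2) prepare_dynamic_mask, re-implemented for illustration: optionally push the
#    smallest `dynamic_mask_ratio` fraction of key positions to the dtype minimum,
#    then fold in the usual causal mask.
dynamic_mask_ratio = 0.1
attn_mask = dynamic_mask[:, :, None, :]                                        # (bsz, num_heads, 1, seq_len)
k = int(attn_mask.shape[-1] * dynamic_mask_ratio)
if k > 0:
    rate_value = torch.kthvalue(attn_mask, k, dim=-1, keepdim=True).values     # per-head k-th smallest score
    attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_val)

causal_mask = torch.full((seq_len, seq_len), min_val).triu(1)[None, None]      # (1, 1, seq_len, seq_len)
attn_mask = attn_mask.expand(bsz, num_heads, seq_len, seq_len).masked_fill(causal_mask == min_val, min_val)

# The result is an additive bias for the attention scores: the eager path adds it
# to QK^T / sqrt(d), while the SDPA path passes it as attn_mask.
print(attn_mask.shape)  # torch.Size([2, 8, 10, 10])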