Upload modeling_hunyuan.py

modeling_hunyuan.py  CHANGED  (+303 -144)
@@ -1,5 +1,16 @@
-#
-#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
+#
+# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 #
 """ PyTorch HunYuan model."""
 
@@ -8,7 +19,6 @@ import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
-torch.set_default_dtype(torch.float32)
 from torch import Tensor
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -64,7 +74,8 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
 def topkgating(logits: Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
-    expert_capacity = topk * gates.shape[0]
+    # expert_capacity = topk * gates.shape[0]
+    expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
     num_experts = int(gates.shape[1])
     # Top-k router probability and corresponding expert indices for each token.
     # Shape: [tokens_per_group, num_selected_experts].
@@ -254,7 +265,7 @@ class HunYuanRotaryEmbedding(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        inv_freq = inv_freq.bfloat16()
+        # inv_freq = inv_freq.bfloat16()
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         # Build here to make `torch.jit.trace` work.
@@ -266,6 +277,7 @@ class HunYuanRotaryEmbedding(nn.Module):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
 
+        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).float()
@@ -274,7 +286,7 @@ class HunYuanRotaryEmbedding(nn.Module):
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
+        if seq_len > self.max_seq_len_cached or self.inv_freq.dtype != torch.float32:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
@@ -399,7 +411,7 @@ class HunYuanMLP(nn.Module):
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         if is_shared_mlp:
-            self.intermediate_size = config.intermediate_size * config.num_shared_expert
+            self.intermediate_size = config.intermediate_size * config.num_shared_expert[0]
         else:
             self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -450,142 +462,66 @@ class HunYuanTopKGate(nn.Module):
         if self.moe_topk == 1:
             gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
         else:
-            gate_output = topkgating(logits, self.moe_topk)
+            gate_output = topkgating(logits, self.moe_topk[0])
 
         return gate_output
 
 
 class HunYuanMoE(nn.Module):
-    """Mixture-of-Experts block with vectorized expert execution and Straight-Through Estimator (STE) utilities.
-
-    This implementation removes all Python-side loops over experts. Expert parameters are **stacked** on-the-fly and
-    all experts are executed in a single batched matmul sequence, which allows efficient tensor-parallel execution
-    (e.g. with DeepSpeed ZeRO-3) while keeping the state-dict format unchanged (each expert remains an individual
-    sub-module for full compatibility with existing checkpoints)."""
-
     def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.moe_topk = config.moe_topk
         self.num_experts = config.num_experts
-
-        # Optional shared MLP branch (mixed MoE + dense)
         if config.use_mixed_mlp_moe:
             self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
-
-        # Router
         self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
-
-        # Experts kept as individual sub-modules so that load_state_dict / save_pretrained stay identical.
         self.experts = nn.ModuleList(
             [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
         )
 
-    # Internal helpers
-    # ---------------------------------------------------------------------
-    def _stack_weights(self):
-        """Return stacked (batched) expert weights for gate/up/down projections.
-
-        Shapes:
-            Wg : (E, I, H)
-            Wu : (E, I, H)
-            Wd : (E, H, I) - note transposed compared to nn.Linear's weight so that
-                 torch.matmul(act, Wd.transpose(-2, -1)) produces (E, C, H)
-        """
-        Wg = torch.stack([exp.gate_proj.weight for exp in self.experts], dim=0)
-        Wu = torch.stack([exp.up_proj.weight for exp in self.experts], dim=0)
-        Wd = torch.stack([exp.down_proj.weight for exp in self.experts], dim=0)
-        return Wg, Wu, Wd
-
-    # ---------------------------------------------------------------------
-    # Public API
-    # ---------------------------------------------------------------------
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        *,
-        return_router_logits: bool = False,
-    ):
-        """Sparse Top-k MoE forward pass ("y_sparse") with optional router logits.
-
-        This is the route used during both training and inference for the *sparse*
-        branch. No Python loops over experts are used.
-        """
+    def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
 
-        # Optional dense branch when using mixed MoE
         if self.config.use_mixed_mlp_moe:
             hidden_states_mlp = self.shared_mlp(hidden_states)
 
-        l_moe, combine_weights, dispatch_mask, _ = self.gate(hidden_states)
-
-        flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H) where S = B*T
-        # dispatch tokens -> experts
-        dispatched_input = torch.einsum("sec,sm->ecm",  # (E, C, H)
-                                        dispatch_mask.to(hidden_states.dtype),
-                                        flat_input)
-
-        # ---------------- Expert computation ----------------
-        Wg, Wu, Wd = self._stack_weights()
-        Wg = Wg.to(hidden_states.dtype)
-        Wu = Wu.to(hidden_states.dtype)
-        Wd = Wd.to(hidden_states.dtype)
-
-        gate_out = torch.einsum("ech,eih->eci", dispatched_input, Wg)  # (E, C, I)
-        up_out = torch.einsum("ech,eih->eci", dispatched_input, Wu)  # (E, C, I)
-        act_fn = self.experts[0].act_fn
-        interm = act_fn(gate_out) * up_out  # (E, C, I)
-        expert_output = torch.matmul(interm, Wd.transpose(-2, -1))  # (E, C, H)
-
-        # ---------------- Combine ----------------
-        combined_output = torch.einsum("sec,ecm->sm",  # (S, H)
-                                       combine_weights.to(hidden_states.dtype),
-                                       expert_output)
-        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
-
-        combined_output = hidden_states_mlp + combined_output
-
-        router_logits = self.gate.wg(flat_input).view(bsz, seq_len, self.num_experts)
-        return combined_output, router_logits
-        return combined_output
-
-    def forward_all(self, hidden_states: torch.Tensor):
-        """Compute every expert on every token ("dense" MoE path)."""
-        bsz, seq_len, hidden_size = hidden_states.shape
-        flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H)
-
-        Wg, Wu, Wd = self._stack_weights()
-        Wg = Wg.to(hidden_states.dtype)
-        Wu = Wu.to(hidden_states.dtype)
-        Wd = Wd.to(hidden_states.dtype)
-
-        up_out = torch.einsum("sh,eih->esi", flat_input, Wu)  # (E, S, I)
-        act_fn = self.experts[0].act_fn
-        interm = act_fn(gate_out) * up_out  # (E, S, I)
-        dense_out = torch.matmul(interm, Wd.transpose(-2, -1))  # (E, S, H)
-
-        dense_out = dense_out.permute(1, 0, 2).contiguous()  # (S, E, H)
-        dense_out = dense_out.view(bsz, seq_len, self.num_experts, hidden_size)  # (B, T, E, H)
-
-        return dense_out
+        l_moe, combine_weights, dispatch_mask, exp_counts = self.gate(hidden_states)
+
+        reshaped_input = hidden_states.reshape(-1, hidden_size)
+
+        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
+
+        chunks = dispatched_input.chunk(self.num_experts, dim=0)
+        expert_outputs = []
+        for chunk, expert in zip(chunks, self.experts):
+            expert_outputs.append(expert(chunk))
+
+        expert_output = torch.cat(expert_outputs, dim=0)
+        combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
+        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
+
+        if self.config.use_mixed_mlp_moe:
+            output = hidden_states_mlp + combined_output
+        else:
+            output = combined_output
+
+        return output
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class HunYuanAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -1128,18 +1064,33 @@ class HunYuanDecoderLayer(nn.Module):
         kv_states: Optional[Tuple[torch.Tensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        """
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            kv_states (`Tuple(torch.FloatTensor)`, *optional*): Used when CLA is enabled,
+                key and value states from past attention blocks
+        """
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
                 "`attention_mask` instead.`"
            )
 
-        # ---------------- Self-Attention ----------------
         residual = hidden_states
+
         hidden_states = self.input_layernorm(hidden_states)
 
+        # Self Attention
         hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -1151,40 +1102,47 @@ class HunYuanDecoderLayer(nn.Module):
             **kwargs,
         )
         hidden_states = residual + hidden_states
 
-        #
+        # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
-
-
-            y_sparse, router_logits = self.mlp(hidden_states, return_router_logits=True)
-
-            # Dense (all-experts) path - memory-efficient, no grad
-            with torch.no_grad():
-                y_all = self.mlp.forward_all(hidden_states)  # (B, T, E, H)
-
-            gate = router_logits.softmax(-1).unsqueeze(-1)  # (B, T, E, 1)
-            y_dense = (gate * y_all).sum(-2)  # (B, T, H)
-
-            mlp_out = y_dense + (y_sparse - y_dense).detach()
-        else:
-            mlp_out = self.mlp(hidden_states)
-
-        hidden_states = residual + mlp_out
-
-        # ---------------- Outputs ----------------
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
         outputs = (hidden_states,)
 
         if output_attentions:
             outputs += (self_attn_weights,)
 
         if use_cache:
             outputs += (present_key_value,)
 
         outputs += (kv_states,)
 
         return outputs
+
+
+HUNYUAN_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`HunYuanConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
+    HUNYUAN_START_DOCSTRING,
+)
 class HunYuanPreTrainedModel(PreTrainedModel):
     config_class = HunYuanConfig
     base_model_prefix = "model"
@@ -1277,6 +1235,10 @@ HUNYUAN_INPUTS_DOCSTRING = r"""
 """
 
 
+@add_start_docstrings(
+    "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
+    HUNYUAN_START_DOCSTRING,
+)
 class HunYuanModel(HunYuanPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
@@ -1455,7 +1417,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
         )
 
 
-class HunYuanForCausalLM(HunYuanPreTrainedModel):
+class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: HunYuanConfig):
@@ -1585,7 +1547,7 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
         if isinstance(past_key_values, Cache):
             cache_length = past_key_values.get_seq_length()
             past_length = past_key_values.seen_tokens
-            max_cache_length = past_key_values.get_max_length()
+            max_cache_length = past_key_values.get_max_cache_shape()
         else:
             cache_length = past_length = past_key_values[0][0].shape[2]
             max_cache_length = None
@@ -1644,6 +1606,21 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
         return reordered_past
 
 
+@add_start_docstrings(
+    """
+    The HunYuan Model transformer with a sequence classification head on top (linear layer).
+
+    [`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    HUNYUAN_START_DOCSTRING,
+)
 class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1748,4 +1725,186 @@ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
+
+
+# ================================================================
+# Dense/Sparse MoE utilities
+# These enable dense training (all experts active) with sparse inference,
+# by fusing the per-expert parameters into single linear layers and
+# applying a straight-through estimator (STE) - gradients flow through the
+# dense path, while the forward pass matches sparse routing behaviour.
+# ================================================================
+
+import types
+
+
+class HunYuanDenseMoE(nn.Module):
+    """Dense counterpart of :class:`HunYuanMoE`.
+
+    * The per-expert linear layers (``gate_proj``, ``up_proj``, ``down_proj``)
+      are concatenated to form three *fused* linear layers.
+    * Forward pass:
+        1. **Dense path** - every expert contributes,
+           weighted by the softmax router probabilities.
+        2. **Sparse path** - only the Top-K experts (identical to the
+           original sparse MoE) are evaluated.
+        3. **STE** - ``output = dense + (sparse - dense).detach()`` -
+           identical values to sparse inference, dense gradients.
+    """
+
+    def __init__(self, moe: "HunYuanMoE"):
+        super().__init__()
+        self.config = moe.config
+        self.layer_idx = moe.layer_idx
+        self.num_experts = moe.num_experts
+        # All experts share the same hidden/intermediate sizes
+        self.hidden_size = moe.experts[0].hidden_size
+        self.intermediate_size = moe.experts[0].intermediate_size
+
+        # Router is reused directly
+        self.gate = moe.gate
+
+        # ------------------------------------------------------------------
+        # Fuse per-expert parameters
+        # ------------------------------------------------------------------
+        with torch.no_grad():
+            fused_gate_w = torch.cat([exp.gate_proj.weight for exp in moe.experts], dim=0).clone()
+            fused_up_w = torch.cat([exp.up_proj.weight for exp in moe.experts], dim=0).clone()
+            # down_proj weights are shaped (hidden, intermediate)
+            fused_down_w = torch.cat([exp.down_proj.weight for exp in moe.experts], dim=1).clone()
+
+        self.fused_gate_proj = nn.Linear(self.hidden_size,
+                                         self.intermediate_size * self.num_experts,
+                                         bias=False)
+        self.fused_up_proj = nn.Linear(self.hidden_size,
+                                       self.intermediate_size * self.num_experts,
+                                       bias=False)
+        self.fused_down_proj = nn.Linear(self.intermediate_size * self.num_experts,
+                                         self.hidden_size,
+                                         bias=False)
+
+        # Load weights
+        self.fused_gate_proj.weight.data.copy_(fused_gate_w)
+        self.fused_up_proj.weight.data.copy_(fused_up_w)
+        self.fused_down_proj.weight.data.copy_(fused_down_w)
+
+        self.act_fn = moe.experts[0].act_fn
+        self.topk = self.gate.moe_topk[0] if isinstance(self.gate.moe_topk, (list, tuple)) else self.gate.moe_topk
+
+    def _dense_path(self, x, probs):
+        """Compute dense mixture - every expert active."""
+        gate_out = self.fused_gate_proj(x)       # (T, I*E)
+        up_out = self.fused_up_proj(x)           # (T, I*E)
+        interm = self.act_fn(gate_out) * up_out  # (T, I*E)
+
+        # Reshape to (T, E, I)
+        interm = interm.view(-1, self.num_experts, self.intermediate_size)
+
+        # Weight by softmax router probabilities
+        interm_weighted = interm * probs.unsqueeze(-1)  # (T, E, I)
+
+        # Collapse experts back to vector and project down
+        dense_flat = interm_weighted.reshape(-1, self.intermediate_size * self.num_experts)
+        return self.fused_down_proj(dense_flat)  # (T, H)
+
+    def _sparse_path(self, x, probs):
+        """Compute sparse Top-K mixture (matches original inference)."""
+        # Pre-compute per-expert activations to avoid repeated forward calls
+        gate_out = self.fused_gate_proj(x)       # (T, I*E)
+        up_out = self.fused_up_proj(x)           # (T, I*E)
+        interm = self.act_fn(gate_out) * up_out  # (T, I*E)
+        interm = interm.view(-1, self.num_experts, self.intermediate_size)  # (T,E,I)
+
+        # Top-K experts per token
+        values, indices = torch.topk(probs, self.topk, dim=1)  # (T,K)
+
+        # Gather corresponding intermediate activations
+        gathered = interm.gather(1, indices.unsqueeze(-1).expand(-1, -1, self.intermediate_size))  # (T,K,I)
+        gathered = gathered * values.unsqueeze(-1)  # weight by router prob
+
+        # Gather matching down-proj weights
+        # fused_down_proj.weight: (H, I*E) -> reshape to (E,I,H)
+        down_w = self.fused_down_proj.weight.view(self.hidden_size,
+                                                  self.num_experts,
+                                                  self.intermediate_size).permute(1, 2, 0).contiguous()  # (E,I,H)
+        selected_w = down_w.index_select(0, indices.reshape(-1)).view(indices.size(0), self.topk,
+                                                                      self.intermediate_size, self.hidden_size)
+        # (T,K,I,H)
+
+        # Compute output: batch matmul over I
+        sparse_out = torch.einsum('t k i, t k i h -> t h', gathered, selected_w)
+        return sparse_out  # (T, H)
+
+    def forward(self, hidden_states):
+        bsz, seq_len, _ = hidden_states.shape
+        x = hidden_states.reshape(-1, self.hidden_size)  # T x H
+        logits = self.gate.wg(x)               # (T,E)
+        probs = torch.softmax(logits, dim=1)   # (T,E)
+
+        dense_out = self._dense_path(x, probs)
+        sparse_out = self._sparse_path(x, probs)
+
+        out = dense_out + (sparse_out - dense_out).detach()  # STE
+        return out.view(bsz, seq_len, self.hidden_size)
+
+
+# -----------------------------------------------------------------------
+# Helper for module replacement
+# -----------------------------------------------------------------------
+def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
+    """Replace a (possibly nested) sub-module.
+
+    ``target`` is the dotted path returned by ``model.named_modules()``.
+    """
+    parts = target.split('.')
+    parent = root
+    for p in parts[:-1]:
+        parent = getattr(parent, p)
+    setattr(parent, parts[-1], new_module)
+
+
+# -----------------------------------------------------------------------
+# Public APIs
+# -----------------------------------------------------------------------
+def densify(model: nn.Module):
+    """Convert all :class:`HunYuanMoE` modules under *model* to
+    :class:`HunYuanDenseMoE`. Operates **in-place**."""
+    replacements = []
+    for name, module in model.named_modules():
+        if isinstance(module, HunYuanMoE):
+            replacements.append((name, module))
+    for name, sparse_moe in replacements:
+        dense_moe = HunYuanDenseMoE(sparse_moe).to(next(sparse_moe.parameters()).device)
+        _replace_submodule(model, name, dense_moe)
+    return model
+
+
+def sparsify(model: nn.Module):
+    """Rebuild standard sparse :class:`HunYuanMoE` modules from their
+    fused :class:`HunYuanDenseMoE` form. Operates **in-place**."""
+    replacements = []
+    for name, module in model.named_modules():
+        if isinstance(module, HunYuanDenseMoE):
+            replacements.append((name, module))
+    for name, dense_moe in replacements:
+        cfg = dense_moe.config
+        sparse_moe = HunYuanMoE(cfg, layer_idx=dense_moe.layer_idx).to(next(dense_moe.parameters()).device)
+
+        # Copy router
+        sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())
+
+        # Slice fused weights back to per-expert
+        for idx, expert in enumerate(sparse_moe.experts):
+            start = idx * dense_moe.intermediate_size
+            end = (idx + 1) * dense_moe.intermediate_size
+
+            expert.gate_proj.weight.data.copy_(
+                dense_moe.fused_gate_proj.weight.data[start:end]
+            )
+            expert.up_proj.weight.data.copy_(
+                dense_moe.fused_up_proj.weight.data[start:end]
+            )
+            expert.down_proj.weight.data.copy_(
+                dense_moe.fused_down_proj.weight.data[:, start:end]
+            )
+
+        _replace_submodule(model, name, sparse_moe)
+    return model
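For reference, a minimal usage sketch (not part of the uploaded file) of the `densify`/`sparsify` helpers added at the end of this diff. The checkpoint paths and the training loop below are placeholders; the only behaviour assumed from the diff is that `output = dense + (sparse - dense).detach()` keeps the forward pass identical to sparse routing while gradients flow through every expert.

# Hypothetical usage sketch; paths and the training loop are placeholders.
from modeling_hunyuan import HunYuanMoEV1ForCausalLM, densify, sparsify

model = HunYuanMoEV1ForCausalLM.from_pretrained("path/to/hunyuan-checkpoint")

densify(model)    # swap each sparse HunYuanMoE for a fused HunYuanDenseMoE (STE forward, dense gradients)
# ... run fine-tuning steps: forward, loss.backward(), optimizer.step() ...
sparsify(model)   # slice the fused weights back into per-expert modules

model.save_pretrained("path/to/output")  # state dict again matches the sparse checkpoint layout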