codys12 committed on
Commit b161686 · verified · 1 Parent(s): 860686d

Upload modeling_hunyuan.py

Files changed (1)
  1. modeling_hunyuan.py +303 -144
modeling_hunyuan.py CHANGED
@@ -1,5 +1,16 @@
- # coding=utf-8
- # Copyright 2024 Tencent Inc. All Rights Reserved.
  #
  """ PyTorch HunYuan model."""
 
@@ -8,7 +19,6 @@ import warnings
  from typing import List, Optional, Tuple, Union
 
  import torch
- torch.set_default_dtype(torch.float32)
  from torch import Tensor
  import torch.nn.functional as F
  import torch.utils.checkpoint
@@ -64,7 +74,8 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
  def topkgating(logits: Tensor, topk: int):
      logits = logits.float()
      gates = F.softmax(logits, dim=1)
-     expert_capacity = topk * gates.shape[0]
      num_experts = int(gates.shape[1])
      # Top-k router probability and corresponding expert indices for each token.
      # Shape: [tokens_per_group, num_selected_experts].
@@ -254,7 +265,7 @@ class HunYuanRotaryEmbedding(nn.Module):
          self.max_position_embeddings = max_position_embeddings
          self.base = base
          inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-         inv_freq = inv_freq.bfloat16()
          self.register_buffer("inv_freq", inv_freq, persistent=False)
 
          # Build here to make `torch.jit.trace` work.
@@ -266,6 +277,7 @@ class HunYuanRotaryEmbedding(nn.Module):
          self.max_seq_len_cached = seq_len
          t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
 
          freqs = torch.outer(t, self.inv_freq)
          # Different from paper, but it uses a different permutation in order to obtain the same calculation
          emb = torch.cat((freqs, freqs), dim=-1).float()
@@ -274,7 +286,7 @@ class HunYuanRotaryEmbedding(nn.Module):
 
      def forward(self, x, seq_len=None):
          # x: [bs, num_attention_heads, seq_len, head_size]
-         if seq_len > self.max_seq_len_cached:
              self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
          return (
@@ -399,7 +411,7 @@ class HunYuanMLP(nn.Module):
          self.layer_idx = layer_idx
          self.hidden_size = config.hidden_size
          if is_shared_mlp:
-             self.intermediate_size = config.intermediate_size * config.num_shared_expert
          else:
              self.intermediate_size = config.intermediate_size
          self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -450,142 +462,66 @@ class HunYuanTopKGate(nn.Module):
          if self.moe_topk == 1:
              gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
          else:
-             gate_output = topkgating(logits, self.moe_topk)
 
          return gate_output
 
 
  class HunYuanMoE(nn.Module):
-     """Mixture-of-Experts block with vectorized expert execution and Straight‑Through Estimator (STE) utilities.
-
-     This implementation removes all Python‑side loops over experts. Expert parameters are **stacked** on‑the‑fly and
-     all experts are executed in a single batched matmul sequence, which allows efficient tensor‑parallel execution
-     (e.g. with DeepSpeed ZeRO‑3) while keeping the state‑dict format unchanged (each expert remains an individual
-     sub‑module for full compatibility with existing checkpoints)."""
-
      def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          self.moe_topk = config.moe_topk
          self.num_experts = config.num_experts
-
-         # Optional shared MLP branch (mixed MoE + dense)
          if config.use_mixed_mlp_moe:
              self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
-
-         # Router
          self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
-
-         # Experts kept as individual sub‑modules so that load_state_dict / save_pretrained stay identical.
          self.experts = nn.ModuleList(
              [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
          )
 
-     # ---------------------------------------------------------------------
-     # Internal helpers
-     # ---------------------------------------------------------------------
-     def _stack_weights(self):
-         """Return stacked (batched) expert weights for gate/up/down projections.
-
-         Shapes:
-             Wg : (E, I, H)
-             Wu : (E, I, H)
-             Wd : (E, H, I) – note transposed compared to nn.Linear's weight so that
-                 torch.matmul(act, Wd.transpose(-2, -1)) produces (E, C, H)
-         """
-         Wg = torch.stack([exp.gate_proj.weight for exp in self.experts], dim=0)
-         Wu = torch.stack([exp.up_proj.weight for exp in self.experts], dim=0)
-         Wd = torch.stack([exp.down_proj.weight for exp in self.experts], dim=0)
-         return Wg, Wu, Wd
-
-     # ---------------------------------------------------------------------
-     # Public API
-     # ---------------------------------------------------------------------
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         *,
-         return_router_logits: bool = False,
-     ):
-         """Sparse Top‑k MoE forward pass ("y_sparse") with optional router logits.
-
-         This is the route used during both training and inference for the *sparse*
-         branch. No Python loops over experts are used.
-         """
          bsz, seq_len, hidden_size = hidden_states.shape
 
-         # Optional dense branch when using mixed MoE
          if self.config.use_mixed_mlp_moe:
              hidden_states_mlp = self.shared_mlp(hidden_states)
 
-         # ---------------- Routing ----------------
-         l_moe, combine_weights, dispatch_mask, _ = self.gate(hidden_states)
-
-         flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H) where S = B*T
-         # dispatch tokens → experts
-         dispatched_input = torch.einsum("sec,sm->ecm",  # (E, C, H)
-                                         dispatch_mask.to(hidden_states.dtype),
-                                         flat_input)
-
-         # ---------------- Expert computation ----------------
-         Wg, Wu, Wd = self._stack_weights()
-         Wg = Wg.to(hidden_states.dtype)
-         Wu = Wu.to(hidden_states.dtype)
-         Wd = Wd.to(hidden_states.dtype)
-
-         gate_out = torch.einsum("ech,eih->eci", dispatched_input, Wg)  # (E, C, I)
-         up_out = torch.einsum("ech,eih->eci", dispatched_input, Wu)  # (E, C, I)
-         act_fn = self.experts[0].act_fn
-         interm = act_fn(gate_out) * up_out  # (E, C, I)
-         expert_output = torch.matmul(interm, Wd.transpose(-2, -1))  # (E, C, H)
-
-         # ---------------- Combine ----------------
-         combined_output = torch.einsum("sec,ecm->sm",  # (S, H)
-                                        combine_weights.to(hidden_states.dtype),
-                                        expert_output)
-         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
 
-         if self.config.use_mixed_mlp_moe:
-             combined_output = hidden_states_mlp + combined_output
 
-         if return_router_logits:
-             router_logits = self.gate.wg(flat_input).view(bsz, seq_len, self.num_experts)
-             return combined_output, router_logits
-         return combined_output
 
-     # ---------------------------------------------------------------------
-     # Dense branch – outputs *all* expert activations (no routing)
-     # ---------------------------------------------------------------------
-     @torch.no_grad()
-     def forward_all(self, hidden_states: torch.Tensor):
-         """Compute every expert on every token ("dense" MoE path).
-
-         Returns:
-             Tensor of shape (B, T, E, H) – per‑expert hidden states.
-         """
-         bsz, seq_len, hidden_size = hidden_states.shape
-         flat_input = hidden_states.reshape(-1, hidden_size)  # (S, H)
-
-         Wg, Wu, Wd = self._stack_weights()
-         Wg = Wg.to(hidden_states.dtype)
-         Wu = Wu.to(hidden_states.dtype)
-         Wd = Wd.to(hidden_states.dtype)
-
-         gate_out = torch.einsum("sh,eih->esi", flat_input, Wg)  # (E, S, I)
-         up_out = torch.einsum("sh,eih->esi", flat_input, Wu)  # (E, S, I)
-         act_fn = self.experts[0].act_fn
-         interm = act_fn(gate_out) * up_out  # (E, S, I)
-         dense_out = torch.matmul(interm, Wd.transpose(-2, -1))  # (E, S, H)
-
-         dense_out = dense_out.permute(1, 0, 2).contiguous()  # (S, E, H)
-         dense_out = dense_out.view(bsz, seq_len, self.num_experts, hidden_size)  # (B, T, E, H)
-
-         if self.config.use_mixed_mlp_moe:
-             hidden_states_mlp = self.shared_mlp(hidden_states)  # (B, T, H)
-             dense_out = dense_out + hidden_states_mlp.unsqueeze(2)
-
-         return dense_out
 
 
  class HunYuanAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -1128,18 +1064,33 @@ class HunYuanDecoderLayer(nn.Module):
          kv_states: Optional[Tuple[torch.Tensor]] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-         """Modified to include Straight‑Through Estimator (STE) training logic."""
-
          if "padding_mask" in kwargs:
              warnings.warn(
                  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
                  "`attention_mask` instead.`"
              )
-
-         # ---------------- Self‑Attention ----------------
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
-
          hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
@@ -1151,40 +1102,47 @@ class HunYuanDecoderLayer(nn.Module):
              **kwargs,
          )
          hidden_states = residual + hidden_states
-
-         # ---------------- MLP / MoE (+STE) ----------------
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
-
-         if self.training and isinstance(self.mlp, HunYuanMoE):
-             # Sparse path + router logits
-             y_sparse, router_logits = self.mlp(hidden_states, return_router_logits=True)
-
-             # Dense (all‑experts) path – memory‑efficient, no grad
-             with torch.no_grad():
-                 y_all = self.mlp.forward_all(hidden_states)  # (B, T, E, H)
-
-             gate = router_logits.softmax(-1).unsqueeze(-1)  # (B, T, E, 1)
-             y_dense = (gate * y_all).sum(-2)  # (B, T, H)
-
-             mlp_out = y_dense + (y_sparse - y_dense).detach()
-         else:
-             mlp_out = self.mlp(hidden_states)
-
-         hidden_states = residual + mlp_out
-
-         # ---------------- Outputs ----------------
          outputs = (hidden_states,)
-
          if output_attentions:
              outputs += (self_attn_weights,)
-
          if use_cache:
              outputs += (present_key_value,)
-
          outputs += (kv_states,)
-
          return outputs
 
 
  class HunYuanPreTrainedModel(PreTrainedModel):
      config_class = HunYuanConfig
      base_model_prefix = "model"
@@ -1277,6 +1235,10 @@ HUNYUAN_INPUTS_DOCSTRING = r"""
  """
 
 
  class HunYuanModel(HunYuanPreTrainedModel):
      """
      Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
@@ -1455,7 +1417,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
          )
 
 
- class HunYuanForCausalLM(HunYuanPreTrainedModel):
      _tied_weights_keys = ["lm_head.weight"]
 
      def __init__(self, config: HunYuanConfig):
@@ -1585,7 +1547,7 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
          if isinstance(past_key_values, Cache):
              cache_length = past_key_values.get_seq_length()
              past_length = past_key_values.seen_tokens
-             max_cache_length = past_key_values.get_max_length()
          else:
              cache_length = past_length = past_key_values[0][0].shape[2]
              max_cache_length = None
@@ -1644,6 +1606,21 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
          return reordered_past
 
 
  class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
@@ -1748,4 +1725,186 @@ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
              past_key_values=transformer_outputs.past_key_values,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
-         )
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+ #
+ # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  #
  """ PyTorch HunYuan model."""
 
  from typing import List, Optional, Tuple, Union
 
  import torch
  from torch import Tensor
  import torch.nn.functional as F
  import torch.utils.checkpoint
 
  def topkgating(logits: Tensor, topk: int):
      logits = logits.float()
      gates = F.softmax(logits, dim=1)
+     # expert_capacity = topk * gates.shape[0]
+     expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
      num_experts = int(gates.shape[1])
      # Top-k router probability and corresponding expert indices for each token.
      # Shape: [tokens_per_group, num_selected_experts].
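
The rewritten `expert_capacity` above caps each expert's buffer at roughly its even share of the routed tokens instead of the full `topk * tokens` of the old line. A quick arithmetic sketch with toy numbers of my own (not from the commit):

# Toy illustration of the new expert_capacity formula (numbers are made up).
tokens_per_group, num_experts, topk = 8, 4, 2
old_capacity = topk * tokens_per_group                            # 16 slots per expert (previous code)
new_capacity = max(topk, topk * tokens_per_group // num_experts)  # max(2, 16 // 4) = 4 slots per expert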
 
          self.max_position_embeddings = max_position_embeddings
          self.base = base
          inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+         # inv_freq = inv_freq.bfloat16()
          self.register_buffer("inv_freq", inv_freq, persistent=False)
 
          # Build here to make `torch.jit.trace` work.
 
          self.max_seq_len_cached = seq_len
          t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
 
+         self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
          freqs = torch.outer(t, self.inv_freq)
          # Different from paper, but it uses a different permutation in order to obtain the same calculation
          emb = torch.cat((freqs, freqs), dim=-1).float()
 
 
      def forward(self, x, seq_len=None):
          # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached or self.inv_freq.dtype != torch.float32:
              self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
          return (
 
          self.layer_idx = layer_idx
          self.hidden_size = config.hidden_size
          if is_shared_mlp:
+             self.intermediate_size = config.intermediate_size * config.num_shared_expert[0]
          else:
              self.intermediate_size = config.intermediate_size
          self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 
          if self.moe_topk == 1:
              gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
          else:
+             gate_output = topkgating(logits, self.moe_topk[0])
 
          return gate_output
 
 
  class HunYuanMoE(nn.Module):
      def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          self.moe_topk = config.moe_topk
          self.num_experts = config.num_experts
          if config.use_mixed_mlp_moe:
              self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
          self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
          self.experts = nn.ModuleList(
              [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
          )
 
+     def forward(self, hidden_states):
          bsz, seq_len, hidden_size = hidden_states.shape
 
          if self.config.use_mixed_mlp_moe:
              hidden_states_mlp = self.shared_mlp(hidden_states)
 
+         l_moe, combine_weights, dispatch_mask, exp_counts = self.gate(hidden_states)
 
+         reshaped_input = hidden_states.reshape(-1, hidden_size)
 
+         dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
 
+         chunks = dispatched_input.chunk(self.num_experts, dim=0)
+         expert_outputs = []
+         for chunk, expert in zip(chunks, self.experts):
+             expert_outputs.append(expert(chunk))
 
+         expert_output = torch.cat(expert_outputs, dim=0)
+         combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
+         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
 
+         if self.config.use_mixed_mlp_moe:
+             output = hidden_states_mlp + combined_output
+         else:
+             output = combined_output
 
+         return output
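
For readers unfamiliar with the dispatch/combine einsums in this forward pass, here is a standalone toy sketch (mine, not part of the commit) of how a one-hot `dispatch_mask` of shape (S, E, C) scatters S flattened tokens into per-expert capacity buffers and how `combine_weights` folds the expert outputs back per token:

# Illustrative sketch with toy sizes: S=4 tokens, E=2 experts, C=2 capacity slots, H=3 hidden.
import torch

S, E, C, H = 4, 2, 2, 3
x = torch.randn(S, H)                                   # flattened tokens (S, H)
dispatch_mask = torch.zeros(S, E, C)
dispatch_mask[0, 0, 0] = 1                              # token 0 -> expert 0, slot 0
dispatch_mask[1, 1, 0] = 1                              # token 1 -> expert 1, slot 0
dispatch_mask[2, 0, 1] = 1                              # token 2 -> expert 0, slot 1
dispatch_mask[3, 1, 1] = 1                              # token 3 -> expert 1, slot 1

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, x)   # (E, C, H) per-expert buffers
expert_out = dispatched * 2.0                                # stand-in for the per-expert MLPs
combine_weights = dispatch_mask * 0.7                        # stand-in for router probabilities on routed slots
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_out)  # back to (S, H)
assert combined.shape == (S, H)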
 
 
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
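
A tiny check (illustrative only, not from the file) of the equivalence claimed in the `repeat_kv` docstring:

import torch

kv = torch.randn(2, 4, 5, 8)  # (batch, num_key_value_heads, seq_len, head_dim), n_rep = 3
expanded = kv[:, :, None, :, :].expand(2, 4, 3, 5, 8).reshape(2, 12, 5, 8)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))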
 
  class HunYuanAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""
 
          kv_states: Optional[Tuple[torch.Tensor]] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*):
+                 attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                 query_sequence_length, key_sequence_length)` if default attention is used.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+             kv_states (`Tuple(torch.FloatTensor)`, *optional*): Used when CLA is enabled,
+                 key and value states from past attention blocks
+         """
          if "padding_mask" in kwargs:
              warnings.warn(
                  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
                  "`attention_mask` instead.`"
              )
+
          residual = hidden_states
+
          hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
          hidden_states, self_attn_weights, present_key_value, kv_states = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
 
              **kwargs,
          )
          hidden_states = residual + hidden_states
+
+         # Fully Connected
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
          outputs = (hidden_states,)
+
          if output_attentions:
              outputs += (self_attn_weights,)
+
          if use_cache:
              outputs += (present_key_value,)
+
          outputs += (kv_states,)
+
          return outputs
+
+
+ HUNYUAN_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`HunYuanConfig`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ @add_start_docstrings(
+     "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
+     HUNYUAN_START_DOCSTRING,
+ )
  class HunYuanPreTrainedModel(PreTrainedModel):
      config_class = HunYuanConfig
      base_model_prefix = "model"
 
  """
 
 
+ @add_start_docstrings(
+     "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
+     HUNYUAN_START_DOCSTRING,
+ )
  class HunYuanModel(HunYuanPreTrainedModel):
      """
      Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
 
          )
 
 
+ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
      _tied_weights_keys = ["lm_head.weight"]
 
      def __init__(self, config: HunYuanConfig):
 
          if isinstance(past_key_values, Cache):
              cache_length = past_key_values.get_seq_length()
              past_length = past_key_values.seen_tokens
+             max_cache_length = past_key_values.get_max_cache_shape()
          else:
              cache_length = past_length = past_key_values[0][0].shape[2]
              max_cache_length = None
 
          return reordered_past
 
 
+ @add_start_docstrings(
+     """
+     The HunYuan Model transformer with a sequence classification head on top (linear layer).
+
+     [`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+     (e.g. GPT-2) do.
+
+     Since it does classification on the last token, it requires to know the position of the last token. If a
+     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+     each row of the batch).
+     """,
+     HUNYUAN_START_DOCSTRING,
+ )
  class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
 
              past_key_values=transformer_outputs.past_key_values,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
+         )
+
+
+ # ================================================================
+ # Dense/Sparse MoE utilities
+ # These enable dense training (all experts active) with sparse inference,
+ # by fusing the per‑expert parameters into single linear layers and
+ # applying a straight‑through estimator (STE) – gradients flow through the
+ # dense path, while the forward pass matches sparse routing behaviour.
+ # ================================================================
+
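
A minimal sketch of the straight-through estimator named in the comment block above (my illustration, not code from this file): the forward value equals the sparse output, while the gradient flows entirely through the dense mixture.

import torch

dense = torch.tensor([1.0, 2.0], requires_grad=True)   # differentiable dense-mixture output
sparse = torch.tensor([1.5, 1.0])                       # sparse top-k output, treated as a constant
out = dense + (sparse - dense).detach()                 # forward value equals `sparse`
out.sum().backward()
print(out.detach())   # tensor([1.5000, 1.0000])
print(dense.grad)     # tensor([1., 1.])  -- gradient passes straight through to the dense path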
+ import types
+
+ class HunYuanDenseMoE(nn.Module):
+     """Dense counterpart of :class:`HunYuanMoE`.
+
+     * The per‑expert linear layers (``gate_proj``, ``up_proj``, ``down_proj``)
+       are concatenated to form three *fused* linear layers.
+     * Forward pass:
+         1. **Dense path** – every expert contributes,
+            weighted by the softmax router probabilities.
+         2. **Sparse path** – only the Top‑K experts (identical to the
+            original sparse MoE) are evaluated.
+         3. **STE** – ``output = dense + (sparse - dense).detach()`` –
+            identical values to sparse inference, dense gradients.
+     """
+
+     def __init__(self, moe: "HunYuanMoE"):
+         super().__init__()
+         self.config = moe.config
+         self.layer_idx = moe.layer_idx
+         self.num_experts = moe.num_experts
+         # All experts share the same hidden/intermediate sizes
+         self.hidden_size = moe.experts[0].hidden_size
+         self.intermediate_size = moe.experts[0].intermediate_size
+
+         # Router is reused directly
+         self.gate = moe.gate
+
+         # ------------------------------------------------------------------
+         # Fuse per‑expert parameters
+         # ------------------------------------------------------------------
+         with torch.no_grad():
+             fused_gate_w = torch.cat([exp.gate_proj.weight for exp in moe.experts], dim=0).clone()
+             fused_up_w = torch.cat([exp.up_proj.weight for exp in moe.experts], dim=0).clone()
+             # down_proj weights are shaped (hidden, intermediate)
+             fused_down_w = torch.cat([exp.down_proj.weight for exp in moe.experts], dim=1).clone()
+
+         self.fused_gate_proj = nn.Linear(self.hidden_size,
+                                          self.intermediate_size * self.num_experts,
+                                          bias=False)
+         self.fused_up_proj = nn.Linear(self.hidden_size,
+                                        self.intermediate_size * self.num_experts,
+                                        bias=False)
+         self.fused_down_proj = nn.Linear(self.intermediate_size * self.num_experts,
+                                          self.hidden_size,
+                                          bias=False)
+
+         # Load weights
+         self.fused_gate_proj.weight.data.copy_(fused_gate_w)
+         self.fused_up_proj.weight.data.copy_(fused_up_w)
+         self.fused_down_proj.weight.data.copy_(fused_down_w)
+
+         self.act_fn = moe.experts[0].act_fn
+         self.topk = self.gate.moe_topk[0] if isinstance(self.gate.moe_topk, (list, tuple)) else self.gate.moe_topk
+
+     def _dense_path(self, x, probs):
+         """Compute dense mixture – every expert active."""
+         gate_out = self.fused_gate_proj(x)       # (T, I*E)
+         up_out = self.fused_up_proj(x)           # (T, I*E)
+         interm = self.act_fn(gate_out) * up_out  # (T, I*E)
+
+         # Reshape to (T, E, I)
+         interm = interm.view(-1, self.num_experts, self.intermediate_size)
+
+         # Weight by softmax router probabilities
+         interm_weighted = interm * probs.unsqueeze(-1)  # (T, E, I)
+
+         # Collapse experts back to vector and project down
+         dense_flat = interm_weighted.reshape(-1, self.intermediate_size * self.num_experts)
+         return self.fused_down_proj(dense_flat)  # (T, H)
+
+     def _sparse_path(self, x, probs):
+         """Compute sparse Top‑K mixture (matches original inference)."""
+         # Pre‑compute per‑expert activations to avoid repeated forward calls
+         gate_out = self.fused_gate_proj(x)       # (T, I*E)
+         up_out = self.fused_up_proj(x)           # (T, I*E)
+         interm = self.act_fn(gate_out) * up_out  # (T, I*E)
+         interm = interm.view(-1, self.num_experts, self.intermediate_size)  # (T, E, I)
+
+         # Top‑K experts per token
+         values, indices = torch.topk(probs, self.topk, dim=1)  # (T, K)
+
+         # Gather corresponding intermediate activations
+         gathered = interm.gather(1, indices.unsqueeze(-1).expand(-1, -1, self.intermediate_size))  # (T, K, I)
+         gathered = gathered * values.unsqueeze(-1)  # weight by router prob
+
+         # Gather matching down-proj weights
+         # fused_down_proj.weight: (H, I*E) -> reshape to (E, I, H)
+         down_w = self.fused_down_proj.weight.view(self.hidden_size,
+                                                   self.num_experts,
+                                                   self.intermediate_size).permute(1, 2, 0).contiguous()  # (E, I, H)
+         selected_w = down_w.index_select(0, indices.reshape(-1)).view(indices.size(0), self.topk,
+                                                                       self.intermediate_size, self.hidden_size)
+         # (T, K, I, H)
+
+         # Compute output: batch matmul over I
+         sparse_out = torch.einsum('t k i, t k i h -> t h', gathered, selected_w)
+         return sparse_out  # (T, H)
+
+     def forward(self, hidden_states):
+         bsz, seq_len, _ = hidden_states.shape
+         x = hidden_states.reshape(-1, self.hidden_size)  # T x H
+         logits = self.gate.wg(x)              # (T, E)
+         probs = torch.softmax(logits, dim=1)  # (T, E)
+
+         dense_out = self._dense_path(x, probs)
+         sparse_out = self._sparse_path(x, probs)
+
+         out = dense_out + (sparse_out - dense_out).detach()  # STE
+         return out.view(bsz, seq_len, self.hidden_size)
+
+ # -----------------------------------------------------------------------
+ # Helper for module replacement
+ # -----------------------------------------------------------------------
+ def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
+     """Replace a (possibly nested) sub‑module.
+
+     ``target`` is the dotted path returned by ``model.named_modules()``.
+     """
+     parts = target.split('.')
+     parent = root
+     for p in parts[:-1]:
+         parent = getattr(parent, p)
+     setattr(parent, parts[-1], new_module)
+
+ # -----------------------------------------------------------------------
+ # Public APIs
+ # -----------------------------------------------------------------------
+ def densify(model: nn.Module):
+     """Convert all :class:`HunYuanMoE` modules under *model* to
+     :class:`HunYuanDenseMoE`. Operates **in‑place**."""
+     replacements = []
+     for name, module in model.named_modules():
+         if isinstance(module, HunYuanMoE):
+             replacements.append((name, module))
+     for name, sparse_moe in replacements:
+         dense_moe = HunYuanDenseMoE(sparse_moe).to(next(sparse_moe.parameters()).device)
+         _replace_submodule(model, name, dense_moe)
+     return model
+
+
+ def sparsify(model: nn.Module):
+     """Rebuild standard sparse :class:`HunYuanMoE` modules from their
+     fused :class:`HunYuanDenseMoE` form. Operates **in‑place**."""
+     replacements = []
+     for name, module in model.named_modules():
+         if isinstance(module, HunYuanDenseMoE):
+             replacements.append((name, module))
+     for name, dense_moe in replacements:
+         cfg = dense_moe.config
+         sparse_moe = HunYuanMoE(cfg, layer_idx=dense_moe.layer_idx).to(next(dense_moe.parameters()).device)
+
+         # Copy router
+         sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())
+
+         # Slice fused weights back to per‑expert
+         for idx, expert in enumerate(sparse_moe.experts):
+             start = idx * dense_moe.intermediate_size
+             end = (idx + 1) * dense_moe.intermediate_size
+
+             expert.gate_proj.weight.data.copy_(
+                 dense_moe.fused_gate_proj.weight.data[start:end]
+             )
+             expert.up_proj.weight.data.copy_(
+                 dense_moe.fused_up_proj.weight.data[start:end]
+             )
+             expert.down_proj.weight.data.copy_(
+                 dense_moe.fused_down_proj.weight.data[:, start:end]
+             )
+
+         _replace_submodule(model, name, sparse_moe)
+     return model
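
How `densify`/`sparsify` are meant to be used, as I read their docstrings; the checkpoint paths and the training loop below are placeholders of mine, not something specified by this commit:

# Hypothetical workflow sketch; assumes densify/sparsify are in scope (they live in modeling_hunyuan.py).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/hunyuan-checkpoint", trust_remote_code=True)  # placeholder path
model = densify(model)   # swap every HunYuanMoE for a fused HunYuanDenseMoE (in-place)

# ... training loop: the STE forward matches sparse top-k routing while
# gradients reach every expert through the dense mixture ...

model = sparsify(model)  # slice the fused weights back into per-expert HunYuanMLP modules
model.save_pretrained("path/to/output")  # placeholder path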