manaestras committed
Commit 261e790 (verified) · Parent(s): bfcb1f7

Upload ./modeling_hunyuan.py with huggingface_hub

Files changed (1): modeling_hunyuan.py (+3 -2)
modeling_hunyuan.py CHANGED
@@ -74,7 +74,8 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
 def topkgating(logits: Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
-    expert_capacity = topk * gates.shape[0]
+    # expert_capacity = topk * gates.shape[0]
+    expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
     num_experts = int(gates.shape[1])
     # Top-k router probability and corresponding expert indices for each token.
     # Shape: [tokens_per_group, num_selected_experts].
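For scale, the new expression caps capacity at roughly topk * tokens / num_experts rather than topk * tokens. A minimal arithmetic sketch with hypothetical sizes (1024 tokens, 16 experts, topk = 8), not taken from this file:

    # Hypothetical sizes, chosen only to illustrate the two expressions.
    tokens_per_group = 1024   # gates.shape[0]
    num_experts = 16          # gates.shape[1]
    topk = 8

    old_capacity = topk * tokens_per_group                            # 8192 slots
    new_capacity = max(topk, topk * tokens_per_group // num_experts)  # 512 slots
    print(old_capacity, new_capacity)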
 
@@ -1546,7 +1547,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                max_cache_length = past_key_values.get_max_cache_shape()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None
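Newer transformers releases deprecate Cache.get_max_length() in favor of get_max_cache_shape(). A hedged compatibility sketch (the helper name is hypothetical, not part of this repo) that prefers the newer accessor and falls back on older releases:

    from transformers.cache_utils import Cache

    def max_cache_length_compat(past_key_values: Cache):
        # Use the newer accessor when present, otherwise the legacy one.
        if hasattr(past_key_values, "get_max_cache_shape"):
            return past_key_values.get_max_cache_shape()
        return past_key_values.get_max_length()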