manaestras committed (verified)
Commit 03c370c · Parent(s): 0226108

Upload modeling_hunyuan.py with huggingface_hub

Files changed (1):
  modeling_hunyuan.py (+3 -3)
modeling_hunyuan.py CHANGED
@@ -74,7 +74,7 @@ _CONFIG_FOR_DOC = "HunYuanConfig"
 def topkgating(logits: Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
-    # expert_capacity = topk * gates.shape[0]
+    # expert_capacity = topk * gates.shape[0]
     expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
     num_experts = int(gates.shape[1])
     # Top-k router probability and corresponding expert indices for each token.
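For context on the first hunk: `expert_capacity` bounds how many tokens each expert may receive per batch. A minimal sketch of the capacity arithmetic, assuming the shapes implied by the code above (`gates` is `[num_tokens, num_experts]`); the helper name below is hypothetical and not part of the model file:

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def expert_capacity_sketch(logits: Tensor, topk: int) -> int:
    # logits: [num_tokens, num_experts], as in topkgating above.
    gates = F.softmax(logits.float(), dim=1)
    num_tokens, num_experts = gates.shape
    # Each token is routed to `topk` experts, so topk * num_tokens assignments
    # are spread over num_experts slots; the floored average is the per-expert
    # capacity, clamped below by `topk` so small batches still fit.
    return max(topk, topk * num_tokens // num_experts)

# Example: 8 tokens, 4 experts, top-2 routing -> 4 slots per expert.
cap = expert_capacity_sketch(torch.randn(8, 4), topk=2)
assert cap == 4
```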
@@ -1417,7 +1417,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
     )
 
 
-class HunYuanForCausalLM(HunYuanPreTrainedModel):
+class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: HunYuanConfig):
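The second hunk renames the causal-LM class. For a remote-code checkpoint like this one, the class is normally resolved through the `auto_map` entry in `config.json`, so user code usually loads it via `AutoModelForCausalLM` rather than importing the class by name. A minimal loading sketch, assuming a hypothetical local path `./hunyuan-checkpoint` whose `config.json` maps to `HunYuanMoEV1ForCausalLM`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical local path; the auto_map entry in config.json decides which
# class inside modeling_hunyuan.py gets instantiated, so callers do not need
# to reference HunYuanMoEV1ForCausalLM directly.
model_path = "./hunyuan-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
```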
@@ -1547,7 +1547,7 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
         if isinstance(past_key_values, Cache):
             cache_length = past_key_values.get_seq_length()
             past_length = past_key_values.seen_tokens
-            max_cache_length = past_key_values.get_max_length()
+            max_cache_length = past_key_values.get_max_cache_shape()
         else:
             cache_length = past_length = past_key_values[0][0].shape[2]
             max_cache_length = None
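The third hunk tracks a transformers API change: newer `Cache` objects expose `get_max_cache_shape()` where older releases had `get_max_length()`. A version-tolerant sketch of the same lookup, assuming only that `past_key_values` is a transformers `Cache`; the helper name is not part of the model file:

```python
from transformers.cache_utils import Cache, DynamicCache

def max_cache_length_compat(past_key_values: Cache):
    # Prefer the newer accessor, fall back to the older one if this
    # transformers release still provides it.
    if hasattr(past_key_values, "get_max_cache_shape"):
        return past_key_values.get_max_cache_shape()
    return past_key_values.get_max_length()

# Both accessors return None for an unbounded DynamicCache.
print(max_cache_length_compat(DynamicCache()))
```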
 