manaestras commited on
Commit
1cb4ead
·
verified ·
1 Parent(s): 7032224

Upload hunyuan.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hunyuan.py +3 -3
hunyuan.py CHANGED
@@ -358,7 +358,7 @@ class HunYuanModel(HunYuanPreTrainedModel):
358
  )
359
 
360
 
361
- class HunYuanForCausalLM(HunYuanPreTrainedModel):
362
  _tied_weights_keys = ["lm_head.weight"]
363
 
364
  def __init__(self, config: HunYuanConfig):
@@ -527,7 +527,7 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
527
  if isinstance(past_key_values, Cache):
528
  cache_length = past_key_values.get_seq_length()
529
  past_length = past_key_values.seen_tokens
530
- max_cache_length = past_key_values.get_max_length()
531
  else:
532
  cache_length = past_length = past_key_values[0][0].shape[2]
533
  max_cache_length = None
@@ -586,7 +586,7 @@ class HunYuanForCausalLM(HunYuanPreTrainedModel):
586
  return reordered_past
587
 
588
 
589
- class MultimodelHunYuanForCausalLM(HunYuanForCausalLM):
590
  _tied_weights_keys = ["lm_head.weight"]
591
 
592
  def __init__(self, config: HunYuanConfig):
 
358
  )
359
 
360
 
361
+ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
362
  _tied_weights_keys = ["lm_head.weight"]
363
 
364
  def __init__(self, config: HunYuanConfig):
 
527
  if isinstance(past_key_values, Cache):
528
  cache_length = past_key_values.get_seq_length()
529
  past_length = past_key_values.seen_tokens
530
+ max_cache_length = past_key_values.get_max_cache_shape()
531
  else:
532
  cache_length = past_length = past_key_values[0][0].shape[2]
533
  max_cache_length = None
 
586
  return reordered_past
587
 
588
 
589
+ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
590
  _tied_weights_keys = ["lm_head.weight"]
591
 
592
  def __init__(self, config: HunYuanConfig):