from typing import Optional, List, Any, Dict

from transformers.configuration_utils import PretrainedConfig


class Step1Config(PretrainedConfig):
    """Configuration container for the ``step1`` model type.

    The constructor records the architecture hyperparameters as plain
    instance attributes and hands the special-token ids, together with any
    extra keyword arguments, to ``PretrainedConfig.__init__``.

    Args:
        hidden_size: Stored as ``self.hidden_size`` (presumably the width of
            the hidden representation; confirm against the modeling code).
        intermediate_size: Stored as ``self.intermediate_size`` (presumably
            the feed-forward inner width; confirm against the modeling code).
        num_attention_heads: Stored as ``self.num_attention_heads``.
        num_attention_groups: Stored as ``self.num_attention_groups``
            (presumably the number of key/value groups for grouped-query
            attention; confirm against the modeling code).
        num_hidden_layers: Stored as ``self.num_hidden_layers``.
        max_seq_len: Stored as ``self.max_seq_len``.
        vocab_size: Stored as ``self.vocab_size``.
        rms_norm_eps: Stored as ``self.rms_norm_eps`` (presumably the epsilon
            used by RMS normalization layers; confirm against the modeling
            code).
        bos_token_id: Forwarded to the base class as ``bos_token_id``.
        eos_token_id: Forwarded to the base class as ``eos_token_id``.
        pad_token_id: Forwarded to the base class as ``pad_token_id``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier under which this configuration is declared.
    model_type = "step1"
    # Class-level hint read by the transformers library, not by this file.
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        bos_token_id: int = 1,
        eos_token_id: int = 3,
        pad_token_id: int = 0,
        **kwargs,
    ) -> None:
        # Architecture hyperparameters are set before delegating to the base
        # class, mirroring the original statement order.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Special-token ids and any remaining options are handled by the
        # base class.
        super().__init__(
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )


__all__ = ["Step1Config"]