from transformers import PretrainedConfig


class DITConfig(PretrainedConfig):
    """Configuration for a DiT (diffusion transformer) text model."""

    model_type = "dit"

    def __init__(
        self,
        vocab_size: int = 50258,          # GPT-2's 50257 plus one extra special token
        max_seq_len: int = 1024,
        hidden_size: int = 768,
        timestep_cond_dim: int = 128,     # width of the diffusion-timestep conditioning embedding
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        attention_dropout: float = 0.0,
        p_uniform: float = 0.0,           # weight of uniform-noise corruption in the forward process
        t_eps: float = 1e-4,              # lower clamp on the timestep t, avoiding the t = 0 endpoint
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.hidden_size = hidden_size
        self.timestep_cond_dim = timestep_cond_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.p_uniform = p_uniform
        self.t_eps = t_eps
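
# Because DITConfig subclasses PretrainedConfig, it inherits JSON
# serialization from transformers for free. A minimal usage sketch is
# below; the "dit-large" directory name and the override values are
# illustrative, not part of the class above.
if __name__ == "__main__":
    config = DITConfig(
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
    )
    config.save_pretrained("dit-large")  # writes dit-large/config.json
    reloaded = DITConfig.from_pretrained("dit-large")
    assert reloaded.hidden_size == 1024
    assert reloaded.model_type == "dit"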