from transformers import PretrainedConfig


class IsoformerConfig(PretrainedConfig):
    """Configuration container for the Isoformer model.

    Every constructor argument (other than ``**kwargs``) is stored unchanged as
    an instance attribute with exactly the same name. No validation is done and
    no derived values are computed here. The arguments fall into four groups by
    name prefix:

    * ``esm_*`` -- hyperparameters for the ESM sub-model (presumably a protein
      language model; confirm against the modeling code).
    * ``nt_*`` -- hyperparameters for the NT sub-model. It uses the same field
      set as ``esm_*``; presumably a nucleotide-sequence model, to be confirmed.
    * ``enformer_*`` -- hyperparameters for the Enformer sub-model.
    * un-prefixed -- the omics cross-attention head count and the per-sequence
      token counts.

    Any extra keyword arguments are forwarded to ``PretrainedConfig.__init__``
    after the attributes above have been set.
    """

    model_type = "isoformer"

    def __init__(
        self,
        esm_vocab_size=None,
        esm_mask_token_id=None,
        esm_pad_token_id=None,
        esm_hidden_size=768,
        esm_num_hidden_layers=12,
        esm_num_attention_heads=12,
        esm_intermediate_size=3072,
        esm_hidden_dropout_prob=0.1,
        esm_attention_probs_dropout_prob=0.1,
        esm_max_position_embeddings=1026,
        esm_position_embedding_type="absolute",
        esm_use_cache=True,
        esm_emb_layer_norm_before=None,
        esm_token_dropout=False,
        esm_add_bias_fnn=True,
        esm_tie_word_embeddings=0,
        nt_vocab_size=None,
        nt_mask_token_id=None,
        nt_pad_token_id=None,
        nt_hidden_size=768,
        nt_num_hidden_layers=12,
        nt_num_attention_heads=12,
        nt_intermediate_size=3072,
        nt_hidden_dropout_prob=0.1,
        nt_attention_probs_dropout_prob=0.1,
        nt_max_position_embeddings=1026,
        nt_position_embedding_type="absolute",
        nt_use_cache=True,
        nt_emb_layer_norm_before=None,
        nt_token_dropout=False,
        nt_add_bias_fnn=True,
        nt_tie_word_embeddings=0,
        enformer_dim=1536,
        enformer_depth=11,
        enformer_heads=8,
        enformer_output_heads=0,
        enformer_target_length=896,
        enformer_attn_dim_key=64,
        enformer_dropout_rate=0.4,
        enformer_attn_dropout=0.05,
        enformer_pos_dropout=0.01,
        enformer_use_checkpointing=False,
        enformer_use_convnext=False,
        # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer
        # - can be changed for higher resolution
        enformer_num_downsamples=7,
        enformer_dim_divisible_by=128,
        enformer_use_tf_gamma=False,
        num_heads_omics_cross_attention=8,
        num_tokens_per_seq_nuctf=2048,
        num_tokens_per_seq_nuctf_rna=2048,
        num_protein_tokens_per_seq=2048,
        **kwargs,
    ):
        # Each section pairs an attribute-name prefix with its (suffix -> value)
        # fields. Sections and fields are listed in the order the attributes
        # have always been assigned, so the instance __dict__ ordering is kept.
        sections = (
            (
                "esm_",
                dict(
                    vocab_size=esm_vocab_size,
                    mask_token_id=esm_mask_token_id,
                    pad_token_id=esm_pad_token_id,
                    hidden_size=esm_hidden_size,
                    num_hidden_layers=esm_num_hidden_layers,
                    num_attention_heads=esm_num_attention_heads,
                    intermediate_size=esm_intermediate_size,
                    max_position_embeddings=esm_max_position_embeddings,
                    token_dropout=esm_token_dropout,
                    emb_layer_norm_before=esm_emb_layer_norm_before,
                    attention_probs_dropout_prob=esm_attention_probs_dropout_prob,
                    hidden_dropout_prob=esm_hidden_dropout_prob,
                    use_cache=esm_use_cache,
                    add_bias_fnn=esm_add_bias_fnn,
                    position_embedding_type=esm_position_embedding_type,
                    tie_word_embeddings=esm_tie_word_embeddings,
                ),
            ),
            (
                "nt_",
                dict(
                    vocab_size=nt_vocab_size,
                    mask_token_id=nt_mask_token_id,
                    pad_token_id=nt_pad_token_id,
                    hidden_size=nt_hidden_size,
                    num_hidden_layers=nt_num_hidden_layers,
                    num_attention_heads=nt_num_attention_heads,
                    intermediate_size=nt_intermediate_size,
                    max_position_embeddings=nt_max_position_embeddings,
                    token_dropout=nt_token_dropout,
                    emb_layer_norm_before=nt_emb_layer_norm_before,
                    attention_probs_dropout_prob=nt_attention_probs_dropout_prob,
                    hidden_dropout_prob=nt_hidden_dropout_prob,
                    use_cache=nt_use_cache,
                    add_bias_fnn=nt_add_bias_fnn,
                    position_embedding_type=nt_position_embedding_type,
                    tie_word_embeddings=nt_tie_word_embeddings,
                ),
            ),
            (
                "enformer_",
                dict(
                    dim=enformer_dim,
                    depth=enformer_depth,
                    heads=enformer_heads,
                    output_heads=enformer_output_heads,
                    target_length=enformer_target_length,
                    attn_dim_key=enformer_attn_dim_key,
                    dropout_rate=enformer_dropout_rate,
                    attn_dropout=enformer_attn_dropout,
                    pos_dropout=enformer_pos_dropout,
                    use_checkpointing=enformer_use_checkpointing,
                    use_convnext=enformer_use_convnext,
                    num_downsamples=enformer_num_downsamples,
                    dim_divisible_by=enformer_dim_divisible_by,
                    use_tf_gamma=enformer_use_tf_gamma,
                ),
            ),
            (
                "",
                dict(
                    num_heads_omics_cross_attention=num_heads_omics_cross_attention,
                    num_tokens_per_seq_nuctf=num_tokens_per_seq_nuctf,
                    num_tokens_per_seq_nuctf_rna=num_tokens_per_seq_nuctf_rna,
                    num_protein_tokens_per_seq=num_protein_tokens_per_seq,
                ),
            ),
        )

        for prefix, fields in sections:
            for suffix, value in fields.items():
                setattr(self, prefix + suffix, value)

        # Base-class initialisation runs last, receiving only the leftover kwargs.
        super().__init__(**kwargs)