Isoformer / isoformer_config.py
isoformer-anonymous's picture
Upload Isoformer
bed38a1 verified
raw
history blame
5 kB
from transformers import PretrainedConfig
class IsoformerConfig(PretrainedConfig):
    """Configuration container for the Isoformer model.

    The constructor stores every named argument verbatim as an instance
    attribute of the same name. It performs no validation and derives no
    values. The arguments fall into groups by name prefix:

    * ``esm_*`` -- hyperparameters for an ESM-style encoder (presumably the
      protein branch, judging by ``num_protein_tokens_per_seq``; confirm
      against the modeling code).
    * ``nt_*`` -- hyperparameters for a second transformer encoder
      (presumably the nucleotide branch; confirm against the modeling code).
    * ``enformer_*`` -- hyperparameters for an Enformer-style encoder.
    * un-prefixed settings -- ``num_heads_omics_cross_attention`` and the
      ``num_*tokens_per_seq*`` counts.

    Any extra keyword arguments are passed through unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "isoformer"

    def __init__(
        self,
        esm_vocab_size=None,
        esm_mask_token_id=None,
        esm_pad_token_id=None,
        esm_hidden_size=768,
        esm_num_hidden_layers=12,
        esm_num_attention_heads=12,
        esm_intermediate_size=3072,
        esm_hidden_dropout_prob=0.1,
        esm_attention_probs_dropout_prob=0.1,
        esm_max_position_embeddings=1026,
        esm_position_embedding_type="absolute",
        esm_use_cache=True,
        esm_emb_layer_norm_before=None,
        esm_token_dropout=False,
        esm_add_bias_fnn=True,
        esm_tie_word_embeddings=0,
        nt_vocab_size=None,
        nt_mask_token_id=None,
        nt_pad_token_id=None,
        nt_hidden_size=768,
        nt_num_hidden_layers=12,
        nt_num_attention_heads=12,
        nt_intermediate_size=3072,
        nt_hidden_dropout_prob=0.1,
        nt_attention_probs_dropout_prob=0.1,
        nt_max_position_embeddings=1026,
        nt_position_embedding_type="absolute",
        nt_use_cache=True,
        nt_emb_layer_norm_before=None,
        nt_token_dropout=False,
        nt_add_bias_fnn=True,
        nt_tie_word_embeddings=0,
        enformer_dim=1536,
        enformer_depth=11,
        enformer_heads=8,
        enformer_output_heads=0,
        enformer_target_length=896,
        enformer_attn_dim_key=64,
        enformer_dropout_rate=0.4,
        enformer_attn_dropout=0.05,
        enformer_pos_dropout=0.01,
        enformer_use_checkpointing=False,
        enformer_use_convnext=False,
        enformer_num_downsamples=7,  # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
        enformer_dim_divisible_by=128,
        enformer_use_tf_gamma=False,
        num_heads_omics_cross_attention=8,
        num_tokens_per_seq_nuctf=2048,
        num_tokens_per_seq_nuctf_rna=2048,
        num_protein_tokens_per_seq=2048,
        **kwargs,
    ):
        # ---- ESM-prefixed encoder settings --------------------------------
        self.esm_vocab_size = esm_vocab_size
        self.esm_mask_token_id = esm_mask_token_id
        self.esm_pad_token_id = esm_pad_token_id
        self.esm_hidden_size = esm_hidden_size
        self.esm_num_hidden_layers = esm_num_hidden_layers
        self.esm_num_attention_heads = esm_num_attention_heads
        self.esm_intermediate_size = esm_intermediate_size
        self.esm_hidden_dropout_prob = esm_hidden_dropout_prob
        self.esm_attention_probs_dropout_prob = esm_attention_probs_dropout_prob
        self.esm_max_position_embeddings = esm_max_position_embeddings
        self.esm_position_embedding_type = esm_position_embedding_type
        self.esm_use_cache = esm_use_cache
        self.esm_emb_layer_norm_before = esm_emb_layer_norm_before
        self.esm_token_dropout = esm_token_dropout
        self.esm_add_bias_fnn = esm_add_bias_fnn
        # NOTE(review): default is the integer 0, presumably read as a
        # boolean flag downstream -- confirm before changing to False.
        self.esm_tie_word_embeddings = esm_tie_word_embeddings

        # ---- NT-prefixed encoder settings (mirrors the ESM group) ---------
        self.nt_vocab_size = nt_vocab_size
        self.nt_mask_token_id = nt_mask_token_id
        self.nt_pad_token_id = nt_pad_token_id
        self.nt_hidden_size = nt_hidden_size
        self.nt_num_hidden_layers = nt_num_hidden_layers
        self.nt_num_attention_heads = nt_num_attention_heads
        self.nt_intermediate_size = nt_intermediate_size
        self.nt_hidden_dropout_prob = nt_hidden_dropout_prob
        self.nt_attention_probs_dropout_prob = nt_attention_probs_dropout_prob
        self.nt_max_position_embeddings = nt_max_position_embeddings
        self.nt_position_embedding_type = nt_position_embedding_type
        self.nt_use_cache = nt_use_cache
        self.nt_emb_layer_norm_before = nt_emb_layer_norm_before
        self.nt_token_dropout = nt_token_dropout
        self.nt_add_bias_fnn = nt_add_bias_fnn
        # NOTE(review): same integer-0 default as the ESM counterpart.
        self.nt_tie_word_embeddings = nt_tie_word_embeddings

        # ---- Enformer-prefixed encoder settings ---------------------------
        self.enformer_dim = enformer_dim
        self.enformer_depth = enformer_depth
        self.enformer_heads = enformer_heads
        self.enformer_output_heads = enformer_output_heads
        self.enformer_target_length = enformer_target_length
        self.enformer_attn_dim_key = enformer_attn_dim_key
        self.enformer_dropout_rate = enformer_dropout_rate
        self.enformer_attn_dropout = enformer_attn_dropout
        self.enformer_pos_dropout = enformer_pos_dropout
        self.enformer_use_checkpointing = enformer_use_checkpointing
        self.enformer_use_convnext = enformer_use_convnext
        self.enformer_num_downsamples = enformer_num_downsamples
        self.enformer_dim_divisible_by = enformer_dim_divisible_by
        self.enformer_use_tf_gamma = enformer_use_tf_gamma

        # ---- Un-prefixed settings: cross-attention and token counts -------
        self.num_heads_omics_cross_attention = num_heads_omics_cross_attention
        self.num_tokens_per_seq_nuctf = num_tokens_per_seq_nuctf
        self.num_tokens_per_seq_nuctf_rna = num_tokens_per_seq_nuctf_rna
        self.num_protein_tokens_per_seq = num_protein_tokens_per_seq

        # The base-class initialiser runs last, after every custom field
        # above is already set; leftover kwargs are handled there.
        super().__init__(**kwargs)