from transformers import PretrainedConfig

class IsoformerConfig(PretrainedConfig):
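    """Configuration for the Isoformer model.

    Bundles the hyperparameters of three sub-modules under prefixed argument
    names: an ESM protein language model (``esm_*``), a Nucleotide Transformer
    encoder (``nt_*``), and an Enformer genomics trunk (``enformer_*``), plus
    the cross-attention settings used to fuse the omics embeddings.
    """
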
    model_type = "isoformer"

    def __init__(
        self,
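        # ESM protein language model hyperparameters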
        esm_vocab_size=None,
        esm_mask_token_id=None,
        esm_pad_token_id=None,
        esm_hidden_size=768,
        esm_num_hidden_layers=12,
        esm_num_attention_heads=12,
        esm_intermediate_size=3072,
        esm_hidden_dropout_prob=0.1,
        esm_attention_probs_dropout_prob=0.1,
        esm_max_position_embeddings=1026,
        esm_position_embedding_type="absolute",
        esm_use_cache=True,
        esm_emb_layer_norm_before=None,
        esm_token_dropout=False,
        esm_add_bias_fnn=True,
        esm_tie_word_embeddings=0,
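        # Nucleotide Transformer (nt) encoder hyperparameters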
        nt_vocab_size=None,
        nt_mask_token_id=None,
        nt_pad_token_id=None,
        nt_hidden_size=768,
        nt_num_hidden_layers=12,
        nt_num_attention_heads=12,
        nt_intermediate_size=3072,
        nt_hidden_dropout_prob=0.1,
        nt_attention_probs_dropout_prob=0.1,
        nt_max_position_embeddings=1026,
        nt_position_embedding_type="absolute",
        nt_use_cache=True,
        nt_emb_layer_norm_before=None,
        nt_token_dropout=False,
        nt_add_bias_fnn=True,
        nt_tie_word_embeddings=0,
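        # Enformer trunk hyperparameters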
        enformer_dim=1536,
        enformer_depth=11,
        enformer_heads=8,
        enformer_output_heads=0,
        enformer_target_length=896,
        enformer_attn_dim_key=64,
        enformer_dropout_rate=0.4,
        enformer_attn_dropout=0.05,
        enformer_pos_dropout=0.01,
        enformer_use_checkpointing=False,
        enformer_use_convnext=False,
        enformer_num_downsamples=7,  # genetic sequence is downsampled 2 ** 7 == 128x in the default Enformer; can be changed for higher resolution
        enformer_dim_divisible_by=128,
        enformer_use_tf_gamma=False,
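        # Cross-attention fusion and per-modality sequence lengths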
        num_heads_omics_cross_attention=8,
        num_tokens_per_seq_nuctf=2048,
        num_tokens_per_seq_nuctf_rna=2048,
        num_protein_tokens_per_seq=2048,
        **kwargs,
    ):
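        # Store every sub-model hyperparameter as a flat attribute so that
        # PretrainedConfig can serialize and reload the full configuration.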
        self.esm_vocab_size = esm_vocab_size
        self.esm_mask_token_id = esm_mask_token_id
        self.esm_pad_token_id = esm_pad_token_id
        self.esm_hidden_size = esm_hidden_size
        self.esm_num_hidden_layers = esm_num_hidden_layers
        self.esm_num_attention_heads = esm_num_attention_heads
        self.esm_intermediate_size = esm_intermediate_size
        self.esm_max_position_embeddings = esm_max_position_embeddings
        self.esm_token_dropout = esm_token_dropout
        self.esm_emb_layer_norm_before = esm_emb_layer_norm_before
        self.esm_attention_probs_dropout_prob = esm_attention_probs_dropout_prob
        self.esm_hidden_dropout_prob = esm_hidden_dropout_prob
        self.esm_use_cache = esm_use_cache
        self.esm_add_bias_fnn = esm_add_bias_fnn
        self.esm_position_embedding_type = esm_position_embedding_type
        self.esm_tie_word_embeddings = esm_tie_word_embeddings
        self.nt_vocab_size = nt_vocab_size
        self.nt_mask_token_id = nt_mask_token_id
        self.nt_pad_token_id = nt_pad_token_id
        self.nt_hidden_size = nt_hidden_size
        self.nt_num_hidden_layers = nt_num_hidden_layers
        self.nt_num_attention_heads = nt_num_attention_heads
        self.nt_intermediate_size = nt_intermediate_size
        self.nt_max_position_embeddings = nt_max_position_embeddings
        self.nt_token_dropout = nt_token_dropout
        self.nt_emb_layer_norm_before = nt_emb_layer_norm_before
        self.nt_attention_probs_dropout_prob = nt_attention_probs_dropout_prob
        self.nt_hidden_dropout_prob = nt_hidden_dropout_prob
        self.nt_use_cache = nt_use_cache
        self.nt_add_bias_fnn = nt_add_bias_fnn
        self.nt_position_embedding_type = nt_position_embedding_type
        self.nt_tie_word_embeddings = nt_tie_word_embeddings
        self.enformer_dim = enformer_dim
        self.enformer_depth = enformer_depth
        self.enformer_heads = enformer_heads
        self.enformer_output_heads = enformer_output_heads
        self.enformer_target_length = enformer_target_length
        self.enformer_attn_dim_key = enformer_attn_dim_key
        self.enformer_dropout_rate = enformer_dropout_rate
        self.enformer_attn_dropout = enformer_attn_dropout
        self.enformer_pos_dropout = enformer_pos_dropout
        self.enformer_use_checkpointing = enformer_use_checkpointing
        self.enformer_use_convnext = enformer_use_convnext
        self.enformer_num_downsamples = enformer_num_downsamples
        self.enformer_dim_divisible_by = enformer_dim_divisible_by
        self.enformer_use_tf_gamma = enformer_use_tf_gamma
        self.num_heads_omics_cross_attention = num_heads_omics_cross_attention
        self.num_tokens_per_seq_nuctf = num_tokens_per_seq_nuctf
        self.num_tokens_per_seq_nuctf_rna = num_tokens_per_seq_nuctf_rna
        self.num_protein_tokens_per_seq = num_protein_tokens_per_seq

        super().__init__(**kwargs)
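

# Illustrative usage sketch. The vocabulary sizes and special-token ids below
# are placeholders, not values shipped with any particular checkpoint; real
# values come from the tokenizers paired with the ESM and Nucleotide
# Transformer sub-models.
if __name__ == "__main__":
    config = IsoformerConfig(
        esm_vocab_size=33,     # placeholder: protein tokenizer vocabulary size
        esm_mask_token_id=32,  # placeholder
        esm_pad_token_id=1,    # placeholder
        nt_vocab_size=4107,    # placeholder: nucleotide tokenizer vocabulary size
        nt_mask_token_id=2,    # placeholder
        nt_pad_token_id=1,     # placeholder
    )
    print(config.model_type)       # "isoformer"
    print(config.to_json_string())  # serialization is inherited from PretrainedConfig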