isoformer-anonymous committed
Commit bed38a1
1 Parent(s): 12d0972

Upload Isoformer

Files changed (4):
  1. config.json +4 -0
  2. isoformer_config.py +111 -0
  3. modeling_isoformer.py +168 -0
  4. pytorch_model.bin +1 -1
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "Isoformer"
   ],
+  "auto_map": {
+    "AutoConfig": "isoformer_config.IsoformerConfig",
+    "AutoModel": "modeling_isoformer.Isoformer"
+  },
   "enformer_attn_dim_key": 64,
   "enformer_attn_dropout": 0.05,
   "enformer_depth": 11,
isoformer_config.py ADDED
@@ -0,0 +1,111 @@
+ from transformers import PretrainedConfig
+
+ class IsoformerConfig(PretrainedConfig):
+     model_type = "isoformer"
+
+     def __init__(
+         self,
+         esm_vocab_size=None,
+         esm_mask_token_id=None,
+         esm_pad_token_id=None,
+         esm_hidden_size=768,
+         esm_num_hidden_layers=12,
+         esm_num_attention_heads=12,
+         esm_intermediate_size=3072,
+         esm_hidden_dropout_prob=0.1,
+         esm_attention_probs_dropout_prob=0.1,
+         esm_max_position_embeddings=1026,
+         esm_position_embedding_type="absolute",
+         esm_use_cache=True,
+         esm_emb_layer_norm_before=None,
+         esm_token_dropout=False,
+         esm_add_bias_fnn=True,
+         esm_tie_word_embeddings=0,
+         nt_vocab_size=None,
+         nt_mask_token_id=None,
+         nt_pad_token_id=None,
+         nt_hidden_size=768,
+         nt_num_hidden_layers=12,
+         nt_num_attention_heads=12,
+         nt_intermediate_size=3072,
+         nt_hidden_dropout_prob=0.1,
+         nt_attention_probs_dropout_prob=0.1,
+         nt_max_position_embeddings=1026,
+         nt_position_embedding_type="absolute",
+         nt_use_cache=True,
+         nt_emb_layer_norm_before=None,
+         nt_token_dropout=False,
+         nt_add_bias_fnn=True,
+         nt_tie_word_embeddings=0,
+         enformer_dim=1536,
+         enformer_depth=11,
+         enformer_heads=8,
+         enformer_output_heads=0,
+         enformer_target_length=896,
+         enformer_attn_dim_key=64,
+         enformer_dropout_rate=0.4,
+         enformer_attn_dropout=0.05,
+         enformer_pos_dropout=0.01,
+         enformer_use_checkpointing=False,
+         enformer_use_convnext=False,
+         enformer_num_downsamples=7,  # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
+         enformer_dim_divisible_by=128,
+         enformer_use_tf_gamma=False,
+         num_heads_omics_cross_attention=8,
+         num_tokens_per_seq_nuctf=2048,
+         num_tokens_per_seq_nuctf_rna=2048,
+         num_protein_tokens_per_seq=2048,
+         **kwargs,
+     ):
+         self.esm_vocab_size = esm_vocab_size
+         self.esm_mask_token_id = esm_mask_token_id
+         self.esm_pad_token_id = esm_pad_token_id
+         self.esm_hidden_size = esm_hidden_size
+         self.esm_num_hidden_layers = esm_num_hidden_layers
+         self.esm_num_attention_heads = esm_num_attention_heads
+         self.esm_intermediate_size = esm_intermediate_size
+         self.esm_max_position_embeddings = esm_max_position_embeddings
+         self.esm_token_dropout = esm_token_dropout
+         self.esm_emb_layer_norm_before = esm_emb_layer_norm_before
+         self.esm_attention_probs_dropout_prob = esm_attention_probs_dropout_prob
+         self.esm_hidden_dropout_prob = esm_hidden_dropout_prob
+         self.esm_use_cache = esm_use_cache
+         self.esm_add_bias_fnn = esm_add_bias_fnn
+         self.esm_position_embedding_type = esm_position_embedding_type
+         self.esm_tie_word_embeddings = esm_tie_word_embeddings
+         self.nt_vocab_size = nt_vocab_size
+         self.nt_mask_token_id = nt_mask_token_id
+         self.nt_pad_token_id = nt_pad_token_id
+         self.nt_hidden_size = nt_hidden_size
+         self.nt_num_hidden_layers = nt_num_hidden_layers
+         self.nt_num_attention_heads = nt_num_attention_heads
+         self.nt_intermediate_size = nt_intermediate_size
+         self.nt_max_position_embeddings = nt_max_position_embeddings
+         self.nt_token_dropout = nt_token_dropout
+         self.nt_emb_layer_norm_before = nt_emb_layer_norm_before
+         self.nt_attention_probs_dropout_prob = nt_attention_probs_dropout_prob
+         self.nt_hidden_dropout_prob = nt_hidden_dropout_prob
+         self.nt_use_cache = nt_use_cache
+         self.nt_add_bias_fnn = nt_add_bias_fnn
+         self.nt_position_embedding_type = nt_position_embedding_type
+         self.nt_tie_word_embeddings = nt_tie_word_embeddings
+         self.enformer_dim = enformer_dim
+         self.enformer_depth = enformer_depth
+         self.enformer_heads = enformer_heads
+         self.enformer_output_heads = enformer_output_heads
+         self.enformer_target_length = enformer_target_length
+         self.enformer_attn_dim_key = enformer_attn_dim_key
+         self.enformer_dropout_rate = enformer_dropout_rate
+         self.enformer_attn_dropout = enformer_attn_dropout
+         self.enformer_pos_dropout = enformer_pos_dropout
+         self.enformer_use_checkpointing = enformer_use_checkpointing
+         self.enformer_use_convnext = enformer_use_convnext
+         self.enformer_num_downsamples = enformer_num_downsamples
+         self.enformer_dim_divisible_by = enformer_dim_divisible_by
+         self.enformer_use_tf_gamma = enformer_use_tf_gamma
+         self.num_heads_omics_cross_attention = num_heads_omics_cross_attention
+         self.num_tokens_per_seq_nuctf = num_tokens_per_seq_nuctf
+         self.num_tokens_per_seq_nuctf_rna = num_tokens_per_seq_nuctf_rna
+         self.num_protein_tokens_per_seq = num_protein_tokens_per_seq
+
+         super().__init__(**kwargs)
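
For reference, a minimal sketch of exercising the config class on its own with the standard PretrainedConfig round-trip; the override values below are placeholders for illustration, not the ones shipped with this checkpoint:

from isoformer_config import IsoformerConfig

# Placeholder overrides for illustration only; the remaining fields use defaults.
config = IsoformerConfig(esm_vocab_size=33, nt_vocab_size=4107)

config.save_pretrained("./isoformer_config_out")       # writes config.json
reloaded = IsoformerConfig.from_pretrained("./isoformer_config_out")
print(reloaded.enformer_dim, reloaded.enformer_depth)  # 1536 11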
modeling_isoformer.py ADDED
@@ -0,0 +1,168 @@
+ from transformers import PreTrainedModel
+ #from genomics_research.biobrain_p2.huggingface.modeling_enformer import Enformer
+ from genomics_research.biobrain_p2.huggingface.modeling_esm import NTForMaskedLM, MultiHeadAttention
+ from genomics_research.biobrain_p2.huggingface.isoformer_config import IsoformerConfig
+ #from genomics_research.biobrain_p2.huggingface.enformer_config import EnformerConfig
+ from genomics_research.biobrain_p2.huggingface.esm_config import NTConfig
+ from genomics_research.biobrain_p2.huggingface.modeling_esm_original import EsmForMaskedLM
+ from transformers.models.esm.configuration_esm import EsmConfig
+ from enformer_pytorch import Enformer, str_to_one_hot, EnformerConfig
+ import torch
+ from torch import nn
+
+ class Isoformer(PreTrainedModel):
+     config_class = IsoformerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.esm_config = EsmConfig(
+             vocab_size=config.esm_vocab_size,
+             mask_token_id=config.esm_mask_token_id,
+             pad_token_id=config.esm_pad_token_id,
+             hidden_size=config.esm_hidden_size,
+             num_hidden_layers=config.esm_num_hidden_layers,
+             num_attention_heads=config.esm_num_attention_heads,
+             intermediate_size=config.esm_intermediate_size,
+             max_position_embeddings=config.esm_max_position_embeddings,
+             token_dropout=config.esm_token_dropout,
+             emb_layer_norm_before=config.esm_emb_layer_norm_before,
+             attention_probs_dropout_prob=0.0,
+             hidden_dropout_prob=0.0,
+             use_cache=False,
+             add_bias_fnn=config.esm_add_bias_fnn,
+             position_embedding_type="rotary",
+             tie_word_embeddings=False,
+         )
+
+         self.nt_config = NTConfig(
+             vocab_size=config.nt_vocab_size,
+             mask_token_id=config.nt_mask_token_id,
+             pad_token_id=config.nt_pad_token_id,
+             hidden_size=config.nt_hidden_size,
+             num_hidden_layers=config.nt_num_hidden_layers,
+             num_attention_heads=config.nt_num_attention_heads,
+             intermediate_size=config.nt_intermediate_size,
+             max_position_embeddings=config.nt_max_position_embeddings,
+             token_dropout=config.nt_token_dropout,
+             emb_layer_norm_before=config.nt_emb_layer_norm_before,
+             attention_probs_dropout_prob=0.0,
+             hidden_dropout_prob=0.0,
+             use_cache=False,
+             add_bias_fnn=config.nt_add_bias_fnn,
+             position_embedding_type="rotary",
+             tie_word_embeddings=False,
+         )
+         self.config = config
+
+         # self.enformer_config = EnformerConfig(
+         #     dim=config.enformer_dim,
+         #     depth=config.enformer_depth,
+         #     heads=config.enformer_heads,
+         #     output_heads=dict(
+         #         human=1,
+         #         mouse=1  # TODO CHANGE
+         #     ),
+         #     target_length=config.enformer_target_length,  # 896,
+         #     attn_dim_key=config.enformer_attn_dim_key,
+         #     dropout_rate=0.4,
+         #     attn_dropout=0.05,
+         #     pos_dropout=0.01,
+         #     use_checkpointing=config.enformer_use_checkpointing,
+         #     use_convnext=config.enformer_use_convnext,
+         #     num_downsamples=config.enformer_num_downsamples,
+         #     # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
+         #     dim_divisible_by=config.enformer_dim_divisible_by,
+         #     use_tf_gamma=False,
+         # )
+
+         self.esm_model = EsmForMaskedLM(self.esm_config)  # protein encoder
+         self.nt_model = NTForMaskedLM(self.nt_config)  # rna encoder
+         #self.enformer_model = Enformer(self.enformer_config)  # dna encoder
+         self.enformer_model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")
+
+         self.cross_attention_layer_rna = MultiHeadAttention(
+             config=EsmConfig(
+                 num_attention_heads=config.num_heads_omics_cross_attention,
+                 attention_head_size=3072 // config.num_heads_omics_cross_attention,
+                 hidden_size=3072,
+                 attention_probs_dropout_prob=0,
+                 max_position_embeddings=0
+             ),
+             omics_of_interest_size=3072,
+             other_omic_size=768
+         )
+         self.cross_attention_layer_protein = MultiHeadAttention(
+             config=EsmConfig(
+                 num_attention_heads=config.num_heads_omics_cross_attention,
+                 attention_head_size=3072 // config.num_heads_omics_cross_attention,
+                 hidden_size=3072,
+                 attention_probs_dropout_prob=0,
+                 max_position_embeddings=0
+             ),
+             omics_of_interest_size=3072,
+             other_omic_size=640
+         )
+
+         self.head_layer_1 = nn.Linear(3072, 2 * 3072)
+         self.head_layer_2 = nn.Linear(2 * 3072, 30)
+
+     def forward(
+         self,
+         tensor_dna,
+         tensor_rna,
+         tensor_protein,
+         attention_mask_dna,
+         attention_mask_rna,
+         attention_mask_protein
+     ):
+         tensor_dna = tensor_dna[:, 1:]  # remove CLS
+         dna_embedding = self.enformer_model(
+             tensor_dna,
+             return_only_embeddings=True
+             # attention_mask=attention_mask_dna,
+             # encoder_attention_mask=attention_mask_dna,
+             # output_hidden_states=True
+         )
+         protein_embedding = self.esm_model(
+             tensor_protein,
+             attention_mask=attention_mask_protein,
+             encoder_attention_mask=attention_mask_protein,
+             output_hidden_states=True
+         )
+         rna_embedding = self.nt_model(
+             tensor_rna,
+             attention_mask=attention_mask_rna,
+             encoder_attention_mask=attention_mask_rna,
+             output_hidden_states=True
+         )
+
+         encoder_attention_mask = torch.unsqueeze(torch.unsqueeze(tensor_rna != 1, 0), 0).repeat(1, 1, dna_embedding.shape[1], 1)
+         rna_to_dna = self.cross_attention_layer_rna.forward(
+             hidden_states=dna_embedding,
+             encoder_hidden_states=rna_embedding["hidden_states"][-1],
+             encoder_attention_mask=encoder_attention_mask
+         )
+
+         final_dna_embeddings = self.cross_attention_layer_protein.forward(
+             hidden_states=rna_to_dna["embeddings"],
+             encoder_hidden_states=protein_embedding["hidden_states"][-1],
+         )["embeddings"]
+
+         sequence_mask = torch.zeros(final_dna_embeddings.shape[1])
+         sequence_mask[self.config.pool_window_start:self.config.pool_window_end] = 1
+         x = torch.sum(torch.einsum('ijk,j->ijk', final_dna_embeddings, sequence_mask), axis=1) / torch.sum(sequence_mask)
+         x = self.head_layer_1(x)
+         x = torch.nn.functional.softplus(x)
+         x = self.head_layer_2(x)
+
+         return {
+             "gene_expression_predictions": x,
+             "rna_to_dna": rna_to_dna,
+             "final_embeddings": final_dna_embeddings,
+             "dna_embedding": dna_embedding,
+             "rna_embedding": rna_embedding,
+             "protein_embedding": protein_embedding
+         }
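
The pooling step in forward (the sequence_mask plus the einsum) is a masked mean over a window of the final DNA embeddings, followed by the two-layer expression head. A standalone sketch of just that pooling, using dummy shapes; the batch size, sequence length, and window bounds below are assumptions for illustration, and in the model the bounds come from config.pool_window_start / config.pool_window_end:

import torch

batch, seq_len, hidden = 2, 896, 3072          # assumed shapes for illustration
final_dna_embeddings = torch.randn(batch, seq_len, hidden)

pool_window_start, pool_window_end = 440, 456  # hypothetical window bounds

# Build a 0/1 mask over sequence positions, zero out everything outside the
# window, then average the surviving positions along the sequence dimension.
sequence_mask = torch.zeros(seq_len)
sequence_mask[pool_window_start:pool_window_end] = 1
pooled = torch.sum(
    torch.einsum("ijk,j->ijk", final_dna_embeddings, sequence_mask), axis=1
) / torch.sum(sequence_mask)
print(pooled.shape)  # torch.Size([2, 3072])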
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
   version https://git-lfs.github.com/spec/v1
- oid sha256:cf803dfa3d135f58e9deb3b9a4958cca369a7959ab11043e21232bb994f35f36
+ oid sha256:862bbda2ad7efe88014c38c046c517728295f2a038080c00071570dd20c9c7ac
   size 2803153818