isoformer-anonymous committed
Commit bed38a1 • 1 Parent(s): 12d0972
Upload Isoformer

Files changed:
- config.json +4 -0
- isoformer_config.py +111 -0
- modeling_isoformer.py +168 -0
- pytorch_model.bin +1 -1
config.json  CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "Isoformer"
   ],
+  "auto_map": {
+    "AutoConfig": "isoformer_config.IsoformerConfig",
+    "AutoModel": "modeling_isoformer.Isoformer"
+  },
   "enformer_attn_dim_key": 64,
   "enformer_attn_dropout": 0.05,
   "enformer_depth": 11,
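The new "auto_map" entries let the custom classes be resolved through the transformers Auto API when trust_remote_code is enabled. A minimal sketch, assuming a repository id (the id below is illustrative, not taken from this diff):

from transformers import AutoConfig, AutoModel

repo_id = "isoformer-anonymous/Isoformer"  # illustrative repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # resolves isoformer_config.IsoformerConfig
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # resolves modeling_isoformer.Isoformer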
isoformer_config.py  ADDED
@@ -0,0 +1,111 @@
from transformers import PretrainedConfig

class IsoformerConfig(PretrainedConfig):
    model_type = "isoformer"

    def __init__(
        self,
        esm_vocab_size=None,
        esm_mask_token_id=None,
        esm_pad_token_id=None,
        esm_hidden_size=768,
        esm_num_hidden_layers=12,
        esm_num_attention_heads=12,
        esm_intermediate_size=3072,
        esm_hidden_dropout_prob=0.1,
        esm_attention_probs_dropout_prob=0.1,
        esm_max_position_embeddings=1026,
        esm_position_embedding_type="absolute",
        esm_use_cache=True,
        esm_emb_layer_norm_before=None,
        esm_token_dropout=False,
        esm_add_bias_fnn=True,
        esm_tie_word_embeddings=0,
        nt_vocab_size=None,
        nt_mask_token_id=None,
        nt_pad_token_id=None,
        nt_hidden_size=768,
        nt_num_hidden_layers=12,
        nt_num_attention_heads=12,
        nt_intermediate_size=3072,
        nt_hidden_dropout_prob=0.1,
        nt_attention_probs_dropout_prob=0.1,
        nt_max_position_embeddings=1026,
        nt_position_embedding_type="absolute",
        nt_use_cache=True,
        nt_emb_layer_norm_before=None,
        nt_token_dropout=False,
        nt_add_bias_fnn=True,
        nt_tie_word_embeddings=0,
        enformer_dim=1536,
        enformer_depth=11,
        enformer_heads=8,
        enformer_output_heads=0,
        enformer_target_length=896,
        enformer_attn_dim_key=64,
        enformer_dropout_rate=0.4,
        enformer_attn_dropout=0.05,
        enformer_pos_dropout=0.01,
        enformer_use_checkpointing=False,
        enformer_use_convnext=False,
        enformer_num_downsamples=7,  # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
        enformer_dim_divisible_by=128,
        enformer_use_tf_gamma=False,
        num_heads_omics_cross_attention=8,
        num_tokens_per_seq_nuctf=2048,
        num_tokens_per_seq_nuctf_rna=2048,
        num_protein_tokens_per_seq=2048,
        **kwargs,
    ):
        self.esm_vocab_size = esm_vocab_size
        self.esm_mask_token_id = esm_mask_token_id
        self.esm_pad_token_id = esm_pad_token_id
        self.esm_hidden_size = esm_hidden_size
        self.esm_num_hidden_layers = esm_num_hidden_layers
        self.esm_num_attention_heads = esm_num_attention_heads
        self.esm_intermediate_size = esm_intermediate_size
        self.esm_max_position_embeddings = esm_max_position_embeddings
        self.esm_token_dropout = esm_token_dropout
        self.esm_emb_layer_norm_before = esm_emb_layer_norm_before
        self.esm_attention_probs_dropout_prob = esm_attention_probs_dropout_prob
        self.esm_hidden_dropout_prob = esm_hidden_dropout_prob
        self.esm_use_cache = esm_use_cache
        self.esm_add_bias_fnn = esm_add_bias_fnn
        self.esm_position_embedding_type = esm_position_embedding_type
        self.esm_tie_word_embeddings = esm_tie_word_embeddings
        self.nt_vocab_size = nt_vocab_size
        self.nt_mask_token_id = nt_mask_token_id
        self.nt_pad_token_id = nt_pad_token_id
        self.nt_hidden_size = nt_hidden_size
        self.nt_num_hidden_layers = nt_num_hidden_layers
        self.nt_num_attention_heads = nt_num_attention_heads
        self.nt_intermediate_size = nt_intermediate_size
        self.nt_max_position_embeddings = nt_max_position_embeddings
        self.nt_token_dropout = nt_token_dropout
        self.nt_emb_layer_norm_before = nt_emb_layer_norm_before
        self.nt_attention_probs_dropout_prob = nt_attention_probs_dropout_prob
        self.nt_hidden_dropout_prob = nt_hidden_dropout_prob
        self.nt_use_cache = nt_use_cache
        self.nt_add_bias_fnn = nt_add_bias_fnn
        self.nt_position_embedding_type = nt_position_embedding_type
        self.nt_tie_word_embeddings = nt_tie_word_embeddings
        self.enformer_dim = enformer_dim
        self.enformer_depth = enformer_depth
        self.enformer_heads = enformer_heads
        self.enformer_output_heads = enformer_output_heads
        self.enformer_target_length = enformer_target_length
        self.enformer_attn_dim_key = enformer_attn_dim_key
        self.enformer_dropout_rate = enformer_dropout_rate
        self.enformer_attn_dropout = enformer_attn_dropout
        self.enformer_pos_dropout = enformer_pos_dropout
        self.enformer_use_checkpointing = enformer_use_checkpointing
        self.enformer_use_convnext = enformer_use_convnext
        self.enformer_num_downsamples = enformer_num_downsamples
        self.enformer_dim_divisible_by = enformer_dim_divisible_by
        self.enformer_use_tf_gamma = enformer_use_tf_gamma
        self.num_heads_omics_cross_attention = num_heads_omics_cross_attention
        self.num_tokens_per_seq_nuctf = num_tokens_per_seq_nuctf
        self.num_tokens_per_seq_nuctf_rna = num_tokens_per_seq_nuctf_rna
        self.num_protein_tokens_per_seq = num_protein_tokens_per_seq

        super().__init__(**kwargs)
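isoformer_config.py collects the per-encoder hyperparameters (esm_* for the protein encoder, nt_* for the RNA encoder, enformer_* for the DNA encoder) in a single PretrainedConfig. A minimal usage sketch; the vocabulary sizes below are placeholders, not values taken from this commit:

from isoformer_config import IsoformerConfig

config = IsoformerConfig(
    esm_vocab_size=33,   # placeholder protein vocabulary size
    nt_vocab_size=4105,  # placeholder RNA vocabulary size
)
config.save_pretrained("./isoformer-demo")  # writes a config.json with model_type "isoformer"
reloaded = IsoformerConfig.from_pretrained("./isoformer-demo")
assert reloaded.enformer_target_length == 896  # defaults round-trip through serialization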
modeling_isoformer.py  ADDED
@@ -0,0 +1,168 @@
from transformers import PreTrainedModel
# from genomics_research.biobrain_p2.huggingface.modeling_enformer import Enformer
from genomics_research.biobrain_p2.huggingface.modeling_esm import NTForMaskedLM, MultiHeadAttention
from genomics_research.biobrain_p2.huggingface.isoformer_config import IsoformerConfig
# from genomics_research.biobrain_p2.huggingface.enformer_config import EnformerConfig
from genomics_research.biobrain_p2.huggingface.esm_config import NTConfig
from genomics_research.biobrain_p2.huggingface.modeling_esm_original import EsmForMaskedLM
from transformers.models.esm.configuration_esm import EsmConfig
from enformer_pytorch import Enformer, str_to_one_hot, EnformerConfig
import torch
from torch import nn

class Isoformer(PreTrainedModel):
    config_class = IsoformerConfig

    def __init__(self, config):
        super().__init__(config)

        self.esm_config = EsmConfig(
            vocab_size=config.esm_vocab_size,
            mask_token_id=config.esm_mask_token_id,
            pad_token_id=config.esm_pad_token_id,
            hidden_size=config.esm_hidden_size,
            num_hidden_layers=config.esm_num_hidden_layers,
            num_attention_heads=config.esm_num_attention_heads,
            intermediate_size=config.esm_intermediate_size,
            max_position_embeddings=config.esm_max_position_embeddings,
            token_dropout=config.esm_token_dropout,
            emb_layer_norm_before=config.esm_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.esm_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )

        self.nt_config = NTConfig(
            vocab_size=config.nt_vocab_size,
            mask_token_id=config.nt_mask_token_id,
            pad_token_id=config.nt_pad_token_id,
            hidden_size=config.nt_hidden_size,
            num_hidden_layers=config.nt_num_hidden_layers,
            num_attention_heads=config.nt_num_attention_heads,
            intermediate_size=config.nt_intermediate_size,
            max_position_embeddings=config.nt_max_position_embeddings,
            token_dropout=config.nt_token_dropout,
            emb_layer_norm_before=config.nt_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.nt_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )
        self.config = config

        # self.enformer_config = EnformerConfig(
        #     dim=config.enformer_dim,
        #     depth=config.enformer_depth,
        #     heads=config.enformer_heads,
        #     output_heads=dict(
        #         human=1,
        #         mouse=1  # TODO CHANGE
        #     ),
        #     target_length=config.enformer_target_length,  # 896,
        #     attn_dim_key=config.enformer_attn_dim_key,
        #     dropout_rate=0.4,
        #     attn_dropout=0.05,
        #     pos_dropout=0.01,
        #     use_checkpointing=config.enformer_use_checkpointing,
        #     use_convnext=config.enformer_use_convnext,
        #     num_downsamples=config.enformer_num_downsamples,
        #     # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
        #     dim_divisible_by=config.enformer_dim_divisible_by,
        #     use_tf_gamma=False,
        # )

        self.esm_model = EsmForMaskedLM(self.esm_config)  # protein encoder
        self.nt_model = NTForMaskedLM(self.nt_config)  # rna encoder
        # self.enformer_model = Enformer(self.enformer_config)  # dna encoder
        self.enformer_model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        self.cross_attention_layer_rna = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0
            ),
            omics_of_interest_size=3072,
            other_omic_size=768
        )
        self.cross_attention_layer_protein = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0
            ),
            omics_of_interest_size=3072,
            other_omic_size=640
        )

        self.head_layer_1 = nn.Linear(3072, 2 * 3072)
        self.head_layer_2 = nn.Linear(2 * 3072, 30)

    def forward(
        self,
        tensor_dna,
        tensor_rna,
        tensor_protein,
        attention_mask_dna,
        attention_mask_rna,
        attention_mask_protein
    ):
        tensor_dna = tensor_dna[:, 1:]  # remove CLS
        dna_embedding = self.enformer_model(
            tensor_dna,
            return_only_embeddings=True
            # attention_mask=attention_mask_dna,
            # encoder_attention_mask=attention_mask_dna,
            # output_hidden_states=True
        )
        protein_embedding = self.esm_model(
            tensor_protein,
            attention_mask=attention_mask_protein,
            encoder_attention_mask=attention_mask_protein,
            output_hidden_states=True
        )
        rna_embedding = self.nt_model(
            tensor_rna,
            attention_mask=attention_mask_rna,
            encoder_attention_mask=attention_mask_rna,
            output_hidden_states=True
        )

        encoder_attention_mask = torch.unsqueeze(torch.unsqueeze(tensor_rna != 1, 0), 0).repeat(1, 1, dna_embedding.shape[1], 1)
        rna_to_dna = self.cross_attention_layer_rna.forward(
            hidden_states=dna_embedding,
            encoder_hidden_states=rna_embedding["hidden_states"][-1],
            encoder_attention_mask=encoder_attention_mask
        )

        final_dna_embeddings = self.cross_attention_layer_protein.forward(
            hidden_states=rna_to_dna["embeddings"],
            encoder_hidden_states=protein_embedding["hidden_states"][-1],
        )["embeddings"]

        sequence_mask = torch.zeros(final_dna_embeddings.shape[1])
        sequence_mask[self.config.pool_window_start:self.config.pool_window_end] = 1
        x = torch.sum(torch.einsum('ijk,j->ijk', final_dna_embeddings, sequence_mask), axis=1) / torch.sum(sequence_mask)
        x = self.head_layer_1(x)
        x = torch.nn.functional.softplus(x)
        x = self.head_layer_2(x)

        return {
            "gene_expression_predictions": x,
            "rna_to_dna": rna_to_dna,
            "final_embeddings": final_dna_embeddings,
            "dna_embedding": dna_embedding,
            "rna_embedding": rna_embedding,
            "protein_embedding": protein_embedding
        }
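The end of forward() pools the cross-attended DNA embeddings over a window given by config.pool_window_start / config.pool_window_end (these are read from the config but are not explicit IsoformerConfig.__init__ arguments, so they would arrive through **kwargs from config.json), then applies the two-layer softplus head. A standalone sketch of just that readout, with placeholder shapes and window bounds:

import torch
from torch import nn

batch, target_length, dim = 2, 896, 3072           # placeholder shapes
final_dna_embeddings = torch.randn(batch, target_length, dim)
pool_window_start, pool_window_end = 440, 456      # placeholder window bounds

# mask positions outside the window, then average over the window
sequence_mask = torch.zeros(target_length)
sequence_mask[pool_window_start:pool_window_end] = 1
pooled = torch.sum(torch.einsum('ijk,j->ijk', final_dna_embeddings, sequence_mask), axis=1) / torch.sum(sequence_mask)

head_layer_1 = nn.Linear(dim, 2 * dim)
head_layer_2 = nn.Linear(2 * dim, 30)
x = head_layer_2(torch.nn.functional.softplus(head_layer_1(pooled)))
print(x.shape)  # torch.Size([2, 30]) -> one 30-dim gene expression prediction per sequence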
pytorch_model.bin  CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:862bbda2ad7efe88014c38c046c517728295f2a038080c00071570dd20c9c7ac
 size 2803153818
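The updated LFS pointer pins the checkpoint to a specific sha256 digest. A quick integrity check against a locally downloaded copy (the local file path is an assumption):

import hashlib

expected = "862bbda2ad7efe88014c38c046c517728295f2a038080c00071570dd20c9c7ac"
h = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:          # path to your downloaded checkpoint
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
assert h.hexdigest() == expected, "checkpoint does not match the LFS pointer"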