NeptuniaNep committed on
Commit
6f48e1f
·
1 Parent(s): 1a65f10

Upload 2 files

Browse files
Files changed (2) hide show
  1. best_model.pth +3 -0
  2. proteinbind_new.py +282 -0
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f04226a7bfc5cbb097348fa4f721a1d0da1b3aa248062ddef43136ff4ece1673
3
+ size 52399787
proteinbind_new.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import Dataset
7
+
8
+
9
# String identifiers for the supported modalities, exposed as upper-case
# attributes (ModalityType.AA == "aa", ...).  ProteinBindModel uses these
# strings as the keys of its per-modality nn.ModuleDict containers.
ModalityType = SimpleNamespace(
    **{
        key.upper(): key
        for key in ("aa", "dna", "pdb", "go", "msa", "text")
    }
)
17
+
18
class Normalize(nn.Module):
    """Parameter-free layer that L2-normalises its input along one axis."""

    def __init__(self, dim: int) -> None:
        """Store ``dim``, the axis over which the L2 norm is taken."""
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` L2-normalised (``p=2``) along ``self.dim``."""
        unit_norm = nn.functional.normalize(x, p=2, dim=self.dim)
        return unit_norm
25
+
26
class EmbeddingDataset(Dataset):
    """
    Pair raw sequences from a CSV file with precomputed embeddings.

    Each sample is a dict with two entries: ``"aa"`` holds the value from the
    first CSV column of row ``idx`` and ``modality`` holds ``embedding[idx]``.
    Any modality that doesn't fit into the __getitem__ method can subclass
    this and modify the __getitem__ method.

    :param sequence_file_path: CSV file read with ``pandas.read_csv``; only
        its first column is used.
    :param embeddings_file_path: file read with ``torch.load``; the loaded
        object must be indexable by integer position.
    :param modality: key under which the embedding is returned.  If it is
        ``"aa"`` the embedding overwrites the sequence entry in the sample.
    :raises ValueError: if there are fewer embeddings than sequences.
    """

    def __init__(self, sequence_file_path, embeddings_file_path, modality):
        self.sequence = pd.read_csv(sequence_file_path)
        # NOTE(security): torch.load unpickles the file -- only point this at
        # embedding files from a trusted source.
        self.embedding = torch.load(embeddings_file_path)
        self.modality = modality

        # __len__ is driven by the CSV, but __getitem__ also indexes the
        # embeddings.  Fail here with a clear message instead of raising an
        # IndexError part-way through an epoch.
        n_sequences = len(self.sequence)
        n_embeddings = len(self.embedding)
        if n_embeddings < n_sequences:
            raise ValueError(
                f"Found {n_embeddings} embeddings in {embeddings_file_path!r} "
                f"but {n_sequences} sequences in {sequence_file_path!r}; "
                "every sequence needs a matching embedding."
            )

    def __len__(self):
        """Return the number of rows in the sequence CSV."""
        return len(self.sequence)

    def __getitem__(self, idx):
        """Return ``{"aa": sequence, modality: embedding}`` for row ``idx``."""
        sequence = self.sequence.iloc[idx, 0]
        embedding = self.embedding[idx]
        return {"aa": sequence, self.modality: embedding}
44
+
45
class DualEmbeddingDataset(Dataset):
    """
    Pair two sets of precomputed embeddings position by position.

    Unlike a dataset that yields raw sequences, both sides are loaded with
    ``torch.load``.  Each sample is a dict in which ``"aa"`` holds
    ``sequence_embedding[idx]`` and ``modality`` holds ``embedding[idx]``.
    Any modality that doesn't fit into the __getitem__ method can subclass
    this and modify the __getitem__ method.

    :param sequence_embeddings_file_path: file with the embeddings returned
        under the ``"aa"`` key.
    :param embeddings_file_path: file with the embeddings returned under
        ``modality``.
    :param modality: key under which the second embedding is returned.  If it
        is ``"aa"`` it overwrites the first entry in the sample.
    :raises ValueError: if the second file holds fewer entries than the first.
    """

    def __init__(self, sequence_embeddings_file_path, embeddings_file_path, modality):
        # NOTE(security): torch.load unpickles the files -- only point this at
        # embedding files from a trusted source.
        self.sequence_embedding = torch.load(sequence_embeddings_file_path)
        self.embedding = torch.load(embeddings_file_path)
        self.modality = modality

        # __len__ follows the sequence embeddings, but __getitem__ indexes
        # both containers.  Fail here with a clear message instead of raising
        # an IndexError part-way through an epoch.
        n_sequences = len(self.sequence_embedding)
        n_embeddings = len(self.embedding)
        if n_embeddings < n_sequences:
            raise ValueError(
                f"Found {n_embeddings} embeddings in {embeddings_file_path!r} "
                f"but {n_sequences} sequence embeddings in "
                f"{sequence_embeddings_file_path!r}; every sequence embedding "
                "needs a matching embedding."
            )

    def __len__(self):
        """Return the number of sequence embeddings."""
        return len(self.sequence_embedding)

    def __getitem__(self, idx):
        """Return ``{"aa": sequence_embedding, modality: embedding}`` at ``idx``."""
        sequence_embedding = self.sequence_embedding[idx]
        embedding = self.embedding[idx]
        return {"aa": sequence_embedding, self.modality: embedding}
63
+
64
class ProteinBindModel(nn.Module):
    """
    Project embeddings of several modalities through per-modality pipelines.

    Every modality listed in ``ModalityType`` owns three stages, stored in
    ``nn.ModuleDict`` containers keyed by the modality string:

    * trunk: ``Linear(<modality>_embed_dim, 512) -> ReLU -> Linear(512, 512)
      -> ReLU -> Linear(512, in_embed_dim)``
    * head: ``LayerNorm(in_embed_dim) -> Dropout(0.5) ->
      Linear(in_embed_dim, out_embed_dim, bias=False)``
    * postprocessor: L2 normalisation over the last dimension.

    :param aa_embed_dim: feature size of the ``aa`` input embeddings.
    :param dna_embed_dim: feature size of the ``dna`` input embeddings.
    :param pdb_embed_dim: feature size of the ``pdb`` input embeddings.
    :param go_embed_dim: feature size of the ``go`` input embeddings.
    :param msa_embed_dim: feature size of the ``msa`` input embeddings.
    :param text_embed_dim: feature size of the ``text`` input embeddings.
    :param in_embed_dim: feature size produced by every trunk and consumed
        by every head.
    :param out_embed_dim: feature size produced by every head.
    """

    # Width of the two hidden layers in every trunk MLP.
    _TRUNK_HIDDEN_DIM = 512
    # Dropout probability used in every projection head.
    _HEAD_DROPOUT = 0.5

    def __init__(
        self,
        aa_embed_dim,
        dna_embed_dim,
        pdb_embed_dim,
        go_embed_dim,
        msa_embed_dim,
        text_embed_dim,
        in_embed_dim,
        out_embed_dim
    ):
        super().__init__()
        # The trunks must emit ``in_embed_dim`` features because that is what
        # the heads' LayerNorm/Linear layers expect.  (Passing
        # ``out_embed_dim`` here only works when both sizes are equal.)
        self.modality_trunks = self._create_modality_trunk(
            aa_embed_dim,
            dna_embed_dim,
            pdb_embed_dim,
            go_embed_dim,
            msa_embed_dim,
            text_embed_dim,
            in_embed_dim
        )
        self.modality_heads = self._create_modality_head(
            in_embed_dim,
            out_embed_dim,
        )
        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    @classmethod
    def _build_trunk(cls, input_dim, output_dim):
        """Build one trunk MLP: input_dim -> hidden -> hidden -> output_dim."""
        hidden_dim = cls._TRUNK_HIDDEN_DIM
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    @classmethod
    def _build_head(cls, in_embed_dim, out_embed_dim):
        """Build one projection head: LayerNorm, Dropout, bias-free Linear."""
        return nn.Sequential(
            nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6),
            nn.Dropout(p=cls._HEAD_DROPOUT),
            nn.Linear(in_embed_dim, out_embed_dim, bias=False),
        )

    def _create_modality_trunk(
        self,
        aa_embed_dim,
        dna_embed_dim,
        pdb_embed_dim,
        go_embed_dim,
        msa_embed_dim,
        text_embed_dim,
        in_embed_dim
    ):
        """
        Create one trunk MLP per modality.

        The current layers are just a proof of concept
        and are subject to the opinion of others.

        :param aa_embed_dim: input feature size of the ``aa`` trunk.
        :param dna_embed_dim: input feature size of the ``dna`` trunk.
        :param pdb_embed_dim: input feature size of the ``pdb`` trunk.
        :param go_embed_dim: input feature size of the ``go`` trunk.
        :param msa_embed_dim: input feature size of the ``msa`` trunk.
        :param text_embed_dim: input feature size of the ``text`` trunk.
        :param in_embed_dim: output feature size shared by all trunks.
        :return: ``nn.ModuleDict`` mapping modality string to its trunk.
        """
        # Insertion order fixes the order in which layers are constructed
        # (and therefore initialised); keep it stable.
        input_dims = {
            ModalityType.AA: aa_embed_dim,
            ModalityType.DNA: dna_embed_dim,
            ModalityType.PDB: pdb_embed_dim,
            ModalityType.GO: go_embed_dim,
            ModalityType.MSA: msa_embed_dim,
            ModalityType.TEXT: text_embed_dim,
        }
        modality_trunks = {
            modality: self._build_trunk(input_dim, in_embed_dim)
            for modality, input_dim in input_dims.items()
        }
        return nn.ModuleDict(modality_trunks)

    def _create_modality_head(
        self,
        in_embed_dim,
        out_embed_dim
    ):
        """
        Create one projection head per modality.

        :param in_embed_dim: feature size the head receives from the trunk.
        :param out_embed_dim: feature size the head produces.
        :return: ``nn.ModuleDict`` mapping modality string to its head.
        """
        modalities = (
            ModalityType.AA,
            ModalityType.DNA,
            ModalityType.PDB,
            ModalityType.GO,
            ModalityType.MSA,
            ModalityType.TEXT,
        )
        modality_heads = {
            modality: self._build_head(in_embed_dim, out_embed_dim)
            for modality in modalities
        }
        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """
        Create one L2-normalisation layer per modality.

        :param out_embed_dim: unused; kept so the signature stays unchanged.
        :return: ``nn.ModuleDict`` mapping modality string to ``Normalize``.
        """
        modalities = (
            ModalityType.AA,
            ModalityType.DNA,
            ModalityType.PDB,
            ModalityType.TEXT,
            ModalityType.GO,
            ModalityType.MSA,
        )
        modality_postprocessors = {
            modality: Normalize(dim=-1) for modality in modalities
        }
        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """
        Run every supplied modality through its trunk, head and postprocessor.

        :param inputs: dict mapping a modality string to that modality's
            input; only the modalities present in the dict are processed.
        :return: dict with the same keys, each mapped to the L2-normalised
            output of that modality's pipeline.
        :raises KeyError: if a key has no registered trunk/head/postprocessor.
        """
        outputs = {}

        for modality_key, modality_value in inputs.items():
            trunk_out = self.modality_trunks[modality_key](modality_value)
            head_out = self.modality_heads[modality_key](trunk_out)
            outputs[modality_key] = self.modality_postprocessors[modality_key](
                head_out
            )

        return outputs
257
+
258
+
259
def create_proteinbind(pretrained=False, checkpoint_path="best_model.pth"):
    """
    Build a ``ProteinBindModel`` and optionally load trained weights.

    The embedding dimensions here are dummy.

    The model is returned in the default ``nn.Module`` training mode, so the
    Dropout layers in the heads are active; call ``.eval()`` before inference.

    :param pretrained: when true, load a state dict from ``checkpoint_path``.
    :param checkpoint_path: state-dict file to load when ``pretrained`` is
        true.  A relative path is resolved against the current working
        directory.
    :return: the constructed (and optionally weight-loaded) model.
    """
    model = ProteinBindModel(
        aa_embed_dim=480,
        dna_embed_dim=1280,
        pdb_embed_dim=128,
        go_embed_dim=600,
        msa_embed_dim=768,
        text_embed_dim=768,
        in_embed_dim=1024,
        out_embed_dim=1024
    )

    if pretrained:
        # TODO: get path from config
        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine; load_state_dict copies the values into the
        # model's own parameters afterwards.
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)

    return model