osbm committed on
Commit
fd37f66
·
1 Parent(s): e08fb79

Upload layers.py

Files changed (1)
  1. layers.py +435 -0
layers.py ADDED
@@ -0,0 +1,435 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.module import Module
+ from torch.nn import functional as F
+ from torch.nn import Embedding, ModuleList
+ from torch_geometric.nn import PNAConv, global_add_pool, Set2Set, GraphMultisetTransformer
+ import math
+
+ class MLP(nn.Module):
+     def __init__(self, act, in_feat, hid_feat=None, out_feat=None,
+                  dropout=0.):
+         super().__init__()
+         if not hid_feat:
+             hid_feat = in_feat
+         if not out_feat:
+             out_feat = in_feat
+         self.fc1 = nn.Linear(in_feat, hid_feat)
+         # Note: the `act` argument is kept for API compatibility but is not used;
+         # the activation is hard-coded to ReLU.
+         self.act = torch.nn.ReLU()
+         self.fc2 = nn.Linear(hid_feat, out_feat)
+         self.droprateout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+         return self.droprateout(x)
+
+ class Attention_new(nn.Module):
+     def __init__(self, dim, heads, act, attention_dropout=0., proj_dropout=0.):
+         super().__init__()
+         assert dim % heads == 0
+         self.heads = heads
+         self.scale = 1. / dim ** 0.5
+
+         self.q = nn.Linear(dim, dim)
+         self.k = nn.Linear(dim, dim)
+         self.v = nn.Linear(dim, dim)
+         self.e = nn.Linear(dim, dim)
+         #self.attention_dropout = nn.Dropout(attention_dropout)
+
+         self.d_k = dim // heads
+         self.heads = heads
+         self.out_e = nn.Linear(dim, dim)
+         self.out_n = nn.Linear(dim, dim)
+
+     def forward(self, node, edge):
+         # node: (batch, n, dim), edge: (batch, n, n, dim)
+         b, n, c = node.shape
+
+         q_embed = self.q(node).view(-1, n, self.heads, c // self.heads)
+         k_embed = self.k(node).view(-1, n, self.heads, c // self.heads)
+         v_embed = self.v(node).view(-1, n, self.heads, c // self.heads)
+
+         e_embed = self.e(edge).view(-1, n, n, self.heads, c // self.heads)
+
+         q_embed = q_embed.unsqueeze(2)
+         k_embed = k_embed.unsqueeze(1)
+
+         attn = q_embed * k_embed
+
+         attn = attn / math.sqrt(self.d_k)
+
+         # modulate attention scores with the edge features
+         attn = attn * (e_embed + 1) * e_embed
+
+         edge = self.out_e(attn.flatten(3))
+
+         attn = F.softmax(attn, dim=2)
+
+         v_embed = v_embed.unsqueeze(1)
+
+         v_embed = attn * v_embed
+
+         v_embed = v_embed.sum(dim=2).flatten(2)
+
+         node = self.out_n(v_embed)
+
+         return node, edge
+
+ class Encoder_Block(nn.Module):
+     def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(dim)
+
+         self.attn = Attention_new(dim, heads, act, drop_rate, drop_rate)
+         self.ln3 = nn.LayerNorm(dim)
+         self.ln4 = nn.LayerNorm(dim)
+         self.mlp = MLP(act, dim, dim * mlp_ratio, dim, dropout=drop_rate)
+         self.mlp2 = MLP(act, dim, dim * mlp_ratio, dim, dropout=drop_rate)
+         self.ln5 = nn.LayerNorm(dim)
+         self.ln6 = nn.LayerNorm(dim)
+
+     def forward(self, x, y):
+         x1 = self.ln1(x)
+         x2, y1 = self.attn(x1, y)
+         x2 = x1 + x2
+         y2 = y1 + y
+         x2 = self.ln3(x2)
+         y2 = self.ln4(y2)
+
+         x = self.ln5(x2 + self.mlp(x2))
+         y = self.ln6(y2 + self.mlp2(y2))
+         return x, y
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
+         super().__init__()
+
+         self.Encoder_Blocks = nn.ModuleList([
+             Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
+             for i in range(depth)])
+
+     def forward(self, x, y):
+         for encoder_block in self.Encoder_Blocks:
+             x, y = encoder_block(x, y)
+         return x, y
+
+ class enc_dec_attention(nn.Module):
+     def __init__(self, dim, heads, attention_dropout=0., proj_dropout=0.):
+         super().__init__()
+         self.dim = dim
+         self.heads = heads
+         self.scale = 1. / dim ** 0.5
+
+         # query is the molecule, key is the protein, value is the molecule again
+         self.q_mx = nn.Linear(dim, dim)
+         self.k_px = nn.Linear(dim, dim)
+         self.v_mx = nn.Linear(dim, dim)
+
+         self.k_pa = nn.Linear(dim, dim)
+         self.v_ma = nn.Linear(dim, dim)
+
+         #self.dropout_dec = nn.Dropout(proj_dropout)
+         self.out_nd = nn.Linear(dim, dim)
+         self.out_ed = nn.Linear(dim, dim)
+
+     def forward(self, mol_annot, prot_annot, mol_adj, prot_adj):
+         # The reshapes below mix n (molecule nodes) and m (protein nodes), so this
+         # block assumes both graphs are padded to the same node count.
+         b, n, c = mol_annot.shape
+         _, m, _ = prot_annot.shape
+
+         query_mol_annot = self.q_mx(mol_annot).view(-1, m, self.heads, c // self.heads)
+         key_prot_annot = self.k_px(prot_annot).view(-1, n, self.heads, c // self.heads)
+         value_mol_annot = self.v_mx(mol_annot).view(-1, m, self.heads, c // self.heads)
+
+         mol_e = self.v_ma(mol_adj).view(-1, m, m, self.heads, c // self.heads)
+         prot_e = self.k_pa(prot_adj).view(-1, m, m, self.heads, c // self.heads)
+
+         query_mol_annot = query_mol_annot.unsqueeze(2)
+         key_prot_annot = key_prot_annot.unsqueeze(1)
+
+         #attn = torch.einsum('bnchd,bmahd->bnahd', query_mol_annot, key_prot_annot)
+
+         attn = query_mol_annot * key_prot_annot
+
+         attn = attn / math.sqrt(self.dim)
+
+         attn = attn * (prot_e + 1) * mol_e
+
+         mol_e_new = attn.flatten(3)
+
+         mol_adj = self.out_ed(mol_e_new)
+
+         attn = F.softmax(attn, dim=2)
+
+         value_mol_annot = value_mol_annot.unsqueeze(1)
+
+         value_mol_annot = attn * value_mol_annot
+
+         value_mol_annot = value_mol_annot.sum(dim=2).flatten(2)
+
+         mol_annot = self.out_nd(value_mol_annot)
+
+         return mol_annot, prot_annot, mol_adj, prot_adj
+
+ class Decoder_Block(nn.Module):
+     def __init__(self, dim, heads, mlp_ratio=4, drop_rate=0.):
+         super().__init__()
+
+         self.ln1_ma = nn.LayerNorm(dim)
+         self.ln1_pa = nn.LayerNorm(dim)
+         self.ln1_mx = nn.LayerNorm(dim)
+         self.ln1_px = nn.LayerNorm(dim)
+
+         # Note: Attention_new and MLP ignore their `act` argument (ReLU is
+         # hard-coded), so the values passed in that positional slot below have
+         # no effect.
+         self.attn2 = Attention_new(dim, heads, drop_rate, drop_rate)
+
+         self.ln2_pa = nn.LayerNorm(dim)
+         self.ln2_px = nn.LayerNorm(dim)
+
+         self.dec_attn = enc_dec_attention(dim, heads, drop_rate, drop_rate)
+
+         self.ln3_ma = nn.LayerNorm(dim)
+         self.ln3_mx = nn.LayerNorm(dim)
+
+         self.mlp_ma = MLP(dim, dim, dropout=drop_rate)
+         self.mlp_mx = MLP(dim, dim, dropout=drop_rate)
+
+         self.ln4_ma = nn.LayerNorm(dim)
+         self.ln4_mx = nn.LayerNorm(dim)
+
+     def forward(self, mol_annot, prot_annot, mol_adj, prot_adj):
+         mol_annot = self.ln1_mx(mol_annot)
+         mol_adj = self.ln1_ma(mol_adj)
+
+         prot_annot = self.ln1_px(prot_annot)
+         prot_adj = self.ln1_pa(prot_adj)
+
+         # self-attention over the protein graph
+         px1, pa1 = self.attn2(prot_annot, prot_adj)
+
+         prot_annot = prot_annot + px1
+         prot_adj = prot_adj + pa1
+
+         prot_annot = self.ln2_px(prot_annot)
+         prot_adj = self.ln2_pa(prot_adj)
+
+         # cross-attention between the molecule and protein graphs
+         mx1, prot_annot, ma1, prot_adj = self.dec_attn(mol_annot, prot_annot, mol_adj, prot_adj)
+
+         ma1 = mol_adj + ma1
+         mx1 = mol_annot + mx1
+
+         ma2 = self.ln3_ma(ma1)
+         mx2 = self.ln3_mx(mx1)
+
+         ma3 = self.mlp_ma(ma2)
+         mx3 = self.mlp_mx(mx2)
+
+         ma = ma3 + ma2
+         mx = mx3 + mx2
+
+         mol_adj = self.ln4_ma(ma)
+         mol_annot = self.ln4_mx(mx)
+
+         return mol_annot, prot_annot, mol_adj, prot_adj
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, dim, depth, heads, mlp_ratio=4, drop_rate=0.):
+         super().__init__()
+
+         self.Decoder_Blocks = nn.ModuleList([
+             Decoder_Block(dim, heads, mlp_ratio, drop_rate)
+             for i in range(depth)])
+
+     def forward(self, mol_annot, prot_annot, mol_adj, prot_adj):
+         for decoder_block in self.Decoder_Blocks:
+             mol_annot, prot_annot, mol_adj, prot_adj = decoder_block(mol_annot, prot_annot, mol_adj, prot_adj)
+         return mol_annot, prot_annot, mol_adj, prot_adj
+
+
+ """class PNA(torch.nn.Module):
+     def __init__(self, deg, agg, sca, pna_in_ch, pna_out_ch, edge_dim, towers, pre_lay, post_lay, pna_layer_num, graph_add):
+         super(PNA, self).__init__()
+
+         self.node_emb = Embedding(30, pna_in_ch)
+         self.edge_emb = Embedding(30, edge_dim)
+         degree = deg
+         aggregators = agg.split(",")  #["max"]  # 'sum', 'min', 'max', 'std', 'var', 'mean'  ## try changing these
+         scalers = sca.split(",")  # ['amplification', 'attenuation']  # 'amplification', 'attenuation', 'linear', 'inverse_linear', 'identity'
+         self.graph_add = graph_add
+         self.convs = ModuleList()
+         self.batch_norms = ModuleList()
+
+         for _ in range(pna_layer_num):  ##### consider making the number of layers a hyperparameter
+             conv = PNAConv(in_channels=pna_in_ch, out_channels=pna_out_ch,
+                            aggregators=aggregators, scalers=scalers, deg=degree,
+                            edge_dim=edge_dim, towers=towers, pre_layers=pre_lay, post_layers=post_lay,  ## try different numbers of towers, default is 1
+                            divide_input=True)
+             self.convs.append(conv)
+             self.batch_norms.append(nn.LayerNorm(pna_out_ch))
+
+         #self.graph_multitrans = GraphMultisetTransformer(in_channels=pna_out_ch, hidden_channels=200,
+         #                                                 out_channels=pna_out_ch, layer_norm=True)
+         if self.graph_add == "set2set":
+             self.s2s = Set2Set(in_channels=pna_out_ch, processing_steps=1, num_layers=1)
+
+         if self.graph_add == "set2set":
+             pna_out_ch = pna_out_ch * 2
+         self.mlp = nn.Sequential(nn.Linear(pna_out_ch, pna_out_ch), nn.Tanh(), nn.Linear(pna_out_ch, 25), nn.Tanh(), nn.Linear(25, 1))
+
+     def forward(self, x, edge_index, edge_attr, batch):
+
+         x = self.node_emb(x.squeeze())
+
+         edge_attr = self.edge_emb(edge_attr)
+
+         for conv, batch_norm in zip(self.convs, self.batch_norms):
+             x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))
+
+         if self.graph_add == "global_add":
+             x = global_add_pool(x, batch.squeeze())
+
+         elif self.graph_add == "set2set":
+             x = self.s2s(x, batch.squeeze())
+         #elif self.graph_add == "graph_multitrans":
+         #    x = self.graph_multitrans(x, batch.squeeze(), edge_index)
+         x = self.mlp(x)
+
+         return x"""
+
+
+ """class GraphConvolution(nn.Module):
+
+     def __init__(self, in_features, out_feature_list, b_dim, dropout, gcn_depth):
+         super(GraphConvolution, self).__init__()
+         self.in_features = in_features
+
+         self.gcn_depth = gcn_depth
+
+         self.out_feature_list = out_feature_list
+
+         self.gcn_in = nn.Sequential(nn.Linear(in_features, out_feature_list[0]), nn.Tanh(),
+                                     nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
+                                     nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Dropout(dropout))
+
+         self.gcn_convs = nn.ModuleList()
+
+         for _ in range(gcn_depth):
+
+             gcn_conv = nn.Sequential(nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
+                                      nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
+                                      nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Dropout(dropout))
+
+             self.gcn_convs.append(gcn_conv)
+
+         self.gcn_out = nn.Sequential(nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
+                                      nn.Linear(out_feature_list[0], out_feature_list[0]), nn.Tanh(),
+                                      nn.Linear(out_feature_list[0], out_feature_list[1]), nn.Dropout(dropout))
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, input, adj, activation=None):
+         # input : 16x9x9
+         # adj : 16x4x9x9
+         hidden = torch.stack([self.gcn_in(input) for _ in range(adj.size(1))], 1)
+         hidden = torch.einsum('bijk,bikl->bijl', (adj, hidden))
+
+         hidden = torch.sum(hidden, 1) + self.gcn_in(input)
+         hidden = activation(hidden) if activation is not None else hidden
+
+         for gcn_conv in self.gcn_convs:
+             hidden1 = torch.stack([gcn_conv(hidden) for _ in range(adj.size(1))], 1)
+             hidden1 = torch.einsum('bijk,bikl->bijl', (adj, hidden1))
+             hidden = torch.sum(hidden1, 1) + gcn_conv(hidden)
+             hidden = activation(hidden) if activation is not None else hidden
+
+         output = torch.stack([self.gcn_out(hidden) for _ in range(adj.size(1))], 1)
+         output = torch.einsum('bijk,bikl->bijl', (adj, output))
+         output = torch.sum(output, 1) + self.gcn_out(hidden)
+         output = activation(output) if activation is not None else output
+
+         return output
+
+
+ class GraphAggregation(Module):
+
+     def __init__(self, in_features, out_features, m_dim, dropout):
+         super(GraphAggregation, self).__init__()
+         self.sigmoid_linear = nn.Sequential(nn.Linear(in_features + m_dim, out_features), nn.Sigmoid())
+         self.tanh_linear = nn.Sequential(nn.Linear(in_features + m_dim, out_features), nn.Tanh())
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, input, activation):
+         i = self.sigmoid_linear(input)
+         j = self.tanh_linear(input)
+         output = torch.sum(torch.mul(i, j), 1)
+         output = activation(output) if activation is not None else output
+         output = self.dropout(output)
+
+         return output"""
+
+ """class Attention(nn.Module):
+     def __init__(self, dim, heads=4, attention_dropout=0., proj_dropout=0.):
+         super().__init__()
+         self.heads = heads
+         self.scale = 1. / dim ** 0.5
+         #self.scale = torch.div(1, torch.pow(dim, 0.5))  #1./torch.pow(dim, 0.5)  #dim**0.5  torch.div(x, 0.5)
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=False)
+
+         self.attention_dropout = nn.Dropout(attention_dropout)
+         self.out = nn.Sequential(
+             nn.Linear(dim, dim),
+             nn.Dropout(proj_dropout)
+         )
+         #self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
+
+     def forward(self, x):
+         b, n, c = x.shape
+
+         #x = x + torch.randn([x.size(0), x.size(1), 1], device=x.device) * self.noise_strength_1
+
+         qkv = self.qkv(x).reshape(b, n, 3, self.heads, c // self.heads)
+
+         q, k, v = qkv.permute(2, 0, 3, 1, 4)
+
+         dot = (q @ k.transpose(-2, -1)) * self.scale
+
+         attn = dot.softmax(dim=-1)
+         attn = self.attention_dropout(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(b, n, c)
+
+         x = self.out(x)
+
+         return x, attn"""
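
For reference, a minimal smoke-test sketch of how the active classes in layers.py might be exercised. This is not part of the commit: the dimension, depth, head count, and node count below are arbitrary assumptions, torch_geometric must be installed for the module-level imports to succeed, and the decoder is fed molecule and protein tensors with the same node count, which its reshapes require. The `act` argument is passed but currently ignored by the layers.

    import torch
    from layers import TransformerEncoder, TransformerDecoder

    dim, depth, heads = 64, 2, 4        # assumed values; dim must be divisible by heads
    batch, n_nodes = 8, 9               # same node count assumed for molecule and protein graphs

    node = torch.randn(batch, n_nodes, dim)            # node (annotation) features
    edge = torch.randn(batch, n_nodes, n_nodes, dim)   # edge (adjacency) features

    encoder = TransformerEncoder(dim, depth, heads, act="relu")
    node_out, edge_out = encoder(node, edge)
    print(node_out.shape, edge_out.shape)   # (8, 9, 64) and (8, 9, 9, 64)

    decoder = TransformerDecoder(dim, depth, heads)
    mol_x, prot_x, mol_a, prot_a = decoder(node_out, node_out.clone(),
                                           edge_out, edge_out.clone())
    print(mol_x.shape, mol_a.shape)         # (8, 9, 64) and (8, 9, 9, 64)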