osbm commited on
Commit
ace5f83
·
1 Parent(s): 0280c50

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +6 -188
models.py CHANGED
@@ -169,7 +169,11 @@ class Generator2(nn.Module):
169
  #pos_enc = self.pos_enc(lap)
170
  #drug_n = drug_n + pos_enc
171
 
172
- nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(nodes_logits,akt1_annot,edges_logits,akt1_adj)
 
 
 
 
173
 
174
  edges_logits = self.edges_output_layer(edges_logits)
175
  nodes_logits = self.nodes_output_layer(nodes_logits)
@@ -203,190 +207,4 @@ class simple_disc(nn.Module):
203
 
204
  #prediction = F.softmax(prediction,dim=-1)
205
 
206
- return prediction
207
-
208
- """class Discriminator(nn.Module):
209
-
210
- def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
211
- super(Discriminator, self).__init__()
212
- self.degree = deg
213
- self.aggregators = agg
214
- self.scalers = sca
215
- self.pna_in_channels = pna_in_ch
216
- self.pna_out_channels = pna_out_ch
217
- self.edge_dimension = edge_dim
218
- self.towers = towers
219
- self.pre_layers_num = pre_lay
220
- self.post_layers_num = post_lay
221
- self.pna_layer_num = pna_layer_num
222
- self.graph_add = graph_add
223
- self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
224
- pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
225
- towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
226
- pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
227
-
228
- def forward(self, x, edge_index, edge_attr, batch, activation=None):
229
-
230
- h = self.PNA_layer(x, edge_index, edge_attr, batch)
231
-
232
- h = activation(h) if activation is not None else h
233
-
234
- return h"""
235
-
236
- """class Discriminator2(nn.Module):
237
-
238
- def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
239
- super(Discriminator2, self).__init__()
240
- self.degree = deg
241
- self.aggregators = agg
242
- self.scalers = sca
243
- self.pna_in_channels = pna_in_ch
244
- self.pna_out_channels = pna_out_ch
245
- self.edge_dimension = edge_dim
246
- self.towers = towers
247
- self.pre_layers_num = pre_lay
248
- self.post_layers_num = post_lay
249
- self.pna_layer_num = pna_layer_num
250
- self.graph_add = graph_add
251
- self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
252
- pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
253
- towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
254
- pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
255
-
256
- def forward(self, x, edge_index, edge_attr, batch, activation=None):
257
-
258
- h = self.PNA_layer(x, edge_index, edge_attr, batch)
259
-
260
- h = activation(h) if activation is not None else h
261
-
262
- return h"""
263
-
264
-
265
- """class Discriminator_old(nn.Module):
266
-
267
- def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
268
- super(Discriminator_old, self).__init__()
269
-
270
- graph_conv_dim, aux_dim, linear_dim = conv_dim
271
-
272
- # discriminator
273
- self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout,gcn_depth)
274
- self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
275
-
276
- # multi dense layer
277
- layers = []
278
- for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
279
- layers.append(nn.Linear(c0,c1))
280
- layers.append(nn.Dropout(dropout))
281
- self.linear_layer = nn.Sequential(*layers)
282
-
283
- self.output_layer = nn.Linear(linear_dim[-1], 1)
284
-
285
- def forward(self, adj, hidden, node, activation=None):
286
-
287
- adj = adj[:,:,:,1:].permute(0,3,1,2)
288
-
289
- annotations = torch.cat((hidden, node), -1) if hidden is not None else node
290
-
291
- h = self.gcn_layer(annotations, adj)
292
- annotations = torch.cat((h, hidden, node) if hidden is not None\
293
- else (h, node), -1)
294
-
295
- h = self.agg_layer(annotations, torch.tanh)
296
- h = self.linear_layer(h)
297
-
298
- # Need to implement batch discriminator #
299
- #########################################
300
-
301
- output = self.output_layer(h)
302
- output = activation(output) if activation is not None else output
303
-
304
- return output, h"""
305
-
306
- """class Discriminator_old2(nn.Module):
307
-
308
- def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
309
- super(Discriminator_old2, self).__init__()
310
-
311
- graph_conv_dim, aux_dim, linear_dim = conv_dim
312
-
313
- # discriminator
314
- self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout, gcn_depth)
315
- self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
316
-
317
- # multi dense layer
318
- layers = []
319
- for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
320
- layers.append(nn.Linear(c0,c1))
321
- layers.append(nn.Dropout(dropout))
322
- self.linear_layer = nn.Sequential(*layers)
323
-
324
- self.output_layer = nn.Linear(linear_dim[-1], 1)
325
-
326
- def forward(self, adj, hidden, node, activation=None):
327
-
328
- adj = adj[:,:,:,1:].permute(0,3,1,2)
329
-
330
- annotations = torch.cat((hidden, node), -1) if hidden is not None else node
331
-
332
- h = self.gcn_layer(annotations, adj)
333
- annotations = torch.cat((h, hidden, node) if hidden is not None\
334
- else (h, node), -1)
335
-
336
- h = self.agg_layer(annotations, torch.tanh)
337
- h = self.linear_layer(h)
338
-
339
- # Need to implement batch discriminator #
340
- #########################################
341
-
342
- output = self.output_layer(h)
343
- output = activation(output) if activation is not None else output
344
-
345
- return output, h"""
346
-
347
- """class Discriminator3(nn.Module):
348
-
349
- def __init__(self,in_ch):
350
- super(Discriminator3, self).__init__()
351
- self.dim = in_ch
352
-
353
-
354
- self.TraConv_layer = TransformerConv(in_channels = self.dim,out_channels = self.dim//4,edge_dim = self.dim)
355
- self.mlp = torch.nn.Sequential(torch.nn.Tanh(), torch.nn.Linear(self.dim//4,1))
356
- def forward(self, x, edge_index, edge_attr, batch, activation=None):
357
-
358
- h = self.TraConv_layer(x, edge_index, edge_attr)
359
- h = global_add_pool(h,batch)
360
- h = self.mlp(h)
361
- h = activation(h) if activation is not None else h
362
-
363
- return h"""
364
-
365
-
366
- """class PNA_Net(nn.Module):
367
- def __init__(self,deg):
368
- super().__init__()
369
-
370
-
371
-
372
- self.convs = nn.ModuleList()
373
-
374
- self.lin = nn.Linear(5, 128)
375
- for _ in range(1):
376
- conv = DenseGCNConv(128, 128, improved=False, bias=True)
377
- self.convs.append(conv)
378
-
379
- self.agg_layer = GraphAggregation(128, 128, 0, dropout=0.1)
380
- self.mlp = nn.Sequential(nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 32), nn.Tanh(),
381
- nn.Linear(32, 1))
382
-
383
- def forward(self, x, adj,mask=None):
384
- x = self.lin(x)
385
-
386
- for conv in self.convs:
387
- x = F.relu(conv(x, adj,mask=None))
388
-
389
- x = self.agg_layer(x,torch.tanh)
390
-
391
- return self.mlp(x) """
392
-
 
169
  #pos_enc = self.pos_enc(lap)
170
  #drug_n = drug_n + pos_enc
171
 
172
+ if self.submodel == "Ligand" or self.submodel == "RL" :
173
+ nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(akt1_annot,nodes_logits,akt1_adj,edges_logits)
174
+
175
+ else:
176
+ nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(nodes_logits,akt1_annot,edges_logits,akt1_adj)
177
 
178
  edges_logits = self.edges_output_layer(edges_logits)
179
  nodes_logits = self.nodes_output_layer(nodes_logits)
 
207
 
208
  #prediction = F.softmax(prediction,dim=-1)
209
 
210
+ return prediction