Update models.py
models.py CHANGED
@@ -169,7 +169,11 @@ class Generator2(nn.Module):
         #pos_enc = self.pos_enc(lap)
         #drug_n = drug_n + pos_enc
 
-
+        if self.submodel == "Ligand" or self.submodel == "RL" :
+            nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(akt1_annot,nodes_logits,akt1_adj,edges_logits)
+
+        else:
+            nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(nodes_logits,akt1_annot,edges_logits,akt1_adj)
 
         edges_logits = self.edges_output_layer(edges_logits)
         nodes_logits = self.nodes_output_layer(nodes_logits)
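The branch added above only changes the order in which the generated-molecule tensors (nodes_logits, edges_logits) and the akt1_* tensors are handed to self.TransformerDecoder, depending on self.submodel. Below is a minimal, self-contained sketch of that routing; DummyDecoder, RoutingExample, the toy tensor shapes and the "other" submodel name are illustrative assumptions, not code from models.py.

import torch
import torch.nn as nn

class DummyDecoder(nn.Module):
    """Stand-in that simply echoes its four inputs in the order it receives them."""
    def forward(self, a, b, c, d):
        return a, b, c, d

class RoutingExample(nn.Module):
    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel
        self.TransformerDecoder = DummyDecoder()

    def forward(self, nodes_logits, akt1_annot, edges_logits, akt1_adj):
        # "Ligand" and "RL" submodels pass the akt1_* tensors first;
        # every other submodel passes the generated-molecule tensors first.
        if self.submodel == "Ligand" or self.submodel == "RL":
            return self.TransformerDecoder(akt1_annot, nodes_logits, akt1_adj, edges_logits)
        else:
            return self.TransformerDecoder(nodes_logits, akt1_annot, edges_logits, akt1_adj)

# Toy shapes: batch of 2 graphs, 8 nodes, 5 node types, 4 edge types.
nodes, annot = torch.randn(2, 8, 5), torch.randn(2, 8, 5)
edges, adj = torch.randn(2, 8, 8, 4), torch.randn(2, 8, 8, 4)
first, *_ = RoutingExample("Ligand")(nodes, annot, edges, adj)   # akt1_annot comes back first
print(first is annot)   # True
first, *_ = RoutingExample("other")(nodes, annot, edges, adj)    # nodes_logits comes back first
print(first is nodes)   # True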
@@ -203,190 +207,4 @@ class simple_disc(nn.Module):
 
         #prediction = F.softmax(prediction,dim=-1)
 
-        return prediction
-
-"""class Discriminator(nn.Module):
-
-    def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
-        super(Discriminator, self).__init__()
-        self.degree = deg
-        self.aggregators = agg
-        self.scalers = sca
-        self.pna_in_channels = pna_in_ch
-        self.pna_out_channels = pna_out_ch
-        self.edge_dimension = edge_dim
-        self.towers = towers
-        self.pre_layers_num = pre_lay
-        self.post_layers_num = post_lay
-        self.pna_layer_num = pna_layer_num
-        self.graph_add = graph_add
-        self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
-                             pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
-                             towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
-                             pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
-
-    def forward(self, x, edge_index, edge_attr, batch, activation=None):
-
-        h = self.PNA_layer(x, edge_index, edge_attr, batch)
-
-        h = activation(h) if activation is not None else h
-
-        return h"""
-
-"""class Discriminator2(nn.Module):
-
-    def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
-        super(Discriminator2, self).__init__()
-        self.degree = deg
-        self.aggregators = agg
-        self.scalers = sca
-        self.pna_in_channels = pna_in_ch
-        self.pna_out_channels = pna_out_ch
-        self.edge_dimension = edge_dim
-        self.towers = towers
-        self.pre_layers_num = pre_lay
-        self.post_layers_num = post_lay
-        self.pna_layer_num = pna_layer_num
-        self.graph_add = graph_add
-        self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
-                             pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
-                             towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
-                             pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
-
-    def forward(self, x, edge_index, edge_attr, batch, activation=None):
-
-        h = self.PNA_layer(x, edge_index, edge_attr, batch)
-
-        h = activation(h) if activation is not None else h
-
-        return h"""
-
-
-"""class Discriminator_old(nn.Module):
-
-    def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
-        super(Discriminator_old, self).__init__()
-
-        graph_conv_dim, aux_dim, linear_dim = conv_dim
-
-        # discriminator
-        self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout,gcn_depth)
-        self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
-
-        # multi dense layer
-        layers = []
-        for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
-            layers.append(nn.Linear(c0,c1))
-            layers.append(nn.Dropout(dropout))
-        self.linear_layer = nn.Sequential(*layers)
-
-        self.output_layer = nn.Linear(linear_dim[-1], 1)
-
-    def forward(self, adj, hidden, node, activation=None):
-
-        adj = adj[:,:,:,1:].permute(0,3,1,2)
-
-        annotations = torch.cat((hidden, node), -1) if hidden is not None else node
-
-        h = self.gcn_layer(annotations, adj)
-        annotations = torch.cat((h, hidden, node) if hidden is not None\
-                                else (h, node), -1)
-
-        h = self.agg_layer(annotations, torch.tanh)
-        h = self.linear_layer(h)
-
-        # Need to implement batch discriminator #
-        #########################################
-
-        output = self.output_layer(h)
-        output = activation(output) if activation is not None else output
-
-        return output, h"""
-
-"""class Discriminator_old2(nn.Module):
-
-    def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
-        super(Discriminator_old2, self).__init__()
-
-        graph_conv_dim, aux_dim, linear_dim = conv_dim
-
-        # discriminator
-        self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout, gcn_depth)
-        self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
-
-        # multi dense layer
-        layers = []
-        for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
-            layers.append(nn.Linear(c0,c1))
-            layers.append(nn.Dropout(dropout))
-        self.linear_layer = nn.Sequential(*layers)
-
-        self.output_layer = nn.Linear(linear_dim[-1], 1)
-
-    def forward(self, adj, hidden, node, activation=None):
-
-        adj = adj[:,:,:,1:].permute(0,3,1,2)
-
-        annotations = torch.cat((hidden, node), -1) if hidden is not None else node
-
-        h = self.gcn_layer(annotations, adj)
-        annotations = torch.cat((h, hidden, node) if hidden is not None\
-                                else (h, node), -1)
-
-        h = self.agg_layer(annotations, torch.tanh)
-        h = self.linear_layer(h)
-
-        # Need to implement batch discriminator #
-        #########################################
-
-        output = self.output_layer(h)
-        output = activation(output) if activation is not None else output
-
-        return output, h"""
-
-"""class Discriminator3(nn.Module):
-
-    def __init__(self,in_ch):
-        super(Discriminator3, self).__init__()
-        self.dim = in_ch
-
-
-        self.TraConv_layer = TransformerConv(in_channels = self.dim,out_channels = self.dim//4,edge_dim = self.dim)
-        self.mlp = torch.nn.Sequential(torch.nn.Tanh(), torch.nn.Linear(self.dim//4,1))
-    def forward(self, x, edge_index, edge_attr, batch, activation=None):
-
-        h = self.TraConv_layer(x, edge_index, edge_attr)
-        h = global_add_pool(h,batch)
-        h = self.mlp(h)
-        h = activation(h) if activation is not None else h
-
-        return h"""
-
-
-"""class PNA_Net(nn.Module):
-    def __init__(self,deg):
-        super().__init__()
-
-
-
-        self.convs = nn.ModuleList()
-
-        self.lin = nn.Linear(5, 128)
-        for _ in range(1):
-            conv = DenseGCNConv(128, 128, improved=False, bias=True)
-            self.convs.append(conv)
-
-        self.agg_layer = GraphAggregation(128, 128, 0, dropout=0.1)
-        self.mlp = nn.Sequential(nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 32), nn.Tanh(),
-                                 nn.Linear(32, 1))
-
-    def forward(self, x, adj,mask=None):
-        x = self.lin(x)
-
-        for conv in self.convs:
-            x = F.relu(conv(x, adj,mask=None))
-
-        x = self.agg_layer(x,torch.tanh)
-
-        return self.mlp(x) """
-
+        return prediction
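After this commit, simple_disc.forward still ends by returning the raw prediction with the softmax left commented out, so mapping the score to a probability (or folding that into the loss) is the caller's job. A short sketch of that pattern follows, assuming a generic linear discriminator head; TinyDisc and its layer sizes are illustrative, not the actual simple_disc architecture.

import torch
import torch.nn as nn

class TinyDisc(nn.Module):
    """Generic discriminator head that, like simple_disc, returns raw logits."""
    def __init__(self, in_dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 32), nn.Tanh(), nn.Linear(32, 1))

    def forward(self, features):
        prediction = self.net(features)
        # prediction = F.softmax(prediction, dim=-1)  # intentionally not applied
        return prediction

disc = TinyDisc()
features = torch.randn(4, 64)
logits = disc(features)                                           # raw scores, shape (4, 1)
loss = nn.BCEWithLogitsLoss()(logits, torch.ones_like(logits))    # sigmoid happens inside the loss
print(loss.item())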