osbm committed
Commit 1a3cfaf · 1 Parent(s): 0b7b562
Files changed (6):
  1. models.py +392 -0
  2. new_dataloader.py +349 -0
  3. requirements.txt +8 -0
  4. trainer.py +892 -0
  5. training_data.py +50 -0
  6. utils.py +462 -0
models.py ADDED
@@ -0,0 +1,392 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from layers import TransformerEncoder, TransformerDecoder
5
+
6
+ class Generator(nn.Module):
7
+ """Generator network."""
8
+ def __init__(self, z_dim, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio, submodel):
9
+ super(Generator, self).__init__()
10
+
11
+ self.submodel = submodel
12
+ self.vertexes = vertexes
13
+ self.edges = edges
14
+ self.nodes = nodes
15
+ self.depth = depth
16
+ self.dim = dim
17
+ self.heads = heads
18
+ self.mlp_ratio = mlp_ratio
19
+
20
+ self.dropout = dropout
21
+ self.z_dim = z_dim
22
+
23
+ if act == "relu":
24
+ act = nn.ReLU()
25
+ elif act == "leaky":
26
+ act = nn.LeakyReLU()
27
+ elif act == "sigmoid":
28
+ act = nn.Sigmoid()
29
+ elif act == "tanh":
30
+ act = nn.Tanh()
31
+ self.features = vertexes * vertexes * edges + vertexes * nodes
32
+ self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
33
+ self.pos_enc_dim = 5
34
+ #self.pos_enc = nn.Linear(self.pos_enc_dim, self.dim)
35
+
36
+ self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
37
+ self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
38
+
39
+ self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act = act,
40
+ mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
41
+
42
+ self.readout_e = nn.Linear(self.dim, edges)
43
+ self.readout_n = nn.Linear(self.dim, nodes)
44
+ self.softmax = nn.Softmax(dim = -1)
45
+
46
+ def _generate_square_subsequent_mask(self, sz):
47
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
48
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
49
+ return mask
50
+
51
+ def laplacian_positional_enc(self, adj):
52
+
53
+ A = adj
54
+ D = torch.diag(torch.count_nonzero(A, dim=-1))
55
+ L = torch.eye(A.shape[0], device=A.device) - D * A * D
56
+
57
+ EigVal, EigVec = torch.linalg.eig(L)
58
+
59
+ idx = torch.argsort(torch.real(EigVal))
60
+ EigVal, EigVec = EigVal[idx], torch.real(EigVec[:,idx])
61
+ pos_enc = EigVec[:,1:self.pos_enc_dim + 1]
62
+
63
+ return pos_enc
64
+
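For reference, a self-contained sketch of Laplacian positional encoding using the standard symmetric-normalized Laplacian L = I - D^(-1/2) A D^(-1/2). The toy 4-node cycle graph and the choice of two encoding dimensions below are illustrative only and are not taken from this commit; note that the method above is currently unused, since its call sites in forward() are commented out.

import torch

A = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 0., 1., 0.]])              # toy 4-node cycle graph
deg = A.sum(dim=-1)
D_inv_sqrt = torch.diag(deg.clamp(min=1).rsqrt())
L = torch.eye(A.shape[0]) - D_inv_sqrt @ A @ D_inv_sqrt

eigval, eigvec = torch.linalg.eigh(L)             # symmetric matrix, so the spectrum is real
pos_enc = eigvec[:, 1:3]                          # skip the trivial eigenvector, keep 2 dimensions
print(pos_enc.shape)                              # torch.Size([4, 2])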
65
+ def forward(self, z_e, z_n):
66
+ b, n, c = z_n.shape
67
+ _, _, _ , d = z_e.shape
68
+ #random_mask_e = torch.randint(low=0,high=2,size=(b,n,n,d)).to(z_e.device).float()
69
+ #random_mask_n = torch.randint(low=0,high=2,size=(b,n,c)).to(z_n.device).float()
70
+ #z_e = F.relu(z_e - random_mask_e)
71
+ #z_n = F.relu(z_n - random_mask_n)
72
+
73
+ #mask = self._generate_square_subsequent_mask(self.vertexes).to(z_e.device)
74
+
75
+ node = self.node_layers(z_n)
76
+
77
+ edge = self.edge_layers(z_e)
78
+
79
+ edge = (edge + edge.permute(0,2,1,3))/2
80
+
81
+ #lap = [self.laplacian_positional_enc(torch.max(x,-1)[1]) for x in edge]
82
+
83
+ #lap = torch.stack(lap).to(node.device)
84
+
85
+ #pos_enc = self.pos_enc(lap)
86
+
87
+ #node = node + pos_enc
88
+
89
+ node, edge = self.TransformerEncoder(node,edge)
90
+
91
+ node_sample = self.softmax(self.readout_n(node))
92
+
93
+ edge_sample = self.softmax(self.readout_e(edge))
94
+
95
+ return node, edge, node_sample, edge_sample
96
+
97
+
98
+
99
+ class Generator2(nn.Module):
100
+ def __init__(self, dim, dec_dim, depth, heads, mlp_ratio, drop_rate, drugs_m_dim, drugs_b_dim, submodel):
101
+ super().__init__()
102
+ self.submodel = submodel
103
+ self.depth = depth
104
+ self.dim = dim
105
+ self.mlp_ratio = mlp_ratio
106
+ self.heads = heads
107
+ self.dropout_rate = drop_rate
108
+ self.drugs_m_dim = drugs_m_dim
109
+ self.drugs_b_dim = drugs_b_dim
110
+
111
+ self.pos_enc_dim = 5
112
+
113
+
114
+ if self.submodel == "Prot":
115
+ self.prot_n = torch.nn.Linear(3822, 45) ## exact dimension of protein features
116
+ self.prot_e = torch.nn.Linear(298116, 2025) ## exact dimension of protein features
117
+
118
+ self.protn_dim = torch.nn.Linear(1, dec_dim)
119
+ self.prote_dim = torch.nn.Linear(1, dec_dim)
120
+
121
+
122
+ self.mol_nodes = nn.Linear(dim, dec_dim)
123
+ self.mol_edges = nn.Linear(dim, dec_dim)
124
+
125
+ self.drug_nodes = nn.Linear(self.drugs_m_dim, dec_dim)
126
+ self.drug_edges = nn.Linear(self.drugs_b_dim, dec_dim)
127
+
128
+ self.TransformerDecoder = TransformerDecoder(dec_dim, depth, heads, mlp_ratio, drop_rate=self.dropout_rate)
129
+
130
+ self.nodes_output_layer = nn.Linear(dec_dim, self.drugs_m_dim)
131
+ self.edges_output_layer = nn.Linear(dec_dim, self.drugs_b_dim)
132
+ self.softmax = nn.Softmax(dim=-1)
133
+
134
+ def laplacian_positional_enc(self, adj):
135
+
136
+ A = adj
137
+ D = torch.diag(torch.count_nonzero(A, dim=-1))
138
+ L = torch.eye(A.shape[0], device=A.device) - D * A * D
139
+
140
+ EigVal, EigVec = torch.linalg.eig(L)
141
+
142
+ idx = torch.argsort(torch.real(EigVal))
143
+ EigVal, EigVec = EigVal[idx], torch.real(EigVec[:,idx])
144
+ pos_enc = EigVec[:,1:self.pos_enc_dim + 1]
145
+
146
+ return pos_enc
147
+
148
+ def _generate_square_subsequent_mask(self, sz):
149
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
150
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
151
+ return mask
152
+
153
+ def forward(self, edges_logits, nodes_logits ,akt1_adj,akt1_annot):
154
+
155
+ edges_logits = self.mol_edges(edges_logits)
156
+ nodes_logits = self.mol_nodes(nodes_logits)
157
+
158
+ if self.submodel != "Prot":
159
+ akt1_annot = self.drug_nodes(akt1_annot)
160
+ akt1_adj = self.drug_edges(akt1_adj)
161
+
162
+ else:
163
+ akt1_adj = self.prote_dim(self.prot_e(akt1_adj).view(1,45,45,1))
164
+ akt1_annot = self.protn_dim(self.prot_n(akt1_annot).view(1,45,1))
165
+
166
+
167
+ #lap = [self.laplacian_positional_enc(torch.max(x,-1)[1]) for x in drug_e]
168
+ #lap = torch.stack(lap).to(drug_e.device)
169
+ #pos_enc = self.pos_enc(lap)
170
+ #drug_n = drug_n + pos_enc
171
+
172
+ nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(nodes_logits,akt1_annot,edges_logits,akt1_adj)
173
+
174
+ edges_logits = self.edges_output_layer(edges_logits)
175
+ nodes_logits = self.nodes_output_layer(nodes_logits)
176
+
177
+ edges_logits = self.softmax(edges_logits)
178
+ nodes_logits = self.softmax(nodes_logits)
179
+
180
+ return edges_logits, nodes_logits
181
+
182
+
183
+ class simple_disc(nn.Module):
184
+ def __init__(self, act, m_dim, vertexes, b_dim):
185
+ super().__init__()
186
+ if act == "relu":
187
+ act = nn.ReLU()
188
+ elif act == "leaky":
189
+ act = nn.LeakyReLU()
190
+ elif act == "sigmoid":
191
+ act = nn.Sigmoid()
192
+ elif act == "tanh":
193
+ act = nn.Tanh()
194
+ features = vertexes * m_dim + vertexes * vertexes * b_dim
195
+
196
+ self.predictor = nn.Sequential(nn.Linear(features,256), act, nn.Linear(256,128), act, nn.Linear(128,64), act,
197
+ nn.Linear(64,32), act, nn.Linear(32,16), act,
198
+ nn.Linear(16,1))
199
+
200
+ def forward(self, x):
201
+
202
+ prediction = self.predictor(x)
203
+
204
+ #prediction = F.softmax(prediction,dim=-1)
205
+
206
+ return prediction
207
+
208
+ """class Discriminator(nn.Module):
209
+
210
+ def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
211
+ super(Discriminator, self).__init__()
212
+ self.degree = deg
213
+ self.aggregators = agg
214
+ self.scalers = sca
215
+ self.pna_in_channels = pna_in_ch
216
+ self.pna_out_channels = pna_out_ch
217
+ self.edge_dimension = edge_dim
218
+ self.towers = towers
219
+ self.pre_layers_num = pre_lay
220
+ self.post_layers_num = post_lay
221
+ self.pna_layer_num = pna_layer_num
222
+ self.graph_add = graph_add
223
+ self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
224
+ pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
225
+ towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
226
+ pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
227
+
228
+ def forward(self, x, edge_index, edge_attr, batch, activation=None):
229
+
230
+ h = self.PNA_layer(x, edge_index, edge_attr, batch)
231
+
232
+ h = activation(h) if activation is not None else h
233
+
234
+ return h"""
235
+
236
+ """class Discriminator2(nn.Module):
237
+
238
+ def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
239
+ super(Discriminator2, self).__init__()
240
+ self.degree = deg
241
+ self.aggregators = agg
242
+ self.scalers = sca
243
+ self.pna_in_channels = pna_in_ch
244
+ self.pna_out_channels = pna_out_ch
245
+ self.edge_dimension = edge_dim
246
+ self.towers = towers
247
+ self.pre_layers_num = pre_lay
248
+ self.post_layers_num = post_lay
249
+ self.pna_layer_num = pna_layer_num
250
+ self.graph_add = graph_add
251
+ self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
252
+ pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
253
+ towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
254
+ pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
255
+
256
+ def forward(self, x, edge_index, edge_attr, batch, activation=None):
257
+
258
+ h = self.PNA_layer(x, edge_index, edge_attr, batch)
259
+
260
+ h = activation(h) if activation is not None else h
261
+
262
+ return h"""
263
+
264
+
265
+ """class Discriminator_old(nn.Module):
266
+
267
+ def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
268
+ super(Discriminator_old, self).__init__()
269
+
270
+ graph_conv_dim, aux_dim, linear_dim = conv_dim
271
+
272
+ # discriminator
273
+ self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout,gcn_depth)
274
+ self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
275
+
276
+ # multi dense layer
277
+ layers = []
278
+ for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
279
+ layers.append(nn.Linear(c0,c1))
280
+ layers.append(nn.Dropout(dropout))
281
+ self.linear_layer = nn.Sequential(*layers)
282
+
283
+ self.output_layer = nn.Linear(linear_dim[-1], 1)
284
+
285
+ def forward(self, adj, hidden, node, activation=None):
286
+
287
+ adj = adj[:,:,:,1:].permute(0,3,1,2)
288
+
289
+ annotations = torch.cat((hidden, node), -1) if hidden is not None else node
290
+
291
+ h = self.gcn_layer(annotations, adj)
292
+ annotations = torch.cat((h, hidden, node) if hidden is not None\
293
+ else (h, node), -1)
294
+
295
+ h = self.agg_layer(annotations, torch.tanh)
296
+ h = self.linear_layer(h)
297
+
298
+ # Need to implement batch discriminator #
299
+ #########################################
300
+
301
+ output = self.output_layer(h)
302
+ output = activation(output) if activation is not None else output
303
+
304
+ return output, h"""
305
+
306
+ """class Discriminator_old2(nn.Module):
307
+
308
+ def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
309
+ super(Discriminator_old2, self).__init__()
310
+
311
+ graph_conv_dim, aux_dim, linear_dim = conv_dim
312
+
313
+ # discriminator
314
+ self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout, gcn_depth)
315
+ self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
316
+
317
+ # multi dense layer
318
+ layers = []
319
+ for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
320
+ layers.append(nn.Linear(c0,c1))
321
+ layers.append(nn.Dropout(dropout))
322
+ self.linear_layer = nn.Sequential(*layers)
323
+
324
+ self.output_layer = nn.Linear(linear_dim[-1], 1)
325
+
326
+ def forward(self, adj, hidden, node, activation=None):
327
+
328
+ adj = adj[:,:,:,1:].permute(0,3,1,2)
329
+
330
+ annotations = torch.cat((hidden, node), -1) if hidden is not None else node
331
+
332
+ h = self.gcn_layer(annotations, adj)
333
+ annotations = torch.cat((h, hidden, node) if hidden is not None\
334
+ else (h, node), -1)
335
+
336
+ h = self.agg_layer(annotations, torch.tanh)
337
+ h = self.linear_layer(h)
338
+
339
+ # Need to implement batch discriminator #
340
+ #########################################
341
+
342
+ output = self.output_layer(h)
343
+ output = activation(output) if activation is not None else output
344
+
345
+ return output, h"""
346
+
347
+ """class Discriminator3(nn.Module):
348
+
349
+ def __init__(self,in_ch):
350
+ super(Discriminator3, self).__init__()
351
+ self.dim = in_ch
352
+
353
+
354
+ self.TraConv_layer = TransformerConv(in_channels = self.dim,out_channels = self.dim//4,edge_dim = self.dim)
355
+ self.mlp = torch.nn.Sequential(torch.nn.Tanh(), torch.nn.Linear(self.dim//4,1))
356
+ def forward(self, x, edge_index, edge_attr, batch, activation=None):
357
+
358
+ h = self.TraConv_layer(x, edge_index, edge_attr)
359
+ h = global_add_pool(h,batch)
360
+ h = self.mlp(h)
361
+ h = activation(h) if activation is not None else h
362
+
363
+ return h"""
364
+
365
+
366
+ """class PNA_Net(nn.Module):
367
+ def __init__(self,deg):
368
+ super().__init__()
369
+
370
+
371
+
372
+ self.convs = nn.ModuleList()
373
+
374
+ self.lin = nn.Linear(5, 128)
375
+ for _ in range(1):
376
+ conv = DenseGCNConv(128, 128, improved=False, bias=True)
377
+ self.convs.append(conv)
378
+
379
+ self.agg_layer = GraphAggregation(128, 128, 0, dropout=0.1)
380
+ self.mlp = nn.Sequential(nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 32), nn.Tanh(),
381
+ nn.Linear(32, 1))
382
+
383
+ def forward(self, x, adj,mask=None):
384
+ x = self.lin(x)
385
+
386
+ for conv in self.convs:
387
+ x = F.relu(conv(x, adj,mask=None))
388
+
389
+ x = self.agg_layer(x,torch.tanh)
390
+
391
+ return self.mlp(x) """
392
+
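A minimal, self-contained sketch of the input convention behind simple_disc: the discriminator consumes one flattened vector per graph whose size matches features = vertexes * m_dim + vertexes * vertexes * b_dim, i.e. the node matrix and edge tensor concatenated after flattening. Only torch is required; the layer widths and tensor sizes below are illustrative, not the values used in training.

import torch
import torch.nn as nn

vertexes, m_dim, b_dim = 9, 5, 5                  # illustrative sizes
features = vertexes * m_dim + vertexes * vertexes * b_dim

# Same flattened-input convention as simple_disc.forward, with a shorter MLP.
disc = nn.Sequential(nn.Linear(features, 256), nn.ReLU(),
                     nn.Linear(256, 64), nn.ReLU(),
                     nn.Linear(64, 1))

nodes = torch.rand(2, vertexes, m_dim)            # batch of 2 dummy node matrices
edges = torch.rand(2, vertexes, vertexes, b_dim)  # batch of 2 dummy edge tensors
x = torch.cat((nodes.reshape(2, -1), edges.reshape(2, -1)), dim=-1)
print(disc(x).shape)                              # torch.Size([2, 1]): one critic score per graph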
new_dataloader.py ADDED
@@ -0,0 +1,349 @@
1
+ import pickle
2
+ import os.path as osp
3
+ import re
4
+
5
+ import torch
6
+ import numpy as np
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ from rdkit import Chem
9
+ from rdkit import RDLogger
10
+ from torch_geometric.data import (Data, InMemoryDataset)
11
+
12
+ RDLogger.DisableLog('rdApp.*')
13
+ class DruggenDataset(InMemoryDataset):
14
+
15
+ def __init__(self, root, dataset_file, raw_files, max_atom, features, transform=None, pre_transform=None, pre_filter=None):
16
+ self.dataset_name = dataset_file.split(".")[0]
17
+ self.dataset_file = dataset_file
18
+ self.raw_files = raw_files
19
+ self.max_atom = max_atom
20
+ self.features = features
21
+
22
+ super().__init__(root, transform, pre_transform, pre_filter)
23
+ self.data, self.slices = torch.load(osp.join(root, dataset_file))
24
+
25
+
26
+ @property
27
+ def raw_file_names(self):
28
+ return self.raw_files
29
+
30
+ @property
31
+ def processed_file_names(self):
32
+ '''
33
+ Return the processed file names. If these files are not present, they are created automatically by the process() function of this class.
34
+ '''
35
+ return self.dataset_file
36
+
37
+ def _generate_encoders_decoders(self, data):
38
+ """
39
+ Generates the encoders and decoders for the atoms and bonds.
40
+ """
41
+ self.data = data
42
+ print('Creating atoms encoder and decoder..')
43
+
44
+ atom_labels = set()
45
+ # bond_labels = set()
46
+ self.max_atom_size_in_data = 0
47
+
48
+ for smile in data:
49
+ mol = Chem.MolFromSmiles(smile)
50
+ atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
51
+ # bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])
52
+ self.max_atom_size_in_data = max(self.max_atom_size_in_data, mol.GetNumAtoms())
53
+ atom_labels.update([0]) # add PAD symbol (for unknown atoms)
54
+ atom_labels = sorted(atom_labels) # turn set into list and sort it
55
+
56
+ # atom_labels = sorted(set([atom.GetAtomicNum() for mol in self.data for atom in mol.GetAtoms()] + [0]))
57
+ self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
58
+ self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
59
+ self.atom_num_types = len(atom_labels)
60
+ print(f'Created atoms encoder and decoder with {self.atom_num_types - 1} atom types and 1 PAD symbol!')
61
+ print("atom_labels", atom_labels)
62
+ print('Creating bonds encoder and decoder..')
63
+ # bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType()
64
+ # for mol in self.data
65
+ # for bond in mol.GetBonds())))
66
+ bond_labels = [
67
+ Chem.rdchem.BondType.ZERO,
68
+ Chem.rdchem.BondType.SINGLE,
69
+ Chem.rdchem.BondType.DOUBLE,
70
+ Chem.rdchem.BondType.TRIPLE,
71
+ Chem.rdchem.BondType.AROMATIC,
72
+ ]
73
+
74
+ print("bond labels", bond_labels)
75
+ self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
76
+ self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
77
+ self.bond_num_types = len(bond_labels)
78
+ print(f'Created bonds encoder and decoder with {self.bond_num_types - 1} bond types and 1 PAD symbol!')
79
+ #dataset_names = str(self.dataset_name)
80
+ with open("DrugGEN/data/encoders/" +"atom_" + self.dataset_name + ".pkl","wb") as atom_encoders:
81
+ pickle.dump(self.atom_encoder_m,atom_encoders)
82
+
83
+
84
+ with open("DrugGEN/data/decoders/" +"atom_" + self.dataset_name + ".pkl","wb") as atom_decoders:
85
+ pickle.dump(self.atom_decoder_m,atom_decoders)
86
+
87
+
88
+ with open("DrugGEN/data/encoders/" +"bond_" + self.dataset_name + ".pkl","wb") as bond_encoders:
89
+ pickle.dump(self.bond_encoder_m,bond_encoders)
90
+
91
+
92
+ with open("DrugGEN/data/decoders/" +"bond_" + self.dataset_name + ".pkl","wb") as bond_decoders:
93
+ pickle.dump(self.bond_decoder_m,bond_decoders)
94
+
95
+
96
+
97
+ def generate_adjacency_matrix(self, mol, connected=True, max_length=None):
98
+ """
99
+ Generates the adjacency matrix for a molecule.
100
+
101
+ Args:
102
+ mol (Molecule): The molecule object.
103
+ connected (bool): Whether to check for connectivity in the molecule. Defaults to True.
104
+ max_length (int): The maximum length of the adjacency matrix. Defaults to the number of atoms in the molecule.
105
+
106
+ Returns:
107
+ numpy.ndarray or None: The adjacency matrix if connected and all atoms have a degree greater than 0,
108
+ otherwise None.
109
+ """
110
+ max_length = max_length if max_length is not None else mol.GetNumAtoms()
111
+
112
+ A = np.zeros(shape=(max_length, max_length))
113
+
114
+ begin, end = [b.GetBeginAtomIdx() for b in mol.GetBonds()], [b.GetEndAtomIdx() for b in mol.GetBonds()]
115
+ bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
116
+
117
+ A[begin, end] = bond_type
118
+ A[end, begin] = bond_type
119
+
120
+ degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
121
+
122
+ return A if connected and (degree > 0).all() else None
123
+
124
+ def generate_node_features(self, mol, max_length=None):
125
+ """
126
+ Generates the node features for a molecule.
127
+
128
+ Args:
129
+ mol (Molecule): The molecule object.
130
+ max_length (int): The maximum length of the node features. Defaults to the number of atoms in the molecule.
131
+
132
+ Returns:
133
+ numpy.ndarray: The node features matrix.
134
+ """
135
+ max_length = max_length if max_length is not None else mol.GetNumAtoms()
136
+
137
+ return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
138
+ max_length - mol.GetNumAtoms()))
139
+
140
+ def generate_additional_features(self, mol, max_length=None):
141
+ """
142
+ Generates additional features for a molecule.
143
+
144
+ Args:
145
+ mol (Molecule): The molecule object.
146
+ max_length (int): The maximum length of the additional features. Defaults to the number of atoms in the molecule.
147
+
148
+ Returns:
149
+ numpy.ndarray: The additional features matrix.
150
+ """
151
+ max_length = max_length if max_length is not None else mol.GetNumAtoms()
152
+
153
+ features = np.array([[*[a.GetDegree() == i for i in range(5)],
154
+ *[a.GetExplicitValence() == i for i in range(9)],
155
+ *[int(a.GetHybridization()) == i for i in range(1, 7)],
156
+ *[a.GetImplicitValence() == i for i in range(9)],
157
+ a.GetIsAromatic(),
158
+ a.GetNoImplicit(),
159
+ *[a.GetNumExplicitHs() == i for i in range(5)],
160
+ *[a.GetNumImplicitHs() == i for i in range(5)],
161
+ *[a.GetNumRadicalElectrons() == i for i in range(5)],
162
+ a.IsInRing(),
163
+ *[a.IsInRingSize(i) for i in range(2, 9)]] for a in mol.GetAtoms()], dtype=np.int32)
164
+
165
+ return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))
166
+
167
+ def decoder_load(self, dictionary_name):
168
+ with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
169
+ return pickle.load(f)
170
+
171
+ def drugs_decoder_load(self, dictionary_name):
172
+ with open("DrugGEN/data/decoders/" + dictionary_name +'.pkl', 'rb') as f:
173
+ return pickle.load(f)
174
+
175
+ def matrices2mol(self, node_labels, edge_labels, strict=True):
176
+ mol = Chem.RWMol()
177
+ RDLogger.DisableLog('rdApp.*')
178
+ atom_decoders = self.decoder_load("atom")
179
+ bond_decoders = self.decoder_load("bond")
180
+
181
+ for node_label in node_labels:
182
+ mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
183
+
184
+ for start, end in zip(*np.nonzero(edge_labels)):
185
+ if start > end:
186
+ mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
187
+ mol = self.correct_mol(mol)
188
+ if strict:
189
+ try:
190
+
191
+ Chem.SanitizeMol(mol)
192
+ except:
193
+ mol = None
194
+
195
+ return mol
196
+
197
+ def drug_decoder_load(self, dictionary_name):
198
+
199
+ ''' Loading the atom and bond decoders '''
200
+
201
+ with open("DrugGEN/data/decoders/" + dictionary_name +"_" + "akt_train" +'.pkl', 'rb') as f:
202
+
203
+ return pickle.load(f)
204
+ def matrices2mol_drugs(self, node_labels, edge_labels, strict=True):
205
+ mol = Chem.RWMol()
206
+ RDLogger.DisableLog('rdApp.*')
207
+ atom_decoders = self.drug_decoder_load("atom")
208
+ bond_decoders = self.drug_decoder_load("bond")
209
+
210
+ for node_label in node_labels:
211
+
212
+ mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
213
+
214
+ for start, end in zip(*np.nonzero(edge_labels)):
215
+ if start > end:
216
+ mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
217
+ mol = self.correct_mol(mol)
218
+ if strict:
219
+ try:
220
+ Chem.SanitizeMol(mol)
221
+ except:
222
+ mol = None
223
+
224
+ return mol
225
+ def check_valency(self,mol):
226
+ """
227
+ Checks that no atoms in the mol have exceeded their possible
228
+ valency
229
+ :return: True if no valency issues, False otherwise
230
+ """
231
+ try:
232
+ Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
233
+ return True, None
234
+ except ValueError as e:
235
+ e = str(e)
236
+ p = e.find('#')
237
+ e_sub = e[p:]
238
+ atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
239
+ return False, atomid_valence
240
+
241
+
242
+ def correct_mol(self,x):
243
+ # xsm = Chem.MolToSmiles(x, isomericSmiles=True)
244
+ mol = x
245
+ while True:
246
+ flag, atomid_valence = self.check_valency(mol)
247
+ if flag:
248
+ break
249
+ else:
250
+ assert len (atomid_valence) == 2
251
+ idx = atomid_valence[0]
252
+ v = atomid_valence[1]
253
+ queue = []
254
+ for b in mol.GetAtomWithIdx(idx).GetBonds():
255
+ queue.append(
256
+ (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
257
+ )
258
+ queue.sort(key=lambda tup: tup[1], reverse=True)
259
+ if len(queue) > 0:
260
+ start = queue[0][2]
261
+ end = queue[0][3]
262
+ t = queue[0][1] - 1
263
+ mol.RemoveBond(start, end)
264
+
265
+ #if t >= 1:
266
+
267
+ #mol.AddBond(start, end, self.decoder_load('bond_decoders')[t])
268
+ # if '.' in Chem.MolToSmiles(mol, isomericSmiles=True):
269
+ # mol.AddBond(start, end, self.decoder_load('bond_decoders')[t])
270
+ # print(tt)
271
+ # print(Chem.MolToSmiles(mol, isomericSmiles=True))
272
+
273
+ return mol
274
+
275
+
276
+
277
+ def label2onehot(self, labels, dim):
278
+
279
+ """Convert label indices to one-hot vectors."""
280
+
281
+ out = torch.zeros(list(labels.size())+[dim])
282
+ out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
283
+
284
+ return out.float()
285
+
286
+ def process(self, size= None):
287
+ '''
288
+ Process the dataset. This function only runs if the file named in processed_file_names is not already present in the data folder.
289
+ '''
290
+ # mols = [Chem.MolFromSmiles(line) for line in open(self.raw_files, 'r').readlines()]
291
+ # mols = list(filter(lambda x: x.GetNumAtoms() <= self.max_atom, mols))
292
+ # mols = mols[:size] # i
293
+ # indices = range(len(mols))
294
+
295
+ smiles = pd.read_csv(self.raw_files, header=None)[0].tolist()
296
+ self._generate_encoders_decoders(smiles)
297
+
298
+ # pbar.set_description(f'Processing chembl dataset')
299
+ # max_length = max(mol.GetNumAtoms() for mol in mols)
300
+ data_list = []
301
+ max_length = min(self.max_atom_size_in_data, self.max_atom)
302
+ self.m_dim = len(self.atom_decoder_m)
303
+ # for idx in indices:
304
+ for smile in tqdm(smiles, desc='Processing chembl dataset', total=len(smiles)):
305
+ # mol = mols[idx]
306
+
307
+ mol = Chem.MolFromSmiles(smile)
308
+
309
+ # filter by max atom size
310
+ if mol.GetNumAtoms() > max_length:
311
+ continue
312
+
313
+ A = self.generate_adjacency_matrix(mol, connected=True, max_length=max_length)
314
+ if A is not None:
315
+
316
+
317
+ x = torch.from_numpy(self.generate_node_features(mol, max_length=max_length)).to(torch.long).view(1, -1)
318
+
319
+ x = self.label2onehot(x,self.m_dim).squeeze()
320
+ if self.features:
321
+ f = torch.from_numpy(self.generate_additional_features(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
322
+ x = torch.concat((x,f), dim=-1)
323
+
324
+ adjacency = torch.from_numpy(A)
325
+
326
+ edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
327
+ edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
328
+
329
+ data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
330
+
331
+ if self.pre_filter is not None and not self.pre_filter(data):
332
+ continue
333
+
334
+ if self.pre_transform is not None:
335
+ data = self.pre_transform(data)
336
+
337
+ data_list.append(data)
338
+ # pbar.update(1)
339
+
340
+ # pbar.close()
341
+
342
+ torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
343
+
344
+
345
+
346
+
347
+ if __name__ == '__main__':
348
+ data = DruggenDataset("DrugGEN/data")
349
+
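To illustrate what process() stores per molecule, here is a standalone sketch that builds the dense bond-type adjacency matrix for one SMILES and converts it to the sparse edge_index / edge_attr pair used above. The SMILES string, the padding size, and the helper name are illustrative only; rdkit, numpy, and torch are the only requirements.

import numpy as np
import torch
from rdkit import Chem

# Same bond ordering as the bond_labels list in _generate_encoders_decoders.
BOND_ENCODER = {Chem.rdchem.BondType.ZERO: 0, Chem.rdchem.BondType.SINGLE: 1,
                Chem.rdchem.BondType.DOUBLE: 2, Chem.rdchem.BondType.TRIPLE: 3,
                Chem.rdchem.BondType.AROMATIC: 4}

def dense_adjacency(smiles, max_length):
    mol = Chem.MolFromSmiles(smiles)
    A = np.zeros((max_length, max_length))
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        A[i, j] = A[j, i] = BOND_ENCODER[b.GetBondType()]
    return A

A = torch.from_numpy(dense_adjacency("c1ccccc1O", max_length=9))   # phenol, padded to 9 atoms
edge_index = A.nonzero(as_tuple=False).t().contiguous()            # 2 x num_edges, as in process()
edge_attr = A[edge_index[0], edge_index[1]].to(torch.long)         # bond-type label per edge
print(edge_index.shape, edge_attr.shape)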
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ rdkit-pypi
+ tqdm
+ numpy
+ seaborn
+ matplotlib
+ pandas
+ torch_geometric
trainer.py ADDED
@@ -0,0 +1,892 @@
1
+ import os
2
+ import time
3
+ import torch.nn
4
+ import torch
5
+
6
+ from utils import *
7
+ from models import Generator, Generator2, simple_disc
8
+ import torch_geometric.utils as geoutils
9
+ #import #wandb
10
+ import re
11
+ from torch_geometric.loader import DataLoader
12
+ from new_dataloader import DruggenDataset
13
+ import torch.utils.data
14
+ from rdkit import Chem, RDLogger
15
+ import pickle
16
+ from rdkit.Chem.Scaffolds import MurckoScaffold
17
+ torch.set_num_threads(5)
18
+ RDLogger.DisableLog('rdApp.*')
19
+ from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
20
+ from training_data import load_data
21
+ import random
22
+
23
+
24
+ class Trainer(object):
25
+
26
+ """Trainer for training and testing DrugGEN."""
27
+
28
+ def __init__(self, config):
29
+
30
+ self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
31
+ """Initialize configurations."""
32
+ self.submodel = config.submodel
33
+ self.inference_model = config.inference_model
34
+ # Data loader.
35
+ self.raw_file = config.raw_file # SMILES containing text file for first dataset.
36
+ # Write the full path to file.
37
+
38
+ self.drug_raw_file = config.drug_raw_file # SMILES containing text file for second dataset.
39
+ # Write the full path to file.
40
+
41
+
42
+ self.dataset_file = config.dataset_file # Dataset file name for the first GAN.
43
+ # Contains large number of molecules.
44
+
45
+ self.drugs_dataset_file = config.drug_dataset_file # Drug dataset file name for the second GAN.
46
+ # Contains drug molecules only. (In this case AKT1 inhibitors.)
47
+
48
+ self.inf_raw_file = config.inf_raw_file # SMILES containing text file for first dataset.
49
+ # Write the full path to file.
50
+
51
+ self.inf_drug_raw_file = config.inf_drug_raw_file # SMILES containing text file for second dataset.
52
+ # Write the full path to file.
53
+
54
+
55
+ self.inf_dataset_file = config.inf_dataset_file # Dataset file name for the first GAN.
56
+ # Contains large number of molecules.
57
+
58
+ self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
59
+ # Contains drug molecules only. (In this case AKT1 inhibitors.)
60
+
61
+ self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
62
+
63
+ self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
64
+
65
+ self.dataset_name = self.dataset_file.split(".")[0]
66
+ self.drugs_name = self.drugs_dataset_file.split(".")[0]
67
+
68
+ self.max_atom = config.max_atom # Model is based on one-shot generation.
69
+ # Max atom number for molecules must be specified.
70
+
71
+ self.features = config.features # Whether to use additional node features (Boolean); False uses atom types only.
72
+ # Additional node features can be added. Please check new_dataloader.py Line 102.
73
+
74
+
75
+ self.batch_size = config.batch_size # Batch size for training.
76
+
77
+ self.dataset = DruggenDataset(self.mol_data_dir,
78
+ self.dataset_file,
79
+ self.raw_file,
80
+ self.max_atom,
81
+ self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
82
+ # Can create any molecular graph dataset given smiles string.
83
+ # Nonisomeric SMILES are suggested but not necessary.
84
+ # Uses sparse matrix representation for graphs,
85
+ # For computational and speed efficiency.
86
+
87
+ self.loader = DataLoader(self.dataset,
88
+ shuffle=True,
89
+ batch_size=self.batch_size,
90
+ drop_last=True) # PyG dataloader for the first GAN.
91
+
92
+ self.drugs = DruggenDataset(self.drug_data_dir,
93
+ self.drugs_dataset_file,
94
+ self.drug_raw_file,
95
+ self.max_atom,
96
+ self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
97
+ # Can create any molecular graph dataset given smiles string.
98
+ # Nonisomeric SMILES are suggested but not necessary.
99
+ # Uses sparse matrix representation for graphs,
100
+ # For computational and speed efficiency.
101
+
102
+ self.drugs_loader = DataLoader(self.drugs,
103
+ shuffle=True,
104
+ batch_size=self.batch_size,
105
+ drop_last=True) # PyG dataloader for the second GAN.
106
+
107
+ # Atom and bond type dimensions for the construction of the model.
108
+
109
+ self.atom_decoders = self.decoder_load("atom") # Atom type decoders for first GAN.
110
+ # eg. 0:0, 1:6 (C), 2:7 (N), 3:8 (O), 4:9 (F)
111
+
112
+ self.bond_decoders = self.decoder_load("bond") # Bond type decoders for first GAN.
113
+ # eg. 0: (no-bond), 1: (single), 2: (double), 3: (triple), 4: (aromatic)
114
+
115
+ self.m_dim = len(self.atom_decoders) if not self.features else int(self.loader.dataset[0].x.shape[1]) # Atom type dimension.
116
+
117
+ self.b_dim = len(self.bond_decoders) # Bond type dimension.
118
+
119
+ self.vertexes = int(self.loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
120
+
121
+ self.drugs_atom_decoders = self.drug_decoder_load("atom") # Atom type decoders for second GAN.
122
+ # eg. 0:0, 1:6 (C), 2:7 (N), 3:8 (O), 4:9 (F)
123
+
124
+ self.drugs_bond_decoders = self.drug_decoder_load("bond") # Bond type decoders for second GAN.
125
+ # eg. 0: (no-bond), 1: (single), 2: (double), 3: (triple), 4: (aromatic)
126
+
127
+ self.drugs_m_dim = len(self.drugs_atom_decoders) if not self.features else int(self.drugs_loader.dataset[0].x.shape[1]) # Atom type dimension.
128
+
129
+ self.drugs_b_dim = len(self.drugs_bond_decoders) # Bond type dimension.
130
+
131
+ self.drug_vertexes = int(self.drugs_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
132
+
133
+ # Transformer and Convolution configurations.
134
+
135
+ self.act = config.act
136
+
137
+ self.z_dim = config.z_dim
138
+
139
+ self.lambda_gp = config.lambda_gp
140
+
141
+ self.dim = config.dim
142
+
143
+ self.depth = config.depth
144
+
145
+ self.heads = config.heads
146
+
147
+ self.mlp_ratio = config.mlp_ratio
148
+
149
+ self.dec_depth = config.dec_depth
150
+
151
+ self.dec_heads = config.dec_heads
152
+
153
+ self.dec_dim = config.dec_dim
154
+
155
+ self.dis_select = config.dis_select
156
+
157
+ """self.la = config.la
158
+ self.la2 = config.la2
159
+ self.gcn_depth = config.gcn_depth
160
+ self.g_conv_dim = config.g_conv_dim
161
+ self.d_conv_dim = config.d_conv_dim"""
162
+ """# PNA config
163
+
164
+ self.agg = config.aggregators
165
+ self.sca = config.scalers
166
+ self.pna_in_ch = config.pna_in_ch
167
+ self.pna_out_ch = config.pna_out_ch
168
+ self.edge_dim = config.edge_dim
169
+ self.towers = config.towers
170
+ self.pre_lay = config.pre_lay
171
+ self.post_lay = config.post_lay
172
+ self.pna_layer_num = config.pna_layer_num
173
+ self.graph_add = config.graph_add"""
174
+
175
+ # Training configurations.
176
+
177
+ self.epoch = config.epoch
178
+
179
+ self.g_lr = config.g_lr
180
+
181
+ self.d_lr = config.d_lr
182
+
183
+ self.g2_lr = config.g2_lr
184
+
185
+ self.d2_lr = config.d2_lr
186
+
187
+ self.dropout = config.dropout
188
+
189
+ self.dec_dropout = config.dec_dropout
190
+
191
+ self.n_critic = config.n_critic
192
+
193
+ self.beta1 = config.beta1
194
+
195
+ self.beta2 = config.beta2
196
+
197
+ self.resume_iters = config.resume_iters
198
+
199
+ self.warm_up_steps = config.warm_up_steps
200
+
201
+ # Test configurations.
202
+
203
+ self.num_test_epoch = config.num_test_epoch
204
+
205
+ self.test_iters = config.test_iters
206
+
207
+ self.inference_sample_num = config.inference_sample_num
208
+
209
+ # Directories.
210
+
211
+ self.log_dir = config.log_dir
212
+ self.sample_dir = config.sample_dir
213
+ self.model_save_dir = config.model_save_dir
214
+ self.result_dir = config.result_dir
215
+
216
+ # Step size.
217
+
218
+ self.log_step = config.log_sample_step
219
+ self.clipping_value = config.clipping_value
220
+ # Miscellaneous.
221
+
222
+ self.mode = config.mode
223
+
224
+ self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
225
+ self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
226
+ self.noise_strength_2 = torch.nn.Parameter(torch.zeros([]))
227
+ self.noise_strength_3 = torch.nn.Parameter(torch.zeros([]))
228
+
229
+ self.init_type = config.init_type
230
+ self.build_model()
231
+
232
+
233
+
234
+ def build_model(self):
235
+ """Create generators and discriminators."""
236
+
237
+ ''' Generator is based on Transformer Encoder:
238
+
239
+ @ g_conv_dim: Dimensions for first MLP layers before Transformer Encoder
240
+ @ vertexes: maximum length of generated molecules (atom length)
241
+ @ b_dim: number of bond types
242
+ @ m_dim: number of atom types (or number of features used)
243
+ @ dropout: dropout possibility
244
+ @ dim: Hidden dimension of Transformer Encoder
245
+ @ depth: Transformer layer number
246
+ @ heads: Number of multihead-attention heads
247
+ @ mlp_ratio: Read-out layer dimension of Transformer
248
+ @ drop_rate: deprecated
249
+ @ tra_conv: Whether module creates output for TransformerConv discriminator
250
+ '''
251
+
252
+ self.G = Generator(self.z_dim,
253
+ self.act,
254
+ self.vertexes,
255
+ self.b_dim,
256
+ self.m_dim,
257
+ self.dropout,
258
+ dim=self.dim,
259
+ depth=self.depth,
260
+ heads=self.heads,
261
+ mlp_ratio=self.mlp_ratio,
262
+ submodel = self.submodel)
263
+
264
+ self.G2 = Generator2(self.dim,
265
+ self.dec_dim,
266
+ self.dec_depth,
267
+ self.dec_heads,
268
+ self.mlp_ratio,
269
+ self.dec_dropout,
270
+ self.drugs_m_dim,
271
+ self.drugs_b_dim,
272
+ self.submodel)
273
+
274
+
275
+
276
+ ''' Discriminator implementation with PNA:
277
+
278
+ @ deg: Degree distribution based on used data. (Created with _genDegree() function)
279
+ @ agg: aggregators used in PNA
280
+ @ sca: scalers used in PNA
281
+ @ pna_in_ch: First PNA hidden dimension
282
+ @ pna_out_ch: Last PNA hidden dimension
283
+ @ edge_dim: Edge hidden dimension
284
+ @ towers: Number of towers (Splitting the hidden dimension to multiple parallel processes)
285
+ @ pre_lay: Pre-transformation layer
286
+ @ post_lay: Post-transformation layer
287
+ @ pna_layer_num: number of PNA layers
288
+ @ graph_add: global pooling layer selection
289
+ '''
290
+
291
+
292
+ ''' Discriminator implementation with Graph Convolution:
293
+
294
+ @ d_conv_dim: convolution dimensions for GCN
295
+ @ m_dim: number of atom types (or number of features used)
296
+ @ b_dim: number of bond types
297
+ @ dropout: dropout possibility
298
+ '''
299
+
300
+ ''' Discriminator implementation with MLP:
301
+
302
+ @ act: Activation function for MLP
303
+ @ m_dim: number of atom types (or number of features used)
304
+ @ b_dim: number of bond types
305
+ @ dropout: dropout possibility
306
+ @ vertexes: maximum length of generated molecules (molecule length)
307
+ '''
308
+
309
+ #self.D = Discriminator_old(self.d_conv_dim, self.m_dim , self.b_dim, self.dropout, self.gcn_depth)
310
+ self.D2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
311
+ self.D = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
312
+ self.V = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
313
+ self.V2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
314
+
315
+ ''' Optimizers for G1, G2, D1, and D2:
316
+
317
+ Adam Optimizer is used and different beta1 and beta2s are used for GAN1 and GAN2
318
+ '''
319
+
320
+ self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
321
+ self.g2_optimizer = torch.optim.AdamW(self.G2.parameters(), self.g2_lr, [self.beta1, self.beta2])
322
+
323
+ self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
324
+ self.d2_optimizer = torch.optim.AdamW(self.D2.parameters(), self.d2_lr, [self.beta1, self.beta2])
325
+
326
+
327
+
328
+ self.v_optimizer = torch.optim.AdamW(self.V.parameters(), self.d_lr, [self.beta1, self.beta2])
329
+ self.v2_optimizer = torch.optim.AdamW(self.V2.parameters(), self.d2_lr, [self.beta1, self.beta2])
330
+ ''' Learning rate scheduler:
331
+
332
+ Changes learning rate based on loss.
333
+ '''
334
+
335
+ #self.scheduler_g = ReduceLROnPlateau(self.g_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
336
+
337
+
338
+ #self.scheduler_d = ReduceLROnPlateau(self.d_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
339
+
340
+ #self.scheduler_v = ReduceLROnPlateau(self.v_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
341
+ #self.scheduler_g2 = ReduceLROnPlateau(self.g2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
342
+ #self.scheduler_d2 = ReduceLROnPlateau(self.d2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
343
+ #self.scheduler_v2 = ReduceLROnPlateau(self.v2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
344
+ self.print_network(self.G, 'G')
345
+ self.print_network(self.D, 'D')
346
+
347
+ self.print_network(self.G2, 'G2')
348
+ self.print_network(self.D2, 'D2')
349
+
350
+ self.G.to(self.device)
351
+ self.D.to(self.device)
352
+
353
+ self.V.to(self.device)
354
+ self.V2.to(self.device)
355
+ self.G2.to(self.device)
356
+ self.D2.to(self.device)
357
+
358
+ #self.V2.to(self.device)
359
+ #self.modules_of_the_model = (self.G, self.D, self.G2, self.D2)
360
+ """for p in self.G.parameters():
361
+ if p.dim() > 1:
362
+ if self.init_type == 'uniform':
363
+ torch.nn.init.xavier_uniform_(p)
364
+ elif self.init_type == 'normal':
365
+ torch.nn.init.xavier_normal_(p)
366
+ elif self.init_type == 'random_normal':
367
+ torch.nn.init.normal_(p, 0.0, 0.02)
368
+ for p in self.G2.parameters():
369
+ if p.dim() > 1:
370
+ if self.init_type == 'uniform':
371
+ torch.nn.init.xavier_uniform_(p)
372
+ elif self.init_type == 'normal':
373
+ torch.nn.init.xavier_normal_(p)
374
+ elif self.init_type == 'random_normal':
375
+ torch.nn.init.normal_(p, 0.0, 0.02)
376
+ if self.dis_select == "conv":
377
+ for p in self.D.parameters():
378
+ if p.dim() > 1:
379
+ if self.init_type == 'uniform':
380
+ torch.nn.init.xavier_uniform_(p)
381
+ elif self.init_type == 'normal':
382
+ torch.nn.init.xavier_normal_(p)
383
+ elif self.init_type == 'random_normal':
384
+ torch.nn.init.normal_(p, 0.0, 0.02)
385
+
386
+ if self.dis_select == "conv":
387
+ for p in self.D2.parameters():
388
+ if p.dim() > 1:
389
+ if self.init_type == 'uniform':
390
+ torch.nn.init.xavier_uniform_(p)
391
+ elif self.init_type == 'normal':
392
+ torch.nn.init.xavier_normal_(p)
393
+ elif self.init_type == 'random_normal':
394
+ torch.nn.init.normal_(p, 0.0, 0.02)"""
395
+
396
+
397
+ def decoder_load(self, dictionary_name):
398
+
399
+ ''' Loading the atom and bond decoders'''
400
+
401
+ with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
402
+
403
+ return pickle.load(f)
404
+
405
+ def drug_decoder_load(self, dictionary_name):
406
+
407
+ ''' Loading the atom and bond decoders'''
408
+
409
+ with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
410
+
411
+ return pickle.load(f)
412
+
413
+ def print_network(self, model, name):
414
+
415
+ """Print out the network information."""
416
+
417
+ num_params = 0
418
+ for p in model.parameters():
419
+ num_params += p.numel()
420
+ print(model)
421
+ print(name)
422
+ print("The number of parameters: {}".format(num_params))
423
+
424
+
425
+ def restore_model(self, epoch, iteration, model_directory):
426
+
427
+ """Restore the trained generator and discriminator."""
428
+
429
+ print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
430
+
431
+ G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
432
+ #D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
433
+
434
+ self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
435
+ #self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
436
+
437
+
438
+ G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
439
+ #D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
440
+
441
+ self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
442
+ #self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
443
+
444
+
445
+ def save_model(self, model_directory, idx,i):
446
+ G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
447
+ D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
448
+ torch.save(self.G.state_dict(), G_path)
449
+ torch.save(self.D.state_dict(), D_path)
450
+
451
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
452
+ G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(idx+1,i+1))
453
+ D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(idx+1,i+1))
454
+
455
+ torch.save(self.G2.state_dict(), G2_path)
456
+ torch.save(self.D2.state_dict(), D2_path)
457
+
458
+ def reset_grad(self):
459
+
460
+ """Reset the gradient buffers."""
461
+
462
+ self.g_optimizer.zero_grad()
463
+ self.v_optimizer.zero_grad()
464
+ self.g2_optimizer.zero_grad()
465
+ self.v2_optimizer.zero_grad()
466
+
467
+ self.d_optimizer.zero_grad()
468
+ self.d2_optimizer.zero_grad()
469
+
470
+ def gradient_penalty(self, y, x):
471
+
472
+ """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
473
+
474
+ weight = torch.ones(y.size(),requires_grad=False).to(self.device)
475
+ dydx = torch.autograd.grad(outputs=y,
476
+ inputs=x,
477
+ grad_outputs=weight,
478
+ retain_graph=True,
479
+ create_graph=True,
480
+ only_inputs=True)[0]
481
+
482
+ dydx = dydx.view(dydx.size(0), -1)
483
+ gradient_penalty = ((dydx.norm(2, dim=1) - 1) ** 2).mean()
484
+
485
+ return gradient_penalty
486
+
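To make the call convention of gradient_penalty concrete, here is a toy WGAN-GP computation with a stand-in critic and random interpolates. The critic, tensor sizes, and lambda value are placeholders, not the models or settings used elsewhere in this file.

import torch
import torch.nn as nn

critic = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))   # stand-in discriminator

real = torch.rand(8, 16)
fake = torch.rand(8, 16)
eps = torch.rand(8, 1)
x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)   # interpolates need grad enabled

scores = critic(x_hat)
grads = torch.autograd.grad(outputs=scores, inputs=x_hat,
                            grad_outputs=torch.ones_like(scores),
                            create_graph=True, retain_graph=True, only_inputs=True)[0]
penalty = ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
loss_term = 10.0 * penalty        # lambda_gp = 10 is a common WGAN-GP choice, not read from config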
487
+ def train(self):
488
+
489
+ ''' Training Script starts from here'''
490
+
491
+ #wandb.config = {'beta2': 0.999}
492
+ #wandb.init(project="DrugGEN2", entity="atabeyunlu")
493
+
494
+ # Defining sampling paths and creating logger
495
+
496
+ self.arguments = "{}_glr{}_dlr{}_g2lr{}_d2lr{}_dim{}_depth{}_heads{}_decdepth{}_decheads{}_ncritic{}_batch{}_epoch{}_warmup{}_dataset{}_dropout{}".format(self.submodel,self.g_lr,self.d_lr,self.g2_lr,self.d2_lr,self.dim,self.depth,self.heads,self.dec_depth,self.dec_heads,self.n_critic,self.batch_size,self.epoch,self.warm_up_steps,self.dataset_name,self.dropout)
497
+
498
+ self.model_directory= os.path.join(self.model_save_dir,self.arguments)
499
+ self.sample_directory=os.path.join(self.sample_dir,self.arguments)
500
+ self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
501
+ if not os.path.exists(self.model_directory):
502
+ os.makedirs(self.model_directory)
503
+ if not os.path.exists(self.sample_directory):
504
+ os.makedirs(self.sample_directory)
505
+
506
+ # Learning rate cache for decaying.
507
+
508
+
509
+ # protein data
510
+ full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
511
+ drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
512
+
513
+ drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
514
+ drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
515
+ fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
516
+
517
+ akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
518
+ akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
519
+
520
+ # Start training.
521
+
522
+ print('Start training...')
523
+ self.start_time = time.time()
524
+ for idx in range(self.epoch):
525
+
526
+ # =================================================================================== #
527
+ # 1. Preprocess input data #
528
+ # =================================================================================== #
529
+
530
+ # Load the data
531
+
532
+ dataloader_iterator = iter(self.drugs_loader)
533
+
534
+ for i, data in enumerate(self.loader):
535
+ try:
536
+ drugs = next(dataloader_iterator)
537
+ except StopIteration:
538
+ dataloader_iterator = iter(self.drugs_loader)
539
+ drugs = next(dataloader_iterator)
540
+
541
+ # Preprocess both dataset
542
+
543
+ bulk_data = load_data(data,
544
+ drugs,
545
+ self.batch_size,
546
+ self.device,
547
+ self.b_dim,
548
+ self.m_dim,
549
+ self.drugs_b_dim,
550
+ self.drugs_m_dim,
551
+ self.z_dim,
552
+ self.vertexes)
553
+
554
+ drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
555
+
556
+ if self.submodel == "CrossLoss":
557
+ GAN1_input_e = drugs_a_tensor
558
+ GAN1_input_x = drugs_x_tensor
559
+ GAN1_disc_e = a_tensor
560
+ GAN1_disc_x = x_tensor
561
+ elif self.submodel == "Ligand":
562
+ GAN1_input_e = a_tensor
563
+ GAN1_input_x = x_tensor
564
+ GAN1_disc_e = a_tensor
565
+ GAN1_disc_x = x_tensor
566
+ GAN2_input_e = drugs_a_tensor
567
+ GAN2_input_x = drugs_x_tensor
568
+ GAN2_disc_e = drugs_a_tensor
569
+ GAN2_disc_x = drugs_x_tensor
570
+ elif self.submodel == "Prot":
571
+ GAN1_input_e = a_tensor
572
+ GAN1_input_x = x_tensor
573
+ GAN1_disc_e = a_tensor
574
+ GAN1_disc_x = x_tensor
575
+ GAN2_input_e = akt1_human_adj
576
+ GAN2_input_x = akt1_human_annot
577
+ GAN2_disc_e = drugs_a_tensor
578
+ GAN2_disc_x = drugs_x_tensor
579
+ elif self.submodel == "RL":
580
+ GAN1_input_e = z_edge
581
+ GAN1_input_x = z_node
582
+ GAN1_disc_e = a_tensor
583
+ GAN1_disc_x = x_tensor
584
+ GAN2_input_e = drugs_a_tensor
585
+ GAN2_input_x = drugs_x_tensor
586
+ GAN2_disc_e = drugs_a_tensor
587
+ GAN2_disc_x = drugs_x_tensor
588
+ elif self.submodel == "NoTarget":
589
+ GAN1_input_e = z_edge
590
+ GAN1_input_x = z_node
591
+ GAN1_disc_e = a_tensor
592
+ GAN1_disc_x = x_tensor
593
+
594
+ # =================================================================================== #
595
+ # 2. Train the discriminator #
596
+ # =================================================================================== #
597
+ loss = {}
598
+ self.reset_grad()
599
+
600
+ # Compute discriminator loss.
601
+
602
+ node, edge, d_loss = discriminator_loss(self.G,
603
+ self.D,
604
+ real_graphs,
605
+ GAN1_disc_e,
606
+ GAN1_disc_x,
607
+ self.batch_size,
608
+ self.device,
609
+ self.gradient_penalty,
610
+ self.lambda_gp,
611
+ GAN1_input_e,
612
+ GAN1_input_x)
613
+
614
+ d_total = d_loss
615
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
616
+ d2_loss = discriminator2_loss(self.G2,
617
+ self.D2,
618
+ drug_graphs,
619
+ edge,
620
+ node,
621
+ self.batch_size,
622
+ self.device,
623
+ self.gradient_penalty,
624
+ self.lambda_gp,
625
+ GAN2_input_e,
626
+ GAN2_input_x)
627
+ d_total = d_loss + d2_loss
628
+
629
+ loss["d_total"] = d_total.item()
630
+ d_total.backward()
631
+ self.d_optimizer.step()
632
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
633
+ self.d2_optimizer.step()
634
+ self.reset_grad()
635
+ generator_output = generator_loss(self.G,
636
+ self.D,
637
+ self.V,
638
+ GAN1_input_e,
639
+ GAN1_input_x,
640
+ self.batch_size,
641
+ sim_reward,
642
+ self.dataset.matrices2mol_drugs,
643
+ fps_r,
644
+ self.submodel)
645
+
646
+ g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
647
+
648
+ self.reset_grad()
649
+ g_total = g_loss
650
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
651
+ output = generator2_loss(self.G2,
652
+ self.D2,
653
+ self.V2,
654
+ edge,
655
+ node,
656
+ self.batch_size,
657
+ sim_reward,
658
+ self.dataset.matrices2mol_drugs,
659
+ fps_r,
660
+ GAN2_input_e,
661
+ GAN2_input_x,
662
+ self.submodel)
663
+
664
+ g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
665
+
666
+ g_total = g_loss + g2_loss
667
+
668
+ loss["g_total"] = g_total.item()
669
+ g_total.backward()
670
+ self.g_optimizer.step()
671
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
672
+ self.g2_optimizer.step()
673
+
674
+ if self.submodel == "RL":
675
+ self.v_optimizer.step()
676
+ self.v2_optimizer.step()
677
+
678
+
679
+ if (i+1) % self.log_step == 0:
680
+
681
+ logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1,self.sample_directory)
682
+ mol_sample(self.sample_directory,"GAN1",fake_mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), idx, i)
683
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
684
+ logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2,self.sample_directory)
685
+ mol_sample(self.sample_directory,"GAN2",fake_mol_g, dr_g_edges_hat_sample.detach(), dr_g_nodes_hat_sample.detach(), idx, i)
686
+
687
+
688
+ if (idx+1) % 10 == 0:
689
+ self.save_model(self.model_directory,idx,i)
690
+ print("model saved at epoch {} and iteration {}".format(idx,i))
691
+
692
+
693
+
694
+ def inference(self):
695
+
696
+ # Load the trained generator.
697
+ self.G.to(self.device)
698
+ #self.D.to(self.device)
699
+ self.G2.to(self.device)
700
+ #self.D2.to(self.device)
701
+
702
+ G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
703
+ self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
704
+ G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
705
+ self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
706
+
707
+
708
+ drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
709
+
710
+ drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
711
+ drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
712
+ fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
713
+
714
+ akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
715
+ akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
716
+
717
+ self.G.eval()
718
+ #self.D.eval()
719
+ self.G2.eval()
720
+ #self.D2.eval()
721
+
722
+ self.inf_batch_size =256
723
+ self.inf_dataset = DruggenDataset(self.mol_data_dir,
724
+ self.inf_dataset_file,
725
+ self.inf_raw_file,
726
+ self.max_atom,
727
+ self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
728
+ # Can create any molecular graph dataset given smiles string.
729
+ # Nonisomeric SMILES are suggested but not necessary.
730
+ # Uses sparse matrix representation for graphs,
731
+ # For computational and speed efficiency.
732
+
733
+ self.inf_loader = DataLoader(self.inf_dataset,
734
+ shuffle=True,
735
+ batch_size=self.inf_batch_size,
736
+ drop_last=True) # PyG dataloader for the first GAN.
737
+
738
+ self.inf_drugs = DruggenDataset(self.drug_data_dir,
739
+ self.inf_drugs_dataset_file,
740
+ self.inf_drug_raw_file,
741
+ self.max_atom,
742
+ self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
743
+ # Can create any molecular graph dataset given smiles string.
744
+ # Non-isomeric SMILES are suggested but not required.
745
+ # Uses a sparse matrix representation for graphs
746
+ # for computational and speed efficiency.
747
+
748
+ self.inf_drugs_loader = DataLoader(self.inf_drugs,
749
+ shuffle=True,
750
+ batch_size=self.inf_batch_size,
751
+ drop_last=True) # PyG dataloader for the second GAN.
752
+ start_time = time.time()
753
+ #metric_calc_mol = []
754
+ metric_calc_dr = []
755
+ date = time.time()
756
+ if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
757
+ os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
758
+ with torch.inference_mode():
759
+
760
+ dataloader_iterator = iter(self.inf_drugs_loader)
761
+
762
+ for i, data in enumerate(self.inf_loader):
763
+ try:
764
+ drugs = next(dataloader_iterator)
765
+ except StopIteration:
766
+ dataloader_iterator = iter(self.inf_drugs_loader)
767
+ drugs = next(dataloader_iterator)
768
+
769
+ # Preprocess both datasets
770
+
771
+ bulk_data = load_data(data,
772
+ drugs,
773
+ self.inf_batch_size,
774
+ self.device,
775
+ self.b_dim,
776
+ self.m_dim,
777
+ self.drugs_b_dim,
778
+ self.drugs_m_dim,
779
+ self.z_dim,
780
+ self.vertexes)
781
+
782
+ drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
783
+
784
+ if self.submodel == "CrossLoss":
785
+ GAN1_input_e = a_tensor
786
+ GAN1_input_x = x_tensor
787
+ GAN1_disc_e = drugs_a_tensor
788
+ GAN1_disc_x = drugs_x_tensor
789
+ GAN2_input_e = drugs_a_tensor
790
+ GAN2_input_x = drugs_x_tensor
791
+ GAN2_disc_e = a_tensor
792
+ GAN2_disc_x = x_tensor
793
+ elif self.submodel == "Ligand":
794
+ GAN1_input_e = a_tensor
795
+ GAN1_input_x = x_tensor
796
+ GAN1_disc_e = a_tensor
797
+ GAN1_disc_x = x_tensor
798
+ GAN2_input_e = drugs_a_tensor
799
+ GAN2_input_x = drugs_x_tensor
800
+ GAN2_disc_e = drugs_a_tensor
801
+ GAN2_disc_x = drugs_x_tensor
802
+ elif self.submodel == "Prot":
803
+ GAN1_input_e = a_tensor
804
+ GAN1_input_x = x_tensor
805
+ GAN1_disc_e = a_tensor
806
+ GAN1_disc_x = x_tensor
807
+ GAN2_input_e = akt1_human_adj
808
+ GAN2_input_x = akt1_human_annot
809
+ GAN2_disc_e = drugs_a_tensor
810
+ GAN2_disc_x = drugs_x_tensor
811
+ elif self.submodel == "RL":
812
+ GAN1_input_e = z_edge
813
+ GAN1_input_x = z_node
814
+ GAN1_disc_e = a_tensor
815
+ GAN1_disc_x = x_tensor
816
+ GAN2_input_e = drugs_a_tensor
817
+ GAN2_input_x = drugs_x_tensor
818
+ GAN2_disc_e = drugs_a_tensor
819
+ GAN2_disc_x = drugs_x_tensor
820
+ elif self.submodel == "NoTarget":
821
+ GAN1_input_e = z_edge
822
+ GAN1_input_x = z_node
823
+ GAN1_disc_e = a_tensor
824
+ GAN1_disc_x = x_tensor
825
+ # =================================================================================== #
826
+ # 2. GAN1 Inference #
827
+ # =================================================================================== #
828
+ generator_output = generator_loss(self.G,
829
+ self.D,
830
+ self.V,
831
+ GAN1_input_e,
832
+ GAN1_input_x,
833
+ self.inf_batch_size,
834
+ sim_reward,
835
+ self.dataset.matrices2mol_drugs,
836
+ fps_r,
837
+ self.submodel)
838
+
839
+ _, fake_mol, _, _, node, edge = generator_output
840
+
841
+ # =================================================================================== #
842
+ # 3. GAN2 Inference #
843
+ # =================================================================================== #
844
+
845
+ output = generator2_loss(self.G2,
846
+ self.D2,
847
+ self.V2,
848
+ edge,
849
+ node,
850
+ self.inf_batch_size,
851
+ sim_reward,
852
+ self.dataset.matrices2mol_drugs,
853
+ fps_r,
854
+ GAN2_input_e,
855
+ GAN2_input_x,
856
+ self.submodel)
857
+
858
+ _, fake_mol_g, _, _ = output
859
+
860
+ inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
861
+
862
+
863
+
864
+ #inference_smiles = [Chem.MolToSmiles(line) for line in fake_mol]
865
+
866
+
867
+
868
+ print("molecule batch {} inferred".format(i))
869
+
870
+ with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
871
+ for molecules in inference_drugs:
872
+
873
+ f.write(molecules)
874
+ f.write("\n")
875
+ metric_calc_dr.append(molecules)
876
+
877
+
878
+
879
+ if i == 120:
880
+ break
881
+
882
+ et = time.time() - start_time
883
+
884
+ print("Inference took {:.2f} seconds.".format(et))
885
+
886
+ print("Metrics calculation started using MOSES.")
887
+
888
+ print("Validity: ", fraction_valid(metric_calc_dr), "\n")
889
+ print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
890
+ print("Novelty: ", novelty(metric_calc_dr, drug_smiles), "\n")
891
+
892
+ print("Metrics calculation completed.")
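The metrics printed above come from the MOSES-style helpers defined in utils.py further down. A minimal offline sketch of the same calculation, re-reading the SMILES file written during inference (the file paths are the ones used in the code above; the submodel name is only an example):

    # Recompute the inference metrics from the saved SMILES file (sketch).
    from utils import fraction_valid, fraction_unique, novelty

    submodel = "CrossLoss"  # example; use whichever submodel was run

    with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(submodel)) as f:
        generated = f.read().splitlines()

    reference = open("DrugGEN/data/akt_test.smi").read().splitlines()

    print("Validity: ", fraction_valid(generated))
    print("Uniqueness: ", fraction_unique(generated))
    print("Novelty: ", novelty(generated, reference))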
training_data.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch_geometric.utils as geoutils
3
+ from utils import *
4
+
5
+ def load_data(data, drugs, batch_size, device, b_dim, m_dim, drugs_b_dim, drugs_m_dim,z_dim,vertexes):
6
+
7
+ z = sample_z(batch_size, z_dim) # (batch_size, z_dim)
8
+
9
+ z = torch.from_numpy(z).to(device).float().requires_grad_(True)
10
+ data = data.to(device)
11
+ drugs = drugs.to(device)
12
+ z_e = sample_z_edge(batch_size,vertexes,b_dim) # (batch_size, vertexes, vertexes, b_dim)
13
+ z_n = sample_z_node(batch_size,vertexes,m_dim) # (batch_size, vertexes, m_dim)
14
+ z_edge = torch.from_numpy(z_e).to(device).float().requires_grad_(True) # Edge noise. (batch_size, vertexes, vertexes, b_dim)
15
+ z_node = torch.from_numpy(z_n).to(device).float().requires_grad_(True) # Node noise. (batch_size, vertexes, m_dim)
16
+ a = geoutils.to_dense_adj(edge_index = data.edge_index,batch=data.batch,edge_attr=data.edge_attr, max_num_nodes=int(data.batch.shape[0]/batch_size))
17
+ x = data.x.view(batch_size,int(data.batch.shape[0]/batch_size),-1)
18
+
19
+ a_tensor = label2onehot(a, b_dim, device)
20
+ #x_tensor = label2onehot(x, m_dim)
21
+ x_tensor = x
22
+
23
+ a_tensor = a_tensor #+ torch.randn([a_tensor.size(0), a_tensor.size(1), a_tensor.size(2),1], device=a_tensor.device) * noise_strength_0
24
+ x_tensor = x_tensor #+ torch.randn([x_tensor.size(0), x_tensor.size(1),1], device=x_tensor.device) * noise_strength_1
25
+
26
+ drugs_a = geoutils.to_dense_adj(edge_index = drugs.edge_index,batch=drugs.batch,edge_attr=drugs.edge_attr, max_num_nodes=int(drugs.batch.shape[0]/batch_size))
27
+
28
+ drugs_x = drugs.x.view(batch_size,int(drugs.batch.shape[0]/batch_size),-1)
29
+
30
+ drugs_a = drugs_a.to(device).long()
31
+ drugs_x = drugs_x.to(device)
32
+ drugs_a_tensor = label2onehot(drugs_a, drugs_b_dim,device).float()
33
+ drugs_x_tensor = drugs_x
34
+
35
+ drugs_a_tensor = drugs_a_tensor #+ torch.randn([drugs_a_tensor.size(0), drugs_a_tensor.size(1), drugs_a_tensor.size(2),1], device=drugs_a_tensor.device) * noise_strength_2
36
+ drugs_x_tensor = drugs_x_tensor #+ torch.randn([drugs_x_tensor.size(0), drugs_x_tensor.size(1),1], device=drugs_x_tensor.device) * noise_strength_3
37
+ #prot_n = akt1_human_annot[None,:].to(device).float()
38
+ #prot_e = akt1_human_adj[None,None,:].view(1,546,546,1).to(device).float()
39
+
40
+
41
+
42
+ a_tensor_vec = a_tensor.reshape(batch_size,-1)
43
+ x_tensor_vec = x_tensor.reshape(batch_size,-1)
44
+ real_graphs = torch.concat((x_tensor_vec,a_tensor_vec),dim=-1)
45
+
46
+ a_drug_vec = drugs_a_tensor.reshape(batch_size,-1)
47
+ x_drug_vec = drugs_x_tensor.reshape(batch_size,-1)
48
+ drug_graphs = torch.concat((x_drug_vec,a_drug_vec),dim=-1)
49
+
50
+ return drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node
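For orientation: the final reshape/concat above flattens each dense graph into a single vector of length vertexes*m_dim + vertexes*vertexes*b_dim, node annotations first and the one-hot adjacency second. A small standalone sketch of that step with made-up sizes (the real values come from the dataset):

    import torch

    # Toy stand-ins for vertexes, m_dim (atom types) and b_dim (bond types).
    batch_size, vertexes, m_dim, b_dim = 2, 9, 5, 4

    x_tensor = torch.rand(batch_size, vertexes, m_dim)            # node annotations
    a_tensor = torch.rand(batch_size, vertexes, vertexes, b_dim)  # one-hot-style adjacency

    x_vec = x_tensor.reshape(batch_size, -1)                      # (2, 45)
    a_vec = a_tensor.reshape(batch_size, -1)                      # (2, 324)
    real_graphs = torch.concat((x_vec, a_vec), dim=-1)

    print(real_graphs.shape)  # torch.Size([2, 369])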
utils.py ADDED
@@ -0,0 +1,462 @@
1
+ from statistics import mean
2
+ from rdkit import DataStructs
3
+ from rdkit import Chem
4
+ from rdkit.Chem import AllChem
5
+ from rdkit.Chem import Draw
6
+ import os
7
+ import numpy as np
8
+ import seaborn as sns
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.lines import Line2D
11
+ from rdkit import RDLogger
12
+ import torch
13
+ from rdkit.Chem.Scaffolds import MurckoScaffold
14
+ import math
15
+ import time
16
+ import datetime
17
+ import re
18
+ RDLogger.DisableLog('rdApp.*')
19
+ import warnings
20
+ from multiprocessing import Pool
21
+ class Metrics(object):
22
+
23
+ @staticmethod
24
+ def valid(x):
25
+ return x is not None and Chem.MolToSmiles(x) != ''
26
+
27
+ @staticmethod
28
+ def tanimoto_sim_1v2(data1, data2):
29
+ min_len = data1.size if data1.size < data2.size else data2.size
30
+ sims = []
31
+ for i in range(min_len):
32
+ sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
33
+ sims.append(sim)
34
+ mean_sim = mean(sims)
35
+ return mean_sim
36
+
37
+ @staticmethod
38
+ def mol_length(x):
39
+ if x is not None:
40
+ return len([char for char in max(Chem.MolToSmiles(x).split(sep =".")).upper() if char.isalpha()])
41
+ else:
42
+ return 0
43
+
44
+ @staticmethod
45
+ def max_component(data, max_len):
46
+
47
+ return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
48
+
49
+ def sim_reward(mol_gen, fps_r):
50
+
51
+ gen_scaf = []
52
+
53
+ for x in mol_gen:
54
+ if x is not None:
55
+ try:
56
+
57
+ gen_scaf.append(MurckoScaffold.GetScaffoldForMol(x))
58
+ except:
59
+ pass
60
+
61
+ if len(gen_scaf) == 0:
62
+
63
+ rew = 1
64
+ else:
65
+ fps = [Chem.RDKFingerprint(x) for x in gen_scaf]
66
+
67
+
68
+ fps = np.array(fps)
69
+ fps_r = np.array(fps_r)
70
+
71
+ rew = average_agg_tanimoto(fps_r, fps)
72
+ if math.isnan(rew):
73
+ rew = 1
74
+
75
+ return rew ## change this to penalty
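sim_reward scores a generated batch by the scaffold similarity of its molecules to the reference fingerprints fps_r, falling back to 1 when nothing valid was generated. A hedged usage sketch, assuming this module is importable as utils and using placeholder SMILES rather than real reference inhibitors:

    from rdkit import Chem
    from rdkit.Chem.Scaffolds import MurckoScaffold
    from utils import sim_reward

    # Reference scaffold fingerprints, built the same way inference() builds fps_r.
    ref_mols = [Chem.MolFromSmiles(s) for s in ["CC(=O)Oc1ccccc1C(=O)O", "c1ccc2[nH]ccc2c1"]]
    fps_r = [Chem.RDKFingerprint(MurckoScaffold.GetScaffoldForMol(m)) for m in ref_mols]

    # A toy "generated" batch of RDKit mols; None entries are skipped inside sim_reward.
    generated = [Chem.MolFromSmiles(s) for s in ["c1ccccc1CCN", "c1ccncc1C"]] + [None]

    print(sim_reward(generated, fps_r))  # scalar scaffold-similarity score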
76
+
77
+ ##########################################
78
+ ##########################################
79
+ ##########################################
80
+
81
+ def mols2grid_image(mols,path):
82
+ mols = [e if e is not None else Chem.RWMol() for e in mols]
83
+
84
+ for i in range(len(mols)):
85
+ if Metrics.valid(mols[i]):
86
+ #if Chem.MolToSmiles(mols[i]) != '':
87
+ AllChem.Compute2DCoords(mols[i])
88
+ Draw.MolToFile(mols[i], os.path.join(path,"{}.png".format(i+1)), size=(1200,1200))
89
+ else:
90
+ continue
91
+
92
+ def save_smiles_matrices(mols,edges_hard, nodes_hard,path,data_source = None):
93
+ mols = [e if e is not None else Chem.RWMol() for e in mols]
94
+
95
+ for i in range(len(mols)):
96
+ if Metrics.valid(mols[i]):
97
+ #m0= all_scores_for_print(mols[i], data_source, norm=False)
98
+ #if Chem.MolToSmiles(mols[i]) != '':
99
+ save_path = os.path.join(path,"{}.txt".format(i+1))
100
+ with open(save_path, "a") as f:
101
+ np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n",fmt='%1.2f')
102
+ f.write("\n")
103
+ np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:",fmt='%1.2f')
104
+ f.write("\n")
105
+ #f.write(m0)
106
+ f.write("\n")
107
+
108
+
109
+ print(Chem.MolToSmiles(mols[i]), file=open(save_path,"a"))
110
+ else:
111
+ continue
112
+
113
+ ##########################################
114
+ ##########################################
115
+ ##########################################
116
+
117
+ def dense_to_sparse_with_attr(adj):
118
+ ###
119
+ assert adj.dim() >= 2 and adj.dim() <= 3
120
+ assert adj.size(-1) == adj.size(-2)
121
+
122
+ index = adj.nonzero(as_tuple=True)
123
+ edge_attr = adj[index]
124
+
125
+ if len(index) == 3:
126
+ batch = index[0] * adj.size(-1)
127
+ index = (batch + index[1], batch + index[2])
128
+ #index = torch.stack(index, dim=0)
129
+ return index, edge_attr
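dense_to_sparse_with_attr returns the nonzero indices of a (batched) dense adjacency together with their attribute values, shifting node indices by batch_id * num_nodes so graphs in a batch do not collide. A quick sketch using the function as defined above:

    import torch

    # Two 3-node toy graphs with integer bond labels (0 = no bond).
    adj = torch.tensor([
        [[0, 1, 0],
         [1, 0, 2],
         [0, 2, 0]],
        [[0, 3, 0],
         [3, 0, 0],
         [0, 0, 0]],
    ])

    index, edge_attr = dense_to_sparse_with_attr(adj)
    print(index)      # row/col tensors using node ids 0..2 for graph 0 and 3..5 for graph 1
    print(edge_attr)  # tensor([1, 1, 2, 2, 3, 3])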
130
+
131
+
132
+ def label2onehot(labels, dim, device):
133
+
134
+ """Convert label indices to one-hot vectors."""
135
+
136
+ out = torch.zeros(list(labels.size())+[dim]).to(device)
137
+ out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
138
+
139
+ return out.float()
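label2onehot appends a one-hot dimension to an integer label tensor; this is how the dense bond-type matrices become (vertexes, vertexes, b_dim) tensors in load_data. A quick sketch using the function as defined above:

    import torch

    bond_labels = torch.tensor([[0, 1], [2, 0]])  # bond types for a 2-node toy graph
    onehot = label2onehot(bond_labels, dim=3, device="cpu")

    print(onehot.shape)  # torch.Size([2, 2, 3])
    print(onehot[0, 1])  # tensor([0., 1., 0.]) -> bond type 1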
140
+
141
+
142
+ def sample_z_node(batch_size, vertexes, nodes):
143
+
144
+ ''' Random noise for nodes logits. '''
145
+
146
+ return np.random.normal(0,1, size=(batch_size,vertexes, nodes)) # 128, 9, 5
147
+
148
+
149
+ def sample_z_edge(batch_size, vertexes, edges):
150
+
151
+ ''' Random noise for edges logits. '''
152
+
153
+ return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
154
+
155
+ def sample_z( batch_size, z_dim):
156
+
157
+ ''' Random noise. '''
158
+
159
+ return np.random.normal(0,1, size=(batch_size,z_dim)) # (batch_size, z_dim)
160
+
161
+
162
+ def mol_sample(sample_directory, model_name, mol, edges, nodes, idx, i):
163
+ sample_path = os.path.join(sample_directory,"{}-{}_{}-epoch_iteration".format(model_name,idx+1, i+1))
164
+
165
+ if not os.path.exists(sample_path):
166
+ os.makedirs(sample_path)
167
+
168
+ mols2grid_image(mol,sample_path)
169
+
170
+ save_smiles_matrices(mol,edges.detach(), nodes.detach(), sample_path)
171
+
172
+ if len(os.listdir(sample_path)) == 0:
173
+ os.rmdir(sample_path)
174
+
175
+ print("Valid molecules are saved.")
176
+ print("Valid matrices and SMILES are saved.")
177
+
178
+
179
+
180
+
181
+
182
+ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path):
183
+
184
+ gen_smiles = []
185
+ for line in mols:
186
+ if line is not None:
187
+ gen_smiles.append(Chem.MolToSmiles(line))
188
+ elif line is None:
189
+ gen_smiles.append(None)
190
+
191
+ #gen_smiles_saves = [None if x is None else re.sub('\*', '', x) for x in gen_smiles]
192
+ #gen_smiles_saves = [None if x is None else re.sub('\.', '', x) for x in gen_smiles_saves]
193
+ gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
194
+
195
+ sample_save_dir = os.path.join(save_path, "samples-GAN{}.txt".format(model_num))
196
+ with open(sample_save_dir, "a") as f:
197
+ for idxs in range(len(gen_smiles_saves)):
198
+ if gen_smiles_saves[idxs] is not None:
199
+
200
+ f.write(gen_smiles_saves[idxs])
201
+ f.write("\n")
202
+
203
+ k = len(set(gen_smiles_saves) - {None})
204
+
205
+
206
+ et = time.time() - start_time
207
+ et = str(datetime.timedelta(seconds=et))[:-7]
208
+ log = "Elapsed [{}], Epoch/Iteration [{}/{}] for GAN{}".format(et, idx, i+1, model_num)
209
+
210
+ # Log update
211
+ #m0 = get_all_metrics(gen = gen_smiles, train = train_smiles, batch_size=batch_size, k = valid_mol_num, device=self.device)
212
+ valid = fraction_valid(gen_smiles_saves)
213
+ unique = fraction_unique(gen_smiles_saves, k, check_validity=False)
214
+ novel = novelty(gen_smiles_saves, train_smiles)
215
+
216
+ #qed = [QED(mol) for mol in mols if mol is not None]
217
+ #sa = [SA(mol) for mol in mols if mol is not None]
218
+ #logp = [logP(mol) for mol in mols if mol is not None]
219
+
220
+ #IntDiv = internal_diversity(gen_smiles)
221
+ #m0= all_scores_val(fake_mol, mols, full_mols, full_smiles, vert, norm=True) # 'mols' is output of Fake Reward
222
+ #m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
223
+ #m0.update(m1)
224
+
225
+ #maxlen = MolecularMetrics.max_component(mols, 45)
226
+
227
+ #m0 = {k: np.array(v).mean() for k, v in m0.items()}
228
+ #loss.update(m0)
229
+ loss.update({'Valid': valid})
230
+ loss.update({'Unique@{}'.format(k): unique})
231
+ loss.update({'Novel': novel})
232
+ #loss.update({'QED': statistics.mean(qed)})
233
+ #loss.update({'SA': statistics.mean(sa)})
234
+ #loss.update({'LogP': statistics.mean(logp)})
235
+ #loss.update({'IntDiv': IntDiv})
236
+
237
+ #wandb.log({"maxlen": maxlen})
238
+
239
+ for tag, value in loss.items():
240
+
241
+ log += ", {}: {:.4f}".format(tag, value)
242
+ with open(log_path, "a") as f:
243
+ f.write(log)
244
+ f.write("\n")
245
+ print(log)
246
+ print("\n")
247
+
248
+
249
+
250
+ def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
251
+
252
+ cols = 4
253
+ rows = int(heads/cols)
254
+
255
+ fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
256
+ axes = axes.flat
257
+ attentions_pos = attn_w[0]
258
+ attentions_pos = attentions_pos.cpu().detach().numpy()
259
+ for i,att in enumerate(attentions_pos):
260
+
261
+ #im = axes[i].imshow(att, cmap='gray')
262
+ sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
263
+ axes[i].set_title(f'head - {i} ')
264
+ axes[i].set_ylabel('layers')
265
+ pltsavedir = "/home/atabey/attn/second"
266
+ plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
267
+
268
+
269
+ def plot_grad_flow(named_parameters, model, iter, epoch):
270
+
271
+ # Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
272
+ '''Plots the gradients flowing through different layers in the net during training.
273
+ Can be used for checking for possible gradient vanishing / exploding problems.
274
+
275
+ Usage: Plug this function in Trainer class after loss.backwards() as
276
+ "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
277
+ ave_grads = []
278
+ max_grads= []
279
+ layers = []
280
+ for n, p in named_parameters:
281
+ if(p.requires_grad) and ("bias" not in n):
282
+ print(p.grad,n)
283
+ layers.append(n)
284
+ ave_grads.append(p.grad.abs().mean().cpu())
285
+ max_grads.append(p.grad.abs().max().cpu())
286
+ plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
287
+ plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
288
+ plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
289
+ plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
290
+ plt.xlim(left=0, right=len(ave_grads))
291
+ plt.ylim(bottom = -0.001, top=1) # zoom in on the lower gradient regions
292
+ plt.xlabel("Layers")
293
+ plt.ylabel("average gradient")
294
+ plt.title("Gradient flow")
295
+ plt.grid(True)
296
+ plt.legend([Line2D([0], [0], color="c", lw=4),
297
+ Line2D([0], [0], color="b", lw=4),
298
+ Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
299
+ pltsavedir = "/home/atabey/gradients/tryout"
300
+ plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
301
+
302
+ """
303
+ def _genDegree():
304
+
305
+ ''' Generates the Degree distribution tensor for PNA, should be used everytime a different
306
+ dataset is used.
307
+ Can be called without arguments and saves the tensor for later use. If tensor was created
308
+ before, it just loads the degree tensor.
309
+ '''
310
+
311
+ degree_path = os.path.join(self.degree_dir, self.dataset_name + '-degree.pt')
312
+ if not os.path.exists(degree_path):
313
+
314
+
315
+ max_degree = -1
316
+ for data in self.dataset:
317
+ d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
318
+ max_degree = max(max_degree, int(d.max()))
319
+
320
+ # Compute the in-degree histogram tensor
321
+ deg = torch.zeros(max_degree + 1, dtype=torch.long)
322
+ for data in self.dataset:
323
+ d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
324
+ deg += torch.bincount(d, minlength=deg.numel())
325
+ torch.save(deg, 'DrugGEN/data/' + self.dataset_name + '-degree.pt')
326
+ else:
327
+ deg = torch.load(degree_path, map_location=lambda storage, loc: storage)
328
+
329
+ return deg
330
+ """
331
+ def get_mol(smiles_or_mol):
332
+ '''
333
+ Loads SMILES/molecule into RDKit's object
334
+ '''
335
+ if isinstance(smiles_or_mol, str):
336
+ if len(smiles_or_mol) == 0:
337
+ return None
338
+ mol = Chem.MolFromSmiles(smiles_or_mol)
339
+ if mol is None:
340
+ return None
341
+ try:
342
+ Chem.SanitizeMol(mol)
343
+ except ValueError:
344
+ return None
345
+ return mol
346
+ return smiles_or_mol
347
+
348
+ def mapper(n_jobs):
349
+ '''
350
+ Returns function for map call.
351
+ If n_jobs == 1, will use standard map
352
+ If n_jobs > 1, will use multiprocessing pool
353
+ If n_jobs is a pool object, will return its map function
354
+ '''
355
+ if n_jobs == 1:
356
+ def _mapper(*args, **kwargs):
357
+ return list(map(*args, **kwargs))
358
+
359
+ return _mapper
360
+ if isinstance(n_jobs, int):
361
+ pool = Pool(n_jobs)
362
+
363
+ def _mapper(*args, **kwargs):
364
+ try:
365
+ result = pool.map(*args, **kwargs)
366
+ finally:
367
+ pool.terminate()
368
+ return result
369
+
370
+ return _mapper
371
+ return n_jobs.map
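mapper hides whether work is done with the built-in map, a temporary multiprocessing pool, or a caller-supplied pool, which is why the metric helpers below accept a plain n_jobs integer. A short sketch of the single-process path, assuming mapper and get_mol are in scope:

    smiles = ["CCO", "c1ccccc1", "not_a_smiles"]
    mols = mapper(1)(get_mol, smiles)
    print([m is not None for m in mols])  # [True, True, False]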
372
+ def remove_invalid(gen, canonize=True, n_jobs=1):
373
+ """
374
+ Removes invalid molecules from the dataset
375
+ """
376
+ if not canonize:
377
+ mols = mapper(n_jobs)(get_mol, gen)
378
+ return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
379
+ return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
380
+ x is not None]
381
+ def fraction_valid(gen, n_jobs=1):
382
+ """
383
+ Computes the fraction of valid molecules
384
+ Parameters:
385
+ gen: list of SMILES
386
+ n_jobs: number of threads for calculation
387
+ """
388
+ gen = mapper(n_jobs)(get_mol, gen)
389
+ return 1 - gen.count(None) / len(gen)
390
+ def canonic_smiles(smiles_or_mol):
391
+ mol = get_mol(smiles_or_mol)
392
+ if mol is None:
393
+ return None
394
+ return Chem.MolToSmiles(mol)
395
+ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
396
+ """
397
+ Computes the fraction of unique molecules
398
+ Parameters:
399
+ gen: list of SMILES
400
+ k: compute unique@k
401
+ n_jobs: number of threads for calculation
402
+ check_validity: raises ValueError if invalid molecules are present
403
+ """
404
+ if k is not None:
405
+ if len(gen) < k:
406
+ warnings.warn(
407
+ "Can't compute unique@{}.".format(k) +
408
+ " gen contains only {} molecules".format(len(gen))
409
+ )
410
+ gen = gen[:k]
411
+ canonic = set(mapper(n_jobs)(canonic_smiles, gen))
412
+ if None in canonic and check_validity:
413
+ raise ValueError("Invalid molecule passed to unique@k")
414
+ return 0 if len(gen) == 0 else len(canonic) / len(gen)
415
+
416
+ def novelty(gen, train, n_jobs=1):
417
+ gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
418
+ gen_smiles_set = set(gen_smiles) - {None}
419
+ train_set = set(train)
420
+ return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
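Taken together, these helpers mirror the MOSES definitions: validity is the fraction of parseable SMILES, uniqueness the fraction of distinct canonical SMILES, and novelty the fraction of unique generated molecules not present in the training set. A toy check with hand-worked values, assuming the functions above are in scope:

    gen = ["CCO", "CCO", "c1ccccc1", "not_a_smiles"]
    train = ["CCO"]

    valid_gen = remove_invalid(gen)      # drops "not_a_smiles"
    print(fraction_valid(gen))           # 3/4 = 0.75
    print(fraction_unique(valid_gen))    # 2 canonical SMILES out of 3 -> 0.67
    print(novelty(valid_gen, train))     # only c1ccccc1 is new -> 0.5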
421
+
422
+
423
+
424
+ def average_agg_tanimoto(stock_vecs, gen_vecs,
425
+ batch_size=5000, agg='max',
426
+ device='cpu', p=1):
427
+ """
428
+ For each molecule in gen_vecs finds closest molecule in stock_vecs.
429
+ Returns average tanimoto score for between these molecules
430
+
431
+ Parameters:
432
+ stock_vecs: numpy array <n_vectors x dim>
433
+ gen_vecs: numpy array <n_vectors' x dim>
434
+ agg: max or mean
435
+ p: power for averaging: (mean x^p)^(1/p)
436
+ """
437
+ assert agg in ['max', 'mean'], "Can aggregate only max or mean"
438
+ agg_tanimoto = np.zeros(len(gen_vecs))
439
+ total = np.zeros(len(gen_vecs))
440
+ for j in range(0, stock_vecs.shape[0], batch_size):
441
+ x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
442
+ for i in range(0, gen_vecs.shape[0], batch_size):
443
+
444
+ y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
445
+ y_gen = y_gen.transpose(0, 1)
446
+ tp = torch.mm(x_stock, y_gen)
447
+ jac = (tp / (x_stock.sum(1, keepdim=True) +
448
+ y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
449
+ jac[np.isnan(jac)] = 1
450
+ if p != 1:
451
+ jac = jac**p
452
+ if agg == 'max':
453
+ agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
454
+ agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
455
+ elif agg == 'mean':
456
+ agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
457
+ total[i:i + y_gen.shape[1]] += jac.shape[0]
458
+ if agg == 'mean':
459
+ agg_tanimoto /= total
460
+ if p != 1:
461
+ agg_tanimoto = (agg_tanimoto)**(1/p)
462
+ return np.mean(agg_tanimoto)
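average_agg_tanimoto treats each row as a binary fingerprint and, with agg='max', reports the mean over generated fingerprints of their best Tanimoto match in the reference set; sim_reward above relies on exactly this. A minimal sketch with tiny hand-made fingerprints:

    import numpy as np

    stock = np.array([[1, 1, 0, 0],
                      [0, 0, 1, 1]])   # reference fingerprints
    gen = np.array([[1, 1, 0, 0],      # identical to stock[0] -> Tanimoto 1.0
                    [1, 0, 0, 1]])     # best match against either row is 1/3

    print(average_agg_tanimoto(stock, gen, agg='max'))  # (1.0 + 1/3) / 2 = 0.67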