mgyigit commited on
Commit
4e04e76
·
verified ·
1 Parent(s): b29080a

Upload 7 files

Browse files
Files changed (6) hide show
  1. inference.py +61 -37
  2. layers.py +2 -2
  3. loss.py +39 -15
  4. models.py +56 -6
  5. smiles_cor.py +1291 -0
  6. utils.py +18 -7
inference.py CHANGED
@@ -4,19 +4,20 @@ import pickle
4
  import random
5
  from tqdm import tqdm
6
  import argparse
7
-
8
  import torch
9
  from torch_geometric.loader import DataLoader
10
  import torch.utils.data
11
  from rdkit import RDLogger
12
  torch.set_num_threads(5)
13
  RDLogger.DisableLog('rdApp.*')
14
-
15
  from utils import *
16
  from models import Generator
17
  from new_dataloader import DruggenDataset
18
  from loss import generator_loss
19
  from training_data import load_molecules
 
20
 
21
 
22
  class Inference(object):
@@ -43,6 +44,7 @@ class Inference(object):
43
 
44
  self.inference_model = config.inference_model
45
  self.sample_num = config.sample_num
 
46
 
47
  # Data loader.
48
  self.inf_raw_file = config.inf_raw_file # SMILES containing text file for first dataset.
@@ -103,8 +105,7 @@ class Inference(object):
103
  dim=self.dim,
104
  depth=self.depth,
105
  heads=self.heads,
106
- mlp_ratio=self.mlp_ratio,
107
- submodel = self.submodel)
108
 
109
  self.print_network(self.G, 'G')
110
 
@@ -113,7 +114,7 @@ class Inference(object):
113
 
114
  def decoder_load(self, dictionary_name):
115
  ''' Loading the atom and bond decoders'''
116
- with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
117
  return pickle.load(f)
118
 
119
 
@@ -139,18 +140,25 @@ class Inference(object):
139
  self.restore_model(self.submodel, self.inference_model)
140
 
141
  # smiles data for metrics calculation.
142
- chembl_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
143
- chembl_test = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
144
- drug_smiles = [line for line in open("data/akt_inhibitors.smi", 'r').read().splitlines()]
145
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
146
  drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
147
 
148
 
149
  # Make directories if not exist.
150
- if not os.path.exists("experiments/inference/{}".format(self.submodel)):
151
- os.makedirs("experiments/inference/{}".format(self.submodel))
152
-
153
-
 
 
 
 
 
 
 
154
  self.G.eval()
155
 
156
  start_time = time.time()
@@ -158,7 +166,9 @@ class Inference(object):
158
  uniqueness_calc = []
159
  real_smiles_snn = []
160
  nodes_sample = torch.Tensor(size=[1,45,1]).to(self.device)
161
-
 
 
162
  val_counter = 0
163
  none_counter = 0
164
  # Inference mode
@@ -182,7 +192,7 @@ class Inference(object):
182
  g_edges_hat_sample = torch.max(edge_sample, -1)[1]
183
  g_nodes_hat_sample = torch.max(node_sample, -1)[1]
184
 
185
- fake_mol_g = [self.inf_dataset.matrices2mol_drugs(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
186
  for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
187
 
188
  a_tensor_sample = torch.max(a_tensor, -1)[1]
@@ -197,34 +207,47 @@ class Inference(object):
197
  if molecules is None:
198
  none_counter += 1
199
 
200
- with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
201
- for molecules in inference_drugs:
202
- if molecules is not None:
203
- molecules = molecules.replace("*", "C")
204
- f.write(molecules)
205
- f.write("\n")
206
- uniqueness_calc.append(molecules)
207
- nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1,45,1)), 0)
208
- pbar.update(1)
209
- metric_calc_dr.append(molecules)
210
-
211
 
 
212
  generation_number = len([x for x in metric_calc_dr if x is not None])
213
  if generation_number == self.sample_num or none_counter == self.sample_num:
214
  break
215
- real_smiles_snn.append(real_mols[0])
216
-
 
 
 
 
 
 
 
 
 
 
217
  et = time.time() - start_time
218
- gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
219
- real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
220
- print("Inference mode is lasted for {:.2f} seconds".format(et))
221
 
222
- print("Metrics calculation started using MOSES.")
223
- # post-process * to Carbon atom in valid molecules
 
 
 
 
 
 
 
224
 
225
  return{
226
  "Runtime (seconds)": f"{et:.2f}",
227
- "Validity": f"{fraction_valid(metric_calc_dr):.2f}",
228
  "Uniqueness": f"{fraction_unique(uniqueness_calc):.2f}",
229
  "Novelty (Train)": f"{novelty(metric_calc_dr, chembl_smiles):.2f}",
230
  "Novelty (Inference)": f"{novelty(metric_calc_dr, chembl_test):.2f}",
@@ -237,13 +260,14 @@ if __name__=="__main__":
237
  # Inference configuration.
238
  parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
239
  parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
240
- parser.add_argument('--sample_num', type=int, default=10000, help='inference samples')
241
-
 
242
  # Data configuration.
243
  parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
244
- parser.add_argument('--inf_raw_file', type=str, default='data/chembl_test.smi')
245
  parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
246
- parser.add_argument('--mol_data_dir', type=str, default='data')
247
  parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')
248
 
249
  # Model configuration.
 
4
  import random
5
  from tqdm import tqdm
6
  import argparse
7
+ import pandas as pd
8
  import torch
9
  from torch_geometric.loader import DataLoader
10
  import torch.utils.data
11
  from rdkit import RDLogger
12
  torch.set_num_threads(5)
13
  RDLogger.DisableLog('rdApp.*')
14
+ from rdkit.Chem import QED
15
  from utils import *
16
  from models import Generator
17
  from new_dataloader import DruggenDataset
18
  from loss import generator_loss
19
  from training_data import load_molecules
20
+ from smiles_cor import smi_correct
21
 
22
 
23
  class Inference(object):
 
44
 
45
  self.inference_model = config.inference_model
46
  self.sample_num = config.sample_num
47
+ self.correct = config.correct
48
 
49
  # Data loader.
50
  self.inf_raw_file = config.inf_raw_file # SMILES containing text file for first dataset.
 
105
  dim=self.dim,
106
  depth=self.depth,
107
  heads=self.heads,
108
+ mlp_ratio=self.mlp_ratio)
 
109
 
110
  self.print_network(self.G, 'G')
111
 
 
114
 
115
  def decoder_load(self, dictionary_name):
116
  ''' Loading the atom and bond decoders'''
117
+ with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
118
  return pickle.load(f)
119
 
120
 
 
140
  self.restore_model(self.submodel, self.inference_model)
141
 
142
  # smiles data for metrics calculation.
143
+ chembl_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
144
+ chembl_test = [line for line in open("DrugGEN/data/chembl_test.smi", 'r').read().splitlines()]
145
+ drug_smiles = [line for line in open("DrugGEN/data/akt_inhibitors.smi", 'r').read().splitlines()]
146
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
147
  drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
148
 
149
 
150
  # Make directories if not exist.
151
+ if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
152
+ os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
153
+ if self.correct:
154
+ correct = smi_correct(self.submodel, "DrugGEN_/experiments/inference/{}".format(self.submodel))
155
+ search_res = pd.DataFrame(columns=["submodel", "validity",
156
+ "uniqueness", "novelty",
157
+ "novelty_test", "AKT_novelty",
158
+ "max_len", "mean_atom_type",
159
+ "snn_chembl", "snn_akt", "IntDiv", "qed"])
160
+
161
+
162
  self.G.eval()
163
 
164
  start_time = time.time()
 
166
  uniqueness_calc = []
167
  real_smiles_snn = []
168
  nodes_sample = torch.Tensor(size=[1,45,1]).to(self.device)
169
+ f = open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
170
+ f.write("SMILES")
171
+ f.write("\n")
172
  val_counter = 0
173
  none_counter = 0
174
  # Inference mode
 
192
  g_edges_hat_sample = torch.max(edge_sample, -1)[1]
193
  g_nodes_hat_sample = torch.max(node_sample, -1)[1]
194
 
195
+ fake_mol_g = [self.inf_dataset.matrices2mol_drugs(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name)
196
  for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
197
 
198
  a_tensor_sample = torch.max(a_tensor, -1)[1]
 
207
  if molecules is None:
208
  none_counter += 1
209
 
210
+ for molecules in inference_drugs:
211
+ if molecules is not None:
212
+ molecules = molecules.replace("*", "C")
213
+ f.write(molecules)
214
+ f.write("\n")
215
+ uniqueness_calc.append(molecules)
216
+ nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1,45,1)), 0)
217
+ pbar.update(1)
218
+ metric_calc_dr.append(molecules)
 
 
219
 
220
+ real_smiles_snn.append(real_mols[0])
221
  generation_number = len([x for x in metric_calc_dr if x is not None])
222
  if generation_number == self.sample_num or none_counter == self.sample_num:
223
  break
224
+
225
+
226
+ f.close()
227
+ print("Inference completed, starting metrics calculation.")
228
+ if self.correct:
229
+ corrected = correct.correct("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel))
230
+ gen_smi = corrected["SMILES"].tolist()
231
+
232
+ else:
233
+ gen_smi = pd.read_csv("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
234
+
235
+
236
  et = time.time() - start_time
 
 
 
237
 
238
+ with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w") as f:
239
+ for i in gen_smi:
240
+ f.write(i)
241
+ f.write("\n")
242
+
243
+ if self.correct:
244
+ val = round(len(gen_smi)/self.sample_num,3)
245
+ else:
246
+ val = round(fraction_valid(gen_smi),3)
247
 
248
  return{
249
  "Runtime (seconds)": f"{et:.2f}",
250
+ "Validity": str(val),
251
  "Uniqueness": f"{fraction_unique(uniqueness_calc):.2f}",
252
  "Novelty (Train)": f"{novelty(metric_calc_dr, chembl_smiles):.2f}",
253
  "Novelty (Inference)": f"{novelty(metric_calc_dr, chembl_test):.2f}",
 
260
  # Inference configuration.
261
  parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
262
  parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
263
+ parser.add_argument('--sample_num', type=int, default=100, help='inference samples')
264
+ parser.add_argument('--correct', type=str2bool, default=False, help='Correct smiles')
265
+
266
  # Data configuration.
267
  parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
268
+ parser.add_argument('--inf_raw_file', type=str, default='DrugGEN/data/chembl_test.smi')
269
  parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
270
+ parser.add_argument('--mol_data_dir', type=str, default='DrugGEN/data')
271
  parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')
272
 
273
  # Model configuration.
layers.py CHANGED
@@ -82,7 +82,7 @@ class Encoder_Block(nn.Module):
82
 
83
  def forward(self, x, y):
84
  x1 = self.ln1(x)
85
- x2,y1 = self.attn(x1,y)
86
  x2 = x1 + x2
87
  y2 = y1 + y
88
  x2 = self.ln3(x2)
@@ -102,5 +102,5 @@ class TransformerEncoder(nn.Module):
102
 
103
  def forward(self, x, y):
104
  for Encoder_Block in self.Encoder_Blocks:
105
- x, y = Encoder_Block(x,y)
106
  return x, y
 
82
 
83
  def forward(self, x, y):
84
  x1 = self.ln1(x)
85
+ x2, y1 = self.attn(x1, y)
86
  x2 = x1 + x2
87
  y2 = y1 + y
88
  x2 = self.ln3(x2)
 
102
 
103
  def forward(self, x, y):
104
  for Encoder_Block in self.Encoder_Blocks:
105
+ x, y = Encoder_Block(x, y)
106
  return x, y
loss.py CHANGED
@@ -1,36 +1,60 @@
1
  import torch
2
 
3
 
4
- def discriminator_loss(generator, discriminator, mol_graph, batch_size, device, grad_pen, lambda_gp, z_edge, z_node):
5
  # Compute loss with real molecules.
6
- logits_real_disc = discriminator(mol_graph)
 
 
 
 
7
  prediction_real = - torch.mean(logits_real_disc)
8
 
9
  # Compute loss with fake molecules.
10
  node, edge, node_sample, edge_sample = generator(z_edge, z_node)
11
- graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)
12
- logits_fake_disc = discriminator(graph.detach())
 
 
 
 
13
  prediction_fake = torch.mean(logits_fake_disc)
14
 
15
- # Compute gradient loss.
16
- eps = torch.rand(mol_graph.size(0),1).to(device)
17
- x_int0 = (eps * mol_graph + (1. - eps) * graph).requires_grad_(True)
18
- grad0 = discriminator(x_int0)
19
- d_loss_gp = grad_pen(grad0, x_int0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Calculate total loss
22
- d_loss = prediction_fake + prediction_real + d_loss_gp * lambda_gp
23
  return node, edge, d_loss
24
 
25
 
26
- def generator_loss(generator, discriminator, adj, annot, batch_size):
27
  # Compute loss with fake molecules.
28
  node, edge, node_sample, edge_sample = generator(adj, annot)
 
 
 
 
 
29
 
30
- graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)
31
-
32
- logits_fake_disc = discriminator(graph)
33
  prediction_fake = - torch.mean(logits_fake_disc)
 
34
  g_loss = prediction_fake
35
 
36
  return g_loss, node, edge, node_sample, edge_sample
 
1
  import torch
2
 
3
 
4
+ def discriminator_loss(generator, discriminator, drug_edge, drug_node, batch_size, device, grad_pen, lambda_gp, z_edge, z_node, submodel):
5
  # Compute loss with real molecules.
6
+ if submodel == "DrugGEN":
7
+ logits_real_disc = discriminator(drug_edge, drug_node)
8
+ else:
9
+ logits_real_disc = discriminator(drug_node)
10
+
11
  prediction_real = - torch.mean(logits_real_disc)
12
 
13
  # Compute loss with fake molecules.
14
  node, edge, node_sample, edge_sample = generator(z_edge, z_node)
15
+ if submodel == "DrugGEN":
16
+ logits_fake_disc = discriminator(edge_sample, node_sample)
17
+ else:
18
+ graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)
19
+ logits_fake_disc = discriminator(graph.detach())
20
+
21
  prediction_fake = torch.mean(logits_fake_disc)
22
 
23
+ # Compute gradient penalty.
24
+ eps_edge = torch.rand(batch_size, 1, 1, 1, device=device) # Shape adapted for broadcasting with edges and nodes
25
+ eps_node = torch.rand(batch_size, 1, 1, device=device) # Shape adapted for broadcasting with edges and nodes
26
+ int_node = eps_node * drug_node + (1 - eps_node) * node_sample
27
+ int_edge = eps_edge * drug_edge + (1 - eps_edge) * edge_sample
28
+ int_node.requires_grad_(True)
29
+ int_edge.requires_grad_(True)
30
+
31
+ # Compute discriminator output for interpolated samples
32
+ if submodel == "DrugGEN":
33
+ logits_interpolated = discriminator(int_edge, int_node)
34
+ else:
35
+ graph = torch.cat((int_node.view(batch_size, -1), int_edge.view(batch_size, -1)), dim=-1)
36
+ logits_interpolated = discriminator(graph)
37
+
38
+ # Compute gradient penalty for nodes and edges
39
+ grad_penalty = grad_pen(logits_interpolated, int_node)
40
+
41
+ # Calculate total discriminator loss
42
+ d_loss = prediction_fake + prediction_real + lambda_gp * grad_penalty
43
 
 
 
44
  return node, edge, d_loss
45
 
46
 
47
+ def generator_loss(generator, discriminator, adj, annot, batch_size, submodel):
48
  # Compute loss with fake molecules.
49
  node, edge, node_sample, edge_sample = generator(adj, annot)
50
+ if submodel == "DrugGEN":
51
+ logits_fake_disc = discriminator(edge_sample, node_sample)
52
+ else:
53
+ graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)
54
+ logits_fake_disc = discriminator(graph)
55
 
 
 
 
56
  prediction_fake = - torch.mean(logits_fake_disc)
57
+
58
  g_loss = prediction_fake
59
 
60
  return g_loss, node, edge, node_sample, edge_sample
models.py CHANGED
@@ -5,9 +5,8 @@ from layers import TransformerEncoder
5
  class Generator(nn.Module):
6
  """Generator network."""
7
 
8
- def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio, submodel):
9
  super(Generator, self).__init__()
10
- self.submodel = submodel
11
  self.vertexes = vertexes
12
  self.edges = edges
13
  self.nodes = nodes
@@ -30,8 +29,8 @@ class Generator(nn.Module):
30
  self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
31
  self.pos_enc_dim = 5
32
 
33
- self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
34
- self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
35
  self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act = act,
36
  mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
37
 
@@ -63,12 +62,61 @@ class Generator(nn.Module):
63
  edge = self.edge_layers(z_e)
64
  edge = (edge + edge.permute(0, 2, 1, 3)) / 2
65
 
66
- node, edge = self.TransformerEncoder(node,edge)
67
 
68
  node_sample = self.readout_n(node)
69
  edge_sample = self.readout_e(edge)
 
70
  return node, edge, node_sample, edge_sample
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  class simple_disc(nn.Module):
74
  def __init__(self, act, m_dim, vertexes, b_dim):
@@ -82,6 +130,8 @@ class simple_disc(nn.Module):
82
  act = nn.Sigmoid()
83
  elif act == "tanh":
84
  act = nn.Tanh()
 
 
85
 
86
  features = vertexes * m_dim + vertexes * vertexes * b_dim
87
  self.predictor = nn.Sequential(nn.Linear(features,256), act, nn.Linear(256,128), act, nn.Linear(128,64), act,
@@ -90,4 +140,4 @@ class simple_disc(nn.Module):
90
 
91
  def forward(self, x):
92
  prediction = self.predictor(x)
93
- return prediction
 
5
  class Generator(nn.Module):
6
  """Generator network."""
7
 
8
+ def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
9
  super(Generator, self).__init__()
 
10
  self.vertexes = vertexes
11
  self.edges = edges
12
  self.nodes = nodes
 
29
  self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
30
  self.pos_enc_dim = 5
31
 
32
+ self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
33
+ self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
34
  self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act = act,
35
  mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
36
 
 
62
  edge = self.edge_layers(z_e)
63
  edge = (edge + edge.permute(0, 2, 1, 3)) / 2
64
 
65
+ node, edge = self.TransformerEncoder(node, edge)
66
 
67
  node_sample = self.readout_n(node)
68
  edge_sample = self.readout_e(edge)
69
+
70
  return node, edge, node_sample, edge_sample
71
 
72
+ class Discriminator(nn.Module):
73
+
74
+ def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
75
+ super(Discriminator, self).__init__()
76
+ self.vertexes = vertexes
77
+ self.edges = edges
78
+ self.nodes = nodes
79
+ self.depth = depth
80
+ self.dim = dim
81
+ self.heads = heads
82
+ self.mlp_ratio = mlp_ratio
83
+ self.dropout = dropout
84
+
85
+ if act == "relu":
86
+ act = nn.ReLU()
87
+ elif act == "leaky":
88
+ act = nn.LeakyReLU()
89
+ elif act == "sigmoid":
90
+ act = nn.Sigmoid()
91
+ elif act == "tanh":
92
+ act = nn.Tanh()
93
+
94
+ self.features = vertexes * vertexes * edges + vertexes * nodes
95
+ self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
96
+
97
+ self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
98
+ self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
99
+ self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act = act,
100
+ mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
101
+ self.node_features = vertexes * dim
102
+ self.edge_features = vertexes * vertexes * dim
103
+ self.node_mlp = nn.Sequential(nn.Linear(self.node_features, 64), act, nn.Linear(64, 32), act, nn.Linear(32, 16), act, nn.Linear(16, 1))
104
+
105
+ def forward(self, z_e, z_n):
106
+ b, n, c = z_n.shape
107
+ _, _, _ , d = z_e.shape
108
+
109
+ node = self.node_layers(z_n)
110
+ edge = self.edge_layers(z_e)
111
+ edge = (edge + edge.permute(0, 2, 1, 3)) / 2
112
+
113
+ node, edge = self.TransformerEncoder(node, edge)
114
+
115
+ node = node.view(b, -1)
116
+
117
+ prediction = self.node_mlp(node)
118
+
119
+ return prediction
120
 
121
  class simple_disc(nn.Module):
122
  def __init__(self, act, m_dim, vertexes, b_dim):
 
130
  act = nn.Sigmoid()
131
  elif act == "tanh":
132
  act = nn.Tanh()
133
+ else:
134
+ raise ValueError("Unsupported activation function: {}".format(act))
135
 
136
  features = vertexes * m_dim + vertexes * vertexes * b_dim
137
  self.predictor = nn.Sequential(nn.Linear(features,256), act, nn.Linear(256,128), act, nn.Linear(128,64), act,
 
140
 
141
  def forward(self, x):
142
  prediction = self.predictor(x)
143
+ return prediction
smiles_cor.py ADDED
@@ -0,0 +1,1291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import pandas as pd
4
+ import random
5
+ from chembl_structure_pipeline import standardizer
6
+ from rdkit.Chem import MolStandardize
7
+ from rdkit import Chem
8
+ import time
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchtext.data import TabularDataset, Field, BucketIterator, Iterator
12
+ import random
13
+ import os
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.utils.data import DataLoader
17
+ import random
18
+ from torch import optim
19
+ import numpy as np
20
+ import itertools
21
+ import time
22
+ import statistics
23
+ from rdkit.Chem import GraphDescriptors, Lipinski, AllChem
24
+ from rdkit.Chem.rdSLNParse import MolFromSLN
25
+ from rdkit.Chem.rdmolfiles import MolFromSmiles
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.optim as optim
29
+ import pandas as pd
30
+ import numpy as np
31
+ from rdkit import rdBase, Chem
32
+ import re
33
+ from rdkit import RDLogger
34
+ RDLogger.DisableLog('rdApp.*')
35
+
36
+ SEED = 42
37
+ random.seed(SEED)
38
+ torch.manual_seed(SEED)
39
+ torch.backends.cudnn.deterministic = True
40
+
41
+ ##################################################################################################
42
+ ##################################################################################################
43
+ # #
44
+ #  THIS SCRIPT IS DIRECTLY ADAPTED FROM https://github.com/LindeSchoenmaker/SMILES-corrector #
45
+ # #
46
+ ##################################################################################################
47
+ ##################################################################################################
48
+ def is_smiles(array,
49
+ TRG,
50
+ reverse: bool,
51
+ return_output=False,
52
+ src=None,
53
+ src_field=None):
54
+ """Turns predicted tokens within batch into smiles and evaluates their validity
55
+ Arguments:
56
+ array: Tensor with most probable token for each location for each sequence in batch
57
+ [trg len, batch size]
58
+ TRG: target field for getting tokens from vocab
59
+ reverse (bool): True if the target sequence is reversed
60
+ return_output (bool): True if output sequences and their validity should be saved
61
+ Returns:
62
+ df: dataframe with correct and incorrect sequences
63
+ valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
64
+ smiless: list of the predicted smiles
65
+ """
66
+ trg_field = TRG
67
+ valids = []
68
+ smiless = []
69
+ if return_output:
70
+ df = pd.DataFrame()
71
+ else:
72
+ df = None
73
+ batch_size = array.size(1)
74
+ # check if the first token should be removed, first token is zero because
75
+ # outputs initaliazed to all be zeros
76
+ if int((array[0, 0]).tolist()) == 0:
77
+ start = 1
78
+ else:
79
+ start = 0
80
+ # for each sequence in the batch
81
+ for i in range(0, batch_size):
82
+ # turns sequence from tensor to list skipps first row as this is not
83
+ # filled in in forward
84
+ sequence = (array[start:, i]).tolist()
85
+ # goes from embedded to tokens
86
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
87
+ # print(trg_tokens)
88
+ # takes all tokens untill eos token, model would be faster if did this
89
+ # one step earlier, but then changes in vocab order would disrupt.
90
+ rev_tokens = list(
91
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
92
+ if reverse:
93
+ rev_tokens = rev_tokens[::-1]
94
+ smiles = "".join(rev_tokens)
95
+ # determine how many valid smiles are made
96
+ valid = True if MolFromSmiles(smiles) else False
97
+ valids.append(valid)
98
+ smiless.append(smiles)
99
+ if return_output:
100
+ if valid:
101
+ df.loc[i, "CORRECT"] = smiles
102
+ else:
103
+ df.loc[i, "INCORRECT"] = smiles
104
+
105
+ # add the original drugex outputs to the _de dataframe
106
+ if return_output and src is not None:
107
+ for i in range(0, batch_size):
108
+ # turns sequence from tensor to list skipps first row as this is
109
+ # <sos> for src
110
+ sequence = (src[1:, i]).tolist()
111
+ # goes from embedded to tokens
112
+ src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
113
+ # takes all tokens untill eos token, model would be faster if did
114
+ # this one step earlier, but then changes in vocab order would
115
+ # disrupt.
116
+ rev_tokens = list(
117
+ itertools.takewhile(lambda x: x != "<eos>", src_tokens))
118
+ smiles = "".join(rev_tokens)
119
+ df.loc[i, "ORIGINAL"] = smiles
120
+
121
+ return df, valids, smiless
122
+
123
+
124
+ def is_unchanged(array,
125
+ TRG,
126
+ reverse: bool,
127
+ return_output=False,
128
+ src=None,
129
+ src_field=None):
130
+ """Checks is output is different from input
131
+ Arguments:
132
+ array: Tensor with most probable token for each location for each sequence in batch
133
+ [trg len, batch size]
134
+ TRG: target field for getting tokens from vocab
135
+ reverse (bool): True if the target sequence is reversed
136
+ return_output (bool): True if output sequences and their validity should be saved
137
+ Returns:
138
+ df: dataframe with correct and incorrect sequences
139
+ valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
140
+ smiless: list of the predicted smiles
141
+ """
142
+ trg_field = TRG
143
+ sources = []
144
+ batch_size = array.size(1)
145
+ unchanged = 0
146
+
147
+ # check if the first token should be removed, first token is zero because
148
+ # outputs initaliazed to all be zeros
149
+ if int((array[0, 0]).tolist()) == 0:
150
+ start = 1
151
+ else:
152
+ start = 0
153
+
154
+ for i in range(0, batch_size):
155
+ # turns sequence from tensor to list skipps first row as this is <sos>
156
+ # for src
157
+ sequence = (src[1:, i]).tolist()
158
+ # goes from embedded to tokens
159
+ src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
160
+ # takes all tokens untill eos token, model would be faster if did this
161
+ # one step earlier, but then changes in vocab order would disrupt.
162
+ rev_tokens = list(
163
+ itertools.takewhile(lambda x: x != "<eos>", src_tokens))
164
+ smiles = "".join(rev_tokens)
165
+ sources.append(smiles)
166
+
167
+ # for each sequence in the batch
168
+ for i in range(0, batch_size):
169
+ # turns sequence from tensor to list skipps first row as this is not
170
+ # filled in in forward
171
+ sequence = (array[start:, i]).tolist()
172
+ # goes from embedded to tokens
173
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
174
+ # print(trg_tokens)
175
+ # takes all tokens untill eos token, model would be faster if did this
176
+ # one step earlier, but then changes in vocab order would disrupt.
177
+ rev_tokens = list(
178
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
179
+ if reverse:
180
+ rev_tokens = rev_tokens[::-1]
181
+ smiles = "".join(rev_tokens)
182
+ # determine how many valid smiles are made
183
+ valid = True if MolFromSmiles(smiles) else False
184
+ if not valid:
185
+ if smiles == sources[i]:
186
+ unchanged += 1
187
+
188
+ return unchanged
189
+
190
+
191
+ def molecule_reconstruction(array, TRG, reverse: bool, outputs):
192
+ """Turns target tokens within batch into smiles and compares them to predicted output smiles
193
+ Arguments:
194
+ array: Tensor with target's token for each location for each sequence in batch
195
+ [trg len, batch size]
196
+ TRG: target field for getting tokens from vocab
197
+ reverse (bool): True if the target sequence is reversed
198
+ outputs: list of predicted SMILES sequences
199
+ Returns:
200
+ matches(int): number of total right molecules
201
+ """
202
+ trg_field = TRG
203
+ matches = 0
204
+ targets = []
205
+ batch_size = array.size(1)
206
+ # for each sequence in the batch
207
+ for i in range(0, batch_size):
208
+ # turns sequence from tensor to list skipps first row as this is not
209
+ # filled in in forward
210
+ sequence = (array[1:, i]).tolist()
211
+ # goes from embedded to tokens
212
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
213
+ # takes all tokens untill eos token, model would be faster if did this
214
+ # one step earlier, but then changes in vocab order would disrupt.
215
+ rev_tokens = list(
216
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
217
+ if reverse:
218
+ rev_tokens = rev_tokens[::-1]
219
+ smiles = "".join(rev_tokens)
220
+ targets.append(smiles)
221
+ for i in range(0, batch_size):
222
+ m = MolFromSmiles(targets[i])
223
+ p = MolFromSmiles(outputs[i])
224
+ if p is not None:
225
+ if m.HasSubstructMatch(p) and p.HasSubstructMatch(m):
226
+ matches += 1
227
+ return matches
228
+
229
+
230
+ def complexity_whitlock(mol: Chem.Mol, includeAllDescs=False):
231
+ """
232
+ Complexity as defined in DOI:10.1021/jo9814546
233
+ S: complexity = 4*#rings + 2*#unsat + #hetatm + 2*#chiral
234
+ Other descriptors:
235
+ H: size = #bonds (Hydrogen atoms included)
236
+ G: S + H
237
+ Ratio: S / H
238
+ """
239
+ mol_ = Chem.Mol(mol)
240
+ nrings = Lipinski.RingCount(mol_) - Lipinski.NumAromaticRings(mol_)
241
+ Chem.rdmolops.SetAromaticity(mol_)
242
+ unsat = sum(1 for bond in mol_.GetBonds()
243
+ if bond.GetBondTypeAsDouble() == 2)
244
+ hetatm = len(mol_.GetSubstructMatches(Chem.MolFromSmarts("[!#6]")))
245
+ AllChem.EmbedMolecule(mol_)
246
+ Chem.AssignAtomChiralTagsFromStructure(mol_)
247
+ chiral = len(Chem.FindMolChiralCenters(mol_))
248
+ S = 4 * nrings + 2 * unsat + hetatm + 2 * chiral
249
+ if not includeAllDescs:
250
+ return S
251
+ Chem.rdmolops.Kekulize(mol_)
252
+ mol_ = Chem.AddHs(mol_)
253
+ H = sum(bond.GetBondTypeAsDouble() for bond in mol_.GetBonds())
254
+ G = S + H
255
+ R = S / H
256
+ return {"WhitlockS": S, "WhitlockH": H, "WhitlockG": G, "WhitlockRatio": R}
257
+
258
+
259
+ def complexity_baronechanon(mol: Chem.Mol):
260
+ """
261
+ Complexity as defined in DOI:10.1021/ci000145p
262
+ """
263
+ mol_ = Chem.Mol(mol)
264
+ Chem.Kekulize(mol_)
265
+ Chem.RemoveStereochemistry(mol_)
266
+ mol_ = Chem.RemoveHs(mol_, updateExplicitCount=True)
267
+ degree, counts = 0, 0
268
+ for atom in mol_.GetAtoms():
269
+ degree += 3 * 2**(atom.GetExplicitValence() - atom.GetNumExplicitHs() -
270
+ 1)
271
+ counts += 3 if atom.GetSymbol() == "C" else 6
272
+ ringterm = sum(map(lambda x: 6 * len(x), mol_.GetRingInfo().AtomRings()))
273
+ return degree + counts + ringterm
274
+
275
+
276
+ def calc_complexity(array,
277
+ TRG,
278
+ reverse,
279
+ valids,
280
+ complexity_function=GraphDescriptors.BertzCT):
281
+ """Calculates the complexity of inputs that are not correct.
282
+ Arguments:
283
+ array: Tensor with target's token for each location for each sequence in batch
284
+ [trg len, batch size]
285
+ TRG: target field for getting tokens from vocab
286
+ reverse (bool): True if the target sequence is reversed
287
+ valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
288
+ complexity_function: the type of complexity measure that will be used
289
+ GraphDescriptors.BertzCT
290
+ complexity_whitlock
291
+ complexity_baronechanon
292
+ Returns:
293
+ matches(int): mean of complexity values
294
+ """
295
+ trg_field = TRG
296
+ sources = []
297
+ complexities = []
298
+ loc = torch.BoolTensor(valids)
299
+ # only keeps rows in batch size dimension where valid is false
300
+ array = array[:, loc == False]
301
+ # should check if this still works
302
+ # array = torch.transpose(array, 0, 1)
303
+ array_size = array.size(1)
304
+ for i in range(0, array_size):
305
+ # turns sequence from tensor to list skipps first row as this is not
306
+ # filled in in forward
307
+ sequence = (array[1:, i]).tolist()
308
+ # goes from embedded to tokens
309
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
310
+ # takes all tokens untill eos token, model would be faster if did this
311
+ # one step earlier, but then changes in vocab order would disrupt.
312
+ rev_tokens = list(
313
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
314
+ if reverse:
315
+ rev_tokens = rev_tokens[::-1]
316
+ smiles = "".join(rev_tokens)
317
+ sources.append(smiles)
318
+ for source in sources:
319
+ try:
320
+ m = MolFromSmiles(source)
321
+ except BaseException:
322
+ m = MolFromSLN(source)
323
+ complexities.append(complexity_function(m))
324
+ if len(complexities) > 0:
325
+ mean = statistics.mean(complexities)
326
+ else:
327
+ mean = 0
328
+ return mean
329
+
330
+
331
+ def epoch_time(start_time, end_time):
332
+ elapsed_time = end_time - start_time
333
+ elapsed_mins = int(elapsed_time / 60)
334
+ elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
335
+ return elapsed_mins, elapsed_secs
336
+
337
+
338
+ class Convo:
339
+ """Class for training and evaluating transformer and convolutional neural network
340
+
341
+ Methods
342
+ -------
343
+ train_model()
344
+ train model for initialized number of epochs
345
+ evaluate(return_output)
346
+ use model with validation loader (& optionally drugex loader) to get test loss & other metrics
347
+ translate(loader)
348
+ translate inputs from loader (different from evaluate in that no target sequence is used)
349
+ """
350
+
351
+ def train_model(self):
352
+ optimizer = optim.Adam(self.parameters(), lr=self.lr)
353
+ log = open(f"{self.out}.log", "a")
354
+ best_error = np.inf
355
+ for epoch in range(self.epochs):
356
+ self.train()
357
+ start_time = time.time()
358
+ loss_train = 0
359
+ for i, batch in enumerate(self.loader_train):
360
+ optimizer.zero_grad()
361
+ # changed src,trg call to match with bentrevett
362
+ # src, trg = batch['src'], batch['trg']
363
+ trg = batch.trg
364
+ src = batch.src
365
+ output, attention = self(src, trg[:, :-1])
366
+ # feed the source and target into def forward to get the output
367
+ # Xuhan uses forward for this, with istrain = true
368
+ output_dim = output.shape[-1]
369
+ # changed
370
+ output = output.contiguous().view(-1, output_dim)
371
+ trg = trg[:, 1:].contiguous().view(-1)
372
+ # output = output[:,:,0]#.view(-1)
373
+ # output = output[1:].view(-1, output.shape[-1])
374
+ # trg = trg[1:].view(-1)
375
+ loss = nn.CrossEntropyLoss(
376
+ ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
377
+ a, b = output.view(-1), trg.to(self.device).view(-1)
378
+ # changed
379
+ # loss = loss(output.view(0), trg.view(0).to(device))
380
+ loss = loss(output, trg)
381
+ loss.backward()
382
+ torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
383
+ optimizer.step()
384
+ loss_train += loss.item()
385
+ # turned off for now, as not using voc so won't work, output is a tensor
386
+ # output = [(trg len - 1) * batch size, output dim]
387
+ # smiles, valid = is_valid_smiles(output, reversed)
388
+ # if valid:
389
+ # valids += 1
390
+ # smiless.append(smiles)
391
+ # added .dataset becaue len(iterator) gives len(self.dataset) /
392
+ # self.batch_size)
393
+ loss_train /= len(self.loader_train)
394
+ info = f"Epoch: {epoch+1:02} step: {i} loss_train: {loss_train:.4g}"
395
+ # model is used to generate trg based on src from the validation set to assess performance
396
+ # similar to Xuhan, although he doesn't use the if loop
397
+ if self.loader_valid is not None:
398
+ return_output = False
399
+ if epoch + 1 == self.epochs:
400
+ return_output = True
401
+ (
402
+ valids,
403
+ loss_valid,
404
+ valids_de,
405
+ df_output,
406
+ df_output_de,
407
+ right_molecules,
408
+ complexity,
409
+ unchanged,
410
+ unchanged_de,
411
+ ) = self.evaluate(return_output)
412
+ reconstruction_error = 1 - right_molecules / len(
413
+ self.loader_valid.dataset)
414
+ error = 1 - valids / len(self.loader_valid.dataset)
415
+ complexity = complexity / len(self.loader_valid)
416
+ unchan = unchanged / (len(self.loader_valid.dataset) - valids)
417
+ info += f" loss_valid: {loss_valid:.4g} error_rate: {error:.4g} molecule_reconstruction_error_rate: {reconstruction_error:.4g} unchanged: {unchan:.4g} invalid_target_complexity: {complexity:.4g}"
418
+ if self.loader_drugex is not None:
419
+ error_de = 1 - valids_de / len(self.loader_drugex.dataset)
420
+ unchan_de = unchanged_de / (
421
+ len(self.loader_drugex.dataset) - valids_de)
422
+ info += f" error_rate_drugex: {error_de:.4g} unchanged_drugex: {unchan_de:.4g}"
423
+
424
+ if reconstruction_error < best_error:
425
+ torch.save(self.state_dict(), f"{self.out}.pkg")
426
+ best_error = reconstruction_error
427
+ last_save = epoch
428
+ else:
429
+ if epoch - last_save >= 10 and best_error != 1:
430
+ torch.save(self.state_dict(), f"{self.out}_last.pkg")
431
+ (
432
+ valids,
433
+ loss_valid,
434
+ valids_de,
435
+ df_output,
436
+ df_output_de,
437
+ right_molecules,
438
+ complexity,
439
+ unchanged,
440
+ unchanged_de,
441
+ ) = self.evaluate(True)
442
+ end_time = time.time()
443
+ epoch_mins, epoch_secs = epoch_time(
444
+ start_time, end_time)
445
+ info += f" Time: {epoch_mins}m {epoch_secs}s"
446
+
447
+ break
448
+ elif error < best_error:
449
+ torch.save(self.state_dict(), f"{self.out}.pkg")
450
+ best_error = error
451
+ end_time = time.time()
452
+ epoch_mins, epoch_secs = epoch_time(start_time, end_time)
453
+ info += f" Time: {epoch_mins}m {epoch_secs}s"
454
+
455
+
456
+ torch.save(self.state_dict(), f"{self.out}_last.pkg")
457
+ log.close()
458
+ self.load_state_dict(torch.load(f"{self.out}.pkg"))
459
+ df_output.to_csv(f"{self.out}.csv", index=False)
460
+ df_output_de.to_csv(f"{self.out}_de.csv", index=False)
461
+
462
+ def evaluate(self, return_output):
463
+ self.eval()
464
+ test_loss = 0
465
+ df_output = pd.DataFrame()
466
+ df_output_de = pd.DataFrame()
467
+ valids = 0
468
+ valids_de = 0
469
+ unchanged = 0
470
+ unchanged_de = 0
471
+ right_molecules = 0
472
+ complexity = 0
473
+ with torch.no_grad():
474
+ for _, batch in enumerate(self.loader_valid):
475
+ trg = batch.trg
476
+ src = batch.src
477
+ output, attention = self.forward(src, trg[:, :-1])
478
+ pred_token = output.argmax(2)
479
+ array = torch.transpose(pred_token, 0, 1)
480
+ trg_trans = torch.transpose(trg, 0, 1)
481
+ output_dim = output.shape[-1]
482
+ output = output.contiguous().view(-1, output_dim)
483
+ trg = trg[:, 1:].contiguous().view(-1)
484
+ src_trans = torch.transpose(src, 0, 1)
485
+ df_batch, valid, smiless = is_smiles(
486
+ array, self.TRG, reverse=True, return_output=return_output)
487
+ unchanged += is_unchanged(
488
+ array,
489
+ self.TRG,
490
+ reverse=True,
491
+ return_output=return_output,
492
+ src=src_trans,
493
+ src_field=self.SRC,
494
+ )
495
+ matches = molecule_reconstruction(trg_trans,
496
+ self.TRG,
497
+ reverse=True,
498
+ outputs=smiless)
499
+ complexity += calc_complexity(trg_trans,
500
+ self.TRG,
501
+ reverse=True,
502
+ valids=valid)
503
+ if df_batch is not None:
504
+ df_output = pd.concat([df_output, df_batch],
505
+ ignore_index=True)
506
+ right_molecules += matches
507
+ valids += sum(valid)
508
+ # trg = trg[1:].view(-1)
509
+ # output, trg = output[1:].view(-1, output.shape[-1]), trg[1:].view(-1)
510
+ loss = nn.CrossEntropyLoss(
511
+ ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
512
+ loss = loss(output, trg)
513
+ test_loss += loss.item()
514
+ if self.loader_drugex is not None:
515
+ for _, batch in enumerate(self.loader_drugex):
516
+ src = batch.src
517
+ output = self.translate_sentence(src, self.TRG,
518
+ self.device)
519
+ # checks the number of valid smiles
520
+ pred_token = output.argmax(2)
521
+ array = torch.transpose(pred_token, 0, 1)
522
+ src_trans = torch.transpose(src, 0, 1)
523
+ df_batch, valid, smiless = is_smiles(
524
+ array,
525
+ self.TRG,
526
+ reverse=True,
527
+ return_output=return_output,
528
+ src=src_trans,
529
+ src_field=self.SRC,
530
+ )
531
+ unchanged_de += is_unchanged(
532
+ array,
533
+ self.TRG,
534
+ reverse=True,
535
+ return_output=return_output,
536
+ src=src_trans,
537
+ src_field=self.SRC,
538
+ )
539
+ if df_batch is not None:
540
+ df_output_de = pd.concat([df_output_de, df_batch],
541
+ ignore_index=True)
542
+ valids_de += sum(valid)
543
+ return (
544
+ valids,
545
+ test_loss / len(self.loader_valid),
546
+ valids_de,
547
+ df_output,
548
+ df_output_de,
549
+ right_molecules,
550
+ complexity,
551
+ unchanged,
552
+ unchanged_de,
553
+ )
554
+
555
+ def translate(self, loader):
556
+ self.eval()
557
+ df_output_de = pd.DataFrame()
558
+ valids_de = 0
559
+ with torch.no_grad():
560
+ for _, batch in enumerate(loader):
561
+ src = batch.src
562
+ output = self.translate_sentence(src, self.TRG, self.device)
563
+ # checks the number of valid smiles
564
+ pred_token = output.argmax(2)
565
+ array = torch.transpose(pred_token, 0, 1)
566
+ src_trans = torch.transpose(src, 0, 1)
567
+ df_batch, valid, smiless = is_smiles(
568
+ array,
569
+ self.TRG,
570
+ reverse=True,
571
+ return_output=True,
572
+ src=src_trans,
573
+ src_field=self.SRC,
574
+ )
575
+ if df_batch is not None:
576
+ df_output_de = pd.concat([df_output_de, df_batch],
577
+ ignore_index=True)
578
+ valids_de += sum(valid)
579
+ return valids_de, df_output_de
580
+
581
+
582
+ class Encoder(nn.Module):
583
+
584
+ def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout,
585
+ max_length, device):
586
+ super().__init__()
587
+ self.device = device
588
+ self.tok_embedding = nn.Embedding(input_dim, hid_dim)
589
+ self.pos_embedding = nn.Embedding(max_length, hid_dim)
590
+ self.layers = nn.ModuleList([
591
+ EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
592
+ for _ in range(n_layers)
593
+ ])
594
+
595
+ self.dropout = nn.Dropout(dropout)
596
+ self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
597
+
598
+ def forward(self, src, src_mask):
599
+ # src = [batch size, src len]
600
+ # src_mask = [batch size, src len]
601
+ batch_size = src.shape[0]
602
+ src_len = src.shape[1]
603
+ pos = (torch.arange(0, src_len).unsqueeze(0).repeat(batch_size,
604
+ 1).to(self.device))
605
+ # pos = [batch size, src len]
606
+ src = self.dropout((self.tok_embedding(src) * self.scale) +
607
+ self.pos_embedding(pos))
608
+ # src = [batch size, src len, hid dim]
609
+ for layer in self.layers:
610
+ src = layer(src, src_mask)
611
+ # src = [batch size, src len, hid dim]
612
+ return src
613
+
614
+
615
+ class EncoderLayer(nn.Module):
616
+
617
+ def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
618
+ super().__init__()
619
+
620
+ self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
621
+ self.ff_layer_norm = nn.LayerNorm(hid_dim)
622
+ self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
623
+ dropout, device)
624
+ self.positionwise_feedforward = PositionwiseFeedforwardLayer(
625
+ hid_dim, pf_dim, dropout)
626
+ self.dropout = nn.Dropout(dropout)
627
+
628
+ def forward(self, src, src_mask):
629
+ # src = [batch size, src len, hid dim]
630
+ # src_mask = [batch size, src len]
631
+ # self attention
632
+ _src, _ = self.self_attention(src, src, src, src_mask)
633
+ # dropout, residual connection and layer norm
634
+ src = self.self_attn_layer_norm(src + self.dropout(_src))
635
+ # src = [batch size, src len, hid dim]
636
+ # positionwise feedforward
637
+ _src = self.positionwise_feedforward(src)
638
+ # dropout, residual and layer norm
639
+ src = self.ff_layer_norm(src + self.dropout(_src))
640
+ # src = [batch size, src len, hid dim]
641
+
642
+ return src
643
+
644
+
645
+ class MultiHeadAttentionLayer(nn.Module):
646
+
647
+ def __init__(self, hid_dim, n_heads, dropout, device):
648
+ super().__init__()
649
+ assert hid_dim % n_heads == 0
650
+ self.hid_dim = hid_dim
651
+ self.n_heads = n_heads
652
+ self.head_dim = hid_dim // n_heads
653
+ self.fc_q = nn.Linear(hid_dim, hid_dim)
654
+ self.fc_k = nn.Linear(hid_dim, hid_dim)
655
+ self.fc_v = nn.Linear(hid_dim, hid_dim)
656
+ self.fc_o = nn.Linear(hid_dim, hid_dim)
657
+ self.dropout = nn.Dropout(dropout)
658
+ self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
659
+
660
+ def forward(self, query, key, value, mask=None):
661
+ batch_size = query.shape[0]
662
+ # query = [batch size, query len, hid dim]
663
+ # key = [batch size, key len, hid dim]
664
+ # value = [batch size, value len, hid dim]
665
+ Q = self.fc_q(query)
666
+ K = self.fc_k(key)
667
+ V = self.fc_v(value)
668
+ # Q = [batch size, query len, hid dim]
669
+ # K = [batch size, key len, hid dim]
670
+ # V = [batch size, value len, hid dim]
671
+ Q = Q.view(batch_size, -1, self.n_heads,
672
+ self.head_dim).permute(0, 2, 1, 3)
673
+ K = K.view(batch_size, -1, self.n_heads,
674
+ self.head_dim).permute(0, 2, 1, 3)
675
+ V = V.view(batch_size, -1, self.n_heads,
676
+ self.head_dim).permute(0, 2, 1, 3)
677
+ # Q = [batch size, n heads, query len, head dim]
678
+ # K = [batch size, n heads, key len, head dim]
679
+ # V = [batch size, n heads, value len, head dim]
680
+ energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
681
+ # energy = [batch size, n heads, query len, key len]
682
+ if mask is not None:
683
+ energy = energy.masked_fill(mask == 0, -1e10)
684
+ attention = torch.softmax(energy, dim=-1)
685
+ # attention = [batch size, n heads, query len, key len]
686
+ x = torch.matmul(self.dropout(attention), V)
687
+ # x = [batch size, n heads, query len, head dim]
688
+ x = x.permute(0, 2, 1, 3).contiguous()
689
+ # x = [batch size, query len, n heads, head dim]
690
+ x = x.view(batch_size, -1, self.hid_dim)
691
+ # x = [batch size, query len, hid dim]
692
+ x = self.fc_o(x)
693
+ # x = [batch size, query len, hid dim]
694
+ return x, attention
695
+
696
+
697
+ class PositionwiseFeedforwardLayer(nn.Module):
698
+
699
+ def __init__(self, hid_dim, pf_dim, dropout):
700
+ super().__init__()
701
+ self.fc_1 = nn.Linear(hid_dim, pf_dim)
702
+ self.fc_2 = nn.Linear(pf_dim, hid_dim)
703
+ self.dropout = nn.Dropout(dropout)
704
+
705
+ def forward(self, x):
706
+ # x = [batch size, seq len, hid dim]
707
+ x = self.dropout(torch.relu(self.fc_1(x)))
708
+ # x = [batch size, seq len, pf dim]
709
+ x = self.fc_2(x)
710
+ # x = [batch size, seq len, hid dim]
711
+
712
+ return x
713
+
714
+
715
+ class Decoder(nn.Module):
716
+
717
+ def __init__(
718
+ self,
719
+ output_dim,
720
+ hid_dim,
721
+ n_layers,
722
+ n_heads,
723
+ pf_dim,
724
+ dropout,
725
+ max_length,
726
+ device,
727
+ ):
728
+ super().__init__()
729
+ self.device = device
730
+ self.tok_embedding = nn.Embedding(output_dim, hid_dim)
731
+ self.pos_embedding = nn.Embedding(max_length, hid_dim)
732
+ self.layers = nn.ModuleList([
733
+ DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
734
+ for _ in range(n_layers)
735
+ ])
736
+ self.fc_out = nn.Linear(hid_dim, output_dim)
737
+ self.dropout = nn.Dropout(dropout)
738
+ self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
739
+
740
+ def forward(self, trg, enc_src, trg_mask, src_mask):
741
+ # trg = [batch size, trg len]
742
+ # enc_src = [batch size, src len, hid dim]
743
+ # trg_mask = [batch size, trg len]
744
+ # src_mask = [batch size, src len]
745
+ batch_size = trg.shape[0]
746
+ trg_len = trg.shape[1]
747
+ pos = (torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size,
748
+ 1).to(self.device))
749
+ # pos = [batch size, trg len]
750
+ trg = self.dropout((self.tok_embedding(trg) * self.scale) +
751
+ self.pos_embedding(pos))
752
+ # trg = [batch size, trg len, hid dim]
753
+ for layer in self.layers:
754
+ trg, attention = layer(trg, enc_src, trg_mask, src_mask)
755
+ # trg = [batch size, trg len, hid dim]
756
+ # attention = [batch size, n heads, trg len, src len]
757
+ output = self.fc_out(trg)
758
+ # output = [batch size, trg len, output dim]
759
+ return output, attention
760
+
761
+
762
+ class DecoderLayer(nn.Module):
763
+
764
+ def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
765
+ super().__init__()
766
+ self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
767
+ self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
768
+ self.ff_layer_norm = nn.LayerNorm(hid_dim)
769
+ self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
770
+ dropout, device)
771
+ self.encoder_attention = MultiHeadAttentionLayer(
772
+ hid_dim, n_heads, dropout, device)
773
+ self.positionwise_feedforward = PositionwiseFeedforwardLayer(
774
+ hid_dim, pf_dim, dropout)
775
+ self.dropout = nn.Dropout(dropout)
776
+
777
+ def forward(self, trg, enc_src, trg_mask, src_mask):
778
+ # trg = [batch size, trg len, hid dim]
779
+ # enc_src = [batch size, src len, hid dim]
780
+ # trg_mask = [batch size, trg len]
781
+ # src_mask = [batch size, src len]
782
+ # self attention
783
+ _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
784
+ # dropout, residual connection and layer norm
785
+ trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
786
+ # trg = [batch size, trg len, hid dim]
787
+ # encoder attention
788
+ _trg, attention = self.encoder_attention(trg, enc_src, enc_src,
789
+ src_mask)
790
+ # dropout, residual connection and layer norm
791
+ trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
792
+ # trg = [batch size, trg len, hid dim]
793
+ # positionwise feedforward
794
+ _trg = self.positionwise_feedforward(trg)
795
+ # dropout, residual and layer norm
796
+ trg = self.ff_layer_norm(trg + self.dropout(_trg))
797
+ # trg = [batch size, trg len, hid dim]
798
+ # attention = [batch size, n heads, trg len, src len]
799
+ return trg, attention
800
+
801
+
802
+ class Seq2Seq(nn.Module, Convo):
803
+
804
+ def __init__(
805
+ self,
806
+ encoder,
807
+ decoder,
808
+ src_pad_idx,
809
+ trg_pad_idx,
810
+ device,
811
+ loader_train: DataLoader,
812
+ out: str,
813
+ loader_valid=None,
814
+ loader_drugex=None,
815
+ epochs=100,
816
+ lr=0.0005,
817
+ clip=0.1,
818
+ reverse=True,
819
+ TRG=None,
820
+ SRC=None,
821
+ ):
822
+ super().__init__()
823
+ self.encoder = encoder
824
+ self.decoder = decoder
825
+ self.src_pad_idx = src_pad_idx
826
+ self.trg_pad_idx = trg_pad_idx
827
+ self.device = device
828
+ self.loader_train = loader_train
829
+ self.out = out
830
+ self.loader_valid = loader_valid
831
+ self.loader_drugex = loader_drugex
832
+ self.epochs = epochs
833
+ self.lr = lr
834
+ self.clip = clip
835
+ self.reverse = reverse
836
+ self.TRG = TRG
837
+ self.SRC = SRC
838
+
839
+ def make_src_mask(self, src):
840
+ # src = [batch size, src len]
841
+ src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
842
+ # src_mask = [batch size, 1, 1, src len]
843
+ return src_mask
844
+
845
+ def make_trg_mask(self, trg):
846
+ # trg = [batch size, trg len]
847
+ trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
848
+ # trg_pad_mask = [batch size, 1, 1, trg len]
849
+ trg_len = trg.shape[1]
850
+ trg_sub_mask = torch.tril(
851
+ torch.ones((trg_len, trg_len), device=self.device)).bool()
852
+ # trg_sub_mask = [trg len, trg len]
853
+ trg_mask = trg_pad_mask & trg_sub_mask
854
+ # trg_mask = [batch size, 1, trg len, trg len]
855
+ return trg_mask
856
+
857
+ def forward(self, src, trg):
858
+ # src = [batch size, src len]
859
+ # trg = [batch size, trg len]
860
+ src_mask = self.make_src_mask(src)
861
+ trg_mask = self.make_trg_mask(trg)
862
+ # src_mask = [batch size, 1, 1, src len]
863
+ # trg_mask = [batch size, 1, trg len, trg len]
864
+ enc_src = self.encoder(src, src_mask)
865
+ # enc_src = [batch size, src len, hid dim]
866
+ output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
867
+ # output = [batch size, trg len, output dim]
868
+ # attention = [batch size, n heads, trg len, src len]
869
+ return output, attention
870
+
871
+ def translate_sentence(self, src, trg_field, device, max_len=202):
872
+ self.eval()
873
+ src_mask = self.make_src_mask(src)
874
+ with torch.no_grad():
875
+ enc_src = self.encoder(src, src_mask)
876
+ trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
877
+ batch_size = src.shape[0]
878
+ trg = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
879
+ trg = trg.repeat(batch_size, 1)
880
+ for i in range(max_len):
881
+ # turned model into self.
882
+ trg_mask = self.make_trg_mask(trg)
883
+ with torch.no_grad():
884
+ output, attention = self.decoder(trg, enc_src, trg_mask,
885
+ src_mask)
886
+ pred_tokens = output.argmax(2)[:, -1].unsqueeze(1)
887
+ trg = torch.cat((trg, pred_tokens), 1)
888
+
889
+ return output
890
+
891
+
892
+ def remove_floats(df: pd.DataFrame, subset: str):
893
+ """Preprocessing step to remove any entries that are not strings"""
894
+ df_subset = df[subset]
895
+ df[subset] = df[subset].astype(str)
896
+ # only keep entries that stayed the same after applying astype str
897
+ df = df[df[subset] == df_subset].copy()
898
+
899
+ return df
900
+
901
+
902
+ def smi_tokenizer(smi: str, reverse=False) -> list:
903
+ """
904
+ Tokenize a SMILES molecule
905
+ """
906
+ pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
907
+ regex = re.compile(pattern)
908
+ # tokens = ['<sos>'] + [token for token in regex.findall(smi)] + ['<eos>']
909
+ tokens = [token for token in regex.findall(smi)]
910
+ # assert smi == ''.join(tokens[1:-1])
911
+ assert smi == "".join(tokens[:])
912
+ # try:
913
+ # assert smi == "".join(tokens[:])
914
+ # except:
915
+ # print(smi)
916
+ # print("".join(tokens[:]))
917
+ if reverse:
918
+ return tokens[::-1]
919
+ return tokens
920
+
921
+
922
+ def init_weights(m: nn.Module):
923
+ if hasattr(m, "weight") and m.weight.dim() > 1:
924
+ nn.init.xavier_uniform_(m.weight.data)
925
+
926
+
927
+ def count_parameters(model: nn.Module):
928
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
929
+
930
+
931
+ def epoch_time(start_time, end_time):
932
+ elapsed_time = end_time - start_time
933
+ elapsed_mins = int(elapsed_time / 60)
934
+ elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
935
+ return elapsed_mins, elapsed_secs
936
+
937
+
938
+ def initialize_model(folder_out: str,
939
+ data_source: str,
940
+ error_source: str,
941
+ device: torch.device,
942
+ threshold: int,
943
+ epochs: int,
944
+ layers: int = 3,
945
+ batch_size: int = 16,
946
+ invalid_type: str = "all",
947
+ num_errors: int = 1,
948
+ validation_step=False):
949
+ """Create encoder decoder models for specified model (currently only translator) & type of invalid SMILES
950
+
951
+ param data: collection of invalid, valid SMILES pairs
952
+ param invalid_smiles_path: path to previously generated invalid SMILES
953
+ param invalid_type: type of errors introduced into invalid SMILES
954
+
955
+ return:
956
+
957
+ """
958
+
959
+ # set fields
960
+ SRC = Field(
961
+ tokenize=lambda x: smi_tokenizer(x),
962
+ init_token="<sos>",
963
+ eos_token="<eos>",
964
+ batch_first=True,
965
+ )
966
+ TRG = Field(
967
+ tokenize=lambda x: smi_tokenizer(x, reverse=True),
968
+ init_token="<sos>",
969
+ eos_token="<eos>",
970
+ batch_first=True,
971
+ )
972
+
973
+ if validation_step:
974
+ train, val = TabularDataset.splits(
975
+ path=f'{folder_out}errors/split/',
976
+ train=f"{data_source}_{invalid_type}_{num_errors}_errors_train.csv",
977
+ validation=
978
+ f"{data_source}_{invalid_type}_{num_errors}_errors_dev.csv",
979
+ format="CSV",
980
+ skip_header=False,
981
+ fields={
982
+ "ERROR": ("src", SRC),
983
+ "STD_SMILES": ("trg", TRG)
984
+ },
985
+ )
986
+ SRC.build_vocab(train, val, max_size=1000)
987
+ TRG.build_vocab(train, val, max_size=1000)
988
+ else:
989
+ train = TabularDataset(
990
+ path=
991
+ f'{folder_out}{data_source}_{invalid_type}_{num_errors}_errors.csv',
992
+ format="CSV",
993
+ skip_header=False,
994
+ fields={
995
+ "ERROR": ("src", SRC),
996
+ "STD_SMILES": ("trg", TRG)
997
+ },
998
+ )
999
+ SRC.build_vocab(train, max_size=1000)
1000
+ TRG.build_vocab(train, max_size=1000)
1001
+
1002
+ drugex = TabularDataset(
1003
+ path=error_source,
1004
+ format="csv",
1005
+ skip_header=False,
1006
+ fields={
1007
+ "SMILES": ("src", SRC),
1008
+ "SMILES_TARGET": ("trg", TRG)
1009
+ },
1010
+ )
1011
+
1012
+
1013
+ #SRC.vocab = torch.load('vocab_src.pth')
1014
+ #TRG.vocab = torch.load('vocab_trg.pth')
1015
+
1016
+ # model parameters
1017
+ EPOCHS = epochs
1018
+ BATCH_SIZE = batch_size
1019
+ INPUT_DIM = len(SRC.vocab)
1020
+ OUTPUT_DIM = len(TRG.vocab)
1021
+ HID_DIM = 256
1022
+ ENC_LAYERS = layers
1023
+ DEC_LAYERS = layers
1024
+ ENC_HEADS = 8
1025
+ DEC_HEADS = 8
1026
+ ENC_PF_DIM = 512
1027
+ DEC_PF_DIM = 512
1028
+ ENC_DROPOUT = 0.1
1029
+ DEC_DROPOUT = 0.1
1030
+ SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
1031
+ TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
1032
+ # add 2 to length for start and stop tokens
1033
+ MAX_LENGTH = threshold + 2
1034
+
1035
+ # model name
1036
+ MODEL_OUT_FOLDER = f"{folder_out}"
1037
+
1038
+ MODEL_NAME = "transformer_%s_%s_%s_%s_%s" % (
1039
+ invalid_type, num_errors, data_source, BATCH_SIZE, layers)
1040
+ if not os.path.exists(MODEL_OUT_FOLDER):
1041
+ os.mkdir(MODEL_OUT_FOLDER)
1042
+
1043
+ out = os.path.join(MODEL_OUT_FOLDER, MODEL_NAME)
1044
+
1045
+ torch.save(SRC.vocab, f'{out}_vocab_src.pth')
1046
+ torch.save(TRG.vocab, f'{out}_vocab_trg.pth')
1047
+
1048
+ # iterator is a dataloader
1049
+ # iterator to pass to the same length and create batches in which the
1050
+ # amount of padding is minimized
1051
+ if validation_step:
1052
+ train_iter, val_iter = BucketIterator.splits(
1053
+ (train, val),
1054
+ batch_sizes=(BATCH_SIZE, 256),
1055
+ sort_within_batch=True,
1056
+ shuffle=True,
1057
+ # the BucketIterator needs to be told what function it should use to
1058
+ # group the data.
1059
+ sort_key=lambda x: len(x.src),
1060
+ device=device,
1061
+ )
1062
+ else:
1063
+ train_iter = BucketIterator(
1064
+ train,
1065
+ batch_size=BATCH_SIZE,
1066
+ sort_within_batch=True,
1067
+ shuffle=True,
1068
+ # the BucketIterator needs to be told what function it should use to
1069
+ # group the data.
1070
+ sort_key=lambda x: len(x.src),
1071
+ device=device,
1072
+ )
1073
+ val_iter = None
1074
+
1075
+ drugex_iter = Iterator(
1076
+ drugex,
1077
+ batch_size=64,
1078
+ device=device,
1079
+ sort=False,
1080
+ sort_within_batch=True,
1081
+ sort_key=lambda x: len(x.src),
1082
+ repeat=False,
1083
+ )
1084
+
1085
+
1086
+ # model initialization
1087
+
1088
+ enc = Encoder(
1089
+ INPUT_DIM,
1090
+ HID_DIM,
1091
+ ENC_LAYERS,
1092
+ ENC_HEADS,
1093
+ ENC_PF_DIM,
1094
+ ENC_DROPOUT,
1095
+ MAX_LENGTH,
1096
+ device,
1097
+ )
1098
+ dec = Decoder(
1099
+ OUTPUT_DIM,
1100
+ HID_DIM,
1101
+ DEC_LAYERS,
1102
+ DEC_HEADS,
1103
+ DEC_PF_DIM,
1104
+ DEC_DROPOUT,
1105
+ MAX_LENGTH,
1106
+ device,
1107
+ )
1108
+
1109
+ model = Seq2Seq(
1110
+ enc,
1111
+ dec,
1112
+ SRC_PAD_IDX,
1113
+ TRG_PAD_IDX,
1114
+ device,
1115
+ train_iter,
1116
+ out=out,
1117
+ loader_valid=val_iter,
1118
+ loader_drugex=drugex_iter,
1119
+ epochs=EPOCHS,
1120
+ TRG=TRG,
1121
+ SRC=SRC,
1122
+ ).to(device)
1123
+
1124
+
1125
+
1126
+
1127
+ return model, out, SRC
1128
+
1129
+
1130
+ def train_model(model, out, assess):
1131
+ """Apply given weights (& assess performance or train further) or start training new model
1132
+
1133
+ Args:
1134
+ model: initialized model
1135
+ out: .pkg file with model parameters
1136
+ asses: bool
1137
+
1138
+ Returns:
1139
+ model with (new) weights
1140
+ """
1141
+
1142
+ if os.path.exists(f"{out}.pkg") and assess:
1143
+
1144
+
1145
+ model.load_state_dict(torch.load(f=out + ".pkg"))
1146
+ (
1147
+ valids,
1148
+ loss_valid,
1149
+ valids_de,
1150
+ df_output,
1151
+ df_output_de,
1152
+ right_molecules,
1153
+ complexity,
1154
+ unchanged,
1155
+ unchanged_de,
1156
+ ) = model.evaluate(True)
1157
+
1158
+
1159
+ # log = open('unchanged.log', 'a')
1160
+ # info = f'type: comb unchanged: {unchan:.4g} unchanged_drugex: {unchan_de:.4g}'
1161
+ # print(info, file=log, flush = True)
1162
+ # print(valids_de)
1163
+ # print(unchanged_de)
1164
+
1165
+ # print(unchan)
1166
+ # print(unchan_de)
1167
+ # df_output_de.to_csv(f'{out}_de_new.csv', index = False)
1168
+
1169
+ # error_de = 1 - valids_de / len(drugex_iter.dataset)
1170
+ # print(error_de)
1171
+ # df_output.to_csv(f'{out}_par.csv', index = False)
1172
+
1173
+ elif os.path.exists(f"{out}.pkg"):
1174
+
1175
+ # starts from the model after the last epoch, not the best epoch
1176
+ model.load_state_dict(torch.load(f=out + "_last.pkg"))
1177
+ # need to change how log file names epochs
1178
+ model.train_model()
1179
+ else:
1180
+
1181
+ model = model.apply(init_weights)
1182
+ model.train_model()
1183
+
1184
+ return model
1185
+
1186
+
1187
+ def correct_SMILES(model, out, error_source, device, SRC):
1188
+ """Model that is given corrects SMILES and return number of correct ouputs and dataframe containing all outputs
1189
+ Args:
1190
+ model: initialized model
1191
+ out: .pkg file with model parameters
1192
+ asses: bool
1193
+
1194
+ Returns:
1195
+ valids: number of fixed outputs
1196
+ df_output: dataframe containing output (either correct or incorrect) & original input
1197
+ """
1198
+ ## account for tokens that are not yet in SRC without changing existing SRC token embeddings
1199
+ errors = TabularDataset(
1200
+ path=error_source,
1201
+ format="csv",
1202
+ skip_header=False,
1203
+ fields={"SMILES": ("src", SRC)},
1204
+ )
1205
+
1206
+ errors_loader = Iterator(
1207
+ errors,
1208
+ batch_size=64,
1209
+ device=device,
1210
+ sort=False,
1211
+ sort_within_batch=True,
1212
+ sort_key=lambda x: len(x.src),
1213
+ repeat=False,
1214
+ )
1215
+ model.load_state_dict(torch.load(f=out + ".pkg",map_location=torch.device('cpu')))
1216
+ # add option to use different iterator maybe?
1217
+
1218
+ valids, df_output = model.translate(errors_loader)
1219
+ #df_output.to_csv(f"{error_source}_fixed.csv", index=False)
1220
+
1221
+
1222
+ return valids, df_output
1223
+
1224
+
1225
+
1226
+ class smi_correct(object):
1227
+ def __init__(self, model_name, trans_file_path):
1228
+ # set random seed, used for error generation & initiation transformer
1229
+
1230
+ self.SEED = 42
1231
+ random.seed(self.SEED)
1232
+ self.model_name = model_name
1233
+ self.folder_out = "DrugGEN/data/"
1234
+
1235
+ self.trans_file_path = trans_file_path
1236
+
1237
+ if not os.path.exists(self.folder_out):
1238
+ os.makedirs(self.folder_out)
1239
+
1240
+ self.invalid_type = 'multiple'
1241
+ self.num_errors = 12
1242
+ self.threshold = 200
1243
+ self.data_source = f"PAPYRUS_{self.threshold}"
1244
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
1245
+ self.initialize_source = 'DrugGEN/data/papyrus_rnn_S.csv' # change this path
1246
+
1247
+ def standardization_pipeline(self, smile):
1248
+ desalter = MolStandardize.fragment.LargestFragmentChooser()
1249
+ std_smile = None
1250
+ if not isinstance(smile, str): return None
1251
+ m = Chem.MolFromSmiles(smile)
1252
+ # skips smiles for which no mol file could be generated
1253
+ if m is not None:
1254
+ # standardizes
1255
+ std_m = standardizer.standardize_mol(m)
1256
+ # strips salts
1257
+ std_m_p, exclude = standardizer.get_parent_mol(std_m)
1258
+ if not exclude:
1259
+ # choose largest fragment for rare cases where chembl structure
1260
+ # pipeline leaves 2 fragments
1261
+ std_m_p_d = desalter.choose(std_m_p)
1262
+ std_smile = Chem.MolToSmiles(std_m_p_d)
1263
+ return std_smile
1264
+
1265
+ def remove_smiles_duplicates(self, dataframe: pd.DataFrame,
1266
+ subset: str) -> pd.DataFrame:
1267
+ return dataframe.drop_duplicates(subset=subset)
1268
+
1269
+ def correct(self, smi):
1270
+
1271
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1272
+
1273
+ model, out, SRC = initialize_model(self.folder_out,
1274
+ self.data_source,
1275
+ error_source=self.initialize_source,
1276
+ device=device,
1277
+ threshold=self.threshold,
1278
+ epochs=30,
1279
+ layers=3,
1280
+ batch_size=16,
1281
+ invalid_type=self.invalid_type,
1282
+ num_errors=self.num_errors)
1283
+
1284
+ valids, df_output = correct_SMILES(model, out, smi, device,
1285
+ SRC)
1286
+
1287
+ df_output["SMILES"] = df_output.apply(lambda row: self.standardization_pipeline(row["CORRECT"]), axis=1)
1288
+
1289
+ df_output = self.remove_smiles_duplicates(df_output, subset="SMILES").drop(columns=["CORRECT", "INCORRECT", "ORIGINAL"]).dropna()
1290
+
1291
+ return df_output
utils.py CHANGED
@@ -42,7 +42,15 @@ class Metrics(object):
42
 
43
  @staticmethod
44
  def max_component(data, max_len):
 
 
45
 
 
 
 
 
 
 
46
  return ((np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean())
47
 
48
  @staticmethod
@@ -347,7 +355,7 @@ def canonic_smiles(smiles_or_mol):
347
  if mol is None:
348
  return None
349
  return Chem.MolToSmiles(mol)
350
- def fraction_unique(gen, k=None, n_jobs=1, check_validity=False):
351
  """
352
  Computes a number of unique molecules
353
  Parameters:
@@ -363,11 +371,13 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=False):
363
  "gen contains only {} molecules".format(len(gen))
364
  )
365
  gen = gen[:k]
366
- canonic = set(mapper(n_jobs)(canonic_smiles, gen))
367
- if None in canonic and check_validity:
368
- #canonic = [i for i in canonic if i is not None]
369
- raise ValueError("Invalid molecule passed to unique@k")
370
- return 0 if len(gen) == 0 else len(canonic) / len(gen)
 
 
371
 
372
  def novelty(gen, train, n_jobs=1):
373
  gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
@@ -375,7 +385,8 @@ def novelty(gen, train, n_jobs=1):
375
  train_set = set(train)
376
  return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
377
 
378
-
 
379
 
380
  def average_agg_tanimoto(stock_vecs, gen_vecs,
381
  batch_size=5000, agg='max',
 
42
 
43
  @staticmethod
44
  def max_component(data, max_len):
45
+
46
+ # There will be a name change for this function to better reflect what it does
47
 
48
+ """Returns the average length of the molecules in the dataset normalized by the maximum length.
49
+
50
+ Returns:
51
+ array: normalized average length of the molecules in the dataset
52
+ """
53
+
54
  return ((np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean())
55
 
56
  @staticmethod
 
355
  if mol is None:
356
  return None
357
  return Chem.MolToSmiles(mol)
358
+ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
359
  """
360
  Computes a number of unique molecules
361
  Parameters:
 
371
  "gen contains only {} molecules".format(len(gen))
372
  )
373
  gen = gen[:k]
374
+ if check_validity:
375
+
376
+ canonic = list(mapper(n_jobs)(canonic_smiles, gen))
377
+ canonic = [i for i in canonic if i is not None]
378
+ set_cannonic = set(canonic)
379
+ #raise ValueError("Invalid molecule passed to unique@k")
380
+ return 0 if len(canonic) == 0 else len(set_cannonic) / len(canonic)
381
 
382
  def novelty(gen, train, n_jobs=1):
383
  gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
 
385
  train_set = set(train)
386
  return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
387
 
388
+ def internal_diversity(gen):
389
+ return 1 - average_agg_tanimoto(gen, gen, agg="mean")
390
 
391
  def average_agg_tanimoto(stock_vecs, gen_vecs,
392
  batch_size=5000, agg='max',