osbm committed on
Commit
74b8fa1
·
1 Parent(s): 9f5b1d1

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +38 -58
utils.py CHANGED
@@ -5,7 +5,7 @@ from rdkit.Chem import AllChem
5
  from rdkit.Chem import Draw
6
  import os
7
  import numpy as np
8
- import seaborn as sns
9
  import matplotlib.pyplot as plt
10
  from matplotlib.lines import Line2D
11
  from rdkit import RDLogger
@@ -46,6 +46,7 @@ class Metrics(object):
46
 
47
  return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
48
 
 
49
  def sim_reward(mol_gen, fps_r):
50
 
51
  gen_scaf = []
@@ -152,6 +153,7 @@ def sample_z_edge(batch_size, vertexes, edges):
152
 
153
  return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
154
 
 
155
  def sample_z( batch_size, z_dim):
156
 
157
  ''' Random noise. '''
@@ -176,10 +178,7 @@ def mol_sample(sample_directory, model_name, mol, edges, nodes, idx, i):
176
  print("Valid matrices and smiles are saved")
177
 
178
 
179
-
180
-
181
-
182
- def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path):
183
 
184
  gen_smiles = []
185
  for line in mols:
@@ -222,20 +221,20 @@ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, sav
222
  #m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
223
  #m0.update(m1)
224
 
225
- #maxlen = MolecularMetrics.max_component(mols, 45)
 
 
226
 
227
  #m0 = {k: np.array(v).mean() for k, v in m0.items()}
228
  #loss.update(m0)
229
  loss.update({'Valid': valid})
230
- loss.update({'Unique@{}'.format(k): unique})
231
  loss.update({'Novel': novel})
232
  #loss.update({'QED': statistics.mean(qed)})
233
  #loss.update({'SA': statistics.mean(sa)})
234
  #loss.update({'LogP': statistics.mean(logp)})
235
  #loss.update({'IntDiv': IntDiv})
236
 
237
- #wandb.log({"maxlen": maxlen})
238
-
239
  for tag, value in loss.items():
240
 
241
  log += ", {}: {:.4f}".format(tag, value)
@@ -246,24 +245,23 @@ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, sav
246
  print("\n")
247
 
248
 
249
-
250
- def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
251
-
252
- cols = 4
253
- rows = int(heads/cols)
254
-
255
- fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
256
- axes = axes.flat
257
- attentions_pos = attn_w[0]
258
- attentions_pos = attentions_pos.cpu().detach().numpy()
259
- for i,att in enumerate(attentions_pos):
260
-
261
- #im = axes[i].imshow(att, cmap='gray')
262
- sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
263
- axes[i].set_title(f'head - {i} ')
264
- axes[i].set_ylabel('layers')
265
- pltsavedir = "/home/atabey/attn/second"
266
- plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
267
 
268
 
269
  def plot_grad_flow(named_parameters, model, iter, epoch):
@@ -298,36 +296,8 @@ def plot_grad_flow(named_parameters, model, iter, epoch):
298
  Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
299
  pltsavedir = "/home/atabey/gradients/tryout"
300
  plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
301
-
302
- """
303
- def _genDegree():
304
 
305
- ''' Generates the Degree distribution tensor for PNA, should be used everytime a different
306
- dataset is used.
307
- Can be called without arguments and saves the tensor for later use. If tensor was created
308
- before, it just loads the degree tensor.
309
- '''
310
-
311
- degree_path = os.path.join(self.degree_dir, self.dataset_name + '-degree.pt')
312
- if not os.path.exists(degree_path):
313
-
314
-
315
- max_degree = -1
316
- for data in self.dataset:
317
- d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
318
- max_degree = max(max_degree, int(d.max()))
319
-
320
- # Compute the in-degree histogram tensor
321
- deg = torch.zeros(max_degree + 1, dtype=torch.long)
322
- for data in self.dataset:
323
- d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
324
- deg += torch.bincount(d, minlength=deg.numel())
325
- torch.save(deg, 'DrugGEN/data/' + self.dataset_name + '-degree.pt')
326
- else:
327
- deg = torch.load(degree_path, map_location=lambda storage, loc: storage)
328
-
329
- return deg
330
- """
331
  def get_mol(smiles_or_mol):
332
  '''
333
  Loads SMILES/molecule into RDKit's object
@@ -345,6 +315,7 @@ def get_mol(smiles_or_mol):
345
  return mol
346
  return smiles_or_mol
347
 
 
348
  def mapper(n_jobs):
349
  '''
350
  Returns function for map call.
@@ -369,6 +340,8 @@ def mapper(n_jobs):
369
 
370
  return _mapper
371
  return n_jobs.map
 
 
372
  def remove_invalid(gen, canonize=True, n_jobs=1):
373
  """
374
  Removes invalid molecules from the dataset
@@ -378,6 +351,8 @@ def remove_invalid(gen, canonize=True, n_jobs=1):
378
  return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
379
  return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
380
  x is not None]
 
 
381
  def fraction_valid(gen, n_jobs=1):
382
  """
383
  Computes a number of valid molecules
@@ -387,11 +362,15 @@ def fraction_valid(gen, n_jobs=1):
387
  """
388
  gen = mapper(n_jobs)(get_mol, gen)
389
  return 1 - gen.count(None) / len(gen)
 
 
390
  def canonic_smiles(smiles_or_mol):
391
  mol = get_mol(smiles_or_mol)
392
  if mol is None:
393
  return None
394
  return Chem.MolToSmiles(mol)
 
 
395
  def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
396
  """
397
  Computes a number of unique molecules
@@ -410,9 +389,11 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
410
  gen = gen[:k]
411
  canonic = set(mapper(n_jobs)(canonic_smiles, gen))
412
  if None in canonic and check_validity:
413
- raise ValueError("Invalid molecule passed to unique@k")
 
414
  return 0 if len(gen) == 0 else len(canonic) / len(gen)
415
 
 
416
  def novelty(gen, train, n_jobs=1):
417
  gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
418
  gen_smiles_set = set(gen_smiles) - {None}
@@ -420,7 +401,6 @@ def novelty(gen, train, n_jobs=1):
420
  return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
421
 
422
 
423
-
424
  def average_agg_tanimoto(stock_vecs, gen_vecs,
425
  batch_size=5000, agg='max',
426
  device='cpu', p=1):
 
5
  from rdkit.Chem import Draw
6
  import os
7
  import numpy as np
8
+ #import seaborn as sns
9
  import matplotlib.pyplot as plt
10
  from matplotlib.lines import Line2D
11
  from rdkit import RDLogger
 
46
 
47
  return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
48
 
49
+
50
  def sim_reward(mol_gen, fps_r):
51
 
52
  gen_scaf = []
 
153
 
154
  return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
155
 
156
+
157
  def sample_z( batch_size, z_dim):
158
 
159
  ''' Random noise. '''
 
178
  print("Valid matrices and smiles are saved")
179
 
180
 
181
+ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path, get_maxlen=False):
 
 
 
182
 
183
  gen_smiles = []
184
  for line in mols:
 
221
  #m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
222
  #m0.update(m1)
223
 
224
+ if get_maxlen:
225
+ maxlen = Metrics.max_component(mols, 45)
226
+ loss.update({"MaxLen": maxlen})
227
 
228
  #m0 = {k: np.array(v).mean() for k, v in m0.items()}
229
  #loss.update(m0)
230
  loss.update({'Valid': valid})
231
+ loss.update({'Unique': unique})
232
  loss.update({'Novel': novel})
233
  #loss.update({'QED': statistics.mean(qed)})
234
  #loss.update({'SA': statistics.mean(sa)})
235
  #loss.update({'LogP': statistics.mean(logp)})
236
  #loss.update({'IntDiv': IntDiv})
237
 
 
 
238
  for tag, value in loss.items():
239
 
240
  log += ", {}: {:.4f}".format(tag, value)
 
245
  print("\n")
246
 
247
 
248
+ #def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
249
+ #
250
+ # cols = 4
251
+ # rows = int(heads/cols)
252
+ #
253
+ # fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
254
+ # axes = axes.flat
255
+ # attentions_pos = attn_w[0]
256
+ # attentions_pos = attentions_pos.cpu().detach().numpy()
257
+ # for i,att in enumerate(attentions_pos):
258
+ #
259
+ # #im = axes[i].imshow(att, cmap='gray')
260
+ # sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
261
+ # axes[i].set_title(f'head - {i} ')
262
+ # axes[i].set_ylabel('layers')
263
+ # pltsavedir = "/home/atabey/attn/second"
264
+ # plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
 
265
 
266
 
267
  def plot_grad_flow(named_parameters, model, iter, epoch):
 
296
  Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
297
  pltsavedir = "/home/atabey/gradients/tryout"
298
  plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
 
 
 
299
 
300
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  def get_mol(smiles_or_mol):
302
  '''
303
  Loads SMILES/molecule into RDKit's object
 
315
  return mol
316
  return smiles_or_mol
317
 
318
+
319
  def mapper(n_jobs):
320
  '''
321
  Returns function for map call.
 
340
 
341
  return _mapper
342
  return n_jobs.map
343
+
344
+
345
  def remove_invalid(gen, canonize=True, n_jobs=1):
346
  """
347
  Removes invalid molecules from the dataset
 
351
  return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
352
  return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
353
  x is not None]
354
+
355
+
356
  def fraction_valid(gen, n_jobs=1):
357
  """
358
  Computes a number of valid molecules
 
362
  """
363
  gen = mapper(n_jobs)(get_mol, gen)
364
  return 1 - gen.count(None) / len(gen)
365
+
366
+
367
  def canonic_smiles(smiles_or_mol):
368
  mol = get_mol(smiles_or_mol)
369
  if mol is None:
370
  return None
371
  return Chem.MolToSmiles(mol)
372
+
373
+
374
  def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
375
  """
376
  Computes a number of unique molecules
 
389
  gen = gen[:k]
390
  canonic = set(mapper(n_jobs)(canonic_smiles, gen))
391
  if None in canonic and check_validity:
392
+ canonic = [i for i in canonic if i is not None]
393
+ #raise ValueError("Invalid molecule passed to unique@k")
394
  return 0 if len(gen) == 0 else len(canonic) / len(gen)
395
 
396
+
397
  def novelty(gen, train, n_jobs=1):
398
  gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
399
  gen_smiles_set = set(gen_smiles) - {None}
 
401
  return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
402
 
403
 
 
404
  def average_agg_tanimoto(stock_vecs, gen_vecs,
405
  batch_size=5000, agg='max',
406
  device='cpu', p=1):