Update utils.py
utils.py CHANGED
@@ -5,7 +5,7 @@ from rdkit.Chem import AllChem
 from rdkit.Chem import Draw
 import os
 import numpy as np
-import seaborn as sns
+#import seaborn as sns
 import matplotlib.pyplot as plt
 from matplotlib.lines import Line2D
 from rdkit import RDLogger
@@ -46,6 +46,7 @@ class Metrics(object):
 
         return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
 
+
 def sim_reward(mol_gen, fps_r):
 
     gen_scaf = []
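Note (not part of the diff): the context line above (line 47) is the body of Metrics.max_component, a mean molecule length normalised by max_len. A minimal standalone sketch, assuming Metrics.mol_length returns the character length of the largest '.'-separated SMILES fragment (that helper is hypothetical here, it is not shown in this diff):

import numpy as np

def mol_length(smiles):
    # assumed helper (hypothetical): length of the largest '.'-separated fragment
    return len(max(smiles.split("."), key=len))

def max_component(data, max_len):
    # mirrors line 47 above: mean fragment length divided by max_len
    return (np.array(list(map(mol_length, data)), dtype=np.float32) / max_len).mean()

print(max_component(["CCO", "c1ccccc1", "CC(=O)O.[Na+]"], 45))  # (3 + 8 + 7) / 3 / 45 ≈ 0.13

Values close to 1 would indicate generated molecules near the 45-character cap used later in logging().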
@@ -152,6 +153,7 @@ def sample_z_edge(batch_size, vertexes, edges):
 
     return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
 
+
 def sample_z( batch_size, z_dim):
 
     ''' Random noise. '''
@@ -176,10 +178,7 @@ def mol_sample(sample_directory, model_name, mol, edges, nodes, idx, i):
     print("Valid matrices and smiles are saved")
 
 
-
-
-
-def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path):
+def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path, get_maxlen=False):
 
     gen_smiles = []
     for line in mols:
@@ -222,20 +221,20 @@ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, sav
     #m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
     #m0.update(m1)
 
-
+    if get_maxlen:
+        maxlen = Metrics.max_component(mols, 45)
+        loss.update({"MaxLen": maxlen})
 
     #m0 = {k: np.array(v).mean() for k, v in m0.items()}
     #loss.update(m0)
     loss.update({'Valid': valid})
-    loss.update({'Unique
+    loss.update({'Unique': unique})
     loss.update({'Novel': novel})
     #loss.update({'QED': statistics.mean(qed)})
     #loss.update({'SA': statistics.mean(sa)})
     #loss.update({'LogP': statistics.mean(logp)})
     #loss.update({'IntDiv': IntDiv})
 
-    #wandb.log({"maxlen": maxlen})
-
     for tag, value in loss.items():
 
         log += ", {}: {:.4f}".format(tag, value)
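For reference, a self-contained toy run of the block above, with placeholder metric values (valid, unique, novel and the Metrics.max_component result are computed earlier in logging() in the real file), showing how the new MaxLen and Unique entries end up in the formatted log line:

loss = {}
valid, unique, novel = 0.95, 0.80, 0.60      # placeholder values, not model output
get_maxlen = True
if get_maxlen:
    maxlen = 0.12                            # stand-in for Metrics.max_component(mols, 45)
    loss.update({"MaxLen": maxlen})
loss.update({'Valid': valid})
loss.update({'Unique': unique})
loss.update({'Novel': novel})

log = "Iteration [10/100]"
for tag, value in loss.items():
    log += ", {}: {:.4f}".format(tag, value)
print(log)   # Iteration [10/100], MaxLen: 0.1200, Valid: 0.9500, Unique: 0.8000, Novel: 0.6000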
@@ -246,24 +245,23 @@ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, sav
     print("\n")
 
 
-
-def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
-
-    cols = 4
-    rows = int(heads/cols)
-
-    fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
-    axes = axes.flat
-    attentions_pos = attn_w[0]
-    attentions_pos = attentions_pos.cpu().detach().numpy()
-    for i,att in enumerate(attentions_pos):
-
-        #im = axes[i].imshow(att, cmap='gray')
-        sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
-        axes[i].set_title(f'head - {i} ')
-        axes[i].set_ylabel('layers')
-    pltsavedir = "/home/atabey/attn/second"
-    plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
+#def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
+#
+#    cols = 4
+#    rows = int(heads/cols)
+#
+#    fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
+#    axes = axes.flat
+#    attentions_pos = attn_w[0]
+#    attentions_pos = attentions_pos.cpu().detach().numpy()
+#    for i,att in enumerate(attentions_pos):
+#
+#        #im = axes[i].imshow(att, cmap='gray')
+#        sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
+#        axes[i].set_title(f'head - {i} ')
+#        axes[i].set_ylabel('layers')
+#    pltsavedir = "/home/atabey/attn/second"
+#    plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
 
 
 def plot_grad_flow(named_parameters, model, iter, epoch):
@@ -298,36 +296,8 @@ def plot_grad_flow(named_parameters, model, iter, epoch):
                 Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
     pltsavedir = "/home/atabey/gradients/tryout"
     plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
-
-"""
-def _genDegree():
 
-
-    dataset is used.
-    Can be called without arguments and saves the tensor for later use. If tensor was created
-    before, it just loads the degree tensor.
-    '''
-
-    degree_path = os.path.join(self.degree_dir, self.dataset_name + '-degree.pt')
-    if not os.path.exists(degree_path):
-
-
-        max_degree = -1
-        for data in self.dataset:
-            d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
-            max_degree = max(max_degree, int(d.max()))
-
-        # Compute the in-degree histogram tensor
-        deg = torch.zeros(max_degree + 1, dtype=torch.long)
-        for data in self.dataset:
-            d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
-            deg += torch.bincount(d, minlength=deg.numel())
-        torch.save(deg, 'DrugGEN/data/' + self.dataset_name + '-degree.pt')
-    else:
-        deg = torch.load(degree_path, map_location=lambda storage, loc: storage)
-
-    return deg
-"""
+
 def get_mol(smiles_or_mol):
     '''
     Loads SMILES/molecule into RDKit's object
@@ -345,6 +315,7 @@ def get_mol(smiles_or_mol):
         return mol
     return smiles_or_mol
 
+
 def mapper(n_jobs):
     '''
     Returns function for map call.
@@ -369,6 +340,8 @@ def mapper(n_jobs):
 
         return _mapper
     return n_jobs.map
+
+
 def remove_invalid(gen, canonize=True, n_jobs=1):
     """
     Removes invalid molecules from the dataset
@@ -378,6 +351,8 @@ def remove_invalid(gen, canonize=True, n_jobs=1):
         return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
     return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
             x is not None]
+
+
 def fraction_valid(gen, n_jobs=1):
     """
     Computes a number of valid molecules
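A hedged usage sketch of the two helpers above, assuming utils.py is importable and that mapper(1) behaves like plain map (consistent with the branches visible in the hunk):

from utils import mapper, get_mol, remove_invalid    # assuming utils.py is on the path

smiles = ["CCO", "not_a_smiles", "c1ccccc1"]
mols = mapper(1)(get_mol, smiles)                    # serial path, equivalent to map(get_mol, smiles)
print(sum(m is not None for m in mols))              # 2 molecules parse
print(remove_invalid(smiles, canonize=True, n_jobs=1))   # expected ['CCO', 'c1ccccc1']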
@@ -387,11 +362,15 @@ def fraction_valid(gen, n_jobs=1):
     """
     gen = mapper(n_jobs)(get_mol, gen)
     return 1 - gen.count(None) / len(gen)
+
+
 def canonic_smiles(smiles_or_mol):
     mol = get_mol(smiles_or_mol)
     if mol is None:
         return None
     return Chem.MolToSmiles(mol)
+
+
 def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
     """
     Computes a number of unique molecules
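A quick sanity check for fraction_valid and canonic_smiles as shown above (assuming utils.py is importable; RDKit rejects the middle string):

from utils import fraction_valid, canonic_smiles     # assuming utils.py is on the path

gen = ["CCO", "not_a_smiles", "c1ccccc1"]
print(fraction_valid(gen, n_jobs=1))                 # 2 of 3 parse -> 0.666...
print(canonic_smiles("OCC"))                         # RDKit canonical form -> 'CCO'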
@@ -410,9 +389,11 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
         gen = gen[:k]
     canonic = set(mapper(n_jobs)(canonic_smiles, gen))
     if None in canonic and check_validity:
-        raise ValueError("Invalid molecule passed to unique@k")
+        canonic = [i for i in canonic if i is not None]
+        #raise ValueError("Invalid molecule passed to unique@k")
     return 0 if len(gen) == 0 else len(canonic) / len(gen)
 
+
 def novelty(gen, train, n_jobs=1):
     gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
     gen_smiles_set = set(gen_smiles) - {None}
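The edit above changes behaviour when the batch contains an invalid SMILES: instead of raising the now commented-out ValueError, the None entry is dropped before the ratio is taken. A small sketch, assuming utils.py is importable:

from utils import fraction_unique                    # assuming utils.py is on the path

gen = ["CCO", "OCC", "not_a_smiles"]                 # first two canonicalise to the same SMILES
print(fraction_unique(gen, check_validity=True, n_jobs=1))   # 1 unique canonical SMILES / 3 -> 0.333...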
@@ -420,7 +401,6 @@ def novelty(gen, train, n_jobs=1):
     return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
 
 
-
 def average_agg_tanimoto(stock_vecs, gen_vecs,
                          batch_size=5000, agg='max',
                          device='cpu', p=1):
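A hedged example of the novelty ratio defined above, assuming utils.py is importable and that the training SMILES are already canonical (the train_set line that sits between these two hunks is not shown in the diff):

from utils import novelty                            # assuming utils.py is on the path

train = ["CCO"]                                      # assumed already canonical
gen = ["OCC", "c1ccccc1", "not_a_smiles"]            # 'OCC' canonicalises to 'CCO', so not novel
print(novelty(gen, train, n_jobs=1))                 # 1 novel of 2 valid -> 0.5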