jiaxianustc committed
Commit 7d69eaa · 1 Parent(s): e749e85
This view is limited to 50 files because it contains too many changes; see the raw diff for the full changeset.
Files changed (50)
  1. UltraFlow/commons/__init__.py +0 -5
  2. UltraFlow/commons/dock_utils.py +0 -355
  3. UltraFlow/commons/geomop.py +0 -529
  4. UltraFlow/commons/get_free_gpu.py +0 -78
  5. UltraFlow/commons/loss_weight.pkl +0 -3
  6. UltraFlow/commons/metrics.py +0 -315
  7. UltraFlow/commons/torch_prepare.py +0 -156
  8. UltraFlow/commons/visualize.py +0 -364
  9. UltraFlow/data/INDEX_general_PL_data.2016 +0 -0
  10. UltraFlow/data/INDEX_general_PL_data.2020 +0 -0
  11. UltraFlow/data/INDEX_refined_data.2020 +0 -0
  12. UltraFlow/data/chembl/P49841/P49841_valid_chains.pdb +0 -0
  13. UltraFlow/data/chembl/P49841/P49841_valid_pvalue.smi +0 -0
  14. UltraFlow/data/chembl/P49841/P49841_valid_smiles.smi +0 -0
  15. UltraFlow/data/chembl/P49841/visualize_dir/total_vs.sdf +0 -0
  16. UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_chains.pdb +0 -0
  17. UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_pvalue.smi +0 -0
  18. UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_smiles.smi +0 -0
  19. UltraFlow/data/chembl/Q9Y233/visualize_dir/total_vs.sdf +0 -0
  20. UltraFlow/data/core_set +0 -0
  21. UltraFlow/data/csar_2016 +0 -0
  22. UltraFlow/data/csar_2020 +0 -0
  23. UltraFlow/data/csar_new_2016 +0 -0
  24. UltraFlow/data/horizontal_test.pkl +0 -0
  25. UltraFlow/data/horizontal_train.pkl +0 -0
  26. UltraFlow/data/horizontal_valid.pkl +0 -0
  27. UltraFlow/data/pdb2016_total +0 -0
  28. UltraFlow/data/pdb_after_2016 +0 -0
  29. UltraFlow/data/pdbbind2016_general_gign_train +0 -0
  30. UltraFlow/data/pdbbind2016_general_gign_valid +0 -0
  31. UltraFlow/data/pdbbind2016_general_train +0 -0
  32. UltraFlow/data/pdbbind2016_general_valid +0 -0
  33. UltraFlow/data/pdbbind2016_test +0 -0
  34. UltraFlow/data/pdbbind2016_train +0 -0
  35. UltraFlow/data/pdbbind2016_train_M +0 -0
  36. UltraFlow/data/pdbbind2016_valid +0 -0
  37. UltraFlow/data/pdbbind2016_valid_M +0 -0
  38. UltraFlow/data/pdbbind2020_finetune_test +0 -0
  39. UltraFlow/data/pdbbind2020_finetune_train +0 -0
  40. UltraFlow/data/pdbbind2020_finetune_valid +0 -0
  41. UltraFlow/data/pdbbind2020_vstrain1 +0 -0
  42. UltraFlow/data/pdbbind2020_vstrain2 +0 -0
  43. UltraFlow/data/pdbbind2020_vstrain3 +0 -0
  44. UltraFlow/data/pdbbind2020_vsvalid1 +0 -0
  45. UltraFlow/data/pdbbind2020_vsvalid2 +0 -0
  46. UltraFlow/data/pdbbind2020_vsvalid3 +0 -0
  47. UltraFlow/data/pdbbind_2020_casf_test +0 -0
  48. UltraFlow/data/pdbbind_2020_casf_train +0 -0
  49. UltraFlow/data/pdbbind_2020_casf_valid +0 -0
  50. UltraFlow/data/tankbind_vtrain +0 -0
UltraFlow/commons/__init__.py CHANGED
@@ -1,7 +1,2 @@
  from .utils import *
- from .torch_prepare import *
  from .process_mols import *
- from .metrics import *
- from .geomop import *
- from .visualize import *
- from .dock_utils import *
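This change trims the package's wildcard exports down to `utils` and `process_mols`. To check what the slimmed-down module still exposes, a quick inspection like the following works (a generic sketch; it assumes UltraFlow is importable in the current environment):

import importlib

mod = importlib.import_module('UltraFlow.commons')
# names re-exported by the remaining `from .utils import *` / `from .process_mols import *`
print(sorted(n for n in dir(mod) if not n.startswith('_')))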
UltraFlow/commons/dock_utils.py DELETED
@@ -1,355 +0,0 @@
- import os
- from collections import defaultdict
- import numpy as np
- import torch
- from openbabel import pybel
- from statistics import stdev
- from time import time
- from .utils import pmap_multi
- import pandas as pd
- from tqdm import tqdm
-
- MGLTols_PYTHON = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/bin/python2.7'
- Prepare_Ligand = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/prepare_ligand4.py'
- Prepare_Receptor = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/prepare_receptor4.py'
- SMINA = '/apdcephfs/private_jiaxianyan/dock/smina'
-
- def read_matric(matric_file_path):
-     with open(matric_file_path) as f:
-         lines = f.read().strip().split('\n')
-     rmsd, centroid = float(lines[0].split(':')[1]), float(lines[1].split(':')[1])
-     return rmsd, centroid
-
- def mol2_add_atom_index_to_atom_name(mol2_file_path):
-     MOL_list = [x for x in open(mol2_file_path, 'r')]
-     idx = [i for i, x in enumerate(MOL_list) if x.startswith('@')]
-     block = MOL_list[idx[1] + 1:idx[2]]
-     block = [x.split() for x in block]
-
-     block_new = []
-     atom_count = defaultdict(int)
-     for i in block:
-         at = i[5].strip().split('.')[0]
-         if 'H' not in at:
-             atom_count[at] += 1
-             count = atom_count[at]
-             at_new = at + str(count)
-             at_new = at_new.rjust(4)
-             block_new.append([i[0], at_new] + i[2:])
-         else:
-             block_new.append(i)
-
-     block_new = ['\t'.join(x) + '\n' for x in block_new]
-     MOL_list_new = MOL_list[:idx[1] + 1] + block_new + MOL_list[idx[2]:]
-     with open(mol2_file_path, 'w') as f:
-         for line in MOL_list_new:
-             f.write(line)
-     return
-
- def prepare_dock_file(pdb_name, config):
-     visualize_dir = os.path.join(config.train.save_path, 'visualize_dir')
-     post_align_sdf = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.sdf')
-     post_align_mol2 = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.mol2')
-     post_align_pdbqt = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.pdbqt')
-
-     # mgltools preprocess
-     cmd = f'cd {visualize_dir}'
-     cmd += f' && obabel -i sdf {post_align_sdf} -o mol2 -O {post_align_mol2}'
-
-     if not os.path.exists(post_align_mol2):
-         os.system(cmd)
-         mol2_add_atom_index_to_atom_name(post_align_mol2)
-
-     cmd = f'cd {visualize_dir}'
-     cmd += f' && {MGLTols_PYTHON} {Prepare_Ligand} -l {post_align_mol2}'
-
-     if not os.path.exists(post_align_pdbqt):
-         os.system(cmd)
-     # cmd = f'obabel -i mol2 {post_align_mol2} -o pdbqt -O {post_align_pdbqt}'
-     # os.system(cmd)
-
-     return
-
- def get_mol2_atom_name(mol2_file_path):
-     MOL_list = [x for x in open(mol2_file_path, 'r')]
-     idx = [i for i, x in enumerate(MOL_list) if x.startswith('@')]
-     block = MOL_list[idx[1] + 1:idx[2]]
-     block = [x.split() for x in block]
-
-     atom_names = []
-
-     for i in block:
-         at = i[1].strip()
-         atom_names.append(at)
-     return atom_names
-
- def align_dock_name_and_target_name(dock_lig_atom_names, target_lig_atom_names):
-     dock_lig_atom_index_in_target_lig = []
-     target_atom_name_dict = {}
-     for index, atom_name in enumerate(target_lig_atom_names):
-         try:
-             assert atom_name not in target_atom_name_dict.keys()
-         except:
-             raise ValueError(atom_name, 'has appeared before')
-         target_atom_name_dict[atom_name] = index
-
-     dock_lig_atom_name_appears_dict = defaultdict(int)
-     for atom_name in dock_lig_atom_names:
-         try:
-             assert atom_name not in dock_lig_atom_name_appears_dict.keys()
-         except:
-             raise ValueError(atom_name, 'has appeared before')
-         dock_lig_atom_name_appears_dict[atom_name] += 1
-         try:
-             dock_lig_atom_index_in_target_lig.append(target_atom_name_dict[atom_name])
-         except:
-             dock_lig_atom_index_in_target_lig.append(target_atom_name_dict[atom_name + '1'])
-
-     return dock_lig_atom_index_in_target_lig
-
- def smina_dock_result_rmsd(pdb_name, config):
-     # target path
-     target_lig_mol2 = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_ligand.mol2')
-
-     # get target coords
-     target_m_lig = next(pybel.readfile('mol2', target_lig_mol2))
-     target_lig_coords = [atom.coords for atom in target_m_lig if atom.atomicnum > 1]
-     target_lig_coords = np.array(target_lig_coords, dtype=np.float32)  # np.array, [n, 3]
-     target_lig_center = target_lig_coords.mean(axis=0)  # np.array, [3]
-
-     # get target atom names
-     visualize_dir = os.path.join(config.train.save_path, 'visualize_dir')
-     lig_init_mol2 = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.mol2')
-     target_atom_name_reference_lig = next(pybel.readfile('mol2', lig_init_mol2))
-     target_lig_atom_names = get_mol2_atom_name(lig_init_mol2)
-     target_lig_atom_names_no_h = [atom_name for atom, atom_name in zip(target_atom_name_reference_lig, target_lig_atom_names) if atom.atomicnum > 1]
-
-     # get init coords
-     coords_before_minimized = [atom.coords for atom in target_atom_name_reference_lig if atom.atomicnum > 1]
-     coords_before_minimized = np.array(coords_before_minimized, dtype=np.float32)  # np.array, [n, 3]
-
-     # get smina minimized coords
-     dock_lig_mol2_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.mol2')
-     dock_m_lig = next(pybel.readfile('mol2', dock_lig_mol2_path))
-     dock_lig_coords = [atom.coords for atom in dock_m_lig if atom.atomicnum > 1]
-     dock_lig_coords = np.array(dock_lig_coords, dtype=np.float32)  # np.array, [n, 3]
-     dock_lig_center = dock_lig_coords.mean(axis=0)  # np.array, [3]
-
-     # get atom names
-     dock_lig_atom_names = get_mol2_atom_name(dock_lig_mol2_path)
-     dock_lig_atom_names_no_h = [atom_name for atom, atom_name in zip(dock_m_lig, dock_lig_atom_names) if atom.atomicnum > 1]
-     dock_lig_atom_index_in_target_lig = align_dock_name_and_target_name(dock_lig_atom_names_no_h, target_lig_atom_names_no_h)
-
-     dock_lig_coords_target_align = np.zeros([len(dock_lig_atom_index_in_target_lig), 3], dtype=np.float32)
-     for atom_coords, atom_index_in_target_lig in zip(dock_lig_coords, dock_lig_atom_index_in_target_lig):
-         dock_lig_coords_target_align[atom_index_in_target_lig] = atom_coords
-
-     # rmsd
-     error_lig_coords = dock_lig_coords_target_align - target_lig_coords
-     rmsd = np.sqrt((error_lig_coords ** 2).sum(axis=1, keepdims=True).mean(axis=0))
-
-     # centroid
-     error_center_coords = dock_lig_center - target_lig_center
-     centorid_d = np.sqrt((error_center_coords ** 2).sum())
-
-     # get rmsd after minimized
-     error_lig_coords_after_minimized = dock_lig_coords_target_align - coords_before_minimized
-     rmsd_after_minimized = np.sqrt((error_lig_coords_after_minimized ** 2).sum(axis=1, keepdims=True).mean(axis=0))
-
-     return float(rmsd), float(centorid_d), float(rmsd_after_minimized)
-
- def get_matric_dict(rmsds, centroids):
-     rmsd_mean = sum(rmsds) / len(rmsds)
-     centroid_mean = sum(centroids) / len(centroids)
-     rmsd_std = stdev(rmsds)
-     centroid_std = stdev(centroids)
-
-     # rmsd < 2
-     count = torch.tensor(rmsds) < 2.0
-     rmsd_less_than_2 = 100 * count.sum().item() / len(count)
-
-     # rmsd < 5
-     count = torch.tensor(rmsds) < 5.0
-     rmsd_less_than_5 = 100 * count.sum().item() / len(count)
-
-     # centroid < 2
-     count = torch.tensor(centroids) < 2.0
-     centroid_less_than_2 = 100 * count.sum().item() / len(count)
-
-     # centroid < 5
-     count = torch.tensor(centroids) < 5.0
-     centroid_less_than_5 = 100 * count.sum().item() / len(count)
-
-     metrics_dict = {'rmsd mean': rmsd_mean, 'rmsd std': rmsd_std, 'centroid mean': centroid_mean, 'centroid std': centroid_std,
-                     'rmsd less than 2': rmsd_less_than_2, 'rmsd less than 5': rmsd_less_than_5,
-                     'centroid less than 2': centroid_less_than_2, 'centroid less than 5': centroid_less_than_5}
-     return metrics_dict
-
- def run_smina_dock(pdb_name, config):
-
-     r_pdbqt = os.path.join(config.test_set.dataset_path, pdb_name, f'{pdb_name}_protein_processed.pdbqt')
-     l_pdbqt = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}.pdbqt')
-     out_mol2 = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.mol2')
-     log_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.log')
-     cmd = f'{SMINA}' \
-           f' --receptor {r_pdbqt}' \
-           f' --ligand {l_pdbqt}' \
-           f' --out {out_mol2}' \
-           f' --log {log_file}' \
-           f' --minimize'
-     os.system(cmd)
-
-     return
-
- def run_score_only(ligand_file, protein_file, out_log_file):
-     cmd = f'{SMINA}' \
-           f' --receptor {protein_file}' \
-           f' --ligand {ligand_file}' \
-           f' --score_only' \
-           f' > {out_log_file}'
-     os.system(cmd)
-
-     with open(out_log_file, 'r') as f:
-         lines = f.read().strip().split('\n')
-     affinity_score = float(lines[21].split()[1])
-
-     return affinity_score
-
- def run_smina_score_after_predict(config):
-     pdb_name_list = config.test_set.names
-     smina_score_list = []
-     for pdb_name in tqdm(pdb_name_list):
-         ligand_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_pred.sdf')
-         protein_file = os.path.join(config.test_set.dataset_path, pdb_name, f'{pdb_name}_protein_processed.pdbqt')
-         out_log_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_pred_smina_score.out')
-         smina_score = run_score_only(ligand_file, protein_file, out_log_file)
-         smina_score_list.append(smina_score)
-
-     result_d = {'pdb_name': pdb_name_list, 'smina_score': smina_score_list}
-     pd.DataFrame(result_d).to_csv(os.path.join(config.train.save_path, 'visualize_dir', 'pred_smina_score.csv'))
-     return
-
- def run_smina_minimize_after_predict(config):
-     minimize_time = 0
-
-     pdb_name_list = config.test_set.names
-
-     # pmap_multi(prepare_dock_file, zip(pdb_name_list, [config] * len(pdb_name_list)),
-     #            n_jobs=8, desc='mgltools preparing ...')
-
-     rmsds_post_dock, centroids_post_dock = [], []
-     rmsds_post, centroids_post = [], []
-     rmsds, centroids = [], []
-
-     rmsds_after_minimized = {'pdb_name': [], 'rmsd': []}
-
-     valid_pdb_name = []
-     error_list = []
-     # for pdb_name in tqdm(pdb_name_list):
-     #     try:
-     #         minimize_begin_time = time()
-     #         run_smina_dock(pdb_name, config)
-     #         minimize_time += time() - minimize_begin_time
-     #         rmsd_post_dock, centroid_post_dock, rmsd_after_minimized = smina_dock_result_rmsd(pdb_name, config)
-     #         rmsds_post_dock.append(rmsd_post_dock)
-     #         centroids_post_dock.append(centroid_post_dock)
-     #
-     #         rmsds_after_minimized['pdb_name'].append(pdb_name)
-     #         rmsds_after_minimized['rmsd'].append(rmsd_after_minimized)
-     #         print(f'{pdb_name} smina minimized, rmsd: {rmsd_post_dock}, centroid: {centroid_post_dock}')
-     #
-     #         text_matics = 'rmsd:{}\ncentroid_d:{}\n'.format(rmsd_post_dock, centroid_post_dock)
-     #         post_dock_matric_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_matrics_post_{config.train.align_method}_dock.txt')
-     #         with open(post_dock_matric_path, 'w') as f:
-     #             f.write(text_matics)
-     #
-     #         # read matrics
-     #         post_matric_path = os.path.join(config.train.save_path, 'visualize_dir',
-     #                                         f'{pdb_name}_matrics_post_{config.train.align_method}.txt')
-     #
-     #         matric_path = os.path.join(config.train.save_path, 'visualize_dir',
-     #                                    f'{pdb_name}_matrics.txt')
-     #         rmsd_post, centroid_post = read_matric(post_matric_path)
-     #         rmsds_post.append(rmsd_post)
-     #         centroids_post.append(centroid_post)
-     #
-     #         rmsd, centroid = read_matric(matric_path)
-     #         rmsds.append(rmsd)
-     #         centroids.append(centroid)
-     #         valid_pdb_name.append(pdb_name)
-     #
-     #     except:
-     #         print(f'{pdb_name} dock error!')
-     #         error_list.append(pdb_name)
-     #
-     dock_score_analysis(pdb_name_list, config)
-
-     pd.DataFrame(rmsds_after_minimized).to_csv(os.path.join(config.train.save_path, 'visualize_dir', f'rmsd_after_smina_minimzed.csv'))
-
-     matric_dict_post_dock = get_matric_dict(rmsds_post_dock, centroids_post_dock)
-     matric_dict_post = get_matric_dict(rmsds_post, centroids_post)
-     matric_dict = get_matric_dict(rmsds, centroids)
-
-     matric_dict_post_dock_d = {'pdb_name': valid_pdb_name, 'rmsd': rmsds_post_dock, 'centroid': centroids_post_dock}
-     pd.DataFrame(matric_dict_post_dock_d).to_csv(
-         os.path.join(config.train.save_path, 'visualize_dir', 'matric_distribution_after_minimized.csv'))
-
-     matric_str = ''
-     for key in matric_dict_post_dock.keys():
-         if key.endswith('mean') or key.endswith('std'):
-             matric_str += '| post dock {}: {:.4f} '.format(key, matric_dict_post_dock[key])
-         else:
-             matric_str += '| post dock {}: {:.4f}% '.format(key, matric_dict_post_dock[key])
-
-     for key in matric_dict_post.keys():
-         if key.endswith('mean') or key.endswith('std'):
-             matric_str += '| post {}: {:.4f} '.format(key, matric_dict_post[key])
-         else:
-             matric_str += '| post {}: {:.4f}% '.format(key, matric_dict_post[key])
-
-     for key in matric_dict.keys():
-         if key.endswith('mean') or key.endswith('std'):
-             matric_str += '| {}: {:.4f} '.format(key, matric_dict[key])
-         else:
-             matric_str += '| {}: {:.4f}% '.format(key, matric_dict[key])
-
-     print(f'smina minimize results ========================')
-     print(matric_str)
-     print(f'pdb name error list ==========================')
-     print('\t'.join(error_list))
-     print(f'smina minimize time: {minimize_time}')
-
-     return
-
- def get_dock_score(log_path):
-     with open(log_path, 'r') as f:
-         lines = f.read().strip().split('\n')
-
-     affinity_score = float(lines[20].split()[1])
-
-     return affinity_score
-
- def dock_score_analysis(pdb_name_list, config):
-     dock_score_d = {'name': [], 'score': []}
-     error_num = 0
-     for pdb_name in tqdm(pdb_name_list):
-         log_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.log')
-         try:
-             affinity_score = get_dock_score(log_path)
-         except:
-             affinity_score = None
-         dock_score_d['name'].append(pdb_name)
-         dock_score_d['score'].append(affinity_score)
-     print('error num,', error_num)
-     pd.DataFrame(dock_score_d).to_csv(os.path.join(config.train.save_path, 'visualize_dir', f'post_align_{config.train.align_method}_smina_minimize_score.csv'))
-
-
- def structure2score(score_type):
-     try:
-         assert score_type in ['vina', 'smina', 'rfscore', 'ign', 'nnscore']
-     except:
-         raise ValueError(f'{score_type} is not a supported scoring function type')
-
-
-
-     return
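For orientation: the heart of `smina_dock_result_rmsd` above is a heavy-atom RMSD between name-aligned coordinate sets plus a centroid displacement. A minimal, self-contained sketch of that arithmetic (NumPy only; the [n, 3] shapes are an assumption for illustration, not a repo API):

import numpy as np

def rmsd_and_centroid(coords_a, coords_b):
    # coords_a, coords_b: [n, 3] arrays of matched heavy-atom coordinates
    rmsd = np.sqrt(((coords_a - coords_b) ** 2).sum(axis=1).mean())  # per-atom squared error, mean, sqrt
    centroid_d = np.linalg.norm(coords_a.mean(axis=0) - coords_b.mean(axis=0))  # distance between centers
    return float(rmsd), float(centroid_d)

a = np.random.rand(10, 3)
b = a + 0.5  # a uniform shift gives rmsd == centroid distance == sqrt(3 * 0.25) ~ 0.866
print(rmsd_and_centroid(a, b))

The deleted module feeds such values into `get_matric_dict`, which reports means, standard deviations, and the fraction of poses under the usual 2 A / 5 A docking thresholds.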
UltraFlow/commons/geomop.py DELETED
@@ -1,529 +0,0 @@
- import torch
- import rdkit.Chem as Chem
- import numpy as np
- import copy
- from rdkit.Chem import AllChem
- from rdkit.Chem import rdMolTransforms
- from rdkit.Geometry import Point3D
- from scipy.optimize import differential_evolution
- from .process_mols import read_rdmol
- import os
- import math
- from openbabel import pybel
- from tqdm import tqdm
-
- def get_d_from_pos(pos, edge_index):
-     return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)  # (num_edge)
-
- def kabsch(coords_A, coords_B, debug=True, device=None):
-     # rotate and translate coords_A to coords_B pos
-     coords_A_mean = coords_A.mean(dim=0, keepdim=True)  # (1,3)
-     coords_B_mean = coords_B.mean(dim=0, keepdim=True)  # (1,3)
-
-     # A = (coords_A - coords_A_mean).transpose(0, 1) @ (coords_B - coords_B_mean)
-     A = coords_A.transpose(0, 1) @ coords_B
-     if torch.isnan(A).any():
-         print('A Nan encountered')
-     assert not torch.isnan(A).any()
-
-     if torch.isinf(A).any():
-         print('inf encountered')
-     assert not torch.isinf(A).any()
-
-     U, S, Vt = torch.linalg.svd(A)
-     num_it = 0
-     while torch.min(S) < 1e-3 or torch.min(
-             torch.abs((S ** 2).view(1, 3) - (S ** 2).view(3, 1) + torch.eye(3).to(device))) < 1e-2:
-         if debug: print('S inside loop ', num_it, ' is ', S, ' and A = ', A)
-         A = A + torch.rand(3, 3).to(device) * torch.eye(3).to(device)
-         U, S, Vt = torch.linalg.svd(A)
-         num_it += 1
-         if num_it > 10: raise Exception('SVD was consistently unstable')
-
-     corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=device))
-     rotation = (U @ corr_mat) @ Vt
-
-     translation = coords_B_mean - torch.t(rotation @ coords_A_mean.t())  # (1,3)
-
-     # new_coords = (rotation @ coords_A.t()).t() + translation
-
-     return rotation, translation
-
- def rigid_transform_Kabsch_3D(A, B):
-     assert A.shape[1] == B.shape[1]
-     num_rows, num_cols = A.shape
-     if num_rows != 3:
-         raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
-     num_rows, num_cols = B.shape
-     if num_rows != 3:
-         raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
-
-
-     # find mean column wise: 3 x 1
-     centroid_A = np.mean(A, axis=1, keepdims=True)
-     centroid_B = np.mean(B, axis=1, keepdims=True)
-
-     # subtract mean
-     Am = A - centroid_A
-     Bm = B - centroid_B
-
-     H = Am @ Bm.T
-
-     # find rotation
-     U, S, Vt = np.linalg.svd(H)
-
-     R = Vt.T @ U.T
-
-     # special reflection case
-     if np.linalg.det(R) < 0:
-         # print("det(R) < R, reflection detected!, correcting for it ...")
-         SS = np.diag([1., 1., -1.])
-         R = (Vt.T @ SS) @ U.T
-     assert math.fabs(np.linalg.det(R) - 1) < 1e-5
-
-     t = -R @ centroid_A + centroid_B
-     return R, t
-
- def align_molecule_a_according_molecule_b(molecule_a_path, molecule_b_path, device=None, save=False, kabsch_no_h=True):
-     m_a = Chem.MolFromMol2File(molecule_a_path, sanitize=False, removeHs=False)
-     m_b = Chem.MolFromMol2File(molecule_b_path, sanitize=False, removeHs=False)
-     pos_a = torch.tensor(m_a.GetConformer().GetPositions())
-     pos_b = torch.tensor(m_b.GetConformer().GetPositions())
-     m_a_no_h = Chem.RemoveHs(m_a)
-     m_b_no_h = Chem.RemoveHs(m_b)
-     pos_a_no_h = torch.tensor(m_a_no_h.GetConformer().GetPositions())
-     pos_b_no_h = torch.tensor(m_b_no_h.GetConformer().GetPositions())
-
-     if kabsch_no_h:
-         rotation, translation = kabsch(pos_a_no_h, pos_b_no_h, device=device)
-     else:
-         rotation, translation = kabsch(pos_a, pos_b, device=device)
-     pos_a_new = (rotation @ pos_a.t()).t() + translation
-     # print(np.sqrt(np.sum((pos_a.numpy() - pos_b.numpy()) ** 2, axis=1).mean()))
-     # print(np.sqrt(np.sum((pos_a_new.numpy() - pos_b.numpy()) ** 2, axis=1).mean()))
-
-     return pos_a_new, rotation, translation
-
- def get_principle_axes(xyz, scale_factor=20, pdb_name=None):
-     # create coordinates array
-     coord = np.array(xyz, float)
-     # compute geometric center
-     center = np.mean(coord, 0)
-     # print("Coordinates of the geometric center:\n", center)
-     # center with geometric center
-     coord = coord - center
-     # compute principal axis matrix
-     inertia = np.dot(coord.transpose(), coord)
-     e_values, e_vectors = np.linalg.eig(inertia)
-     # --------------------------------------------------------------------------
-     # order eigen values (and eigen vectors)
-     #
-     # axis1 is the principal axis with the biggest eigen value (eval1)
-     # axis2 is the principal axis with the second biggest eigen value (eval2)
-     # axis3 is the principal axis with the smallest eigen value (eval3)
-     # --------------------------------------------------------------------------
-     order = np.argsort(e_values)
-     eval3, eval2, eval1 = e_values[order]
-     axis3, axis2, axis1 = e_vectors[:, order].transpose()
-
-     return np.array([axis1, axis2, axis3]), center
-
- def get_rotation_and_translation(xyz):
-     protein_principle_axes_system, system_center = get_principle_axes(xyz)
-     rotation = protein_principle_axes_system.T
-     translation = -system_center
-     return rotation, translation
-
- def canonical_protein_ligand_orientation(lig_coords, prot_coords):
-     rotation, translation = get_rotation_and_translation(prot_coords)
-     lig_canonical_truth_coords = (lig_coords + translation) @ rotation
-     prot_canonical_truth_coords = (prot_coords + translation) @ rotation
-     rotation_lig, translation_lig = get_rotation_and_translation(lig_coords)
-     lig_canonical_init_coords = (lig_coords + translation_lig) @ rotation_lig
-
-     return lig_coords, lig_canonical_truth_coords, lig_canonical_init_coords, \
-            prot_coords, prot_canonical_truth_coords, \
-            rotation, translation
-
- def canonical_single_molecule_orientation(m_coords):
-     rotation, translation = get_rotation_and_translation(m_coords)
-     canonical_init_coords = (m_coords + translation) @ rotation
-     return canonical_init_coords
-
- # Clockwise dihedral2 from https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python
- def GetDihedralFromPointCloud(Z, atom_idx):
-     p = Z[list(atom_idx)]
-     b = p[:-1] - p[1:]
-     b[0] *= -1
-     v = np.array([v - (v.dot(b[1]) / b[1].dot(b[1])) * b[1] for v in [b[0], b[2]]])
-     # Normalize vectors
-     v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1, 1)
-     b1 = b[1] / np.linalg.norm(b[1])
-     x = np.dot(v[0], v[1])
-     m = np.cross(v[0], b1)
-     y = np.dot(m, v[1])
-     return np.degrees(np.arctan2(y, x))
-
- def A_transpose_matrix(alpha):
-     return np.array([[np.cos(np.radians(alpha)), np.sin(np.radians(alpha))],
-                      [-np.sin(np.radians(alpha)), np.cos(np.radians(alpha))]], dtype=np.double)
-
- def S_vec(alpha):
-     return np.array([[np.cos(np.radians(alpha))],
-                      [np.sin(np.radians(alpha))]], dtype=np.double)
-
- def get_dihedral_vonMises(mol, conf, atom_idx, Z):
-     Z = np.array(Z)
-     v = np.zeros((2, 1))
-     iAtom = mol.GetAtomWithIdx(atom_idx[1])
-     jAtom = mol.GetAtomWithIdx(atom_idx[2])
-     k_0 = atom_idx[0]
-     i = atom_idx[1]
-     j = atom_idx[2]
-     l_0 = atom_idx[3]
-     for b1 in iAtom.GetBonds():
-         k = b1.GetOtherAtomIdx(i)
-         if k == j:
-             continue
-         for b2 in jAtom.GetBonds():
-             l = b2.GetOtherAtomIdx(j)
-             if l == i:
-                 continue
-             assert k != l
-             s_star = S_vec(GetDihedralFromPointCloud(Z, (k, i, j, l)))
-             a_mat = A_transpose_matrix(GetDihedral(conf, (k, i, j, k_0)) + GetDihedral(conf, (l_0, i, j, l)))
-             v = v + np.matmul(a_mat, s_star)
-     v = v / np.linalg.norm(v)
-     v = v.reshape(-1)
-     return np.degrees(np.arctan2(v[1], v[0]))
-
- def distance_loss_function(epoch, y_pred, x, protein_nodes_xyz, compound_pair_dis_constraint, LAS_distance_constraint_mask=None, mode=0):
-     dis = torch.cdist(x, protein_nodes_xyz)
-     dis_clamp = torch.clamp(dis, max=10)
-     if mode == 0:
-         interaction_loss = ((dis_clamp - y_pred).abs()).sum()
-     elif mode == 1:
-         interaction_loss = ((dis_clamp - y_pred) ** 2).sum()
-     elif mode == 2:
-         # probably not a good choice. x^0.5 has infinite gradient at x=0. added 1e-5 for numerical stability.
-         interaction_loss = (((dis_clamp - y_pred).abs() + 1e-5) ** 0.5).sum()
-     config_dis = torch.cdist(x, x)
-     if LAS_distance_constraint_mask is not None:
-         configuration_loss = 1 * (((config_dis - compound_pair_dis_constraint).abs())[LAS_distance_constraint_mask]).sum()
-         # basic excluded-volume. the distance between compound atoms should be at least 1.22Å
-         configuration_loss += 2 * ((1.22 - config_dis).relu()).sum()
-     else:
-         configuration_loss = 1 * ((config_dis - compound_pair_dis_constraint).abs()).sum()
-     # if epoch < 500:
-     #     loss = interaction_loss
-     # else:
-     #     loss = 1 * (interaction_loss + 5e-3 * (epoch - 500) * configuration_loss)
-     loss = 1 * (interaction_loss + 5e-3 * (epoch + 200) * configuration_loss)
-     return loss, (interaction_loss.item(), configuration_loss.item())
-
-
- def distance_optimize_compound_coords(coords, y_pred, protein_nodes_xyz,
-                                       compound_pair_dis_constraint, total_epoch=1000, loss_function=distance_loss_function, LAS_distance_constraint_mask=None, mode=0, show_progress=False):
-     # random initialization. center at the protein center.
-     c_pred = protein_nodes_xyz.mean(axis=0)
-     x = coords
-     x.requires_grad = True
-     optimizer = torch.optim.Adam([x], lr=0.1)
-     loss_list = []
-     # optimizer = torch.optim.LBFGS([x], lr=0.01)
-     if show_progress:
-         it = tqdm(range(total_epoch))
-     else:
-         it = range(total_epoch)
-     for epoch in it:
-         optimizer.zero_grad()
-         loss, (interaction_loss, configuration_loss) = loss_function(epoch, y_pred, x, protein_nodes_xyz,
-                                                                      compound_pair_dis_constraint,
-                                                                      LAS_distance_constraint_mask=LAS_distance_constraint_mask,
-                                                                      mode=mode)
-         loss.backward()
-         optimizer.step()
-         loss_list.append(loss.item())
-         # break
-     return x, loss_list
-
- def tankbind_gen(lig_pred_coords, lig_init_coords, prot_coords, LAS_mask, device='cpu', mode=0):
-
-     pred_prot_lig_inter_distance = torch.cdist(lig_pred_coords, prot_coords)
-     init_lig_intra_distance = torch.cdist(lig_init_coords, lig_init_coords)
-     try:
-         x, loss_list = distance_optimize_compound_coords(lig_pred_coords.to('cpu'),
-                                                          pred_prot_lig_inter_distance.to('cpu'),
-                                                          prot_coords.to('cpu'),
-                                                          init_lig_intra_distance.to('cpu'),
-                                                          LAS_distance_constraint_mask=LAS_mask.bool(),
-                                                          mode=mode, show_progress=False)
-     except:
-         print('error')
-
-     return x
-
- def kabsch_align(lig_pred_coords, name, save_path, dataset_path, device='cpu'):
-     rdkit_init_lig_path_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-     openbabel_init_m_lig = next(pybel.readfile('sdf', rdkit_init_lig_path_sdf))
-     rdkit_init_coords = [atom.coords for atom in openbabel_init_m_lig]
-     rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32)  # np.array, [n, 3]
-
-     coords_pred = lig_pred_coords.detach().cpu().numpy()
-
-     R, t = rigid_transform_Kabsch_3D(rdkit_init_coords.T, coords_pred.T)
-     coords_pred_optimized = (R @ (rdkit_init_coords).T).T + t.squeeze()
-
-     opt_ligCoords = torch.tensor(coords_pred_optimized, device=device)
-     return opt_ligCoords
-
- def equibind_align(lig_pred_coords, name, save_path, dataset_path, device='cpu'):
-     lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2')
-     lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf')
-     m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
-     if m_lig == None:  # read mol2 file if sdf file cannot be sanitized
-         m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
-
-     # load rdkit mol
-     lig_path_sdf_error = os.path.join(save_path, 'visualize_dir', f'{name}_init')
-     pred_lig_path_sdf_error = os.path.join(save_path, 'visualize_dir', f'{name}_pred')
-     pred_lig_path_sdf_true = os.path.join(save_path, 'visualize_dir', f'{name}_pred.sdf')
-
-     rdkit_init_lig_path_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-
-     if not os.path.exists(rdkit_init_lig_path_sdf):
-         cmd = f'mv {lig_path_sdf_error} {rdkit_init_lig_path_sdf}'
-         os.system(cmd)
-     if not os.path.exists(pred_lig_path_sdf_true):
-         cmd = f'mv {pred_lig_path_sdf_error} {pred_lig_path_sdf_true}'
-         os.system(cmd)
-
-     openbabel_init_m_lig = next(pybel.readfile('sdf', rdkit_init_lig_path_sdf))
-     rdkit_init_coords = [atom.coords for atom in openbabel_init_m_lig]
-     rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32)  # np.array, [n, 3]
-     # rdkit_init_m_lig = read_rdmol(rdkit_init_lig_path_sdf, sanitize=True, remove_hs=True)
-     # rdkit_init_coords = rdkit_init_m_lig.GetConformer().GetPositions()
-
-     rdkit_init_lig = copy.deepcopy(m_lig)
-     conf = rdkit_init_lig.GetConformer()
-     for i in range(rdkit_init_lig.GetNumAtoms()):
-         x, y, z = rdkit_init_coords[i]
-         conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
-
-     coords_pred = lig_pred_coords.detach().cpu().numpy()
-     Z_pt_cloud = coords_pred
-     rotable_bonds = get_torsions([rdkit_init_lig])
-     new_dihedrals = np.zeros(len(rotable_bonds))
-
-     for idx, r in enumerate(rotable_bonds):
-         new_dihedrals[idx] = get_dihedral_vonMises(rdkit_init_lig, rdkit_init_lig.GetConformer(), r, Z_pt_cloud)
-     optimized_mol = apply_changes_equibind(rdkit_init_lig, new_dihedrals, rotable_bonds)
-
-     coords_pred_optimized = optimized_mol.GetConformer().GetPositions()
-     R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T)
-     coords_pred_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze()
-
-     opt_ligCoords = torch.tensor(coords_pred_optimized, device=device)
-     return opt_ligCoords
-
- def dock_compound(lig_pred_coords, prot_coords, name, save_path,
-                   popsize=150, maxiter=500, seed=None, mutation=(0.5, 1),
-                   recombination=0.8, device='cpu', torsion_num_cut=20):
-     if seed:
-         np.random.seed(seed)
-         torch.cuda.manual_seed_all(seed)
-         torch.manual_seed(seed)
-
-     # load rdkit mol
-     lig_path_init_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-     openbabel_m_lig_init = next(pybel.readfile('sdf', lig_path_init_sdf))
-     rdkit_init_coords = [atom.coords for atom in openbabel_m_lig_init]
-
-     lig_path_true_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_ligand.sdf')
-     lig_path_true_mol2 = os.path.join(save_path, 'visualize_dir', f'{name}_ligand.mol2')
-     m_lig = read_rdmol(lig_path_true_sdf, sanitize=True, remove_hs=True)
-     if m_lig == None:  # read mol2 file if sdf file cannot be sanitized
-         m_lig = read_rdmol(lig_path_true_mol2, sanitize=True, remove_hs=True)
-
-     atom_num = len(m_lig.GetConformer().GetPositions())
-     if len(rdkit_init_coords) != atom_num:
-         rdkit_init_coords = [atom.coords for atom in openbabel_m_lig_init if atom.atomicnum > 1]
-         lig_pred_coords_no_h_list = [atom_coords for atom, atom_coords in zip(openbabel_m_lig_init, lig_pred_coords.tolist()) if atom.atomicnum > 1]
-         lig_pred_coords = torch.tensor(lig_pred_coords_no_h_list, device=device)
-
-     rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32)  # np.array, [n, 3]
-     print(f'{name} init coords shape: {rdkit_init_coords.shape}')
-     print(f'{name} true coords shape: {m_lig.GetConformer().GetPositions().shape}')
-
-     rdkit_init_lig = copy.deepcopy(m_lig)
-     conf = rdkit_init_lig.GetConformer()
-     for i in range(rdkit_init_lig.GetNumAtoms()):
-         x, y, z = rdkit_init_coords[i]
-         conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
-
-     # move m_lig to pred_coords center
-     pred_coords_center = lig_pred_coords.cpu().numpy().mean(axis=0)
-     init_coords_center = rdkit_init_lig.GetConformer().GetPositions().mean(axis=0)
-     # print(f'{name} pred coords shape: {lig_pred_coords.shape}')
-
-     center_rel_vecs = pred_coords_center - init_coords_center
-     values = np.concatenate([np.array([0, 0, 0]), center_rel_vecs])
-     rdMolTransforms.TransformConformer(rdkit_init_lig.GetConformer(), GetTransformationMatrix(values))
-
-     # Set optimization function
-     opt = optimze_conformation(mol=rdkit_init_lig, target_coords=lig_pred_coords, device=device,
-                                n_particles=1, seed=seed)
-     if len(opt.rotable_bonds) > torsion_num_cut:
-         return lig_pred_coords
-
-     # Define bounds for optimization
-     max_bound = np.concatenate([[np.pi] * 3, prot_coords.cpu().max(0)[0].numpy(), [np.pi] * len(opt.rotable_bonds)], axis=0)
-     min_bound = np.concatenate([[-np.pi] * 3, prot_coords.cpu().min(0)[0].numpy(), [-np.pi] * len(opt.rotable_bonds)], axis=0)
-     bounds = (min_bound, max_bound)
-
-     # Optimize conformations
-     result = differential_evolution(opt.score_conformation, list(zip(bounds[0], bounds[1])), maxiter=maxiter,
-                                     popsize=int(np.ceil(popsize / (len(opt.rotable_bonds) + 6))),
-                                     mutation=mutation, recombination=recombination, disp=False, seed=seed)
-
-     # Get optimized molecule
-     starting_mol = opt.mol
-     opt_mol = apply_changes(starting_mol, result['x'], opt.rotable_bonds)
-     opt_ligCoords = torch.tensor(opt_mol.GetConformer().GetPositions(), device=device)
-
-     return opt_ligCoords
-
- class optimze_conformation():
-     def __init__(self, mol, target_coords, n_particles, save_molecules=False, device='cpu',
-                  seed=None):
-         super(optimze_conformation, self).__init__()
-         if seed:
-             np.random.seed(seed)
-
-         self.targetCoords = torch.stack([target_coords for _ in range(n_particles)]).double()
-         self.n_particles = n_particles
-         self.rotable_bonds = get_torsions([mol])
-         self.save_molecules = save_molecules
-         self.mol = mol
-         self.device = device
-
-     def score_conformation(self, values):
-         """
-         Parameters
-         ----------
-         values : numpy.ndarray
-             set of inputs of shape :code:`(n_particles, dimensions)`
-         Returns
-         -------
-         numpy.ndarray
-             computed cost of size :code:`(n_particles, )`
-         """
-         if len(values.shape) < 2: values = np.expand_dims(values, axis=0)
-         mols = [copy.copy(self.mol) for _ in range(self.n_particles)]
-
-         # Apply changes to molecules
-         # apply rotations
-         [SetDihedral(mols[m].GetConformer(), self.rotable_bonds[r], values[m, 6 + r]) for r in
-          range(len(self.rotable_bonds)) for m in range(self.n_particles)]
-
-         # apply transformation matrix
-         [rdMolTransforms.TransformConformer(mols[m].GetConformer(), GetTransformationMatrix(values[m, :6])) for m in
-          range(self.n_particles)]
-
-         # Calculate distances between ligand conformation and pred ligand conformation
-         ligCoords_list = [torch.tensor(m.GetConformer().GetPositions(), device=self.device) for m in mols]  # [n_mols, N, 3]
-         ligCoords = torch.stack(ligCoords_list).double()  # [n_mols, N, 3]
-
-         ligCoords_error = ligCoords - self.targetCoords  # [n_mols, N, 3]
-         ligCoords_rmsd = (ligCoords_error ** 2).sum(dim=-1).mean(dim=-1).sqrt().min().cpu().numpy()
-
-         del ligCoords_error, ligCoords, ligCoords_list, mols
-
-         return ligCoords_rmsd
-
- def apply_changes(mol, values, rotable_bonds):
-     opt_mol = copy.copy(mol)
-
-     # apply rotations
-     [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[6 + r]) for r in range(len(rotable_bonds))]
-
-     # apply transformation matrix
-     rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6]))
-
-     return opt_mol
-
- def apply_changes_equibind(mol, values, rotable_bonds):
-     opt_mol = copy.deepcopy(mol)
-     # opt_mol = add_rdkit_conformer(opt_mol)
-
-     # apply rotations
-     [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[r]) for r in range(len(rotable_bonds))]
-
-     # # apply transformation matrix
-     # rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6]))
-
-     return opt_mol
-
- def get_torsions(mol_list):
-     atom_counter = 0
-     torsionList = []
-     dihedralList = []
-     for m in mol_list:
-         torsionSmarts = '[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]'
-         torsionQuery = Chem.MolFromSmarts(torsionSmarts)
-         matches = m.GetSubstructMatches(torsionQuery)
-         conf = m.GetConformer()
-         for match in matches:
-             idx2 = match[0]
-             idx3 = match[1]
-             bond = m.GetBondBetweenAtoms(idx2, idx3)
-             jAtom = m.GetAtomWithIdx(idx2)
-             kAtom = m.GetAtomWithIdx(idx3)
-             for b1 in jAtom.GetBonds():
-                 if (b1.GetIdx() == bond.GetIdx()):
-                     continue
-                 idx1 = b1.GetOtherAtomIdx(idx2)
-                 for b2 in kAtom.GetBonds():
-                     if ((b2.GetIdx() == bond.GetIdx())
-                             or (b2.GetIdx() == b1.GetIdx())):
-                         continue
-                     idx4 = b2.GetOtherAtomIdx(idx3)
-                     # skip 3-membered rings
-                     if (idx4 == idx1):
-                         continue
-                     # skip torsions that include hydrogens
-                     if ((m.GetAtomWithIdx(idx1).GetAtomicNum() == 1)
-                             or (m.GetAtomWithIdx(idx4).GetAtomicNum() == 1)):
-                         continue
-                     if m.GetAtomWithIdx(idx4).IsInRing():
-                         torsionList.append(
-                             (idx4 + atom_counter, idx3 + atom_counter, idx2 + atom_counter, idx1 + atom_counter))
-                         break
-                     else:
-                         torsionList.append(
-                             (idx1 + atom_counter, idx2 + atom_counter, idx3 + atom_counter, idx4 + atom_counter))
-                         break
-                 break
-
-         atom_counter += m.GetNumAtoms()
-     return torsionList
-
-
- def SetDihedral(conf, atom_idx, new_value):
-     rdMolTransforms.SetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3], new_value)
-
-
- def GetDihedral(conf, atom_idx):
-     return rdMolTransforms.GetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3])
-
-
- def GetTransformationMatrix(transformations):
-     x, y, z, disp_x, disp_y, disp_z = transformations
-     transMat = np.array([[np.cos(z) * np.cos(y), (np.cos(z) * np.sin(y) * np.sin(x)) - (np.sin(z) * np.cos(x)),
-                           (np.cos(z) * np.sin(y) * np.cos(x)) + (np.sin(z) * np.sin(x)), disp_x],
-                          [np.sin(z) * np.cos(y), (np.sin(z) * np.sin(y) * np.sin(x)) + (np.cos(z) * np.cos(x)),
-                           (np.sin(z) * np.sin(y) * np.cos(x)) - (np.cos(z) * np.sin(x)), disp_y],
-                          [-np.sin(y), np.cos(y) * np.sin(x), np.cos(y) * np.cos(x), disp_z],
-                          [0, 0, 0, 1]], dtype=np.double)
-     return transMat
-
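The recurring primitive in this file is the Kabsch superposition: find the rigid rotation and translation that best map one point cloud onto another. A compact reference sketch of just that step (NumPy; written independently of the deleted helpers):

import numpy as np

def kabsch_rt(P, Q):
    # least-squares rigid transform mapping rows of P ([n, 3]) onto Q ([n, 3])
    Pc, Qc = P.mean(axis=0), Q.mean(axis=0)
    H = (P - Pc).T @ (Q - Qc)               # 3x3 cross-covariance of the centered clouds
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))  # guard against an improper rotation (reflection)
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    t = Qc - R @ Pc
    return R, t

rng = np.random.default_rng(0)
P = rng.random((8, 3))
theta = 0.3
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta),  np.cos(theta), 0.0],
                   [0.0, 0.0, 1.0]])
Q = P @ R_true.T + np.array([1.0, 2.0, 3.0])
R, t = kabsch_rt(P, Q)
print(np.allclose(P @ R.T + t, Q))  # True: the transform is recovered

`kabsch_align` and `equibind_align` above wrap this idea, with `equibind_align` additionally matching rotatable-bond torsions to the predicted point cloud before the final rigid fit.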
UltraFlow/commons/get_free_gpu.py DELETED
@@ -1,78 +0,0 @@
- import torch
- from gpustat import GPUStatCollection
- import time
- def get_free_gpu(mode="memory", memory_need=10000) -> list:
-     r"""Get free gpu according to mode (process-free or memory-free).
-     Args:
-         mode (str, optional): memory-free or process-free. Defaults to "memory".
-         memory_need (int): The memory you need, used if mode=='memory'. Defaults to 10000.
-     Returns:
-         list: free gpu ids sorted by free memory
-     """
-     assert mode in ["memory", "process"], "mode must be 'memory' or 'process'"
-     if mode == "memory":
-         assert memory_need is not None, \
-             "'memory_need' is None; 'memory' mode must give the free memory you want to apply for"
-         memory_need = int(memory_need)
-         assert memory_need > 0, "'memory_need' you want must be positive"
-     gpu_stats = GPUStatCollection.new_query()
-     gpu_free_id_list = []
-
-     for idx, gpu_stat in enumerate(gpu_stats):
-         if gpu_check_condition(gpu_stat, mode, memory_need):
-             gpu_free_id_list.append([idx, gpu_stat.memory_free])
-             print("gpu[{}]: {}MB".format(idx, gpu_stat.memory_free))
-
-     if gpu_free_id_list:
-         gpu_free_id_list = sorted(gpu_free_id_list,
-                                   key=lambda x: x[1],
-                                   reverse=True)
-         gpu_free_id_list = [i[0] for i in gpu_free_id_list]
-     return gpu_free_id_list
-
-
- def gpu_check_condition(gpu_stat, mode, memory_need) -> bool:
-     r"""Check gpu is free or not.
-     Args:
-         gpu_stat (gpustat.core): gpustat to check
-         mode (str): memory-free or process-free.
-         memory_need (int): The memory you need, used if mode=='memory'
-     Returns:
-         bool: gpu is free or not
-     """
-     if mode == "memory":
-         return gpu_stat.memory_free > memory_need
-     elif mode == "process":
-         for process in gpu_stat.processes:
-             if process["command"] == "python":
-                 return False
-         return True
-     else:
-         return False
-
- def get_device(target_gpu_idx, memory_need=10000):
-     # check device
-     # assert torch.cuda.device_count() >= len(target_gpus), 'do you set the gpus in config correctly?'
-     flag = None
-
-     while True:
-         # Get the gpu ids which have more than 10000MB memory
-         free_gpu_ids = get_free_gpu('memory', memory_need)
-         if len(free_gpu_ids) < 1:
-             if flag is None:
-                 print("No GPU available now. sleeping 60s ....")
-             time.sleep(60)
-         else:
-
-             gpuid = list(set(free_gpu_ids) & set(target_gpu_idx))[0]
-
-             device = torch.device('cuda:' + str(gpuid))
-             print("Using device %s as main device" % device)
-             break
-
-     return device
-
- if __name__ == '__main__':
-     target_gpu_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8]
-     device = get_device(target_gpu_idx)
-     print(device)
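Typical use is to block at startup until a GPU with enough free memory is available. A hedged sketch (assumes `gpustat` is installed and this module's pre-commit import path; `MyModel` is a placeholder, not a repo class):

from UltraFlow.commons.get_free_gpu import get_device  # import path prior to this commit

device = get_device(target_gpu_idx=[0, 1, 2, 3], memory_need=10000)  # waits for >= 10000 MB free
# model = MyModel().to(device)  # placeholder for whatever module you train

Note that `get_device` indexes `[0]` into the intersection of free GPUs and `target_gpu_idx`; if a GPU frees up that is not in the target list, that intersection is empty and the call raises IndexError rather than continuing to wait.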
UltraFlow/commons/loss_weight.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:42d7d91c0447c79418d9d547d45203612a9dcbf21355e047923237ba36d8765e
- size 748
UltraFlow/commons/metrics.py DELETED
@@ -1,315 +0,0 @@
1
- from scipy import stats
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from math import sqrt, ceil
6
- from sklearn.linear_model import LinearRegression
7
- from sklearn.metrics import ndcg_score, recall_score
8
- import os
9
- import pickle
10
- import dgl
11
- from typing import Union, List
12
- from torch import Tensor
13
- from statistics import stdev
14
-
15
- def affinity_loss(affinity_pred,labels,sec_pred,bg_prot,config):
16
- loss = nn.MSELoss(affinity_pred,labels)
17
- if config.model.aux_w != 0:
18
- loss += config.train.aux_w * nn.CrossEntropyLoss(sec_pred,bg_prot.ndata['s'])
19
- return loss
20
-
21
- def Accurate_num(outputs,y):
22
- _, y_pred_label = torch.max(outputs, dim=1)
23
- return torch.sum(y_pred_label == y.data).item()
24
-
25
- def RMSE(y,f):
26
- rmse = sqrt(((y - f)**2).mean(axis=0))
27
- return rmse
28
-
29
- def MAE(y,f):
30
- mae = (np.abs(y-f)).mean()
31
- return mae
32
-
33
- def SD(y,f):
34
- f,y = f.reshape(-1,1),y.reshape(-1,1)
35
- lr = LinearRegression()
36
- lr.fit(f,y)
37
- y_ = lr.predict(f)
38
- sd = (((y - y_) ** 2).sum() / (len(y) - 1)) ** 0.5
39
- return sd
40
-
41
- def Pearson(y,f):
42
- y,f = y.flatten(),f.flatten()
43
- rp = np.corrcoef(y, f)[0,1]
44
- return rp
45
-
46
- def Spearman(y,f):
47
- y, f = y.flatten(), f.flatten()
48
- rp = stats.spearmanr(y, f)
49
- return rp[0]
50
-
51
- def NDCG(y,f,k=None):
52
- y, f = y.flatten(), f.flatten()
53
- return ndcg_score(np.expand_dims(y, axis=0), np.expand_dims(f,axis=0),k=k)
54
-
55
- def Recall(y, f, postive_threshold = 7.5):
56
- y, f = y.flatten(), f.flatten()
57
- y_class = y > postive_threshold
58
- f_class = f > postive_threshold
59
-
60
- return recall_score(y_class, f_class)
61
-
62
- def Enrichment_Factor(y, f, postive_threshold = 7.5, top_percentage = 0.001):
63
- y, f = y.flatten(), f.flatten()
64
- y_class = y > postive_threshold
65
- f_class = f > postive_threshold
66
-
67
- data = list(zip(y_class.tolist(), f_class.tolist()))
68
- data.sort(key=lambda x:x[1], reverse=True)
69
-
70
- y_class, f_class = map(list, zip(*data))
71
-
72
- total_active_rate = sum(y_class) / len(y_class)
73
- top_num = ceil(len(y_class) * top_percentage)
74
- top_active_rate = sum(y_class[:top_num]) / top_num
75
-
76
- er = top_active_rate / total_active_rate
77
-
78
- return er
79
-
80
- def Auxiliary_Weight_Balance(aux_type='Q8'):
81
- if os.path.exists('loss_weight.pkl'):
82
- with open('loss_weight.pkl','rb') as f:
83
- w = pickle.load(f)
84
- return w[aux_type]
85
-
86
- def RMSD(ligs_coords_pred, ligs_coords):
87
- rmsds = []
88
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
89
- rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))).item())
90
- return rmsds
91
-
92
- def KabschRMSD(ligs_coords_pred, ligs_coords):
93
- rmsds = []
94
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
95
- lig_coords_pred_mean = lig_coords_pred.mean(dim=0, keepdim=True) # (1,3)
96
- lig_coords_mean = lig_coords.mean(dim=0, keepdim=True) # (1,3)
97
-
98
- A = (lig_coords_pred - lig_coords_pred_mean).transpose(0, 1) @ (lig_coords - lig_coords_mean)
99
-
100
- U, S, Vt = torch.linalg.svd(A)
101
-
102
- corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=lig_coords_pred.device))
103
- rotation = (U @ corr_mat) @ Vt
104
- translation = lig_coords_pred_mean - torch.t(rotation @ lig_coords_mean.t()) # (1,3)
105
-
106
- lig_coords = (rotation @ lig_coords.t()).t() + translation
107
- rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))))
108
- return torch.tensor(rmsds).mean()
109
-
110
-
111
- class RMSDmedian(nn.Module):
112
- def __init__(self) -> None:
113
- super(RMSDmedian, self).__init__()
114
-
115
- def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
116
- rmsds = []
117
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
118
- rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))))
119
- return torch.median(torch.tensor(rmsds))
120
-
121
-
122
- class RMSDfraction(nn.Module):
123
- def __init__(self, distance) -> None:
124
- super(RMSDfraction, self).__init__()
125
- self.distance = distance
126
-
127
- def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
128
- rmsds = []
129
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
130
- rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))))
131
- count = torch.tensor(rmsds) < self.distance
132
- return 100 * count.sum() / len(count)
133
-
134
-
135
- def CentroidDist(ligs_coords_pred, ligs_coords):
136
- distances = []
137
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
138
- distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)).item())
139
- return distances
140
-
141
-
142
- class CentroidDistMedian(nn.Module):
143
- def __init__(self) -> None:
144
- super(CentroidDistMedian, self).__init__()
145
-
146
- def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
147
- distances = []
148
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
149
- distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)))
150
- return torch.median(torch.tensor(distances))
151
-
152
-
153
- class CentroidDistFraction(nn.Module):
154
- def __init__(self, distance) -> None:
155
- super(CentroidDistFraction, self).__init__()
156
- self.distance = distance
157
-
158
- def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
159
- distances = []
160
- for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
161
- distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)))
162
- count = torch.tensor(distances) < self.distance
163
- return 100 * count.sum() / len(count)
164
-
165
- class MeanPredictorLoss(nn.Module):
166
-
167
- def __init__(self, loss_func) -> None:
168
- super(MeanPredictorLoss, self).__init__()
169
- self.loss_func = loss_func
170
-
171
- def forward(self, x1: Tensor, targets: Tensor) -> Tensor:
172
- return self.loss_func(torch.full_like(targets, targets.mean()), targets)
173
-
174
-
175
- def compute_mmd(source, target, batch_size=1000, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
176
- """
177
-     Calculate the `maximum mean discrepancy distance <https://jmlr.csail.mit.edu/papers/v13/gretton12a.html>`_ between two sample sets.
-     This implementation is based on `this open source code <https://github.com/ZongxianLee/MMD_Loss.Pytorch>`_.
-     Args:
-         source (pytorch tensor): the pytorch tensor containing data samples of the source distribution.
-         target (pytorch tensor): the pytorch tensor containing data samples of the target distribution.
-     :rtype:
-         :class:`float`
-     """
-     n_source = int(source.size()[0])
-     n_target = int(target.size()[0])
-     n_samples = n_source + n_target
-
-     total = torch.cat([source, target], dim=0)
-     total0 = total.unsqueeze(0)
-     total1 = total.unsqueeze(1)
-
-     if fix_sigma:
-         bandwidth = fix_sigma
-     else:
-         # bandwidth heuristic: mean squared pairwise distance, accumulated in batches
-         bandwidth, id = 0.0, 0
-         while id < n_samples:
-             bandwidth += torch.sum((total0 - total1[id:id + batch_size]) ** 2)
-             id += batch_size
-         bandwidth /= n_samples ** 2 - n_samples
-
-     bandwidth /= kernel_mul ** (kernel_num // 2)
-     bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
-
-     XX_kernel_val = [0 for _ in range(kernel_num)]
-     for i in range(kernel_num):
-         XX_kernel_val[i] += torch.sum(
-             torch.exp(-((total0[:, :n_source] - total1[:n_source, :]) ** 2) / bandwidth_list[i]))
-     XX = sum(XX_kernel_val) / (n_source * n_source)
-
-     YY_kernel_val = [0 for _ in range(kernel_num)]
-     id = n_source
-     while id < n_samples:
-         for i in range(kernel_num):
-             YY_kernel_val[i] += torch.sum(
-                 torch.exp(-((total0[:, n_source:] - total1[id:id + batch_size, :]) ** 2) / bandwidth_list[i]))
-         id += batch_size
-     YY = sum(YY_kernel_val) / (n_target * n_target)
-
-     XY_kernel_val = [0 for _ in range(kernel_num)]
-     id = n_source
-     while id < n_samples:
-         for i in range(kernel_num):
-             XY_kernel_val[i] += torch.sum(
-                 torch.exp(-((total0[:, id:id + batch_size] - total1[:n_source, :]) ** 2) / bandwidth_list[i]))
-         id += batch_size
-     XY = sum(XY_kernel_val) / (n_source * n_target)
-
-     return XX.item() + YY.item() - 2 * XY.item()
-
-
- def get_matric_dict(rmsds, centroids, kabsch_rmsds=None):
-     rmsd_mean = sum(rmsds) / len(rmsds)
-     centroid_mean = sum(centroids) / len(centroids)
-     rmsd_std = stdev(rmsds)
-     centroid_std = stdev(centroids)
-
-     # percentage of poses with rmsd < 2 A
-     count = torch.tensor(rmsds) < 2.0
-     rmsd_less_than_2 = 100 * count.sum().item() / len(count)
-
-     # percentage of poses with rmsd < 5 A
-     count = torch.tensor(rmsds) < 5.0
-     rmsd_less_than_5 = 100 * count.sum().item() / len(count)
-
-     # percentage of poses with centroid distance < 2 A
-     count = torch.tensor(centroids) < 2.0
-     centroid_less_than_2 = 100 * count.sum().item() / len(count)
-
-     # percentage of poses with centroid distance < 5 A
-     count = torch.tensor(centroids) < 5.0
-     centroid_less_than_5 = 100 * count.sum().item() / len(count)
-
-     rmsd_percentiles = np.percentile(np.array(rmsds), [25, 50, 75]).round(4)
-     centroid_percentiles = np.percentile(np.array(centroids), [25, 50, 75]).round(4)
-
-     metrics_dict = {'rmsd mean': rmsd_mean, 'rmsd std': rmsd_std,
-                     'rmsd 25%': rmsd_percentiles[0], 'rmsd 50%': rmsd_percentiles[1], 'rmsd 75%': rmsd_percentiles[2],
-                     'centroid mean': centroid_mean, 'centroid std': centroid_std,
-                     'centroid 25%': centroid_percentiles[0], 'centroid 50%': centroid_percentiles[1], 'centroid 75%': centroid_percentiles[2],
-                     'rmsd less than 2': rmsd_less_than_2, 'rmsd less than 5': rmsd_less_than_5,
-                     'centroid less than 2': centroid_less_than_2, 'centroid less than 5': centroid_less_than_5,
-                     }
-
-     if kabsch_rmsds is not None:
-         kabsch_rmsd_mean = sum(kabsch_rmsds) / len(kabsch_rmsds)
-         kabsch_rmsd_std = stdev(kabsch_rmsds)
-         metrics_dict['kabsch rmsd mean'] = kabsch_rmsd_mean
-         metrics_dict['kabsch rmsd std'] = kabsch_rmsd_std
-
-     return metrics_dict
-
- def get_sbap_regression_metric_dict(np_y, np_f):
-     rmse, mae, pearson, spearman, sd_ = RMSE(np_y, np_f), \
-                                         MAE(np_y, np_f), \
-                                         Pearson(np_y, np_f), \
-                                         Spearman(np_y, np_f), \
-                                         SD(np_y, np_f)
-
-     metrics_dict = {'RMSE': rmse, 'MAE': mae, 'Pearson': pearson, 'Spearman': spearman, 'SD': sd_}
-     return metrics_dict
-
- def get_sbap_matric_dict(np_y, np_f):
-     rmse, mae, pearson, spearman, sd_ = RMSE(np_y, np_f), \
-                                         MAE(np_y, np_f), \
-                                         Pearson(np_y, np_f), \
-                                         Spearman(np_y, np_f), \
-                                         SD(np_y, np_f)
-
-     recall, ndcg = Recall(np_y, np_f), NDCG(np_y, np_f)
-     enrichment_factor = Enrichment_Factor(np_y, np_f)
-
-     metrics_dict = {'RMSE': rmse, 'MAE': mae, 'Pearson': pearson, 'Spearman': spearman, 'SD': sd_,
-                     'Recall': recall, 'NDCG': ndcg, 'EF1%': enrichment_factor
-                     }
-     return metrics_dict
-
- def get_matric_output_str(matric_dict):
-     matric_str = ''
-     for key in matric_dict.keys():
-         if 'less than' not in key:
-             matric_str += '| {}: {:.4f} '.format(key, matric_dict[key])
-         else:
-             matric_str += '| {}: {:.4f}% '.format(key, matric_dict[key])
-     return matric_str
-
- def get_unseen_matric(rmsds, centroids, names, unseen_file_path):
-     with open(unseen_file_path, 'r') as f:
-         unseen_names = f.read().strip().split('\n')
-     unseen_rmsds, unseen_centroids = [], []
-     for name, rmsd, centroid in zip(names, rmsds, centroids):
-         if name in unseen_names:
-             unseen_rmsds.append(rmsd)
-             unseen_centroids.append(centroid)
-     return get_matric_dict(unseen_rmsds, unseen_centroids)
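For orientation, a minimal sketch of how these deleted metric helpers were likely consumed. The sample RMSD and centroid values are hypothetical, and the import assumes a checkout from before this commit, when metrics.py was still re-exported through UltraFlow.commons:

import UltraFlow.commons as commons  # pre-deletion checkout assumed

rmsds = [0.8, 1.5, 2.3, 4.9, 6.1]      # hypothetical per-complex ligand RMSDs (A)
centroids = [0.5, 1.1, 1.9, 3.2, 5.5]  # hypothetical centroid distances (A)

metrics = commons.get_matric_dict(rmsds, centroids)
print(commons.get_matric_output_str(metrics))
# prints e.g. '... | rmsd less than 2: 40.0000% ...' -- 2 of the 5 poses sit under 2 A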
UltraFlow/commons/torch_prepare.py DELETED
@@ -1,156 +0,0 @@
- import copy
- import torch
- import torch.nn as nn
- import warnings
- import dgl
- import os
- from UltraFlow import runner, dataset
- from .utils import get_run_dir
-
- # customize exp lr scheduler with min lr
- class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
-     def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
-         self.gamma = gamma
-         self.min_lr = min_lr
-         super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose)
-
-     def get_lr(self):
-         if not self._get_lr_called_within_step:
-             warnings.warn("To get the last learning rate computed by the scheduler, "
-                           "please use `get_last_lr()`.", UserWarning)
-
-         if self.last_epoch == 0:
-             return self.base_lrs
-         return [max(group['lr'] * self.gamma, self.min_lr)
-                 for group in self.optimizer.param_groups]
-
-     def _get_closed_form_lr(self):
-         return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
-                 for base_lr in self.base_lrs]
-
-
- def get_scheduler(config, optimizer):
-     if config.type == 'plateau':
-         return torch.optim.lr_scheduler.ReduceLROnPlateau(
-             optimizer,
-             factor=config.factor,
-             patience=config.patience,
-         )
-     elif config.type == 'expmin':
-         return ExponentialLR_with_minLr(
-             optimizer,
-             gamma=config.factor,
-             min_lr=config.min_lr,
-         )
-     else:
-         raise NotImplementedError('Scheduler not supported: %s' % config.type)
-
- def get_optimizer(config, model):
-     if config.type == "Adam":
-         return torch.optim.Adam(
-             filter(lambda p: p.requires_grad, model.parameters()),
-             lr=config.lr,
-             weight_decay=config.weight_decay)
-     else:
-         raise NotImplementedError('Optimizer not supported: %s' % config.type)
-
- def get_optimizer_ablation(config, model, interact_ablation):
-     if config.type == "Adam":
-         return torch.optim.Adam(
-             filter(lambda p: p.requires_grad, list(model.parameters()) + list(interact_ablation.parameters())),
-             lr=config.lr,
-             weight_decay=config.weight_decay)
-     else:
-         raise NotImplementedError('Optimizer not supported: %s' % config.type)
-
- def get_dataset(config, ddp=False):
-     if config.data.dataset_name == 'chembl_in_pdbbind_smina':
-         if config.data.split_type == 'assay_specific':
-             if ddp and config.train.use_memory_efficient_dataset == 'v1':
-                 train_data, val_data = dataset.load_memoryefficient_ChEMBL_Dock(config)
-                 test_data = None
-             elif config.train.use_memory_efficient_dataset == 'v2':
-                 train_data, val_data = dataset.load_ChEMBL_Dock_v2(config)
-                 test_data = None
-             else:
-                 train_data, val_data = dataset.load_ChEMBL_Dock(config)
-                 test_data = None
-
-     names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, IC50_flag, Kd_flag, Ki_flag, K_flag, assay_d \
-         = dataset.load_complete_dataset(config.data.finetune_total_names, config.data.finetune_dataset_name, config.data.labels_path, config)
-
-     train_names, valid_names, test_names = dataset.split_names(names, config)
-     finetune_val_data = dataset.select_according_names(valid_names,
-                                                        names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                                        IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                                        assay_d, config)
-
-     return train_data, val_data, test_data, finetune_val_data
-
- def get_finetune_dataset(config):
-     names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, IC50_flag, Kd_flag, Ki_flag, K_flag, assay_d \
-         = dataset.load_complete_dataset(config.data.finetune_total_names, config.data.finetune_dataset_name, config.data.labels_path, config)
-
-     train_names, valid_names, test_names = dataset.split_names(names, config)
-
-     train_data = dataset.select_according_names(train_names,
-                                                 names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                                 IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                                 assay_d, config)
-
-     val_data = dataset.select_according_names(valid_names,
-                                               names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                               IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                               assay_d, config)
-
-     test_data = dataset.select_according_names(test_names,
-                                                names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                                IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                                assay_d, config)
-
-     # previous loading path, kept for reference:
-     # train_data = dataset.pdbbind_finetune(config.data.finetune_train_names, config.data.finetune_dataset_name,
-     #                                       config.data.labels_path, config)
-     # val_data = dataset.pdbbind_finetune(config.data.finetune_valid_names, config.data.finetune_dataset_name,
-     #                                     config.data.labels_path, config)
-     # test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name,
-     #                                      config.data.labels_path, config)
-
-     generalize_csar_data = dataset.pdbbind_finetune(config.data.generalize_csar_test, config.data.generalize_dataset_name,
-                                                     config.data.generalize_labels_path, config)
-
-     return train_data, val_data, test_data, generalize_csar_data
-
- def get_test_dataset(config):
-     test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name,
-                                          config.data.labels_path, config)
-
-     generalize_csar_data = dataset.pdbbind_finetune(config.data.generalize_csar_test, config.data.generalize_dataset_name,
-                                                     config.data.generalize_labels_path, config)
-
-     return test_data, generalize_csar_data
-
- def get_dataset_example(config):
-     example_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name,
-                                             config.data.labels_path, config)
-
-     return example_data
-
- def get_model(config):
-     return globals()[config.model.model_type](config).to(config.train.device)
-
- def repeat_data(data, num_repeat):
-     datas = [copy.deepcopy(data) for i in range(num_repeat)]
-     g_ligs, g_prots, g_inters = list(zip(*datas))
-     return dgl.batch(g_ligs), dgl.batch(g_prots), dgl.batch(g_inters)
-
- def clip_norm(vec, limit, p=2):
-     norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
-     denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
-     return vec * denom
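As a sanity check on the deleted clip_norm helper, a self-contained snippet (the function is reproduced verbatim from the hunk above so it runs on its own): rows whose norm exceeds limit are rescaled onto the limit sphere, shorter rows pass through unchanged.

import torch

def clip_norm(vec, limit, p=2):
    # rescale rows with ||row||_p > limit down to exactly `limit`
    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
    return vec * denom

v = torch.tensor([[3.0, 4.0],   # norm 5.0 -> rescaled to norm 2.0
                  [0.6, 0.8]])  # norm 1.0 -> unchanged
print(clip_norm(v, limit=2.0))
# tensor([[1.2000, 1.6000],
#         [0.6000, 0.8000]])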
UltraFlow/commons/visualize.py DELETED
@@ -1,364 +0,0 @@
- import os
- import pandas as pd
- import torch
- from prody import writePDB
- from rdkit import Chem as Chem
- from rdkit.Chem.rdchem import BondType as BT
- from openbabel import openbabel, pybel
- from io import BytesIO
- from .process_mols import read_molecules_crossdock, read_molecules, read_rdmol
- from .geomop import canonical_protein_ligand_orientation
- from collections import defaultdict
- import numpy as np
-
- BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
- BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
-
- def simply_modify_coords(pred_coords, file_path, file_type='mol2', pos_no=None, file_label='pred'):
-     # rewrite the coordinate columns of a mol2 file in place, keeping all other fields
-     with open(file_path, 'r') as f:
-         lines = f.read().strip().split('\n')
-     index = 0
-     while index < len(lines):
-         if '@<TRIPOS>ATOM' in lines[index]:
-             break
-         index += 1
-     for coord in pred_coords:
-         index += 1
-         new_x = '{:.4f}'.format(coord[0]).rjust(10, ' ')
-         new_y = '{:.4f}'.format(coord[1]).rjust(10, ' ')
-         new_z = '{:.4f}'.format(coord[2]).rjust(10, ' ')
-         new_coord_str = new_x + new_y + new_z
-         lines[index] = lines[index][:16] + new_coord_str + lines[index][46:]
-
-     if pos_no is not None:
-         with open('{}_{}_{}.{}'.format(os.path.join(os.path.dirname(file_path), os.path.basename(file_path).split('.')[0]), file_label, pos_no, file_type), 'w') as f:
-             f.write('\n'.join(lines))
-     else:
-         with open('{}_{}.{}'.format(os.path.join(os.path.dirname(file_path), os.path.basename(file_path).split('.')[0]), file_label, file_type), 'w') as f:
-             f.write('\n'.join(lines))
-
- def set_new_coords_for_protein_atom(m_prot, new_coords):
-     for index, atom in enumerate(m_prot):
-         atom.setCoords(new_coords[index])
-     return m_prot
-
- def save_ligand_file(m_lig, output_path, file_type='mol2'):
-     return
-
- def save_protein_file(m_prot, output_path, file_type='pdb'):
-     if file_type == 'pdb':
-         writePDB(output_path, m_prot)
-     return
-
- def generated_to_xyz(data):
-     ptable = Chem.GetPeriodicTable()
-     num_atoms, atom_type, atom_coords = data
-     xyz = "%d\n\n" % (num_atoms, )
-     for i in range(num_atoms):
-         symb = ptable.GetElementSymbol(int(atom_type[i]))
-         x, y, z = atom_coords[i].clone().cpu().tolist()
-         xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
-     return xyz
-
- def generated_to_sdf(data):
-     xyz = generated_to_xyz(data)
-     obConversion = openbabel.OBConversion()
-     obConversion.SetInAndOutFormats("xyz", "sdf")
-
-     mol = openbabel.OBMol()
-     obConversion.ReadString(mol, xyz)
-     sdf = obConversion.WriteString(mol)
-     return sdf
-
- def sdf_to_rdmol(sdf):
-     stream = BytesIO(sdf.encode())
-     suppl = Chem.ForwardSDMolSupplier(stream)
-     for mol in suppl:
-         return mol
-     return None
-
- def generated_to_rdmol(data):
-     sdf = generated_to_sdf(data)
-     return sdf_to_rdmol(sdf)
-
- def generated_to_rdmol_trajectory(trajectory):
-     sdf_trajectory = ''
-     for data in trajectory:
-         sdf_trajectory += generated_to_sdf(data)
-     return sdf_trajectory
-
- def filter_rd_mol(rdmol):
-     ring_info = rdmol.GetRingInfo()
-     rings = [set(r) for r in ring_info.AtomRings()]
-
-     # reject geometries in which two 3-membered rings share atoms
-     for i, ring_a in enumerate(rings):
-         if len(ring_a) != 3:
-             continue
-         for j, ring_b in enumerate(rings):
-             if i <= j:
-                 continue
-             inter = ring_a.intersection(ring_b)
-             if (len(ring_b) == 3) and (len(inter) > 0):
-                 return False
-     return True
-
-
- def save_sdf_mol(rdmol, save_path, suffix='test'):
-     writer = Chem.SDWriter(os.path.join(save_path, 'visualize_dir', f'{suffix}.sdf'))
-     writer.SetKekulize(False)
-     try:
-         writer.write(rdmol, confId=0)
-     except Exception:
-         writer.close()
-         return False
-     writer.close()
-     return True
-
- def sdf_string_save_sdf_file(sdf_string, save_path, suffix='test'):
-     with open(os.path.join(save_path, 'visualize_dir', f'{suffix}.sdf'), 'w') as f:
-         f.write(sdf_string)
-     return
-
- def visualize_generate_full_trajectory(trajectory, index, dataset, save_path, move_truth=True,
-                                        name_suffix='pred_trajectory', canonical_oritentaion=True):
-     if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-         lig_path = index
-         lig_path_split = lig_path.split('/')
-         lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-         prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-         if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-             os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-         name = index[:-4]
-
-         assert prot_path.endswith('_rec.pdb')
-         molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                             dataset.lig_type, dataset.prot_graph_type,
-                                                             dataset.dataset_path, dataset.chaincut)
-
-         lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-         prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-     elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-         name = index
-         molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                   dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                   init_type=None, chain_cut=dataset.chaincut)
-
-         lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-         if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-             prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-         else:
-             prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-     lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-     if dataset.canonical_oritentaion and canonical_oritentaion:
-         # map the trajectory back from the canonical frame to the original frame
-         new_trajectory = []
-         _, _, _, _, _, rotation, translation = canonical_protein_ligand_orientation(lig_coords, prot_coords)
-         for coords in trajectory:
-             new_trajectory.append((coords @ rotation.T) - translation)
-         trajectory = new_trajectory
-
-     trajectory_data = []
-     num_atoms = len(trajectory[0])
-     for coords in trajectory:
-         data = (num_atoms, lig_node_type, coords)
-         trajectory_data.append(data)
-     sdf_file_string = generated_to_rdmol_trajectory(trajectory_data)
-
-     if name_suffix is None:
-         sdf_string_save_sdf_file(sdf_file_string, save_path, suffix=name)
-     else:
-         sdf_string_save_sdf_file(sdf_file_string, save_path, suffix=f'{name}_{name_suffix}')
-
-     if move_truth:
-         output_path = os.path.join(save_path, 'visualize_dir')
-         cmd = f'cp {prot_path_direct} {output_path}'
-         cmd += f' && cp {lig_path_direct} {output_path}'
-         os.system(cmd)
-
-     return
-
- def visualize_generated_coordinates(coords, index, dataset, save_path, move_truth=True, name_suffix=None, canonical_oritentaion=True):
-     if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-         lig_path = index
-         lig_path_split = lig_path.split('/')
-         lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-         prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-         if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-             os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-         name = index[:-4]
-
-         assert prot_path.endswith('_rec.pdb')
-         molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                             dataset.lig_type, dataset.prot_graph_type, dataset.dataset_path, dataset.chaincut)
-
-         lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-         prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-     elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-         name = index
-         molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                   dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                   init_type=None, chain_cut=dataset.chaincut)
-
-         lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-         if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-             prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-         else:
-             prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-     lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-     if dataset.canonical_oritentaion and canonical_oritentaion:
-         _, _, _, _, _, rotation, translation = canonical_protein_ligand_orientation(lig_coords, prot_coords)
-         coords = (coords @ rotation.T) - translation
-
-     num_atoms = len(coords)
-
-     data = (num_atoms, lig_node_type, coords)
-     sdf_string = generated_to_sdf(data)
-
-     sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-     with open(sdf_path, 'w') as f:
-         f.write(sdf_string)
-
-     if move_truth:
-         lig_path_direct_sdf = os.path.join(dataset.dataset_path, name, f'{name}_ligand.sdf')
-         output_path = os.path.join(save_path, 'visualize_dir')
-         cmd = f'cp {prot_path_direct} {output_path}'
-         cmd += f' && cp {lig_path_direct} {output_path}'
-         cmd += f' && cp {lig_path_direct_sdf} {output_path}'
-         os.system(cmd)
-
- def visualize_predicted_pocket(binding_site_flag, index, dataset, save_path, move_truth=True, name_suffix=None, canonical_oritentaion=True):
-     if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-         os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-     if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-         lig_path = index
-         lig_path_split = lig_path.split('/')
-         lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-         prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-         if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-             os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-         name = index[:-4]
-
-         assert prot_path.endswith('_rec.pdb')
-         molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                             dataset.lig_type, dataset.prot_graph_type, dataset.dataset_path, dataset.chaincut)
-
-         lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-         prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-     elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-         name = index
-         molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                   dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                   init_type=None, chain_cut=dataset.chaincut)
-
-         lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-         if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-             prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-         else:
-             prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-     lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-     coords = torch.from_numpy(prot_coords[binding_site_flag.cpu()])
-
-     num_atoms = len(coords)
-
-     # write the predicted pocket atoms out as pseudo carbons for visualization
-     data = (num_atoms, [6] * num_atoms, coords)
-     sdf_string = generated_to_sdf(data)
-
-     sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-     with open(sdf_path, 'w') as f:
-         f.write(sdf_string)
-
-     if move_truth:
-         output_path = os.path.join(save_path, 'visualize_dir')
-         cmd = f'cp {prot_path_direct} {output_path}'
-         cmd += f' && cp {lig_path_direct} {output_path}'
-         os.system(cmd)
-
- def visualize_predicted_link_map(pred_prob, true_prob, pdb_name, dataset, save_path):
-     """
-     :param pred_prob: [N,M], torch.tensor
-     :param true_prob: [N,M], torch.tensor
-     :param pdb_name: string
-     :param dataset:
-     :param save_path:
-     :return:
-     """
-     if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-         os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-     pd.DataFrame(pred_prob.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_pred.csv'))
-     pd.DataFrame(true_prob.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_true.csv'))
-
- def visualize_edge_coef_map(feats_coef, coords_coef, pdb_name, dataset, save_path, layer_index):
-     if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-         os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-     pd.DataFrame(feats_coef.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_feats_coef_layer_{layer_index}.csv'))
-     pd.DataFrame(coords_coef.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_coords_coef_layer_{layer_index}.csv'))
-
- def collect_bond_dists(index, dataset, save_path, name_suffix='pred'):
-     """
-     Collect the length of every chemical bond in the ground-truth, predicted, and initial geometries of one ligand.
-     Args:
-         index (str): the complex name used to locate the ligand files under the dataset path.
-         dataset: the dataset object providing `dataset_path`.
-         save_path (str): the run directory whose `visualize_dir` holds the predicted and initial sdf files.
-         name_suffix (str): the suffix of the predicted sdf file.
-
-     :rtype: :class:`list` of tuples (z1, z2, bond_type, gd_dist, pred_dist, init_dist), one entry per bond.
-     """
-     name = index
-     bonds_dist = []
-
-     lig_path_mol2 = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-     lig_path_sdf = os.path.join(dataset.dataset_path, name, f'{name}_ligand.sdf')
-     rdmol = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
-     if rdmol is None:  # fall back to the mol2 file if the sdf file cannot be sanitized
-         rdmol = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
-     gd_atom_coords = rdmol.GetConformer().GetPositions()
-
-     pred_sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-     pred_m_lig = next(pybel.readfile('sdf', pred_sdf_path))
-     pred_atom_coords = np.array([atom.coords for atom in pred_m_lig], dtype=np.float32)
-     assert len(pred_atom_coords) == len(gd_atom_coords)
-
-     init_sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-     init_m_lig = next(pybel.readfile('sdf', init_sdf_path))
-     init_atom_coords = np.array([atom.coords for atom in init_m_lig], dtype=np.float32)
-     assert len(init_atom_coords) == len(gd_atom_coords)
-
-     for bond in rdmol.GetBonds():
-         start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
-         start_idx, end_idx = start_atom.GetIdx(), end_atom.GetIdx()
-         if start_idx < end_idx:
-             continue
-         start_atom_type, end_atom_type = start_atom.GetAtomicNum(), end_atom.GetAtomicNum()
-         bond_type = BOND_TYPES[bond.GetBondType()]
-
-         gd_bond_dist = np.linalg.norm(gd_atom_coords[start_idx] - gd_atom_coords[end_idx])
-         pred_bond_dist = np.linalg.norm(pred_atom_coords[start_idx] - pred_atom_coords[end_idx])
-         init_bond_dist = np.linalg.norm(init_atom_coords[start_idx] - init_atom_coords[end_idx])
-
-         z1, z2 = min(start_atom_type, end_atom_type), max(start_atom_type, end_atom_type)
-         bonds_dist.append((z1, z2, bond_type, gd_bond_dist, pred_bond_dist, init_bond_dist))
-
-     return bonds_dist
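A minimal round-trip sketch for the xyz-to-sdf conversion path above (generated_to_xyz / generated_to_sdf / sdf_to_rdmol). The water geometry is purely illustrative, and the import assumes a checkout from before this commit so the helpers are still importable:

import torch
from UltraFlow.commons.visualize import generated_to_sdf, sdf_to_rdmol  # pre-deletion checkout assumed

water = (3,                                    # num_atoms
         [8, 1, 1],                            # atomic numbers: O, H, H
         torch.tensor([[0.000,  0.000,  0.117],
                       [0.000,  0.757, -0.469],
                       [0.000, -0.757, -0.469]]))

sdf_block = generated_to_sdf(water)            # xyz text converted via Open Babel
rdmol = sdf_to_rdmol(sdf_block)                # parsed back into an RDKit Mol
print(rdmol.GetNumAtoms() if rdmol is not None else 'conversion failed')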
UltraFlow/data/INDEX_general_PL_data.2016 DELETED
File without changes
UltraFlow/data/INDEX_general_PL_data.2020 DELETED
File without changes
UltraFlow/data/INDEX_refined_data.2020 DELETED
File without changes
UltraFlow/data/chembl/P49841/P49841_valid_chains.pdb DELETED
File without changes
UltraFlow/data/chembl/P49841/P49841_valid_pvalue.smi DELETED
File without changes
UltraFlow/data/chembl/P49841/P49841_valid_smiles.smi DELETED
File without changes
UltraFlow/data/chembl/P49841/visualize_dir/total_vs.sdf DELETED
File without changes
UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_chains.pdb DELETED
File without changes
UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_pvalue.smi DELETED
File without changes
UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_smiles.smi DELETED
File without changes
UltraFlow/data/chembl/Q9Y233/visualize_dir/total_vs.sdf DELETED
File without changes
UltraFlow/data/core_set DELETED
File without changes
UltraFlow/data/csar_2016 DELETED
File without changes
UltraFlow/data/csar_2020 DELETED
File without changes
UltraFlow/data/csar_new_2016 DELETED
File without changes
UltraFlow/data/horizontal_test.pkl DELETED
File without changes
UltraFlow/data/horizontal_train.pkl DELETED
File without changes
UltraFlow/data/horizontal_valid.pkl DELETED
File without changes
UltraFlow/data/pdb2016_total DELETED
File without changes
UltraFlow/data/pdb_after_2016 DELETED
File without changes
UltraFlow/data/pdbbind2016_general_gign_train DELETED
File without changes
UltraFlow/data/pdbbind2016_general_gign_valid DELETED
File without changes
UltraFlow/data/pdbbind2016_general_train DELETED
File without changes
UltraFlow/data/pdbbind2016_general_valid DELETED
File without changes
UltraFlow/data/pdbbind2016_test DELETED
File without changes
UltraFlow/data/pdbbind2016_train DELETED
File without changes
UltraFlow/data/pdbbind2016_train_M DELETED
File without changes
UltraFlow/data/pdbbind2016_valid DELETED
File without changes
UltraFlow/data/pdbbind2016_valid_M DELETED
File without changes
UltraFlow/data/pdbbind2020_finetune_test DELETED
File without changes
UltraFlow/data/pdbbind2020_finetune_train DELETED
File without changes
UltraFlow/data/pdbbind2020_finetune_valid DELETED
File without changes
UltraFlow/data/pdbbind2020_vstrain1 DELETED
File without changes
UltraFlow/data/pdbbind2020_vstrain2 DELETED
File without changes
UltraFlow/data/pdbbind2020_vstrain3 DELETED
File without changes
UltraFlow/data/pdbbind2020_vsvalid1 DELETED
File without changes
UltraFlow/data/pdbbind2020_vsvalid2 DELETED
File without changes
UltraFlow/data/pdbbind2020_vsvalid3 DELETED
File without changes
UltraFlow/data/pdbbind_2020_casf_test DELETED
File without changes
UltraFlow/data/pdbbind_2020_casf_train DELETED
File without changes
UltraFlow/data/pdbbind_2020_casf_valid DELETED
File without changes
UltraFlow/data/tankbind_vtrain DELETED
File without changes