Commit 7d69eaa (1 parent: e749e85)
Commit message: test

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- UltraFlow/commons/__init__.py +0 -5
- UltraFlow/commons/dock_utils.py +0 -355
- UltraFlow/commons/geomop.py +0 -529
- UltraFlow/commons/get_free_gpu.py +0 -78
- UltraFlow/commons/loss_weight.pkl +0 -3
- UltraFlow/commons/metrics.py +0 -315
- UltraFlow/commons/torch_prepare.py +0 -156
- UltraFlow/commons/visualize.py +0 -364
- UltraFlow/data/INDEX_general_PL_data.2016 +0 -0
- UltraFlow/data/INDEX_general_PL_data.2020 +0 -0
- UltraFlow/data/INDEX_refined_data.2020 +0 -0
- UltraFlow/data/chembl/P49841/P49841_valid_chains.pdb +0 -0
- UltraFlow/data/chembl/P49841/P49841_valid_pvalue.smi +0 -0
- UltraFlow/data/chembl/P49841/P49841_valid_smiles.smi +0 -0
- UltraFlow/data/chembl/P49841/visualize_dir/total_vs.sdf +0 -0
- UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_chains.pdb +0 -0
- UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_pvalue.smi +0 -0
- UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_smiles.smi +0 -0
- UltraFlow/data/chembl/Q9Y233/visualize_dir/total_vs.sdf +0 -0
- UltraFlow/data/core_set +0 -0
- UltraFlow/data/csar_2016 +0 -0
- UltraFlow/data/csar_2020 +0 -0
- UltraFlow/data/csar_new_2016 +0 -0
- UltraFlow/data/horizontal_test.pkl +0 -0
- UltraFlow/data/horizontal_train.pkl +0 -0
- UltraFlow/data/horizontal_valid.pkl +0 -0
- UltraFlow/data/pdb2016_total +0 -0
- UltraFlow/data/pdb_after_2016 +0 -0
- UltraFlow/data/pdbbind2016_general_gign_train +0 -0
- UltraFlow/data/pdbbind2016_general_gign_valid +0 -0
- UltraFlow/data/pdbbind2016_general_train +0 -0
- UltraFlow/data/pdbbind2016_general_valid +0 -0
- UltraFlow/data/pdbbind2016_test +0 -0
- UltraFlow/data/pdbbind2016_train +0 -0
- UltraFlow/data/pdbbind2016_train_M +0 -0
- UltraFlow/data/pdbbind2016_valid +0 -0
- UltraFlow/data/pdbbind2016_valid_M +0 -0
- UltraFlow/data/pdbbind2020_finetune_test +0 -0
- UltraFlow/data/pdbbind2020_finetune_train +0 -0
- UltraFlow/data/pdbbind2020_finetune_valid +0 -0
- UltraFlow/data/pdbbind2020_vstrain1 +0 -0
- UltraFlow/data/pdbbind2020_vstrain2 +0 -0
- UltraFlow/data/pdbbind2020_vstrain3 +0 -0
- UltraFlow/data/pdbbind2020_vsvalid1 +0 -0
- UltraFlow/data/pdbbind2020_vsvalid2 +0 -0
- UltraFlow/data/pdbbind2020_vsvalid3 +0 -0
- UltraFlow/data/pdbbind_2020_casf_test +0 -0
- UltraFlow/data/pdbbind_2020_casf_train +0 -0
- UltraFlow/data/pdbbind_2020_casf_valid +0 -0
- UltraFlow/data/tankbind_vtrain +0 -0
UltraFlow/commons/__init__.py
CHANGED
@@ -1,7 +1,2 @@
 from .utils import *
-from .torch_prepare import *
 from .process_mols import *
-from .metrics import *
-from .geomop import *
-from .visualize import *
-from .dock_utils import *
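With this change, the only wildcard imports left are utils and process_mols, so those two modules now define the entire public surface of UltraFlow.commons. A quick smoke test of that claim (a hypothetical check, not part of the commit; assumes the repo root is on PYTHONPATH):

    import UltraFlow.commons as commons
    # Every public name should now originate from utils or process_mols.
    print(sorted(n for n in dir(commons) if not n.startswith('_')))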
UltraFlow/commons/dock_utils.py
DELETED
@@ -1,355 +0,0 @@
-import os
-from collections import defaultdict
-import numpy as np
-import torch
-from openbabel import pybel
-from statistics import stdev
-from time import time
-from .utils import pmap_multi
-import pandas as pd
-from tqdm import tqdm
-
-MGLTols_PYTHON = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/bin/python2.7'
-Prepare_Ligand = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/prepare_ligand4.py'
-Prepare_Receptor = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/prepare_receptor4.py'
-SMINA = '/apdcephfs/private_jiaxianyan/dock/smina'
-
-def read_matric(matric_file_path):
-    with open(matric_file_path) as f:
-        lines = f.read().strip().split('\n')
-    rmsd, centroid = float(lines[0].split(':')[1]), float(lines[1].split(':')[1])
-    return rmsd, centroid
-
-def mol2_add_atom_index_to_atom_name(mol2_file_path):
-    MOL_list = [x for x in open(mol2_file_path, 'r')]
-    idx = [i for i, x in enumerate(MOL_list) if x.startswith('@')]
-    block = MOL_list[idx[1] + 1:idx[2]]
-    block = [x.split() for x in block]
-
-    block_new = []
-    atom_count = defaultdict(int)
-    for i in block:
-        at = i[5].strip().split('.')[0]
-        if 'H' not in at:
-            atom_count[at] += 1
-            count = atom_count[at]
-            at_new = at + str(count)
-            at_new = at_new.rjust(4)
-            block_new.append([i[0], at_new] + i[2:])
-        else:
-            block_new.append(i)
-
-    block_new = ['\t'.join(x) + '\n' for x in block_new]
-    MOL_list_new = MOL_list[:idx[1] + 1] + block_new + MOL_list[idx[2]:]
-    with open(mol2_file_path, 'w') as f:
-        for line in MOL_list_new:
-            f.write(line)
-    return
-
-def prepare_dock_file(pdb_name, config):
-    visualize_dir = os.path.join(config.train.save_path, 'visualize_dir')
-    post_align_sdf = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.sdf')
-    post_align_mol2 = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.mol2')
-    post_align_pdbqt = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.pdbqt')
-
-    # mgltools preprocess
-    cmd = f'cd {visualize_dir}'
-    cmd += f' && obabel -i sdf {post_align_sdf} -o mol2 -O {post_align_mol2}'
-
-    if not os.path.exists(post_align_mol2):
-        os.system(cmd)
-        mol2_add_atom_index_to_atom_name(post_align_mol2)
-
-    cmd = f'cd {visualize_dir}'
-    cmd += f' && {MGLTols_PYTHON} {Prepare_Ligand} -l {post_align_mol2}'
-
-    if not os.path.exists(post_align_pdbqt):
-        os.system(cmd)
-    # cmd = f'obabel -i mol2 {post_align_mol2} -o pdbqt -O {post_align_pdbqt}'
-    # os.system(cmd)
-
-    return
-
-def get_mol2_atom_name(mol2_file_path):
-    MOL_list = [x for x in open(mol2_file_path, 'r')]
-    idx = [i for i, x in enumerate(MOL_list) if x.startswith('@')]
-    block = MOL_list[idx[1] + 1:idx[2]]
-    block = [x.split() for x in block]
-
-    atom_names = []
-
-    for i in block:
-        at = i[1].strip()
-        atom_names.append(at)
-    return atom_names
-
-def align_dock_name_and_target_name(dock_lig_atom_names, target_lig_atom_names):
-    dock_lig_atom_index_in_target_lig = []
-    target_atom_name_dict = {}
-    for index, atom_name in enumerate(target_lig_atom_names):
-        try:
-            assert atom_name not in target_atom_name_dict.keys()
-        except:
-            raise ValueError(atom_name,'has appeared before')
-        target_atom_name_dict[atom_name] = index
-
-    dock_lig_atom_name_appears_dict = defaultdict(int)
-    for atom_name in dock_lig_atom_names:
-        try:
-            assert atom_name not in dock_lig_atom_name_appears_dict.keys()
-        except:
-            raise ValueError(atom_name,'has appeared before')
-        dock_lig_atom_name_appears_dict[atom_name] += 1
-        try:
-            dock_lig_atom_index_in_target_lig.append(target_atom_name_dict[atom_name])
-        except:
-            dock_lig_atom_index_in_target_lig.append(target_atom_name_dict[atom_name+'1'])
-
-    return dock_lig_atom_index_in_target_lig
-
-def smina_dock_result_rmsd(pdb_name, config):
-    # target path
-    target_lig_mol2 = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_ligand.mol2')
-
-    # get target coords
-    target_m_lig = next(pybel.readfile('mol2', target_lig_mol2))
-    target_lig_coords = [atom.coords for atom in target_m_lig if atom.atomicnum > 1]
-    target_lig_coords = np.array(target_lig_coords, dtype=np.float32)  # np.array, [n, 3]
-    target_lig_center = target_lig_coords.mean(axis=0)  # np.array, [3]
-
-    # get target atom names
-    visualize_dir = os.path.join(config.train.save_path, 'visualize_dir')
-    lig_init_mol2 = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.mol2')
-    target_atom_name_reference_lig = next(pybel.readfile('mol2', lig_init_mol2))
-    target_lig_atom_names = get_mol2_atom_name(lig_init_mol2)
-    target_lig_atom_names_no_h = [atom_name for atom, atom_name in zip(target_atom_name_reference_lig, target_lig_atom_names) if atom.atomicnum > 1]
-
-    # get init coords
-    coords_before_minimized = [atom.coords for atom in target_atom_name_reference_lig if atom.atomicnum > 1]
-    coords_before_minimized = np.array(coords_before_minimized, dtype=np.float32)  # np.array, [n, 3]
-
-    # get smina minimized coords
-    dock_lig_mol2_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.mol2')
-    dock_m_lig = next(pybel.readfile('mol2', dock_lig_mol2_path))
-    dock_lig_coords = [atom.coords for atom in dock_m_lig if atom.atomicnum > 1]
-    dock_lig_coords = np.array(dock_lig_coords, dtype=np.float32)  # np.array, [n, 3]
-    dock_lig_center = dock_lig_coords.mean(axis=0)  # np.array, [3]
-
-    # get atom names
-    dock_lig_atom_names = get_mol2_atom_name(dock_lig_mol2_path)
-    dock_lig_atom_names_no_h = [atom_name for atom, atom_name in zip(dock_m_lig, dock_lig_atom_names) if atom.atomicnum > 1]
-    dock_lig_atom_index_in_target_lig = align_dock_name_and_target_name(dock_lig_atom_names_no_h, target_lig_atom_names_no_h)
-
-    dock_lig_coords_target_align = np.zeros([len(dock_lig_atom_index_in_target_lig),3], dtype=np.float32)
-    for atom_coords, atom_index_in_target_lig in zip(dock_lig_coords, dock_lig_atom_index_in_target_lig):
-        dock_lig_coords_target_align[atom_index_in_target_lig] = atom_coords
-
-    # rmsd
-    error_lig_coords = dock_lig_coords_target_align - target_lig_coords
-    rmsd = np.sqrt((error_lig_coords ** 2).sum(axis=1, keepdims=True).mean(axis=0))
-
-    # centroid
-    error_center_coords = dock_lig_center - target_lig_center
-    centorid_d = np.sqrt( (error_center_coords ** 2).sum() )
-
-    # get rmsd after minimized
-    error_lig_coords_after_minimized = dock_lig_coords_target_align - coords_before_minimized
-    rmsd_after_minimized = np.sqrt((error_lig_coords_after_minimized ** 2).sum(axis=1, keepdims=True).mean(axis=0))
-
-    return float(rmsd), float(centorid_d), float(rmsd_after_minimized)
-
-def get_matric_dict(rmsds, centroids):
-    rmsd_mean = sum(rmsds)/len(rmsds)
-    centroid_mean = sum(centroids) / len(centroids)
-    rmsd_std = stdev(rmsds)
-    centroid_std = stdev(centroids)
-
-    # rmsd < 2
-    count = torch.tensor(rmsds) < 2.0
-    rmsd_less_than_2 = 100 * count.sum().item() / len(count)
-
-    # rmsd < 2
-    count = torch.tensor(rmsds) < 5.0
-    rmsd_less_than_5 = 100 * count.sum().item() / len(count)
-
-    # centorid < 2
-    count = torch.tensor(centroids) < 2.0
-    centroid_less_than_2 = 100 * count.sum().item() / len(count)
-
-    # centorid < 5
-    count = torch.tensor(centroids) < 5.0
-    centroid_less_than_5 = 100 * count.sum().item() / len(count)
-
-    metrics_dict = {'rmsd mean': rmsd_mean, 'rmsd std': rmsd_std, 'centroid mean': centroid_mean, 'centroid std': centroid_std,
-                    'rmsd less than 2': rmsd_less_than_2, 'rmsd less than 5':rmsd_less_than_5,
-                    'centroid less than 2': centroid_less_than_2, 'centroid less than 5': centroid_less_than_5}
-    return metrics_dict
-
-def run_smina_dock(pdb_name ,config):
-
-    r_pdbqt = os.path.join(config.test_set.dataset_path, pdb_name, f'{pdb_name}_protein_processed.pdbqt')
-    l_pdbqt = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}.pdbqt')
-    out_mol2 = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.mol2')
-    log_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.log')
-    cmd = f'{SMINA}' \
-          f' --receptor {r_pdbqt}' \
-          f' --ligand {l_pdbqt}' \
-          f' --out {out_mol2}' \
-          f' --log {log_file}' \
-          f' --minimize'
-    os.system(cmd)
-
-    return
-
-def run_score_only(ligand_file, protein_file, out_log_file):
-    cmd = f'{SMINA}' \
-          f' --receptor {protein_file}' \
-          f' --ligand {ligand_file}' \
-          f' --score_only' \
-          f' > {out_log_file}'
-    os.system(cmd)
-
-    with open(out_log_file, 'r') as f:
-        lines = f.read().strip().split('\n')
-    affinity_score = float(lines[21].split()[1])
-
-    return affinity_score
-
-def run_smina_score_after_predict(config):
-    pdb_name_list = config.test_set.names
-    smina_score_list = []
-    for pdb_name in tqdm(pdb_name_list):
-        ligand_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_pred.sdf')
-        protein_file = os.path.join(config.test_set.dataset_path, pdb_name, f'{pdb_name}_protein_processed.pdbqt')
-        out_log_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_pred_smina_score.out')
-        smina_score = run_score_only(ligand_file, protein_file, out_log_file)
-        smina_score_list.append(smina_score)
-
-    result_d = {'pdb_name':pdb_name_list, 'smina_score':smina_score_list}
-    pd.DataFrame(result_d).to_csv(os.path.join(config.train.save_path, 'visualize_dir', 'pred_smina_score.csv'))
-    return
-
-def run_smina_minimize_after_predict(config):
-    minimize_time = 0
-
-    pdb_name_list = config.test_set.names
-
-    # pmap_multi(prepare_dock_file, zip(pdb_name_list, [config] * len(pdb_name_list)),
-    #            n_jobs=8, desc='mgltools preparing ...')
-
-    rmsds_post_dock, centroids_post_dock = [], []
-    rmsds_post, centroids_post = [], []
-    rmsds, centroids = [], []
-
-    rmsds_after_minimized = {'pdb_name':[], 'rmsd':[]}
-
-    valid_pdb_name = []
-    error_list = []
-    # for pdb_name in tqdm(pdb_name_list):
-    #     try:
-    #         minimize_begin_time = time()
-    #         run_smina_dock(pdb_name, config)
-    #         minimize_time += time() - minimize_begin_time
-    #         rmsd_post_dock, centroid_post_dock, rmsd_after_minimized = smina_dock_result_rmsd(pdb_name, config)
-    #         rmsds_post_dock.append(rmsd_post_dock)
-    #         centroids_post_dock.append(centroid_post_dock)
-    #
-    #         rmsds_after_minimized['pdb_name'].append(pdb_name)
-    #         rmsds_after_minimized['rmsd'].append(rmsd_after_minimized)
-    #         print(f'{pdb_name} smina minimized, rmsd: {rmsd_post_dock}, centroid: {centroid_post_dock}')
-    #
-    #         text_matics = 'rmsd:{}\ncentroid_d:{}\n'.format(rmsd_post_dock, centroid_post_dock)
-    #         post_dock_matric_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_matrics_post_{config.train.align_method}_dock.txt')
-    #         with open(post_dock_matric_path, 'w') as f:
-    #             f.write(text_matics)
-    #
-    #         # read matrics
-    #         post_matric_path = os.path.join(config.train.save_path, 'visualize_dir',
-    #                                         f'{pdb_name}_matrics_post_{config.train.align_method}.txt')
-    #
-    #         matric_path = os.path.join(config.train.save_path, 'visualize_dir',
-    #                                    f'{pdb_name}_matrics.txt')
-    #         rmsd_post, centroid_post = read_matric(post_matric_path)
-    #         rmsds_post.append(rmsd_post)
-    #         centroids_post.append(centroid_post)
-    #
-    #         rmsd, centroid = read_matric(matric_path)
-    #         rmsds.append(rmsd)
-    #         centroids.append(centroid)
-    #         valid_pdb_name.append(pdb_name)
-    #
-    #     except:
-    #         print(f'{pdb_name} dock error!')
-    #         error_list.append(pdb_name)
-    #
-    dock_score_analysis(pdb_name_list, config)
-
-    pd.DataFrame(rmsds_after_minimized).to_csv(os.path.join(config.train.save_path, 'visualize_dir', f'rmsd_after_smina_minimzed.csv'))
-
-    matric_dict_post_dock = get_matric_dict(rmsds_post_dock, centroids_post_dock)
-    matric_dict_post = get_matric_dict(rmsds_post, centroids_post)
-    matric_dict = get_matric_dict(rmsds, centroids)
-
-    matric_dict_post_dock_d = {'pdb_name': valid_pdb_name, 'rmsd': rmsds_post_dock, 'centroid': centroids_post_dock}
-    pd.DataFrame(matric_dict_post_dock_d).to_csv(
-        os.path.join(config.train.save_path, 'visualize_dir', 'matric_distribution_after_minimized.csv'))
-
-    matric_str = ''
-    for key in matric_dict_post_dock.keys():
-        if key.endswith('mean') or key.endswith('std'):
-            matric_str += '| post dock {}: {:.4f} '.format(key, matric_dict_post_dock[key])
-        else:
-            matric_str += '| post dock {}: {:.4f}% '.format(key, matric_dict_post_dock[key])
-
-    for key in matric_dict_post.keys():
-        if key.endswith('mean') or key.endswith('std'):
-            matric_str += '| post {}: {:.4f} '.format(key, matric_dict_post[key])
-        else:
-            matric_str += '| post {}: {:.4f}% '.format(key, matric_dict_post[key])
-
-    for key in matric_dict.keys():
-        if key.endswith('mean') or key.endswith('std'):
-            matric_str += '| {}: {:.4f} '.format(key, matric_dict[key])
-        else:
-            matric_str += '| {}: {:.4f}% '.format(key, matric_dict[key])
-
-    print(f'smina minimize results ========================')
-    print(matric_str)
-    print(f'pdb name error list ==========================')
-    print('\t'.join(error_list))
-    print(f'smina minimize time: {minimize_time}')
-
-    return
-
-def get_dock_score(log_path):
-    with open(log_path, 'r') as f:
-        lines = f.read().strip().split('\n')
-
-    affinity_score = float(lines[20].split()[1])
-
-    return affinity_score
-
-def dock_score_analysis(pdb_name_list, config):
-    dock_score_d = {'name':[], 'score':[]}
-    error_num = 0
-    for pdb_name in tqdm(pdb_name_list):
-        log_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.log')
-        try:
-            affinity_score = get_dock_score(log_path)
-        except:
-            affinity_score = None
-        dock_score_d['name'].append(pdb_name)
-        dock_score_d['score'].append(affinity_score)
-    print('error num,', error_num)
-    pd.DataFrame(dock_score_d).to_csv(os.path.join(config.train.save_path, 'visualize_dir', f'post_align_{config.train.align_method}_smina_minimize_score.csv'))
-
-
-def structure2score(score_type):
-    try:
-        assert score_type in ['vina', 'smina', 'rfscore', 'ign', 'nnscore']
-    except:
-        raise ValueError(f'{score_type} if not supported scoring function type')
-
-
-
-    return
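A note on the log parsing in the deleted module: run_score_only and get_dock_score read the affinity from hard-coded line indices (lines[21] and lines[20] respectively), which breaks silently whenever the length of smina's log banner changes. A more defensive sketch, assuming smina's usual "Affinity: <value> (kcal/mol)" output line (the helper name is ours, not the deleted module's):

    def parse_smina_affinity(log_path):
        # Scan for the first "Affinity:" line instead of trusting a fixed offset.
        with open(log_path) as f:
            for line in f:
                if line.strip().startswith('Affinity:'):
                    return float(line.split()[1])
        raise ValueError(f'no Affinity line found in {log_path}')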
UltraFlow/commons/geomop.py
DELETED
@@ -1,529 +0,0 @@
-import torch
-import rdkit.Chem as Chem
-import numpy as np
-import copy
-from rdkit.Chem import AllChem
-from rdkit.Chem import rdMolTransforms
-from rdkit.Geometry import Point3D
-from scipy.optimize import differential_evolution
-from .process_mols import read_rdmol
-import os
-import math
-from openbabel import pybel
-from tqdm import tqdm
-
-def get_d_from_pos(pos, edge_index):
-    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)  # (num_edge)
-
-def kabsch(coords_A, coords_B, debug=True, device=None):
-    # rotate and translate coords_A to coords_B pos
-    coords_A_mean = coords_A.mean(dim=0, keepdim=True)  # (1,3)
-    coords_B_mean = coords_B.mean(dim=0, keepdim=True)  # (1,3)
-
-    # A = (coords_A - coords_A_mean).transpose(0, 1) @ (coords_B - coords_B_mean)
-    A = (coords_A).transpose(0, 1) @ (coords_B )
-    if torch.isnan(A).any():
-        print('A Nan encountered')
-    assert not torch.isnan(A).any()
-
-    if torch.isinf(A).any():
-        print('inf encountered')
-    assert not torch.isinf(A).any()
-
-    U, S, Vt = torch.linalg.svd(A)
-    num_it = 0
-    while torch.min(S) < 1e-3 or torch.min(
-            torch.abs((S ** 2).view(1, 3) - (S ** 2).view(3, 1) + torch.eye(3).to(device))) < 1e-2:
-        if debug: print('S inside loop ', num_it, ' is ', S, ' and A = ', A)
-        A = A + torch.rand(3, 3).to(device) * torch.eye(3).to(device)
-        U, S, Vt = torch.linalg.svd(A)
-        num_it += 1
-        if num_it > 10: raise Exception('SVD was consitantly unstable')
-
-    corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=device))
-    rotation = (U @ corr_mat) @ Vt
-
-    translation = coords_B_mean - torch.t(rotation @ coords_A_mean.t())  # (1,3)
-
-    # new_coords = (rotation @ coords_A.t()).t() + translation
-
-    return rotation, translation
-
-def rigid_transform_Kabsch_3D(A, B):
-    assert A.shape[1] == B.shape[1]
-    num_rows, num_cols = A.shape
-    if num_rows != 3:
-        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
-    num_rows, num_cols = B.shape
-    if num_rows != 3:
-        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
-
-
-    # find mean column wise: 3 x 1
-    centroid_A = np.mean(A, axis=1, keepdims=True)
-    centroid_B = np.mean(B, axis=1, keepdims=True)
-
-    # subtract mean
-    Am = A - centroid_A
-    Bm = B - centroid_B
-
-    H = Am @ Bm.T
-
-    # find rotation
-    U, S, Vt = np.linalg.svd(H)
-
-    R = Vt.T @ U.T
-
-    # special reflection case
-    if np.linalg.det(R) < 0:
-        # print("det(R) < R, reflection detected!, correcting for it ...")
-        SS = np.diag([1.,1.,-1.])
-        R = (Vt.T @ SS) @ U.T
-    assert math.fabs(np.linalg.det(R) - 1) < 1e-5
-
-    t = -R @ centroid_A + centroid_B
-    return R, t
-
-def align_molecule_a_according_molecule_b(molecule_a_path, molecule_b_path, device=None, save=False, kabsch_no_h=True):
-    m_a = Chem.MolFromMol2File(molecule_a_path, sanitize=False, removeHs=False)
-    m_b = Chem.MolFromMol2File(molecule_b_path, sanitize=False, removeHs=False)
-    pos_a = torch.tensor(m_a.GetConformer().GetPositions())
-    pos_b = torch.tensor(m_b.GetConformer().GetPositions())
-    m_a_no_h = Chem.RemoveHs(m_a)
-    m_b_no_h = Chem.RemoveHs(m_b)
-    pos_a_no_h = torch.tensor(m_a_no_h.GetConformer().GetPositions())
-    pos_b_no_h = torch.tensor(m_b_no_h.GetConformer().GetPositions())
-
-    if kabsch_no_h:
-        rotation, translation = kabsch(pos_a_no_h, pos_b_no_h, device=device)
-    else:
-        rotation, translation = kabsch(pos_a, pos_b, device=device)
-    pos_a_new = (rotation @ pos_a.t()).t() + translation
-    # print(np.sqrt(np.sum((pos_a.numpy() - pos_b.numpy()) ** 2,axis=1).mean()))
-    # print(np.sqrt(np.sum((pos_a_new.numpy() - pos_b.numpy()) ** 2, axis=1).mean()))
-
-    return pos_a_new, rotation, translation
-
-def get_principle_axes(xyz,scale_factor=20,pdb_name=None):
-    #create coordinates array
-    coord = np.array(xyz, float)
-    # compute geometric center
-    center = np.mean(coord, 0)
-    # print("Coordinates of the geometric center:\n", center)
-    # center with geometric center
-    coord = coord - center
-    # compute principal axis matrix
-    inertia = np.dot(coord.transpose(), coord)
-    e_values, e_vectors = np.linalg.eig(inertia)
-    #--------------------------------------------------------------------------
-    # order eigen values (and eigen vectors)
-    #
-    # axis1 is the principal axis with the biggest eigen value (eval1)
-    # axis2 is the principal axis with the second biggest eigen value (eval2)
-    # axis3 is the principal axis with the smallest eigen value (eval3)
-    #--------------------------------------------------------------------------
-    order = np.argsort(e_values)
-    eval3, eval2, eval1 = e_values[order]
-    axis3, axis2, axis1 = e_vectors[:, order].transpose()
-
-    return np.array([axis1, axis2, axis3]), center
-
-def get_rotation_and_translation(xyz):
-    protein_principle_axes_system, system_center = get_principle_axes(xyz)
-    rotation = protein_principle_axes_system.T
-    translation = -system_center
-    return rotation, translation
-
-def canonical_protein_ligand_orientation(lig_coords, prot_coords):
-    rotation, translation = get_rotation_and_translation(prot_coords)
-    lig_canoical_truth_coords = (lig_coords + translation) @ rotation
-    prot_canonical_truth_coords = (prot_coords + translation) @ rotation
-    rotation_lig, translation_lig = get_rotation_and_translation(lig_coords)
-    lig_canonical_init_coords = (lig_coords + translation_lig) @ rotation_lig
-
-    return lig_coords, lig_canoical_truth_coords, lig_canonical_init_coords, \
-           prot_coords, prot_canonical_truth_coords,\
-           rotation, translation
-
-def canonical_single_molecule_orientation(m_coords):
-    rotation, translation = get_rotation_and_translation(m_coords)
-    canonical_init_coords = (m_coords + translation) @ rotation
-    return canonical_init_coords
-
-# Clockwise dihedral2 from https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python
-def GetDihedralFromPointCloud(Z, atom_idx):
-    p = Z[list(atom_idx)]
-    b = p[:-1] - p[1:]
-    b[0] *= -1 #########################
-    v = np.array( [ v - (v.dot(b[1])/b[1].dot(b[1])) * b[1] for v in [b[0], b[2]] ] )
-    # Normalize vectors
-    v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1,1)
-    b1 = b[1] / np.linalg.norm(b[1])
-    x = np.dot(v[0], v[1])
-    m = np.cross(v[0], b1)
-    y = np.dot(m, v[1])
-    return np.degrees(np.arctan2( y, x ))
-
-def A_transpose_matrix(alpha):
-    return np.array([[np.cos(np.radians(alpha)), np.sin(np.radians(alpha))],
-                     [-np.sin(np.radians(alpha)), np.cos(np.radians(alpha))]], dtype=np.double)
-
-def S_vec(alpha):
-    return np.array([[np.cos(np.radians(alpha))],
-                     [np.sin(np.radians(alpha))]], dtype=np.double)
-
-def get_dihedral_vonMises(mol, conf, atom_idx, Z):
-    Z = np.array(Z)
-    v = np.zeros((2,1))
-    iAtom = mol.GetAtomWithIdx(atom_idx[1])
-    jAtom = mol.GetAtomWithIdx(atom_idx[2])
-    k_0 = atom_idx[0]
-    i = atom_idx[1]
-    j = atom_idx[2]
-    l_0 = atom_idx[3]
-    for b1 in iAtom.GetBonds():
-        k = b1.GetOtherAtomIdx(i)
-        if k == j:
-            continue
-        for b2 in jAtom.GetBonds():
-            l = b2.GetOtherAtomIdx(j)
-            if l == i:
-                continue
-            assert k != l
-            s_star = S_vec(GetDihedralFromPointCloud(Z, (k, i, j, l)))
-            a_mat = A_transpose_matrix(GetDihedral(conf, (k, i, j, k_0)) + GetDihedral(conf, (l_0, i, j, l)))
-            v = v + np.matmul(a_mat, s_star)
-    v = v / np.linalg.norm(v)
-    v = v.reshape(-1)
-    return np.degrees(np.arctan2(v[1], v[0]))
-
-def distance_loss_function(epoch, y_pred, x, protein_nodes_xyz, compound_pair_dis_constraint, LAS_distance_constraint_mask=None, mode=0):
-    dis = torch.cdist(x, protein_nodes_xyz)
-    dis_clamp = torch.clamp(dis, max=10)
-    if mode == 0:
-        interaction_loss = ((dis_clamp - y_pred).abs()).sum()
-    elif mode == 1:
-        interaction_loss = ((dis_clamp - y_pred)**2).sum()
-    elif mode == 2:
-        # probably not a good choice. x^0.5 has infinite gradient at x=0. added 1e-5 for numerical stability.
-        interaction_loss = (((dis_clamp - y_pred).abs() + 1e-5)**0.5).sum()
-    config_dis = torch.cdist(x, x)
-    if LAS_distance_constraint_mask is not None:
-        configuration_loss = 1 * (((config_dis-compound_pair_dis_constraint).abs())[LAS_distance_constraint_mask]).sum()
-        # basic exlcuded-volume. the distance between compound atoms should be at least 1.22Å
-        configuration_loss += 2 * ((1.22 - config_dis).relu()).sum()
-    else:
-        configuration_loss = 1 * ((config_dis-compound_pair_dis_constraint).abs()).sum()
-    # if epoch < 500:
-    #     loss = interaction_loss
-    # else:
-    #     loss = 1 * (interaction_loss + 5e-3 * (epoch - 500) * configuration_loss)
-    loss = 1 * (interaction_loss + 5e-3 * (epoch + 200) * configuration_loss)
-    return loss, (interaction_loss.item(), configuration_loss.item())
-
-
-def distance_optimize_compound_coords(coords, y_pred, protein_nodes_xyz,
-                                      compound_pair_dis_constraint,total_epoch=1000, loss_function=distance_loss_function, LAS_distance_constraint_mask=None, mode=0, show_progress=False):
-    # random initialization. center at the protein center.
-    c_pred = protein_nodes_xyz.mean(axis=0)
-    x = coords
-    x.requires_grad = True
-    optimizer = torch.optim.Adam([x], lr=0.1)
-    loss_list = []
-    # optimizer = torch.optim.LBFGS([x], lr=0.01)
-    if show_progress:
-        it = tqdm(range(total_epoch))
-    else:
-        it = range(total_epoch)
-    for epoch in it:
-        optimizer.zero_grad()
-        loss, (interaction_loss, configuration_loss) = loss_function(epoch, y_pred, x, protein_nodes_xyz,
-                                                                     compound_pair_dis_constraint,
-                                                                     LAS_distance_constraint_mask=LAS_distance_constraint_mask,
-                                                                     mode=mode)
-        loss.backward()
-        optimizer.step()
-        loss_list.append(loss.item())
-        # break
-    return x, loss_list
-
-def tankbind_gen(lig_pred_coords, lig_init_coords, prot_coords, LAS_mask, device='cpu', mode=0):
-
-    pred_prot_lig_inter_distance = torch.cdist(lig_pred_coords, prot_coords)
-    init_lig_intra_distance = torch.cdist(lig_init_coords, lig_init_coords)
-    try:
-        x, loss_list = distance_optimize_compound_coords(lig_pred_coords.to('cpu'),
-                                                         pred_prot_lig_inter_distance.to('cpu'),
-                                                         prot_coords.to('cpu'),
-                                                         init_lig_intra_distance.to('cpu'),
-                                                         LAS_distance_constraint_mask=LAS_mask.bool(),
-                                                         mode=mode, show_progress=False)
-    except:
-        print('error')
-
-    return x
-
-def kabsch_align(lig_pred_coords, name, save_path, dataset_path, device='cpu'):
-    rdkit_init_lig_path_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-    openbabel_init_m_lig = next(pybel.readfile('sdf', rdkit_init_lig_path_sdf))
-    rdkit_init_coords = [atom.coords for atom in openbabel_init_m_lig]
-    rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32)  # np.array, [n, 3]
-
-    coords_pred = lig_pred_coords.detach().cpu().numpy()
-
-    R, t = rigid_transform_Kabsch_3D(rdkit_init_coords.T, coords_pred.T)
-    coords_pred_optimized = (R @ (rdkit_init_coords).T).T + t.squeeze()
-
-    opt_ligCoords = torch.tensor(coords_pred_optimized, device=device)
-    return opt_ligCoords
-
-def equibind_align(lig_pred_coords, name, save_path, dataset_path, device='cpu'):
-    lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2')
-    lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf')
-    m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
-    if m_lig == None:  # read mol2 file if sdf file cannot be sanitized
-        m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
-
-    # load rdkit mol
-    lig_path_sdf_error = os.path.join(save_path, 'visualize_dir', f'{name}_init')
-    pred_lig_path_sdf_error = os.path.join(save_path, 'visualize_dir', f'{name}_pred')
-    pred_lig_path_sdf_true = os.path.join(save_path, 'visualize_dir', f'{name}_pred.sdf')
-
-    rdkit_init_lig_path_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-
-    if not os.path.exists(rdkit_init_lig_path_sdf):
-        cmd = f'mv {lig_path_sdf_error} {rdkit_init_lig_path_sdf}'
-        os.system(cmd)
-    if not os.path.exists(pred_lig_path_sdf_true):
-        cmd = f'mv {pred_lig_path_sdf_error} {pred_lig_path_sdf_true}'
-        os.system(cmd)
-
-    openbabel_init_m_lig = next(pybel.readfile('sdf', rdkit_init_lig_path_sdf))
-    rdkit_init_coords = [atom.coords for atom in openbabel_init_m_lig]
-    rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32)  # np.array, [n, 3]
-    # rdkit_init_m_lig = read_rdmol(rdkit_init_lig_path_sdf, sanitize=True, remove_hs=True)
-    # rdkit_init_coords = rdkit_init_m_lig.GetConformer().GetPositions()
-
-    rdkit_init_lig = copy.deepcopy(m_lig)
-    conf = rdkit_init_lig.GetConformer()
-    for i in range(rdkit_init_lig.GetNumAtoms()):
-        x, y, z = rdkit_init_coords[i]
-        conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
-
-    coords_pred = lig_pred_coords.detach().cpu().numpy()
-    Z_pt_cloud = coords_pred
-    rotable_bonds = get_torsions([rdkit_init_lig])
-    new_dihedrals = np.zeros(len(rotable_bonds))
-
-    for idx, r in enumerate(rotable_bonds):
-        new_dihedrals[idx] = get_dihedral_vonMises(rdkit_init_lig, rdkit_init_lig.GetConformer(), r, Z_pt_cloud)
-    optimized_mol = apply_changes_equibind(rdkit_init_lig, new_dihedrals, rotable_bonds)
-
-    coords_pred_optimized = optimized_mol.GetConformer().GetPositions()
-    R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T)
-    coords_pred_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze()
-
-    opt_ligCoords = torch.tensor(coords_pred_optimized, device=device)
-    return opt_ligCoords
-
-def dock_compound(lig_pred_coords, prot_coords, name, save_path,
-                  popsize=150, maxiter=500, seed=None, mutation=(0.5, 1),
-                  recombination=0.8, device='cpu', torsion_num_cut=20):
-    if seed:
-        np.random.seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.manual_seed(seed)
-
-    # load rdkit mol
-    lig_path_init_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-    openbabel_m_lig_init = next(pybel.readfile('sdf', lig_path_init_sdf))
-    rdkit_init_coords = [atom.coords for atom in openbabel_m_lig_init]
-
-    lig_path_true_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_ligand.sdf')
-    lig_path_true_mol2 = os.path.join(save_path, 'visualize_dir', f'{name}_ligand.mol2')
-    m_lig = read_rdmol(lig_path_true_sdf, sanitize=True, remove_hs=True)
-    if m_lig == None:  # read mol2 file if sdf file cannot be sanitized
-        m_lig = read_rdmol(lig_path_true_mol2, sanitize=True, remove_hs=True)
-
-    atom_num = len(m_lig.GetConformer().GetPositions())
-    if len(rdkit_init_coords) != atom_num:
-        rdkit_init_coords = [atom.coords for atom in openbabel_m_lig_init if atom.atomicnum > 1]
-        lig_pred_coords_no_h_list = [atom_coords for atom,atom_coords in zip(openbabel_m_lig_init, lig_pred_coords.tolist()) if atom.atomicnum > 1]
-        lig_pred_coords = torch.tensor(lig_pred_coords_no_h_list, device=device)
-
-    rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32)  # np.array, [n, 3]
-    print(f'{name} init coords shape: {rdkit_init_coords.shape}')
-    print(f'{name} true coords shape: {m_lig.GetConformer().GetPositions().shape}')
-
-    rdkit_init_lig = copy.deepcopy(m_lig)
-    conf = rdkit_init_lig.GetConformer()
-    for i in range(rdkit_init_lig.GetNumAtoms()):
-        x, y, z = rdkit_init_coords[i]
-        conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
-
-    # move m_lig to pred_coords center
-    pred_coords_center = lig_pred_coords.cpu().numpy().mean(axis=0)
-    init_coords_center = rdkit_init_lig.GetConformer().GetPositions().mean(axis=0)
-    # print(f'{name} pred coords shape: {lig_pred_coords.shape}')
-
-    center_rel_vecs = pred_coords_center - init_coords_center
-    values = np.concatenate([np.array([0,0,0]),center_rel_vecs])
-    rdMolTransforms.TransformConformer(rdkit_init_lig.GetConformer(), GetTransformationMatrix(values))
-
-    # Set optimization function
-    opt = optimze_conformation(mol=rdkit_init_lig, target_coords=lig_pred_coords, device=device,
-                               n_particles=1, seed=seed)
-    if len(opt.rotable_bonds) > torsion_num_cut:
-        return lig_pred_coords
-
-    # Define bounds for optimization
-    max_bound = np.concatenate([[np.pi] * 3, prot_coords.cpu().max(0)[0].numpy(), [np.pi] * len(opt.rotable_bonds)], axis=0)
-    min_bound = np.concatenate([[-np.pi] * 3, prot_coords.cpu().min(0)[0].numpy(), [-np.pi] * len(opt.rotable_bonds)], axis=0)
-    bounds = (min_bound, max_bound)
-
-    # Optimize conformations
-    result = differential_evolution(opt.score_conformation, list(zip(bounds[0], bounds[1])), maxiter=maxiter,
-                                    popsize=int(np.ceil(popsize / (len(opt.rotable_bonds) + 6))),
-                                    mutation=mutation, recombination=recombination, disp=False, seed=seed)
-
-    # Get optimized molecule
-    starting_mol = opt.mol
-    opt_mol = apply_changes(starting_mol, result['x'], opt.rotable_bonds)
-    opt_ligCoords = torch.tensor(opt_mol.GetConformer().GetPositions(), device=device)
-
-    return opt_ligCoords
-
-class optimze_conformation():
-    def __init__(self, mol, target_coords, n_particles, save_molecules=False, device='cpu',
-                 seed=None):
-        super(optimze_conformation, self).__init__()
-        if seed:
-            np.random.seed(seed)
-
-        self.targetCoords = torch.stack([target_coords for _ in range(n_particles)]).double()
-        self.n_particles = n_particles
-        self.rotable_bonds = get_torsions([mol])
-        self.save_molecules = save_molecules
-        self.mol = mol
-        self.device = device
-
-    def score_conformation(self, values):
-        """
-        Parameters
-        ----------
-        values : numpy.ndarray
-            set of inputs of shape :code:`(n_particles, dimensions)`
-        Returns
-        -------
-        numpy.ndarray
-            computed cost of size :code:`(n_particles, )`
-        """
-        if len(values.shape) < 2: values = np.expand_dims(values, axis=0)
-        mols = [copy.copy(self.mol) for _ in range(self.n_particles)]
-
-        # Apply changes to molecules
-        # apply rotations
-        [SetDihedral(mols[m].GetConformer(), self.rotable_bonds[r], values[m, 6 + r]) for r in
-         range(len(self.rotable_bonds)) for m in range(self.n_particles)]
-
-        # apply transformation matrix
-        [rdMolTransforms.TransformConformer(mols[m].GetConformer(), GetTransformationMatrix(values[m, :6])) for m in
-         range(self.n_particles)]
-
-        # Calcualte distances between ligand conformation and pred ligand conformation
-        ligCoords_list = [torch.tensor(m.GetConformer().GetPositions(), device=self.device) for m in mols]  # [n_mols, N, 3]
-        ligCoords = torch.stack(ligCoords_list).double()  # [n_mols, N, 3]
-
-        ligCoords_error = ligCoords - self.targetCoords  # [n_mols, N, 3]
-        ligCoords_rmsd = (ligCoords_error ** 2).sum(dim=-1).mean(dim=-1).sqrt().min().cpu().numpy()
-
-        del ligCoords_error, ligCoords, ligCoords_list, mols
-
-        return ligCoords_rmsd
-
-def apply_changes(mol, values, rotable_bonds):
-    opt_mol = copy.copy(mol)
-
-    # apply rotations
-    [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[6 + r]) for r in range(len(rotable_bonds))]
-
-    # apply transformation matrix
-    rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6]))
-
-    return opt_mol
-
-def apply_changes_equibind(mol, values, rotable_bonds):
-    opt_mol = copy.deepcopy(mol)
-    # opt_mol = add_rdkit_conformer(opt_mol)
-
-    # apply rotations
-    [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[r]) for r in range(len(rotable_bonds))]
-
-    # # apply transformation matrix
-    # rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6]))
-
-    return opt_mol
-
-def get_torsions(mol_list):
-    atom_counter = 0
-    torsionList = []
-    dihedralList = []
-    for m in mol_list:
-        torsionSmarts = '[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]'
-        torsionQuery = Chem.MolFromSmarts(torsionSmarts)
-        matches = m.GetSubstructMatches(torsionQuery)
-        conf = m.GetConformer()
-        for match in matches:
-            idx2 = match[0]
-            idx3 = match[1]
-            bond = m.GetBondBetweenAtoms(idx2, idx3)
-            jAtom = m.GetAtomWithIdx(idx2)
-            kAtom = m.GetAtomWithIdx(idx3)
-            for b1 in jAtom.GetBonds():
-                if (b1.GetIdx() == bond.GetIdx()):
-                    continue
-                idx1 = b1.GetOtherAtomIdx(idx2)
-                for b2 in kAtom.GetBonds():
-                    if ((b2.GetIdx() == bond.GetIdx())
-                        or (b2.GetIdx() == b1.GetIdx())):
-                        continue
-                    idx4 = b2.GetOtherAtomIdx(idx3)
-                    # skip 3-membered rings
-                    if (idx4 == idx1):
-                        continue
-                    # skip torsions that include hydrogens
-                    if ((m.GetAtomWithIdx(idx1).GetAtomicNum() == 1)
-                        or (m.GetAtomWithIdx(idx4).GetAtomicNum() == 1)):
-                        continue
-                    if m.GetAtomWithIdx(idx4).IsInRing():
-                        torsionList.append(
-                            (idx4 + atom_counter, idx3 + atom_counter, idx2 + atom_counter, idx1 + atom_counter))
-                        break
-                    else:
-                        torsionList.append(
-                            (idx1 + atom_counter, idx2 + atom_counter, idx3 + atom_counter, idx4 + atom_counter))
-                        break
-                break
-
-        atom_counter += m.GetNumAtoms()
-    return torsionList
-
-
-def SetDihedral(conf, atom_idx, new_vale):
-    rdMolTransforms.SetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3], new_vale)
-
-
-def GetDihedral(conf, atom_idx):
-    return rdMolTransforms.GetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3])
-
-
-def GetTransformationMatrix(transformations):
-    x, y, z, disp_x, disp_y, disp_z = transformations
-    transMat = np.array([[np.cos(z) * np.cos(y), (np.cos(z) * np.sin(y) * np.sin(x)) - (np.sin(z) * np.cos(x)),
-                          (np.cos(z) * np.sin(y) * np.cos(x)) + (np.sin(z) * np.sin(x)), disp_x],
-                         [np.sin(z) * np.cos(y), (np.sin(z) * np.sin(y) * np.sin(x)) + (np.cos(z) * np.cos(x)),
-                          (np.sin(z) * np.sin(y) * np.cos(x)) - (np.cos(z) * np.sin(x)), disp_y],
-                         [-np.sin(y), np.cos(y) * np.sin(x), np.cos(y) * np.cos(x), disp_z],
-                         [0, 0, 0, 1]], dtype=np.double)
-    return transMat
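rigid_transform_Kabsch_3D in the deleted module is the textbook Kabsch solution for the optimal rigid fit between two 3xN point clouds: center both clouds, take the SVD of the cross-covariance, and flip the last singular direction when a reflection sneaks in. A self-contained numpy check of that construction on synthetic data (a sketch independent of the repo):

    import numpy as np

    def kabsch_numpy(A, B):
        # A, B: (3, N) point clouds; returns R, t such that R @ A + t ≈ B.
        cA = A.mean(axis=1, keepdims=True)
        cB = B.mean(axis=1, keepdims=True)
        H = (A - cA) @ (B - cB).T          # 3x3 cross-covariance
        U, S, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        if np.linalg.det(R) < 0:           # special reflection case
            R = (Vt.T @ np.diag([1., 1., -1.])) @ U.T
        t = cB - R @ cA
        return R, t

    A = np.random.rand(3, 10)
    theta = 0.3
    R_true = np.array([[np.cos(theta), -np.sin(theta), 0.],
                       [np.sin(theta),  np.cos(theta), 0.],
                       [0., 0., 1.]])
    B = R_true @ A + np.array([[1.], [2.], [3.]])
    R, t = kabsch_numpy(A, B)
    print(np.allclose(R @ A + t, B))       # True: the rigid motion is recovered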
UltraFlow/commons/get_free_gpu.py
DELETED
@@ -1,78 +0,0 @@
-import torch
-from gpustat import GPUStatCollection
-import time
-def get_free_gpu(mode="memory", memory_need=10000) -> list:
-    r"""Get free gpu according to mode (process-free or memory-free).
-    Args:
-        mode (str, optional): memory-free or process-free. Defaults to "memory".
-        memory_need (int): The memory you need, used if mode=='memory'. Defaults to 10000.
-    Returns:
-        list: free gpu ids sorting by free memory
-    """
-    assert mode in ["memory", "process"], "mode must be 'memory' or 'process'"
-    if mode == "memory":
-        assert memory_need is not None, \
-            "'memory_need' if None, 'memory' mode must give the free memory you want to apply for"
-        memory_need = int(memory_need)
-        assert memory_need > 0, "'memory_need' you want must be positive"
-    gpu_stats = GPUStatCollection.new_query()
-    gpu_free_id_list = []
-
-    for idx, gpu_stat in enumerate(gpu_stats):
-        if gpu_check_condition(gpu_stat, mode, memory_need):
-            gpu_free_id_list.append([idx, gpu_stat.memory_free])
-            print("gpu[{}]: {}MB".format(idx, gpu_stat.memory_free))
-
-    if gpu_free_id_list:
-        gpu_free_id_list = sorted(gpu_free_id_list,
-                                  key=lambda x: x[1],
-                                  reverse=True)
-        gpu_free_id_list = [i[0] for i in gpu_free_id_list]
-    return gpu_free_id_list
-
-
-def gpu_check_condition(gpu_stat, mode, memory_need) -> bool:
-    r"""Check gpu is free or not.
-    Args:
-        gpu_stat (gpustat.core): gpustat to check
-        mode (str): memory-free or process-free.
-        memory_need (int): The memory you need, used if mode=='memory'
-    Returns:
-        bool: gpu is free or not
-    """
-    if mode == "memory":
-        return gpu_stat.memory_free > memory_need
-    elif mode == "process":
-        for process in gpu_stat.processes:
-            if process["command"] == "python":
-                return False
-        return True
-    else:
-        return False
-
-def get_device(target_gpu_idx, memory_need=10000):
-    # check device
-    # assert torch.cuda.device_count() >= len(target_gpus), 'do you set the gpus in config correctly?'
-    flag = None
-
-    while True:
-        # Get the gpu ids which have more than 10000MB memory
-        free_gpu_ids = get_free_gpu('memory', memory_need)
-        if len(free_gpu_ids) < 1:
-            if flag is None:
-                print("No GPU available now. sleeping 60s ....")
-            time.sleep(6)
-        else:
-
-            gpuid = list(set(free_gpu_ids) & set(target_gpu_idx))[0]
-
-            device = torch.device('cuda:'+str(gpuid))
-            print("Using device %s as main device" % device)
-            break
-
-    return device
-
-if __name__ == '__main__':
-    target_gpu_idx = [0,1,2,3,4,5,6,7,8]
-    device = get_device(target_gpu_idx)
-    print(device)
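Typical use of the deleted helper, following its own __main__ block: ask for cards with enough free memory and take the best one, falling back to CPU when none qualify (the CPU fallback is our addition; assumes gpustat is installed and a CUDA build of torch is available):

    import torch

    free_ids = get_free_gpu(mode="memory", memory_need=8000)  # required free MB
    device = torch.device(f'cuda:{free_ids[0]}') if free_ids else torch.device('cpu')
    print(device)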
UltraFlow/commons/loss_weight.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:42d7d91c0447c79418d9d547d45203612a9dcbf21355e047923237ba36d8765e
-size 748
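Note that the three removed lines are a Git LFS pointer (spec version, object hash, byte size) rather than the pickle payload itself; the 748-byte binary lives in LFS storage, so this is the complete diff for that file.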
UltraFlow/commons/metrics.py
DELETED
@@ -1,315 +0,0 @@
-from scipy import stats
-import torch
-import torch.nn as nn
-import numpy as np
-from math import sqrt, ceil
-from sklearn.linear_model import LinearRegression
-from sklearn.metrics import ndcg_score, recall_score
-import os
-import pickle
-import dgl
-from typing import Union, List
-from torch import Tensor
-from statistics import stdev
-
-def affinity_loss(affinity_pred,labels,sec_pred,bg_prot,config):
-    loss = nn.MSELoss(affinity_pred,labels)
-    if config.model.aux_w != 0:
-        loss += config.train.aux_w * nn.CrossEntropyLoss(sec_pred,bg_prot.ndata['s'])
-    return loss
-
-def Accurate_num(outputs,y):
-    _, y_pred_label = torch.max(outputs, dim=1)
-    return torch.sum(y_pred_label == y.data).item()
-
-def RMSE(y,f):
-    rmse = sqrt(((y - f)**2).mean(axis=0))
-    return rmse
-
-def MAE(y,f):
-    mae = (np.abs(y-f)).mean()
-    return mae
-
-def SD(y,f):
-    f,y = f.reshape(-1,1),y.reshape(-1,1)
-    lr = LinearRegression()
-    lr.fit(f,y)
-    y_ = lr.predict(f)
-    sd = (((y - y_) ** 2).sum() / (len(y) - 1)) ** 0.5
-    return sd
-
-def Pearson(y,f):
-    y,f = y.flatten(),f.flatten()
-    rp = np.corrcoef(y, f)[0,1]
-    return rp
-
-def Spearman(y,f):
-    y, f = y.flatten(), f.flatten()
-    rp = stats.spearmanr(y, f)
-    return rp[0]
-
-def NDCG(y,f,k=None):
-    y, f = y.flatten(), f.flatten()
-    return ndcg_score(np.expand_dims(y, axis=0), np.expand_dims(f,axis=0),k=k)
-
-def Recall(y, f, postive_threshold = 7.5):
-    y, f = y.flatten(), f.flatten()
-    y_class = y > postive_threshold
-    f_class = f > postive_threshold
-
-    return recall_score(y_class, f_class)
-
-def Enrichment_Factor(y, f, postive_threshold = 7.5, top_percentage = 0.001):
-    y, f = y.flatten(), f.flatten()
-    y_class = y > postive_threshold
-    f_class = f > postive_threshold
-
-    data = list(zip(y_class.tolist(), f_class.tolist()))
-    data.sort(key=lambda x:x[1], reverse=True)
-
-    y_class, f_class = map(list, zip(*data))
-
-    total_active_rate = sum(y_class) / len(y_class)
-    top_num = ceil(len(y_class) * top_percentage)
-    top_active_rate = sum(y_class[:top_num]) / top_num
-
-    er = top_active_rate / total_active_rate
-
-    return er
-
-def Auxiliary_Weight_Balance(aux_type='Q8'):
-    if os.path.exists('loss_weight.pkl'):
-        with open('loss_weight.pkl','rb') as f:
-            w = pickle.load(f)
-        return w[aux_type]
-
-def RMSD(ligs_coords_pred, ligs_coords):
-    rmsds = []
-    for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-        rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))).item())
-    return rmsds
-
-def KabschRMSD(ligs_coords_pred, ligs_coords):
-    rmsds = []
-    for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-        lig_coords_pred_mean = lig_coords_pred.mean(dim=0, keepdim=True)  # (1,3)
-        lig_coords_mean = lig_coords.mean(dim=0, keepdim=True)  # (1,3)
-
-        A = (lig_coords_pred - lig_coords_pred_mean).transpose(0, 1) @ (lig_coords - lig_coords_mean)
-
-        U, S, Vt = torch.linalg.svd(A)
-
-        corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=lig_coords_pred.device))
-        rotation = (U @ corr_mat) @ Vt
-        translation = lig_coords_pred_mean - torch.t(rotation @ lig_coords_mean.t())  # (1,3)
-
-        lig_coords = (rotation @ lig_coords.t()).t() + translation
-        rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))))
-    return torch.tensor(rmsds).mean()
-
-
-class RMSDmedian(nn.Module):
-    def __init__(self) -> None:
-        super(RMSDmedian, self).__init__()
-
-    def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
-        rmsds = []
-        for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-            rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))))
-        return torch.median(torch.tensor(rmsds))
-
-
-class RMSDfraction(nn.Module):
-    def __init__(self, distance) -> None:
-        super(RMSDfraction, self).__init__()
-        self.distance = distance
-
-    def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
-        rmsds = []
-        for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-            rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))))
-        count = torch.tensor(rmsds) < self.distance
-        return 100 * count.sum() / len(count)
-
-
-def CentroidDist(ligs_coords_pred, ligs_coords):
-    distances = []
-    for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-        distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)).item())
-    return distances
-
-
-class CentroidDistMedian(nn.Module):
-    def __init__(self) -> None:
-        super(CentroidDistMedian, self).__init__()
-
-    def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
-        distances = []
-        for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-            distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)))
-        return torch.median(torch.tensor(distances))
-
-
-class CentroidDistFraction(nn.Module):
-    def __init__(self, distance) -> None:
-        super(CentroidDistFraction, self).__init__()
-        self.distance = distance
-
-    def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor:
-        distances = []
-        for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords):
-            distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)))
-        count = torch.tensor(distances) < self.distance
-        return 100 * count.sum() / len(count)
-
-class MeanPredictorLoss(nn.Module):
-
-    def __init__(self, loss_func) -> None:
-        super(MeanPredictorLoss, self).__init__()
-        self.loss_func = loss_func
-
-    def forward(self, x1: Tensor, targets: Tensor) -> Tensor:
-        return self.loss_func(torch.full_like(targets, targets.mean()), targets)
-
-
-def compute_mmd(source, target, batch_size=1000, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
-    """
-    Calculate the `maximum mean discrepancy distance <https://jmlr.csail.mit.edu/papers/v13/gretton12a.html>`_ between two sample set.
-    This implementation is based on `this open source code <https://github.com/ZongxianLee/MMD_Loss.Pytorch>`_.
-    Args:
-        source (pytorch tensor): the pytorch tensor containing data samples of the source distribution.
-        target (pytorch tensor): the pytorch tensor containing data samples of the target distribution.
-    :rtype:
-        :class:`float`
-    """
-    n_source = int(source.size()[0])
-    n_target = int(target.size()[0])
-    n_samples = n_source + n_target
-
-    total = torch.cat([source, target], dim=0)
-    total0 = total.unsqueeze(0)
-    total1 = total.unsqueeze(1)
-
-    if fix_sigma:
-        bandwidth = fix_sigma
-    else:
-        bandwidth, id = 0.0, 0
|
197 |
-
while id < n_samples:
|
198 |
-
bandwidth += torch.sum((total0 - total1[id:id + batch_size]) ** 2)
|
199 |
-
id += batch_size
|
200 |
-
bandwidth /= n_samples ** 2 - n_samples
|
201 |
-
|
202 |
-
bandwidth /= kernel_mul ** (kernel_num // 2)
|
203 |
-
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
|
204 |
-
|
205 |
-
XX_kernel_val = [0 for _ in range(kernel_num)]
|
206 |
-
for i in range(kernel_num):
|
207 |
-
XX_kernel_val[i] += torch.sum(
|
208 |
-
torch.exp(-((total0[:, :n_source] - total1[:n_source, :]) ** 2) / bandwidth_list[i]))
|
209 |
-
XX = sum(XX_kernel_val) / (n_source * n_source)
|
210 |
-
|
211 |
-
YY_kernel_val = [0 for _ in range(kernel_num)]
|
212 |
-
id = n_source
|
213 |
-
while id < n_samples:
|
214 |
-
for i in range(kernel_num):
|
215 |
-
YY_kernel_val[i] += torch.sum(
|
216 |
-
torch.exp(-((total0[:, n_source:] - total1[id:id + batch_size, :]) ** 2) / bandwidth_list[i]))
|
217 |
-
id += batch_size
|
218 |
-
YY = sum(YY_kernel_val) / (n_target * n_target)
|
219 |
-
|
220 |
-
XY_kernel_val = [0 for _ in range(kernel_num)]
|
221 |
-
id = n_source
|
222 |
-
while id < n_samples:
|
223 |
-
for i in range(kernel_num):
|
224 |
-
XY_kernel_val[i] += torch.sum(
|
225 |
-
torch.exp(-((total0[:, id:id + batch_size] - total1[:n_source, :]) ** 2) / bandwidth_list[i]))
|
226 |
-
id += batch_size
|
227 |
-
XY = sum(XY_kernel_val) / (n_source * n_target)
|
228 |
-
|
229 |
-
return XX.item() + YY.item() - 2 * XY.item()
|
230 |
-
|
231 |
-
|
232 |
-
def get_matric_dict(rmsds, centroids, kabsch_rmsds=None):
|
233 |
-
rmsd_mean = sum(rmsds)/len(rmsds)
|
234 |
-
centroid_mean = sum(centroids) / len(centroids)
|
235 |
-
rmsd_std = stdev(rmsds)
|
236 |
-
centroid_std = stdev(centroids)
|
237 |
-
|
238 |
-
# rmsd < 2
|
239 |
-
count = torch.tensor(rmsds) < 2.0
|
240 |
-
rmsd_less_than_2 = 100 * count.sum().item() / len(count)
|
241 |
-
|
242 |
-
# rmsd < 2
|
243 |
-
count = torch.tensor(rmsds) < 5.0
|
244 |
-
rmsd_less_than_5 = 100 * count.sum().item() / len(count)
|
245 |
-
|
246 |
-
# centorid < 2
|
247 |
-
count = torch.tensor(centroids) < 2.0
|
248 |
-
centroid_less_than_2 = 100 * count.sum().item() / len(count)
|
249 |
-
|
250 |
-
# centorid < 5
|
251 |
-
count = torch.tensor(centroids) < 5.0
|
252 |
-
centroid_less_than_5 = 100 * count.sum().item() / len(count)
|
253 |
-
|
254 |
-
rmsd_precentiles = np.percentile(np.array(rmsds), [25, 50, 75]).round(4)
|
255 |
-
centroid_prcentiles = np.percentile(np.array(centroids), [25, 50, 75]).round(4)
|
256 |
-
|
257 |
-
metrics_dict = {'rmsd mean': rmsd_mean, 'rmsd std': rmsd_std,
|
258 |
-
'rmsd 25%': rmsd_precentiles[0], 'rmsd 50%': rmsd_precentiles[1], 'rmsd 75%': rmsd_precentiles[2],
|
259 |
-
'centroid mean': centroid_mean, 'centroid std': centroid_std,
|
260 |
-
'centroid 25%': centroid_prcentiles[0], 'centroid 50%': centroid_prcentiles[1], 'centroid 75%': centroid_prcentiles[2],
|
261 |
-
'rmsd less than 2': rmsd_less_than_2, 'rmsd less than 5':rmsd_less_than_5,
|
262 |
-
'centroid less than 2': centroid_less_than_2, 'centroid less than 5': centroid_less_than_5,
|
263 |
-
}
|
264 |
-
|
265 |
-
if kabsch_rmsds is not None:
|
266 |
-
kabsch_rmsd_mean = sum(kabsch_rmsds) / len(kabsch_rmsds)
|
267 |
-
kabsch_rmsd_std = stdev(kabsch_rmsd_mean)
|
268 |
-
metrics_dict['kabsch rmsd mean'] = kabsch_rmsd_mean
|
269 |
-
metrics_dict['kabsch rmsd std'] = kabsch_rmsd_std
|
270 |
-
|
271 |
-
return metrics_dict
|
272 |
-
|
273 |
-
def get_sbap_regression_metric_dict(np_y, np_f):
|
274 |
-
rmse, mae, pearson, spearman, sd_ = RMSE(np_y, np_f), \
|
275 |
-
MAE(np_y, np_f),\
|
276 |
-
Pearson(np_y,np_f), \
|
277 |
-
Spearman(np_y, np_f),\
|
278 |
-
SD(np_y, np_f)
|
279 |
-
|
280 |
-
metrics_dict = {'RMSE': rmse, 'MAE': mae, 'Pearson': pearson, 'Spearman': spearman, 'SD':sd_}
|
281 |
-
return metrics_dict
|
282 |
-
|
283 |
-
def get_sbap_matric_dict(np_y, np_f):
|
284 |
-
rmse, mae, pearson, spearman, sd_ = RMSE(np_y, np_f), \
|
285 |
-
MAE(np_y, np_f),\
|
286 |
-
Pearson(np_y,np_f), \
|
287 |
-
Spearman(np_y, np_f),\
|
288 |
-
SD(np_y, np_f)
|
289 |
-
|
290 |
-
recall, ndcg = Recall(np_y, np_f), NDCG(np_y, np_f)
|
291 |
-
enrichment_factor = Enrichment_Factor(np_y, np_f)
|
292 |
-
|
293 |
-
metrics_dict = {'RMSE': rmse, 'MAE': mae, 'Pearson': pearson, 'Spearman': spearman, 'SD':sd_,
|
294 |
-
'Recall': recall, 'NDCG': ndcg, 'EF1%':enrichment_factor
|
295 |
-
}
|
296 |
-
return metrics_dict
|
297 |
-
|
298 |
-
def get_matric_output_str(matric_dict):
|
299 |
-
matric_str = ''
|
300 |
-
for key in matric_dict.keys():
|
301 |
-
if not 'less than' in key:
|
302 |
-
matric_str += '| {}: {:.4f} '.format(key, matric_dict[key])
|
303 |
-
else:
|
304 |
-
matric_str += '| {}: {:.4f}% '.format(key, matric_dict[key])
|
305 |
-
return matric_str
|
306 |
-
|
307 |
-
def get_unseen_matric(rmsds, centroids, names, unseen_file_path):
|
308 |
-
with open(unseen_file_path, 'r') as f:
|
309 |
-
unseen_names = f.read().strip().split('\n')
|
310 |
-
unseen_rmsds, unseen_centroids = [], []
|
311 |
-
for name, rmsd, centroid in zip(names, rmsds, centroids):
|
312 |
-
if name in unseen_names:
|
313 |
-
unseen_rmsds.append(rmsd)
|
314 |
-
unseen_centroids.append(centroid)
|
315 |
-
return get_matric_dict(unseen_rmsds, unseen_centroids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
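For orientation, a minimal, self-contained sketch of the Kabsch-aligned RMSD that KabschRMSD above computes: the optimal rotation comes from an SVD of the 3x3 covariance matrix, with a sign correction that forbids reflections. This is illustrative only and not part of the commit; kabsch_rmsd and all constants below are invented for the example.

import math
import torch

def kabsch_rmsd(pred, ref):
    pred_mean, ref_mean = pred.mean(0, keepdim=True), ref.mean(0, keepdim=True)
    A = (pred - pred_mean).T @ (ref - ref_mean)        # 3x3 covariance matrix
    U, S, Vt = torch.linalg.svd(A)
    d = torch.sign(torch.det(A)).item()                # keep det(R) = +1 (no reflection)
    R = U @ torch.diag(torch.tensor([1.0, 1.0, d])) @ Vt
    t = pred_mean - (R @ ref_mean.T).T
    ref_aligned = (R @ ref.T).T + t                    # rotate/translate ref onto pred
    return torch.sqrt(((pred - ref_aligned) ** 2).sum(dim=1).mean())

coords = torch.randn(20, 3)
c, s = math.cos(0.7), math.sin(0.7)
rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
moved = coords @ rot.T + torch.tensor([1.0, -2.0, 0.5])
print(kabsch_rmsd(coords, moved))  # ~0: alignment removes the rotation and translation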
UltraFlow/commons/torch_prepare.py
DELETED
@@ -1,156 +0,0 @@
-import copy
-import torch
-import torch.nn as nn
-import warnings
-import dgl
-import os
-from UltraFlow import runner, dataset
-from .utils import get_run_dir
-
-# customize exp lr scheduler with min lr
-class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
-    def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
-        self.gamma = gamma
-        self.min_lr = min_lr
-        super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose)
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.", UserWarning)
-
-        if self.last_epoch == 0:
-            return self.base_lrs
-        return [max(group['lr'] * self.gamma, self.min_lr)
-                for group in self.optimizer.param_groups]
-
-    def _get_closed_form_lr(self):
-        return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
-                for base_lr in self.base_lrs]
-
-
-def get_scheduler(config, optimizer):
-    if config.type == 'plateau':
-        return torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer,
-            factor=config.factor,
-            patience=config.patience,
-        )
-    elif config.type == 'expmin':
-        return ExponentialLR_with_minLr(
-            optimizer,
-            gamma=config.factor,
-            min_lr=config.min_lr,
-        )
-    else:
-        raise NotImplementedError('Scheduler not supported: %s' % config.type)
-
-def get_optimizer(config, model):
-    if config.type == "Adam":
-        return torch.optim.Adam(
-            filter(lambda p: p.requires_grad, model.parameters()),
-            lr=config.lr,
-            weight_decay=config.weight_decay)
-    else:
-        raise NotImplementedError('Optimizer not supported: %s' % config.type)
-
-def get_optimizer_ablation(config, model, interact_ablation):
-    if config.type == "Adam":
-        return torch.optim.Adam(
-            filter(lambda p: p.requires_grad, list(model.parameters()) + list(interact_ablation.parameters())),
-            lr=config.lr,
-            weight_decay=config.weight_decay)
-    else:
-        raise NotImplementedError('Optimizer not supported: %s' % config.type)
-
-def get_dataset(config, ddp=False):
-    if config.data.dataset_name == 'chembl_in_pdbbind_smina':
-        if config.data.split_type == 'assay_specific':
-            if ddp and config.train.use_memory_efficient_dataset == 'v1':
-                train_data, val_data = dataset.load_memoryefficient_ChEMBL_Dock(config)
-                test_data = None
-            elif config.train.use_memory_efficient_dataset == 'v2':
-                train_data, val_data = dataset.load_ChEMBL_Dock_v2(config)
-                test_data = None
-            else:
-                train_data, val_data = dataset.load_ChEMBL_Dock(config)
-                test_data = None
-
-    names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, IC50_flag, Kd_flag, Ki_flag, K_flag, assay_d \
-        = dataset.load_complete_dataset(config.data.finetune_total_names, config.data.finetune_dataset_name, config.data.labels_path, config)
-
-    train_names, valid_names, test_names = dataset.split_names(names, config)
-    finetune_val_data = dataset.select_according_names(valid_names,
-                                                       names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                                       IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                                       assay_d, config)
-
-    return train_data, val_data, test_data, finetune_val_data
-
-def get_finetune_dataset(config):
-    names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, IC50_flag, Kd_flag, Ki_flag, K_flag, assay_d \
-        = dataset.load_complete_dataset(config.data.finetune_total_names, config.data.finetune_dataset_name, config.data.labels_path, config)
-
-    train_names, valid_names, test_names = dataset.split_names(names, config)
-
-    train_data = dataset.select_according_names(train_names,
-                                                names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                                IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                                assay_d, config)
-
-    val_data = dataset.select_according_names(valid_names,
-                                              names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                              IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                              assay_d, config)
-
-    test_data = dataset.select_according_names(test_names,
-                                               names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels,
-                                               IC50_flag, Kd_flag, Ki_flag, K_flag,
-                                               assay_d, config)
-
-    # train_data = dataset.pdbbind_finetune(config.data.finetune_train_names, config.data.finetune_dataset_name,
-    #                                       config.data.labels_path, config)
-    # val_data = dataset.pdbbind_finetune(config.data.finetune_valid_names, config.data.finetune_dataset_name,
-    #                                     config.data.labels_path, config)
-    # test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name,
-    #                                      config.data.labels_path, config)
-
-    generalize_csar_data = dataset.pdbbind_finetune(config.data.generalize_csar_test, config.data.generalize_dataset_name,
-                                                    config.data.generalize_labels_path, config)
-
-    return train_data, val_data, test_data, generalize_csar_data
-
-def get_test_dataset(config):
-    test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name,
-                                         config.data.labels_path, config)
-
-    generalize_csar_data = dataset.pdbbind_finetune(config.data.generalize_csar_test, config.data.generalize_dataset_name,
-                                                    config.data.generalize_labels_path, config)
-
-    return test_data, generalize_csar_data
-
-def get_dataset_example(config):
-    example_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name,
-                                            config.data.labels_path, config)
-
-    return example_data
-
-def get_model(config):
-    return globals()[config.model.model_type](config).to(config.train.device)
-
-def repeat_data(data, num_repeat):
-    datas = [copy.deepcopy(data) for i in range(num_repeat)]
-    g_ligs, g_prots, g_inters = list(zip(*datas))
-    return dgl.batch(g_ligs), dgl.batch(g_prots), dgl.batch(g_inters)
-
-def clip_norm(vec, limit, p=2):
-    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
-    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
-    return vec * denom
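For orientation, a minimal sketch of the clamped exponential decay that ExponentialLR_with_minLr above implements; it is not part of the commit, and the optimizer, gamma, and min_lr values below are illustrative.

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)

gamma, min_lr = 0.5, 1e-4
for step in range(6):
    lr = opt.param_groups[0]['lr']
    print(step, lr)  # 1e-3, 5e-4, 2.5e-4, 1.25e-4, 1e-4, 1e-4
    # same rule as the deleted get_lr(): decay by gamma, but never below min_lr
    opt.param_groups[0]['lr'] = max(lr * gamma, min_lr)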
UltraFlow/commons/visualize.py
DELETED
@@ -1,364 +0,0 @@
-import os
-import pandas as pd
-import torch
-from prody import writePDB
-from rdkit import Chem as Chem
-from rdkit.Chem.rdchem import BondType as BT
-from openbabel import openbabel, pybel
-from io import BytesIO
-from .process_mols import read_molecules_crossdock, read_molecules, read_rdmol
-from .geomop import canonical_protein_ligand_orientation
-from collections import defaultdict
-import numpy as np
-
-BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
-BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
-
-def simply_modify_coords(pred_coords, file_path, file_type='mol2', pos_no=None, file_label='pred'):
-    with open(file_path, 'r') as f:
-        lines = f.read().strip().split('\n')
-    index = 0
-    while index < len(lines):
-        if '@<TRIPOS>ATOM' in lines[index]:
-            break
-        index += 1
-    for coord in pred_coords:
-        index += 1
-        new_x = '{:.4f}'.format(coord[0]).rjust(10, ' ')
-        new_y = '{:.4f}'.format(coord[1]).rjust(10, ' ')
-        new_z = '{:.4f}'.format(coord[2]).rjust(10, ' ')
-        new_coord_str = new_x + new_y + new_z
-        lines[index] = lines[index][:16] + new_coord_str + lines[index][46:]
-
-    if pos_no is not None:
-        with open('{}_{}_{}.{}'.format(os.path.join(os.path.dirname(file_path), os.path.basename(file_path).split('.')[0]), file_label, pos_no, file_type), 'w') as f:
-            f.write('\n'.join(lines))
-    else:
-        with open('{}_{}.{}'.format(os.path.join(os.path.dirname(file_path), os.path.basename(file_path).split('.')[0]), file_label, file_type), 'w') as f:
-            f.write('\n'.join(lines))
-
-def set_new_coords_for_protein_atom(m_prot, new_coords):
-    for index, atom in enumerate(m_prot):
-        atom.setCoords(new_coords[index])
-    return m_prot
-
-def save_ligand_file(m_lig, output_path, file_type='mol2'):
-
-    return
-
-def save_protein_file(m_prot, output_path, file_type='pdb'):
-    if file_type == 'pdb':
-        writePDB(output_path, m_prot)
-    return
-
-def generated_to_xyz(data):
-    ptable = Chem.GetPeriodicTable()
-    num_atoms, atom_type, atom_coords = data
-    xyz = "%d\n\n" % (num_atoms, )
-    for i in range(num_atoms):
-        symb = ptable.GetElementSymbol(int(atom_type[i]))
-        x, y, z = atom_coords[i].clone().cpu().tolist()
-        xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
-    return xyz
-
-def generated_to_sdf(data):
-    xyz = generated_to_xyz(data)
-    obConversion = openbabel.OBConversion()
-    obConversion.SetInAndOutFormats("xyz", "sdf")
-
-    mol = openbabel.OBMol()
-    obConversion.ReadString(mol, xyz)
-    sdf = obConversion.WriteString(mol)
-    return sdf
-
-def sdf_to_rdmol(sdf):
-    stream = BytesIO(sdf.encode())
-    suppl = Chem.ForwardSDMolSupplier(stream)
-    for mol in suppl:
-        return mol
-    return None
-
-def generated_to_rdmol(data):
-    sdf = generated_to_sdf(data)
-    return sdf_to_rdmol(sdf)
-
-def generated_to_rdmol_trajectory(trajectory):
-    sdf_trajectory = ''
-    for data in trajectory:
-        sdf_trajectory += generated_to_sdf(data)
-    return sdf_trajectory
-
-def filter_rd_mol(rdmol):
-    ring_info = rdmol.GetRingInfo()
-    rings = [set(r) for r in ring_info.AtomRings()]
-
-    # 3-3 ring intersection
-    for i, ring_a in enumerate(rings):
-        if len(ring_a) != 3: continue
-        for j, ring_b in enumerate(rings):
-            if i <= j: continue
-            inter = ring_a.intersection(ring_b)
-            if (len(ring_b) == 3) and (len(inter) > 0):
-                return False
-    return True
-
-
-def save_sdf_mol(rdmol, save_path, suffix='test'):
-    writer = Chem.SDWriter(os.path.join(save_path, 'visualize_dir', f'{suffix}.sdf'))
-    writer.SetKekulize(False)
-    try:
-        writer.write(rdmol, confId=0)
-    except Exception:
-        writer.close()
-        return False
-    writer.close()
-    return True
-
-def sdf_string_save_sdf_file(sdf_string, save_path, suffix='test'):
-    with open(os.path.join(save_path, 'visualize_dir', f'{suffix}.sdf'), 'w') as f:
-        f.write(sdf_string)
-    return
-
-def visualize_generate_full_trajectory(trajectory, index, dataset, save_path, move_truth=True,
-                                       name_suffix='pred_trajectory', canonical_oritentaion=True):
-    if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-        lig_path = index
-        lig_path_split = lig_path.split('/')
-        lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-        prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-        if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-            os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-        name = index[:-4]
-
-        assert prot_path.endswith('_rec.pdb')
-        molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                            dataset.lig_type, dataset.prot_graph_type,
-                                                            dataset.dataset_path, dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-        prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-    elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-        name = index
-        molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                  dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                  init_type=None, chain_cut=dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-        if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-        else:
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-    lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-    if dataset.canonical_oritentaion and canonical_oritentaion:
-        new_trajectory = []
-        _, _, _, _, _, rotation, translation = canonical_protein_ligand_orientation(lig_coords, prot_coords)
-        for coords in trajectory:
-            new_trajectory.append((coords @ rotation.T) - translation)
-        trajectory = new_trajectory
-
-    trajectory_data = []
-    num_atoms = len(trajectory[0])  # the atom count is shared by every frame
-    for coords in trajectory:
-        data = (num_atoms, lig_node_type, coords)
-        trajectory_data.append(data)
-    sdf_file_string = generated_to_rdmol_trajectory(trajectory_data)
-
-    if name_suffix is None:
-        sdf_string_save_sdf_file(sdf_file_string, save_path, suffix=name)
-    else:
-        sdf_string_save_sdf_file(sdf_file_string, save_path, suffix=f'{name}_{name_suffix}')
-
-    if move_truth:
-        output_path = os.path.join(save_path, 'visualize_dir')
-        cmd = f'cp {prot_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct} {output_path}'
-        os.system(cmd)
-
-    return
-
-def visualize_generated_coordinates(coords, index, dataset, save_path, move_truth=True, name_suffix=None, canonical_oritentaion=True):
-
-    if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-        lig_path = index
-        lig_path_split = lig_path.split('/')
-        lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-        prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-        if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-            os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-        name = index[:-4]
-
-        assert prot_path.endswith('_rec.pdb')
-        molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                            dataset.lig_type, dataset.prot_graph_type, dataset.dataset_path, dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-        prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-    elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-        name = index
-        molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                  dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                  init_type=None, chain_cut=dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-        if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-        else:
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-    lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-    if dataset.canonical_oritentaion and canonical_oritentaion:
-        _, _, _, _, _, rotation, translation = canonical_protein_ligand_orientation(lig_coords, prot_coords)
-        coords = (coords @ rotation.T) - translation
-
-    num_atoms = len(coords)
-
-    data = (num_atoms, lig_node_type, coords)
-    sdf_string = generated_to_sdf(data)
-
-    sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-    with open(sdf_path, 'w') as f:
-        f.write(sdf_string)
-
-    if move_truth:
-        lig_path_direct_sdf = os.path.join(dataset.dataset_path, name, f'{name}_ligand.sdf')
-        output_path = os.path.join(save_path, 'visualize_dir')
-        cmd = f'cp {prot_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct_sdf} {output_path}'
-        os.system(cmd)
-
-def visualize_predicted_pocket(binding_site_flag, index, dataset, save_path, move_truth=True, name_suffix=None, canonical_oritentaion=True):
-    if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-        os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-    if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-        lig_path = index
-        lig_path_split = lig_path.split('/')
-        lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-        prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-        if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-            os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-        name = index[:-4]
-
-        assert prot_path.endswith('_rec.pdb')
-        molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                            dataset.lig_type, dataset.prot_graph_type, dataset.dataset_path, dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-        prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-    elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-        name = index
-        molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                  dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                  init_type=None, chain_cut=dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-        if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-        else:
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-    lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-    coords = torch.from_numpy(prot_coords[binding_site_flag.cpu()])
-
-    num_atoms = len(coords)
-
-    data = (num_atoms, [6] * num_atoms, coords)
-    sdf_string = generated_to_sdf(data)
-
-    sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-    with open(sdf_path, 'w') as f:
-        f.write(sdf_string)
-
-    if move_truth:
-        output_path = os.path.join(save_path, 'visualize_dir')
-        cmd = f'cp {prot_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct} {output_path}'
-        os.system(cmd)
-
-def visualize_predicted_link_map(pred_prob, true_prob, pdb_name, dataset, save_path):
-    """
-    :param pred_prob: [N,M], torch.tensor
-    :param true_prob: [N,M], torch.tensor
-    :param pdb_name: string
-    :param dataset:
-    :param save_path:
-    :return:
-    """
-    if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-        os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-    pd.DataFrame(pred_prob.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_pred.csv'))
-    pd.DataFrame(true_prob.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_true.csv'))
-
-def visualize_edge_coef_map(feats_coef, coords_coef, pdb_name, dataset, save_path, layer_index):
-    if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-        os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-    pd.DataFrame(feats_coef.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_feats_coef_layer_{layer_index}.csv'))
-    pd.DataFrame(coords_coef.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_coords_coef_layer_{layer_index}.csv'))
-
-def collect_bond_dists(index, dataset, save_path, name_suffix='pred'):
-    """
-    Collect the length of every chemical bond in the ground-truth, predicted,
-    and initial geometry of the ligand belonging to one complex.
-    Args:
-        index (str): the complex name used to locate the ligand files.
-        dataset: the dataset object providing `dataset_path`.
-        save_path (str): run directory whose `visualize_dir` holds the predicted and initial sdf files.
-        name_suffix (str): suffix of the predicted sdf file.
-
-    :rtype: :class:`list` of tuples (z1, z2, bond_type, gd_dist, pred_dist, init_dist).
-    """
-    name = index
-    bonds_dist = []
-
-    lig_path_mol2 = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-    lig_path_sdf = os.path.join(dataset.dataset_path, name, f'{name}_ligand.sdf')
-    rdmol = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
-    if rdmol is None:  # read mol2 file if sdf file cannot be sanitized
-        rdmol = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
-    gd_atom_coords = rdmol.GetConformer().GetPositions()
-
-    pred_sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-    pred_m_lig = next(pybel.readfile('sdf', pred_sdf_path))
-    pred_atom_coords = np.array([atom.coords for atom in pred_m_lig], dtype=np.float32)
-    assert len(pred_atom_coords) == len(gd_atom_coords)
-
-    init_sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-    init_m_lig = next(pybel.readfile('sdf', init_sdf_path))
-    init_atom_coords = np.array([atom.coords for atom in init_m_lig], dtype=np.float32)
-    assert len(init_atom_coords) == len(gd_atom_coords)
-
-    for bond in rdmol.GetBonds():
-        start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
-        start_idx, end_idx = start_atom.GetIdx(), end_atom.GetIdx()
-        if start_idx < end_idx:
-            continue
-        start_atom_type, end_atom_type = start_atom.GetAtomicNum(), end_atom.GetAtomicNum()
-        bond_type = BOND_TYPES[bond.GetBondType()]
-
-        gd_bond_dist = np.linalg.norm(gd_atom_coords[start_idx] - gd_atom_coords[end_idx])
-        pred_bond_dist = np.linalg.norm(pred_atom_coords[start_idx] - pred_atom_coords[end_idx])
-        init_bond_dist = np.linalg.norm(init_atom_coords[start_idx] - init_atom_coords[end_idx])
-
-        z1, z2 = min(start_atom_type, end_atom_type), max(start_atom_type, end_atom_type)
-        bonds_dist.append((z1, z2, bond_type, gd_bond_dist, pred_bond_dist, init_bond_dist))
-
-    return bonds_dist
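For orientation, a minimal sketch of the XYZ-to-SDF round trip that generated_to_sdf above relies on; it is not part of the commit, the water molecule is invented for the example, and it assumes the openbabel Python bindings are installed.

from openbabel import openbabel

xyz = "3\n\nO 0.000 0.000 0.000\nH 0.757 0.586 0.000\nH -0.757 0.586 0.000\n"

conv = openbabel.OBConversion()
conv.SetInAndOutFormats("xyz", "sdf")  # parse XYZ text, emit SDF text
mol = openbabel.OBMol()
conv.ReadString(mol, xyz)              # build an OBMol from the XYZ block
print(conv.WriteString(mol))           # SDF block with bonds perceived by Open Babel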
UltraFlow/data/INDEX_general_PL_data.2016
DELETED
File without changes

UltraFlow/data/INDEX_general_PL_data.2020
DELETED
File without changes

UltraFlow/data/INDEX_refined_data.2020
DELETED
File without changes

UltraFlow/data/chembl/P49841/P49841_valid_chains.pdb
DELETED
File without changes

UltraFlow/data/chembl/P49841/P49841_valid_pvalue.smi
DELETED
File without changes

UltraFlow/data/chembl/P49841/P49841_valid_smiles.smi
DELETED
File without changes

UltraFlow/data/chembl/P49841/visualize_dir/total_vs.sdf
DELETED
File without changes

UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_chains.pdb
DELETED
File without changes

UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_pvalue.smi
DELETED
File without changes

UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_smiles.smi
DELETED
File without changes

UltraFlow/data/chembl/Q9Y233/visualize_dir/total_vs.sdf
DELETED
File without changes

UltraFlow/data/core_set
DELETED
File without changes

UltraFlow/data/csar_2016
DELETED
File without changes

UltraFlow/data/csar_2020
DELETED
File without changes

UltraFlow/data/csar_new_2016
DELETED
File without changes

UltraFlow/data/horizontal_test.pkl
DELETED
File without changes

UltraFlow/data/horizontal_train.pkl
DELETED
File without changes

UltraFlow/data/horizontal_valid.pkl
DELETED
File without changes

UltraFlow/data/pdb2016_total
DELETED
File without changes

UltraFlow/data/pdb_after_2016
DELETED
File without changes

UltraFlow/data/pdbbind2016_general_gign_train
DELETED
File without changes

UltraFlow/data/pdbbind2016_general_gign_valid
DELETED
File without changes

UltraFlow/data/pdbbind2016_general_train
DELETED
File without changes

UltraFlow/data/pdbbind2016_general_valid
DELETED
File without changes

UltraFlow/data/pdbbind2016_test
DELETED
File without changes

UltraFlow/data/pdbbind2016_train
DELETED
File without changes

UltraFlow/data/pdbbind2016_train_M
DELETED
File without changes

UltraFlow/data/pdbbind2016_valid
DELETED
File without changes

UltraFlow/data/pdbbind2016_valid_M
DELETED
File without changes

UltraFlow/data/pdbbind2020_finetune_test
DELETED
File without changes

UltraFlow/data/pdbbind2020_finetune_train
DELETED
File without changes

UltraFlow/data/pdbbind2020_finetune_valid
DELETED
File without changes

UltraFlow/data/pdbbind2020_vstrain1
DELETED
File without changes

UltraFlow/data/pdbbind2020_vstrain2
DELETED
File without changes

UltraFlow/data/pdbbind2020_vstrain3
DELETED
File without changes

UltraFlow/data/pdbbind2020_vsvalid1
DELETED
File without changes

UltraFlow/data/pdbbind2020_vsvalid2
DELETED
File without changes

UltraFlow/data/pdbbind2020_vsvalid3
DELETED
File without changes

UltraFlow/data/pdbbind_2020_casf_test
DELETED
File without changes

UltraFlow/data/pdbbind_2020_casf_train
DELETED
File without changes

UltraFlow/data/pdbbind_2020_casf_valid
DELETED
File without changes

UltraFlow/data/tankbind_vtrain
DELETED
File without changes