diff --git a/UltraFlow/commons/__init__.py b/UltraFlow/commons/__init__.py index 6398b819feabbc3ce79a3addb15e685563b18178..77aecc6ffe9e23f929d3bde3f85fd8c2bd227ec1 100644 --- a/UltraFlow/commons/__init__.py +++ b/UltraFlow/commons/__init__.py @@ -1,7 +1,2 @@ from .utils import * -from .torch_prepare import * from .process_mols import * -from .metrics import * -from .geomop import * -from .visualize import * -from .dock_utils import * \ No newline at end of file diff --git a/UltraFlow/commons/dock_utils.py b/UltraFlow/commons/dock_utils.py deleted file mode 100644 index e162a69f4df0942230e5739d17fa73d91619cb60..0000000000000000000000000000000000000000 --- a/UltraFlow/commons/dock_utils.py +++ /dev/null @@ -1,355 +0,0 @@ -import os -from collections import defaultdict -import numpy as np -import torch -from openbabel import pybel -from statistics import stdev -from time import time -from .utils import pmap_multi -import pandas as pd -from tqdm import tqdm - -MGLTols_PYTHON = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/bin/python2.7' -Prepare_Ligand = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/prepare_ligand4.py' -Prepare_Receptor = '/apdcephfs/private_jiaxianyan/dock/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/prepare_receptor4.py' -SMINA = '/apdcephfs/private_jiaxianyan/dock/smina' - -def read_matric(matric_file_path): - with open(matric_file_path) as f: - lines = f.read().strip().split('\n') - rmsd, centroid = float(lines[0].split(':')[1]), float(lines[1].split(':')[1]) - return rmsd, centroid - -def mol2_add_atom_index_to_atom_name(mol2_file_path): - MOL_list = [x for x in open(mol2_file_path, 'r')] - idx = [i for i, x in enumerate(MOL_list) if x.startswith('@')] - block = MOL_list[idx[1] + 1:idx[2]] - block = [x.split() for x in block] - - block_new = [] - atom_count = defaultdict(int) - for i in block: - at = i[5].strip().split('.')[0] - if 'H' not in at: - atom_count[at] += 1 - count = atom_count[at] - at_new = at + str(count) - at_new = at_new.rjust(4) - block_new.append([i[0], at_new] + i[2:]) - else: - block_new.append(i) - - block_new = ['\t'.join(x) + '\n' for x in block_new] - MOL_list_new = MOL_list[:idx[1] + 1] + block_new + MOL_list[idx[2]:] - with open(mol2_file_path, 'w') as f: - for line in MOL_list_new: - f.write(line) - return - -def prepare_dock_file(pdb_name, config): - visualize_dir = os.path.join(config.train.save_path, 'visualize_dir') - post_align_sdf = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.sdf') - post_align_mol2 = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.mol2') - post_align_pdbqt = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.pdbqt') - - # mgltools preprocess - cmd = f'cd {visualize_dir}' - cmd += f' && obabel -i sdf {post_align_sdf} -o mol2 -O {post_align_mol2}' - - if not os.path.exists(post_align_mol2): - os.system(cmd) - mol2_add_atom_index_to_atom_name(post_align_mol2) - - cmd = f'cd {visualize_dir}' - cmd += f' && {MGLTols_PYTHON} {Prepare_Ligand} -l {post_align_mol2}' - - if not os.path.exists(post_align_pdbqt): - os.system(cmd) - # cmd = f'obabel -i mol2 {post_align_mol2} -o pdbqt -O {post_align_pdbqt}' - # os.system(cmd) - - return - -def get_mol2_atom_name(mol2_file_path): - MOL_list = [x for x in open(mol2_file_path, 'r')] - idx = [i for i, x in enumerate(MOL_list) if x.startswith('@')] - block = MOL_list[idx[1] + 1:idx[2]] - block = 
[x.split() for x in block] - - atom_names = [] - - for i in block: - at = i[1].strip() - atom_names.append(at) - return atom_names - -def align_dock_name_and_target_name(dock_lig_atom_names, target_lig_atom_names): - dock_lig_atom_index_in_target_lig = [] - target_atom_name_dict = {} - for index, atom_name in enumerate(target_lig_atom_names): - try: - assert atom_name not in target_atom_name_dict.keys() - except: - raise ValueError(atom_name,'has appeared before') - target_atom_name_dict[atom_name] = index - - dock_lig_atom_name_appears_dict = defaultdict(int) - for atom_name in dock_lig_atom_names: - try: - assert atom_name not in dock_lig_atom_name_appears_dict.keys() - except: - raise ValueError(atom_name,'has appeared before') - dock_lig_atom_name_appears_dict[atom_name] += 1 - try: - dock_lig_atom_index_in_target_lig.append(target_atom_name_dict[atom_name]) - except: - dock_lig_atom_index_in_target_lig.append(target_atom_name_dict[atom_name+'1']) - - return dock_lig_atom_index_in_target_lig - -def smina_dock_result_rmsd(pdb_name, config): - # target path - target_lig_mol2 = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_ligand.mol2') - - # get target coords - target_m_lig = next(pybel.readfile('mol2', target_lig_mol2)) - target_lig_coords = [atom.coords for atom in target_m_lig if atom.atomicnum > 1] - target_lig_coords = np.array(target_lig_coords, dtype=np.float32) # np.array, [n, 3] - target_lig_center = target_lig_coords.mean(axis=0) # np.array, [3] - - # get target atom names - visualize_dir = os.path.join(config.train.save_path, 'visualize_dir') - lig_init_mol2 = os.path.join(visualize_dir, f'{pdb_name}_post_align_{config.train.align_method}.mol2') - target_atom_name_reference_lig = next(pybel.readfile('mol2', lig_init_mol2)) - target_lig_atom_names = get_mol2_atom_name(lig_init_mol2) - target_lig_atom_names_no_h = [atom_name for atom, atom_name in zip(target_atom_name_reference_lig, target_lig_atom_names) if atom.atomicnum > 1] - - # get init coords - coords_before_minimized = [atom.coords for atom in target_atom_name_reference_lig if atom.atomicnum > 1] - coords_before_minimized = np.array(coords_before_minimized, dtype=np.float32) # np.array, [n, 3] - - # get smina minimized coords - dock_lig_mol2_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.mol2') - dock_m_lig = next(pybel.readfile('mol2', dock_lig_mol2_path)) - dock_lig_coords = [atom.coords for atom in dock_m_lig if atom.atomicnum > 1] - dock_lig_coords = np.array(dock_lig_coords, dtype=np.float32) # np.array, [n, 3] - dock_lig_center = dock_lig_coords.mean(axis=0) # np.array, [3] - - # get atom names - dock_lig_atom_names = get_mol2_atom_name(dock_lig_mol2_path) - dock_lig_atom_names_no_h = [atom_name for atom, atom_name in zip(dock_m_lig, dock_lig_atom_names) if atom.atomicnum > 1] - dock_lig_atom_index_in_target_lig = align_dock_name_and_target_name(dock_lig_atom_names_no_h, target_lig_atom_names_no_h) - - dock_lig_coords_target_align = np.zeros([len(dock_lig_atom_index_in_target_lig),3], dtype=np.float32) - for atom_coords, atom_index_in_target_lig in zip(dock_lig_coords, dock_lig_atom_index_in_target_lig): - dock_lig_coords_target_align[atom_index_in_target_lig] = atom_coords - - # rmsd - error_lig_coords = dock_lig_coords_target_align - target_lig_coords - rmsd = np.sqrt((error_lig_coords ** 2).sum(axis=1, keepdims=True).mean(axis=0)) - - # centroid - error_center_coords = dock_lig_center - target_lig_center - centorid_d = 
np.sqrt( (error_center_coords ** 2).sum() )
-
-    # get rmsd after minimized
-    error_lig_coords_after_minimized = dock_lig_coords_target_align - coords_before_minimized
-    rmsd_after_minimized = np.sqrt((error_lig_coords_after_minimized ** 2).sum(axis=1, keepdims=True).mean(axis=0))
-
-    return float(rmsd), float(centorid_d), float(rmsd_after_minimized)
-
-def get_matric_dict(rmsds, centroids):
-    rmsd_mean = sum(rmsds)/len(rmsds)
-    centroid_mean = sum(centroids) / len(centroids)
-    rmsd_std = stdev(rmsds)
-    centroid_std = stdev(centroids)
-
-    # rmsd < 2
-    count = torch.tensor(rmsds) < 2.0
-    rmsd_less_than_2 = 100 * count.sum().item() / len(count)
-
-    # rmsd < 5
-    count = torch.tensor(rmsds) < 5.0
-    rmsd_less_than_5 = 100 * count.sum().item() / len(count)
-
-    # centroid < 2
-    count = torch.tensor(centroids) < 2.0
-    centroid_less_than_2 = 100 * count.sum().item() / len(count)
-
-    # centroid < 5
-    count = torch.tensor(centroids) < 5.0
-    centroid_less_than_5 = 100 * count.sum().item() / len(count)
-
-    metrics_dict = {'rmsd mean': rmsd_mean, 'rmsd std': rmsd_std, 'centroid mean': centroid_mean, 'centroid std': centroid_std,
-                    'rmsd less than 2': rmsd_less_than_2, 'rmsd less than 5': rmsd_less_than_5,
-                    'centroid less than 2': centroid_less_than_2, 'centroid less than 5': centroid_less_than_5}
-    return metrics_dict
-
-def run_smina_dock(pdb_name, config):
-
-    r_pdbqt = os.path.join(config.test_set.dataset_path, pdb_name, f'{pdb_name}_protein_processed.pdbqt')
-    l_pdbqt = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}.pdbqt')
-    out_mol2 = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.mol2')
-    log_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.log')
-    cmd = f'{SMINA}' \
-          f' --receptor {r_pdbqt}' \
-          f' --ligand {l_pdbqt}' \
-          f' --out {out_mol2}' \
-          f' --log {log_file}' \
-          f' --minimize'
-    os.system(cmd)
-
-    return
-
-def run_score_only(ligand_file, protein_file, out_log_file):
-    cmd = f'{SMINA}' \
-          f' --receptor {protein_file}' \
-          f' --ligand {ligand_file}' \
-          f' --score_only' \
-          f' > {out_log_file}'
-    os.system(cmd)
-
-    with open(out_log_file, 'r') as f:
-        lines = f.read().strip().split('\n')
-        affinity_score = float(lines[21].split()[1])
-
-    return affinity_score
-
-def run_smina_score_after_predict(config):
-    pdb_name_list = config.test_set.names
-    smina_score_list = []
-    for pdb_name in tqdm(pdb_name_list):
-        ligand_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_pred.sdf')
-        protein_file = os.path.join(config.test_set.dataset_path, pdb_name, f'{pdb_name}_protein_processed.pdbqt')
-        out_log_file = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_pred_smina_score.out')
-        smina_score = run_score_only(ligand_file, protein_file, out_log_file)
-        smina_score_list.append(smina_score)
-
-    result_d = {'pdb_name':pdb_name_list, 'smina_score':smina_score_list}
-    pd.DataFrame(result_d).to_csv(os.path.join(config.train.save_path, 'visualize_dir', 'pred_smina_score.csv'))
-    return
-
-def run_smina_minimize_after_predict(config):
-    minimize_time = 0
-
-    pdb_name_list = config.test_set.names
-
-    # pmap_multi(prepare_dock_file, zip(pdb_name_list, [config] * len(pdb_name_list)),
-    #            n_jobs=8, desc='mgltools preparing ...')
-
-    rmsds_post_dock, centroids_post_dock = [], []
-    rmsds_post, centroids_post = [], []
-    rmsds, centroids = [], []
-
-    rmsds_after_minimized = 
{'pdb_name':[], 'rmsd':[]} - - valid_pdb_name = [] - error_list = [] - # for pdb_name in tqdm(pdb_name_list): - # try: - # minimize_begin_time = time() - # run_smina_dock(pdb_name, config) - # minimize_time += time() - minimize_begin_time - # rmsd_post_dock, centroid_post_dock, rmsd_after_minimized = smina_dock_result_rmsd(pdb_name, config) - # rmsds_post_dock.append(rmsd_post_dock) - # centroids_post_dock.append(centroid_post_dock) - # - # rmsds_after_minimized['pdb_name'].append(pdb_name) - # rmsds_after_minimized['rmsd'].append(rmsd_after_minimized) - # print(f'{pdb_name} smina minimized, rmsd: {rmsd_post_dock}, centroid: {centroid_post_dock}') - # - # text_matics = 'rmsd:{}\ncentroid_d:{}\n'.format(rmsd_post_dock, centroid_post_dock) - # post_dock_matric_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_matrics_post_{config.train.align_method}_dock.txt') - # with open(post_dock_matric_path, 'w') as f: - # f.write(text_matics) - # - # # read matrics - # post_matric_path = os.path.join(config.train.save_path, 'visualize_dir', - # f'{pdb_name}_matrics_post_{config.train.align_method}.txt') - # - # matric_path = os.path.join(config.train.save_path, 'visualize_dir', - # f'{pdb_name}_matrics.txt') - # rmsd_post, centroid_post = read_matric(post_matric_path) - # rmsds_post.append(rmsd_post) - # centroids_post.append(centroid_post) - # - # rmsd, centroid = read_matric(matric_path) - # rmsds.append(rmsd) - # centroids.append(centroid) - # valid_pdb_name.append(pdb_name) - # - # except: - # print(f'{pdb_name} dock error!') - # error_list.append(pdb_name) - # - dock_score_analysis(pdb_name_list, config) - - pd.DataFrame(rmsds_after_minimized).to_csv(os.path.join(config.train.save_path, 'visualize_dir', f'rmsd_after_smina_minimzed.csv')) - - matric_dict_post_dock = get_matric_dict(rmsds_post_dock, centroids_post_dock) - matric_dict_post = get_matric_dict(rmsds_post, centroids_post) - matric_dict = get_matric_dict(rmsds, centroids) - - matric_dict_post_dock_d = {'pdb_name': valid_pdb_name, 'rmsd': rmsds_post_dock, 'centroid': centroids_post_dock} - pd.DataFrame(matric_dict_post_dock_d).to_csv( - os.path.join(config.train.save_path, 'visualize_dir', 'matric_distribution_after_minimized.csv')) - - matric_str = '' - for key in matric_dict_post_dock.keys(): - if key.endswith('mean') or key.endswith('std'): - matric_str += '| post dock {}: {:.4f} '.format(key, matric_dict_post_dock[key]) - else: - matric_str += '| post dock {}: {:.4f}% '.format(key, matric_dict_post_dock[key]) - - for key in matric_dict_post.keys(): - if key.endswith('mean') or key.endswith('std'): - matric_str += '| post {}: {:.4f} '.format(key, matric_dict_post[key]) - else: - matric_str += '| post {}: {:.4f}% '.format(key, matric_dict_post[key]) - - for key in matric_dict.keys(): - if key.endswith('mean') or key.endswith('std'): - matric_str += '| {}: {:.4f} '.format(key, matric_dict[key]) - else: - matric_str += '| {}: {:.4f}% '.format(key, matric_dict[key]) - - print(f'smina minimize results ========================') - print(matric_str) - print(f'pdb name error list ==========================') - print('\t'.join(error_list)) - print(f'smina minimize time: {minimize_time}') - - return - -def get_dock_score(log_path): - with open(log_path, 'r') as f: - lines = f.read().strip().split('\n') - - affinity_score = float(lines[20].split()[1]) - - return affinity_score - -def dock_score_analysis(pdb_name_list, config): - dock_score_d = {'name':[], 'score':[]} - error_num = 0 - for pdb_name in tqdm(pdb_name_list): - 
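        # NOTE (editorial sketch, not part of the original module): get_dock_score()
        # above reads the affinity from a hard-coded line number (lines[20]) of the
        # smina log, which breaks silently if the log layout shifts. A more defensive
        # parse would scan for the first result row instead, e.g. (hypothetical helper):
        #
        #     def parse_smina_affinity(log_path):
        #         with open(log_path) as f:
        #             for line in f:
        #                 cols = line.split()
        #                 if cols and cols[0] == '1':   # first reported binding mode
        #                     return float(cols[1])     # affinity column (kcal/mol)
        #         return None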
log_path = os.path.join(config.train.save_path, 'visualize_dir', f'{pdb_name}_post_align_{config.train.align_method}_docked.log')
-        try:
-            affinity_score = get_dock_score(log_path)
-        except:
-            affinity_score = None
-            error_num += 1
-        dock_score_d['name'].append(pdb_name)
-        dock_score_d['score'].append(affinity_score)
-    print('error num,', error_num)
-    pd.DataFrame(dock_score_d).to_csv(os.path.join(config.train.save_path, 'visualize_dir', f'post_align_{config.train.align_method}_smina_minimize_score.csv'))
-
-
-def structure2score(score_type):
-    try:
-        assert score_type in ['vina', 'smina', 'rfscore', 'ign', 'nnscore']
-    except:
-        raise ValueError(f'{score_type} is not a supported scoring function type')
-
-
-
-    return
\ No newline at end of file
diff --git a/UltraFlow/commons/geomop.py b/UltraFlow/commons/geomop.py
deleted file mode 100644
index 7210ac598a1373c04f901043bb211cb5e6c8fc89..0000000000000000000000000000000000000000
--- a/UltraFlow/commons/geomop.py
+++ /dev/null
@@ -1,529 +0,0 @@
-import torch
-import rdkit.Chem as Chem
-import numpy as np
-import copy
-from rdkit.Chem import AllChem
-from rdkit.Chem import rdMolTransforms
-from rdkit.Geometry import Point3D
-from scipy.optimize import differential_evolution
-from .process_mols import read_rdmol
-import os
-import math
-from openbabel import pybel
-from tqdm import tqdm
-
-def get_d_from_pos(pos, edge_index):
-    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)  # (num_edge)
-
-def kabsch(coords_A, coords_B, debug=True, device=None):
-    # rotate and translate coords_A to coords_B pos
-    coords_A_mean = coords_A.mean(dim=0, keepdim=True)  # (1,3)
-    coords_B_mean = coords_B.mean(dim=0, keepdim=True)  # (1,3)
-
-    # A = (coords_A - coords_A_mean).transpose(0, 1) @ (coords_B - coords_B_mean)
-    A = (coords_A).transpose(0, 1) @ (coords_B)
-    if torch.isnan(A).any():
-        print('A Nan encountered')
-    assert not torch.isnan(A).any()
-
-    if torch.isinf(A).any():
-        print('inf encountered')
-    assert not torch.isinf(A).any()
-
-    U, S, Vt = torch.linalg.svd(A)
-    num_it = 0
-    while torch.min(S) < 1e-3 or torch.min(
-            torch.abs((S ** 2).view(1, 3) - (S ** 2).view(3, 1) + torch.eye(3).to(device))) < 1e-2:
-        if debug: print('S inside loop ', num_it, ' is ', S, ' and A = ', A)
-        A = A + torch.rand(3, 3).to(device) * torch.eye(3).to(device)
-        U, S, Vt = torch.linalg.svd(A)
-        num_it += 1
-        if num_it > 10: raise Exception('SVD was consistently unstable')
-
-    corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=device))
-    rotation = (U @ corr_mat) @ Vt
-
-    translation = coords_B_mean - torch.t(rotation @ coords_A_mean.t())  # (1,3)
-
-    # new_coords = (rotation @ coords_A.t()).t() + translation
-
-    return rotation, translation
-
-def rigid_transform_Kabsch_3D(A, B):
-    assert A.shape[1] == B.shape[1]
-    num_rows, num_cols = A.shape
-    if num_rows != 3:
-        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
-    num_rows, num_cols = B.shape
-    if num_rows != 3:
-        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
-
-    # find mean column wise: 3 x 1
-    centroid_A = np.mean(A, axis=1, keepdims=True)
-    centroid_B = np.mean(B, axis=1, keepdims=True)
-
-    # subtract mean
-    Am = A - centroid_A
-    Bm = B - centroid_B
-
-    H = Am @ Bm.T
-
-    # find rotation
-    U, S, Vt = np.linalg.svd(H)
-
-    R = Vt.T @ U.T
-
-    # special reflection case
-    if np.linalg.det(R) < 0:
-        # print("det(R) < 0, reflection detected! correcting for it ...")
-        SS = np.diag([1., 1., -1.])
-        R = (Vt.T @ SS) @ U.T
-    assert math.fabs(np.linalg.det(R) - 1) < 1e-5
-
-    t 
= -R @ centroid_A + centroid_B - return R, t - -def align_molecule_a_according_molecule_b(molecule_a_path, molecule_b_path, device=None, save=False, kabsch_no_h=True): - m_a = Chem.MolFromMol2File(molecule_a_path, sanitize=False, removeHs=False) - m_b = Chem.MolFromMol2File(molecule_b_path, sanitize=False, removeHs=False) - pos_a = torch.tensor(m_a.GetConformer().GetPositions()) - pos_b = torch.tensor(m_b.GetConformer().GetPositions()) - m_a_no_h = Chem.RemoveHs(m_a) - m_b_no_h = Chem.RemoveHs(m_b) - pos_a_no_h = torch.tensor(m_a_no_h.GetConformer().GetPositions()) - pos_b_no_h = torch.tensor(m_b_no_h.GetConformer().GetPositions()) - - if kabsch_no_h: - rotation, translation = kabsch(pos_a_no_h, pos_b_no_h, device=device) - else: - rotation, translation = kabsch(pos_a, pos_b, device=device) - pos_a_new = (rotation @ pos_a.t()).t() + translation - # print(np.sqrt(np.sum((pos_a.numpy() - pos_b.numpy()) ** 2,axis=1).mean())) - # print(np.sqrt(np.sum((pos_a_new.numpy() - pos_b.numpy()) ** 2, axis=1).mean())) - - return pos_a_new, rotation, translation - -def get_principle_axes(xyz,scale_factor=20,pdb_name=None): - #create coordinates array - coord = np.array(xyz, float) - # compute geometric center - center = np.mean(coord, 0) - # print("Coordinates of the geometric center:\n", center) - # center with geometric center - coord = coord - center - # compute principal axis matrix - inertia = np.dot(coord.transpose(), coord) - e_values, e_vectors = np.linalg.eig(inertia) - #-------------------------------------------------------------------------- - # order eigen values (and eigen vectors) - # - # axis1 is the principal axis with the biggest eigen value (eval1) - # axis2 is the principal axis with the second biggest eigen value (eval2) - # axis3 is the principal axis with the smallest eigen value (eval3) - #-------------------------------------------------------------------------- - order = np.argsort(e_values) - eval3, eval2, eval1 = e_values[order] - axis3, axis2, axis1 = e_vectors[:, order].transpose() - - return np.array([axis1, axis2, axis3]), center - -def get_rotation_and_translation(xyz): - protein_principle_axes_system, system_center = get_principle_axes(xyz) - rotation = protein_principle_axes_system.T - translation = -system_center - return rotation, translation - -def canonical_protein_ligand_orientation(lig_coords, prot_coords): - rotation, translation = get_rotation_and_translation(prot_coords) - lig_canoical_truth_coords = (lig_coords + translation) @ rotation - prot_canonical_truth_coords = (prot_coords + translation) @ rotation - rotation_lig, translation_lig = get_rotation_and_translation(lig_coords) - lig_canonical_init_coords = (lig_coords + translation_lig) @ rotation_lig - - return lig_coords, lig_canoical_truth_coords, lig_canonical_init_coords, \ - prot_coords, prot_canonical_truth_coords,\ - rotation, translation - -def canonical_single_molecule_orientation(m_coords): - rotation, translation = get_rotation_and_translation(m_coords) - canonical_init_coords = (m_coords + translation) @ rotation - return canonical_init_coords - -# Clockwise dihedral2 from https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python -def GetDihedralFromPointCloud(Z, atom_idx): - p = Z[list(atom_idx)] - b = p[:-1] - p[1:] - b[0] *= -1 ######################### - v = np.array( [ v - (v.dot(b[1])/b[1].dot(b[1])) * b[1] for v in [b[0], b[2]] ] ) - # Normalize vectors - v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1,1) - b1 = b[1] / 
np.linalg.norm(b[1]) - x = np.dot(v[0], v[1]) - m = np.cross(v[0], b1) - y = np.dot(m, v[1]) - return np.degrees(np.arctan2( y, x )) - -def A_transpose_matrix(alpha): - return np.array([[np.cos(np.radians(alpha)), np.sin(np.radians(alpha))], - [-np.sin(np.radians(alpha)), np.cos(np.radians(alpha))]], dtype=np.double) - -def S_vec(alpha): - return np.array([[np.cos(np.radians(alpha))], - [np.sin(np.radians(alpha))]], dtype=np.double) - -def get_dihedral_vonMises(mol, conf, atom_idx, Z): - Z = np.array(Z) - v = np.zeros((2,1)) - iAtom = mol.GetAtomWithIdx(atom_idx[1]) - jAtom = mol.GetAtomWithIdx(atom_idx[2]) - k_0 = atom_idx[0] - i = atom_idx[1] - j = atom_idx[2] - l_0 = atom_idx[3] - for b1 in iAtom.GetBonds(): - k = b1.GetOtherAtomIdx(i) - if k == j: - continue - for b2 in jAtom.GetBonds(): - l = b2.GetOtherAtomIdx(j) - if l == i: - continue - assert k != l - s_star = S_vec(GetDihedralFromPointCloud(Z, (k, i, j, l))) - a_mat = A_transpose_matrix(GetDihedral(conf, (k, i, j, k_0)) + GetDihedral(conf, (l_0, i, j, l))) - v = v + np.matmul(a_mat, s_star) - v = v / np.linalg.norm(v) - v = v.reshape(-1) - return np.degrees(np.arctan2(v[1], v[0])) - -def distance_loss_function(epoch, y_pred, x, protein_nodes_xyz, compound_pair_dis_constraint, LAS_distance_constraint_mask=None, mode=0): - dis = torch.cdist(x, protein_nodes_xyz) - dis_clamp = torch.clamp(dis, max=10) - if mode == 0: - interaction_loss = ((dis_clamp - y_pred).abs()).sum() - elif mode == 1: - interaction_loss = ((dis_clamp - y_pred)**2).sum() - elif mode == 2: - # probably not a good choice. x^0.5 has infinite gradient at x=0. added 1e-5 for numerical stability. - interaction_loss = (((dis_clamp - y_pred).abs() + 1e-5)**0.5).sum() - config_dis = torch.cdist(x, x) - if LAS_distance_constraint_mask is not None: - configuration_loss = 1 * (((config_dis-compound_pair_dis_constraint).abs())[LAS_distance_constraint_mask]).sum() - # basic exlcuded-volume. the distance between compound atoms should be at least 1.22Å - configuration_loss += 2 * ((1.22 - config_dis).relu()).sum() - else: - configuration_loss = 1 * ((config_dis-compound_pair_dis_constraint).abs()).sum() - # if epoch < 500: - # loss = interaction_loss - # else: - # loss = 1 * (interaction_loss + 5e-3 * (epoch - 500) * configuration_loss) - loss = 1 * (interaction_loss + 5e-3 * (epoch + 200) * configuration_loss) - return loss, (interaction_loss.item(), configuration_loss.item()) - - -def distance_optimize_compound_coords(coords, y_pred, protein_nodes_xyz, - compound_pair_dis_constraint,total_epoch=1000, loss_function=distance_loss_function, LAS_distance_constraint_mask=None, mode=0, show_progress=False): - # random initialization. center at the protein center. 
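    # NOTE (editorial, illustrative only): despite the comment above, this version
    # starts from the coordinates passed in rather than a random point. A typical
    # call (variable names hypothetical) looks like
    #
    #     x_opt, losses = distance_optimize_compound_coords(
    #         init_coords, y_pred, prot_xyz, lig_pair_dists,
    #         LAS_distance_constraint_mask=las_mask, mode=0)
    #
    # where y_pred holds the predicted ligand-protein distance map and the
    # configuration loss keeps intra-ligand distances near their unbound values.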
- c_pred = protein_nodes_xyz.mean(axis=0) - x = coords - x.requires_grad = True - optimizer = torch.optim.Adam([x], lr=0.1) - loss_list = [] - # optimizer = torch.optim.LBFGS([x], lr=0.01) - if show_progress: - it = tqdm(range(total_epoch)) - else: - it = range(total_epoch) - for epoch in it: - optimizer.zero_grad() - loss, (interaction_loss, configuration_loss) = loss_function(epoch, y_pred, x, protein_nodes_xyz, - compound_pair_dis_constraint, - LAS_distance_constraint_mask=LAS_distance_constraint_mask, - mode=mode) - loss.backward() - optimizer.step() - loss_list.append(loss.item()) - # break - return x, loss_list - -def tankbind_gen(lig_pred_coords, lig_init_coords, prot_coords, LAS_mask, device='cpu', mode=0): - - pred_prot_lig_inter_distance = torch.cdist(lig_pred_coords, prot_coords) - init_lig_intra_distance = torch.cdist(lig_init_coords, lig_init_coords) - try: - x, loss_list = distance_optimize_compound_coords(lig_pred_coords.to('cpu'), - pred_prot_lig_inter_distance.to('cpu'), - prot_coords.to('cpu'), - init_lig_intra_distance.to('cpu'), - LAS_distance_constraint_mask=LAS_mask.bool(), - mode=mode, show_progress=False) - except: - print('error') - - return x - -def kabsch_align(lig_pred_coords, name, save_path, dataset_path, device='cpu'): - rdkit_init_lig_path_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf') - openbabel_init_m_lig = next(pybel.readfile('sdf', rdkit_init_lig_path_sdf)) - rdkit_init_coords = [atom.coords for atom in openbabel_init_m_lig] - rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32) # np.array, [n, 3] - - coords_pred = lig_pred_coords.detach().cpu().numpy() - - R, t = rigid_transform_Kabsch_3D(rdkit_init_coords.T, coords_pred.T) - coords_pred_optimized = (R @ (rdkit_init_coords).T).T + t.squeeze() - - opt_ligCoords = torch.tensor(coords_pred_optimized, device=device) - return opt_ligCoords - -def equibind_align(lig_pred_coords, name, save_path, dataset_path, device='cpu'): - lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2') - lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf') - m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True) - if m_lig == None: # read mol2 file if sdf file cannot be sanitized - m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True) - - # load rdkit mol - lig_path_sdf_error = os.path.join(save_path, 'visualize_dir', f'{name}_init') - pred_lig_path_sdf_error = os.path.join(save_path, 'visualize_dir', f'{name}_pred') - pred_lig_path_sdf_true = os.path.join(save_path, 'visualize_dir', f'{name}_pred.sdf') - - rdkit_init_lig_path_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf') - - if not os.path.exists(rdkit_init_lig_path_sdf): - cmd = f'mv {lig_path_sdf_error} {rdkit_init_lig_path_sdf}' - os.system(cmd) - if not os.path.exists(pred_lig_path_sdf_true): - cmd = f'mv {pred_lig_path_sdf_error} {pred_lig_path_sdf_true}' - os.system(cmd) - - openbabel_init_m_lig = next(pybel.readfile('sdf', rdkit_init_lig_path_sdf)) - rdkit_init_coords = [atom.coords for atom in openbabel_init_m_lig] - rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32) # np.array, [n, 3] - # rdkit_init_m_lig = read_rdmol(rdkit_init_lig_path_sdf, sanitize=True, remove_hs=True) - # rdkit_init_coords = rdkit_init_m_lig.GetConformer().GetPositions() - - rdkit_init_lig = copy.deepcopy(m_lig) - conf = rdkit_init_lig.GetConformer() - for i in range(rdkit_init_lig.GetNumAtoms()): - x, y, z = rdkit_init_coords[i] - conf.SetAtomPosition(i, Point3D(float(x), 
float(y), float(z))) - - coords_pred = lig_pred_coords.detach().cpu().numpy() - Z_pt_cloud = coords_pred - rotable_bonds = get_torsions([rdkit_init_lig]) - new_dihedrals = np.zeros(len(rotable_bonds)) - - for idx, r in enumerate(rotable_bonds): - new_dihedrals[idx] = get_dihedral_vonMises(rdkit_init_lig, rdkit_init_lig.GetConformer(), r, Z_pt_cloud) - optimized_mol = apply_changes_equibind(rdkit_init_lig, new_dihedrals, rotable_bonds) - - coords_pred_optimized = optimized_mol.GetConformer().GetPositions() - R, t = rigid_transform_Kabsch_3D(coords_pred_optimized.T, coords_pred.T) - coords_pred_optimized = (R @ (coords_pred_optimized).T).T + t.squeeze() - - opt_ligCoords = torch.tensor(coords_pred_optimized, device=device) - return opt_ligCoords - -def dock_compound(lig_pred_coords, prot_coords, name, save_path, - popsize=150, maxiter=500, seed=None, mutation=(0.5, 1), - recombination=0.8, device='cpu', torsion_num_cut=20): - if seed: - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - torch.manual_seed(seed) - - # load rdkit mol - lig_path_init_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf') - openbabel_m_lig_init = next(pybel.readfile('sdf', lig_path_init_sdf)) - rdkit_init_coords = [atom.coords for atom in openbabel_m_lig_init] - - lig_path_true_sdf = os.path.join(save_path, 'visualize_dir', f'{name}_ligand.sdf') - lig_path_true_mol2 = os.path.join(save_path, 'visualize_dir', f'{name}_ligand.mol2') - m_lig = read_rdmol(lig_path_true_sdf, sanitize=True, remove_hs=True) - if m_lig == None: # read mol2 file if sdf file cannot be sanitized - m_lig = read_rdmol(lig_path_true_mol2, sanitize=True, remove_hs=True) - - atom_num = len(m_lig.GetConformer().GetPositions()) - if len(rdkit_init_coords) != atom_num: - rdkit_init_coords = [atom.coords for atom in openbabel_m_lig_init if atom.atomicnum > 1] - lig_pred_coords_no_h_list = [atom_coords for atom,atom_coords in zip(openbabel_m_lig_init, lig_pred_coords.tolist()) if atom.atomicnum > 1] - lig_pred_coords = torch.tensor(lig_pred_coords_no_h_list, device=device) - - rdkit_init_coords = np.array(rdkit_init_coords, dtype=np.float32) # np.array, [n, 3] - print(f'{name} init coords shape: {rdkit_init_coords.shape}') - print(f'{name} true coords shape: {m_lig.GetConformer().GetPositions().shape}') - - rdkit_init_lig = copy.deepcopy(m_lig) - conf = rdkit_init_lig.GetConformer() - for i in range(rdkit_init_lig.GetNumAtoms()): - x, y, z = rdkit_init_coords[i] - conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z))) - - # move m_lig to pred_coords center - pred_coords_center = lig_pred_coords.cpu().numpy().mean(axis=0) - init_coords_center = rdkit_init_lig.GetConformer().GetPositions().mean(axis=0) - # print(f'{name} pred coords shape: {lig_pred_coords.shape}') - - center_rel_vecs = pred_coords_center - init_coords_center - values = np.concatenate([np.array([0,0,0]),center_rel_vecs]) - rdMolTransforms.TransformConformer(rdkit_init_lig.GetConformer(), GetTransformationMatrix(values)) - - # Set optimization function - opt = optimze_conformation(mol=rdkit_init_lig, target_coords=lig_pred_coords, device=device, - n_particles=1, seed=seed) - if len(opt.rotable_bonds) > torsion_num_cut: - return lig_pred_coords - - # Define bounds for optimization - max_bound = np.concatenate([[np.pi] * 3, prot_coords.cpu().max(0)[0].numpy(), [np.pi] * len(opt.rotable_bonds)], axis=0) - min_bound = np.concatenate([[-np.pi] * 3, prot_coords.cpu().min(0)[0].numpy(), [-np.pi] * len(opt.rotable_bonds)], axis=0) - bounds = (min_bound, max_bound) - - 
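    # NOTE (editorial): each candidate vector handed to differential_evolution
    # below is [rot_x, rot_y, rot_z, t_x, t_y, t_z, tau_1 ... tau_k]: three Euler
    # angles in [-pi, pi], a translation bounded by the protein bounding box, and
    # one dihedral per rotatable bond, matching GetTransformationMatrix() and
    # SetDihedral() as used in optimze_conformation.score_conformation().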
# Optimize conformations - result = differential_evolution(opt.score_conformation, list(zip(bounds[0], bounds[1])), maxiter=maxiter, - popsize=int(np.ceil(popsize / (len(opt.rotable_bonds) + 6))), - mutation=mutation, recombination=recombination, disp=False, seed=seed) - - # Get optimized molecule - starting_mol = opt.mol - opt_mol = apply_changes(starting_mol, result['x'], opt.rotable_bonds) - opt_ligCoords = torch.tensor(opt_mol.GetConformer().GetPositions(), device=device) - - return opt_ligCoords - -class optimze_conformation(): - def __init__(self, mol, target_coords, n_particles, save_molecules=False, device='cpu', - seed=None): - super(optimze_conformation, self).__init__() - if seed: - np.random.seed(seed) - - self.targetCoords = torch.stack([target_coords for _ in range(n_particles)]).double() - self.n_particles = n_particles - self.rotable_bonds = get_torsions([mol]) - self.save_molecules = save_molecules - self.mol = mol - self.device = device - - def score_conformation(self, values): - """ - Parameters - ---------- - values : numpy.ndarray - set of inputs of shape :code:`(n_particles, dimensions)` - Returns - ------- - numpy.ndarray - computed cost of size :code:`(n_particles, )` - """ - if len(values.shape) < 2: values = np.expand_dims(values, axis=0) - mols = [copy.copy(self.mol) for _ in range(self.n_particles)] - - # Apply changes to molecules - # apply rotations - [SetDihedral(mols[m].GetConformer(), self.rotable_bonds[r], values[m, 6 + r]) for r in - range(len(self.rotable_bonds)) for m in range(self.n_particles)] - - # apply transformation matrix - [rdMolTransforms.TransformConformer(mols[m].GetConformer(), GetTransformationMatrix(values[m, :6])) for m in - range(self.n_particles)] - - # Calcualte distances between ligand conformation and pred ligand conformation - ligCoords_list = [torch.tensor(m.GetConformer().GetPositions(), device=self.device) for m in mols] # [n_mols, N, 3] - ligCoords = torch.stack(ligCoords_list).double() # [n_mols, N, 3] - - ligCoords_error = ligCoords - self.targetCoords # [n_mols, N, 3] - ligCoords_rmsd = (ligCoords_error ** 2).sum(dim=-1).mean(dim=-1).sqrt().min().cpu().numpy() - - del ligCoords_error, ligCoords, ligCoords_list, mols - - return ligCoords_rmsd - -def apply_changes(mol, values, rotable_bonds): - opt_mol = copy.copy(mol) - - # apply rotations - [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[6 + r]) for r in range(len(rotable_bonds))] - - # apply transformation matrix - rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6])) - - return opt_mol - -def apply_changes_equibind(mol, values, rotable_bonds): - opt_mol = copy.deepcopy(mol) - # opt_mol = add_rdkit_conformer(opt_mol) - - # apply rotations - [SetDihedral(opt_mol.GetConformer(), rotable_bonds[r], values[r]) for r in range(len(rotable_bonds))] - - # # apply transformation matrix - # rdMolTransforms.TransformConformer(opt_mol.GetConformer(), GetTransformationMatrix(values[:6])) - - return opt_mol - -def get_torsions(mol_list): - atom_counter = 0 - torsionList = [] - dihedralList = [] - for m in mol_list: - torsionSmarts = '[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]' - torsionQuery = Chem.MolFromSmarts(torsionSmarts) - matches = m.GetSubstructMatches(torsionQuery) - conf = m.GetConformer() - for match in matches: - idx2 = match[0] - idx3 = match[1] - bond = m.GetBondBetweenAtoms(idx2, idx3) - jAtom = m.GetAtomWithIdx(idx2) - kAtom = m.GetAtomWithIdx(idx3) - for b1 in jAtom.GetBonds(): - if (b1.GetIdx() == bond.GetIdx()): - continue - 
idx1 = b1.GetOtherAtomIdx(idx2) - for b2 in kAtom.GetBonds(): - if ((b2.GetIdx() == bond.GetIdx()) - or (b2.GetIdx() == b1.GetIdx())): - continue - idx4 = b2.GetOtherAtomIdx(idx3) - # skip 3-membered rings - if (idx4 == idx1): - continue - # skip torsions that include hydrogens - if ((m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) - or (m.GetAtomWithIdx(idx4).GetAtomicNum() == 1)): - continue - if m.GetAtomWithIdx(idx4).IsInRing(): - torsionList.append( - (idx4 + atom_counter, idx3 + atom_counter, idx2 + atom_counter, idx1 + atom_counter)) - break - else: - torsionList.append( - (idx1 + atom_counter, idx2 + atom_counter, idx3 + atom_counter, idx4 + atom_counter)) - break - break - - atom_counter += m.GetNumAtoms() - return torsionList - - -def SetDihedral(conf, atom_idx, new_vale): - rdMolTransforms.SetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3], new_vale) - - -def GetDihedral(conf, atom_idx): - return rdMolTransforms.GetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3]) - - -def GetTransformationMatrix(transformations): - x, y, z, disp_x, disp_y, disp_z = transformations - transMat = np.array([[np.cos(z) * np.cos(y), (np.cos(z) * np.sin(y) * np.sin(x)) - (np.sin(z) * np.cos(x)), - (np.cos(z) * np.sin(y) * np.cos(x)) + (np.sin(z) * np.sin(x)), disp_x], - [np.sin(z) * np.cos(y), (np.sin(z) * np.sin(y) * np.sin(x)) + (np.cos(z) * np.cos(x)), - (np.sin(z) * np.sin(y) * np.cos(x)) - (np.cos(z) * np.sin(x)), disp_y], - [-np.sin(y), np.cos(y) * np.sin(x), np.cos(y) * np.cos(x), disp_z], - [0, 0, 0, 1]], dtype=np.double) - return transMat - diff --git a/UltraFlow/commons/get_free_gpu.py b/UltraFlow/commons/get_free_gpu.py deleted file mode 100644 index bb27ab2d5c0dbeeb7f51490ddcc1b139bcce0434..0000000000000000000000000000000000000000 --- a/UltraFlow/commons/get_free_gpu.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -from gpustat import GPUStatCollection -import time -def get_free_gpu(mode="memory", memory_need=10000) -> list: - r"""Get free gpu according to mode (process-free or memory-free). - Args: - mode (str, optional): memory-free or process-free. Defaults to "memory". - memory_need (int): The memory you need, used if mode=='memory'. Defaults to 10000. - Returns: - list: free gpu ids sorting by free memory - """ - assert mode in ["memory", "process"], "mode must be 'memory' or 'process'" - if mode == "memory": - assert memory_need is not None, \ - "'memory_need' if None, 'memory' mode must give the free memory you want to apply for" - memory_need = int(memory_need) - assert memory_need > 0, "'memory_need' you want must be positive" - gpu_stats = GPUStatCollection.new_query() - gpu_free_id_list = [] - - for idx, gpu_stat in enumerate(gpu_stats): - if gpu_check_condition(gpu_stat, mode, memory_need): - gpu_free_id_list.append([idx, gpu_stat.memory_free]) - print("gpu[{}]: {}MB".format(idx, gpu_stat.memory_free)) - - if gpu_free_id_list: - gpu_free_id_list = sorted(gpu_free_id_list, - key=lambda x: x[1], - reverse=True) - gpu_free_id_list = [i[0] for i in gpu_free_id_list] - return gpu_free_id_list - - -def gpu_check_condition(gpu_stat, mode, memory_need) -> bool: - r"""Check gpu is free or not. - Args: - gpu_stat (gpustat.core): gpustat to check - mode (str): memory-free or process-free. 
-        memory_need (int): The memory you need, used if mode=='memory'
-    Returns:
-        bool: gpu is free or not
-    """
-    if mode == "memory":
-        return gpu_stat.memory_free > memory_need
-    elif mode == "process":
-        for process in gpu_stat.processes:
-            if process["command"] == "python":
-                return False
-        return True
-    else:
-        return False
-
-def get_device(target_gpu_idx, memory_need=10000):
-    # check device
-    # assert torch.cuda.device_count() >= len(target_gpus), 'do you set the gpus in config correctly?'
-    flag = None
-
-    while True:
-        # Get the gpu ids which have more than 10000MB memory
-        free_gpu_ids = get_free_gpu('memory', memory_need)
-        if len(free_gpu_ids) < 1:
-            if flag is None:
-                print("No GPU available now. sleeping 60s ....")
-                flag = True  # only announce the wait once
-            time.sleep(60)
-        else:
-            gpuid = list(set(free_gpu_ids) & set(target_gpu_idx))[0]
-
-            device = torch.device('cuda:'+str(gpuid))
-            print("Using device %s as main device" % device)
-            break
-
-    return device
-
-if __name__ == '__main__':
-    target_gpu_idx = [0,1,2,3,4,5,6,7,8]
-    device = get_device(target_gpu_idx)
-    print(device)
\ No newline at end of file
diff --git a/UltraFlow/commons/loss_weight.pkl b/UltraFlow/commons/loss_weight.pkl
deleted file mode 100644
index 7d4c71f659ece70ad6a098c4bc8bd363f2620197..0000000000000000000000000000000000000000
--- a/UltraFlow/commons/loss_weight.pkl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:42d7d91c0447c79418d9d547d45203612a9dcbf21355e047923237ba36d8765e
-size 748
diff --git a/UltraFlow/commons/metrics.py b/UltraFlow/commons/metrics.py
deleted file mode 100644
index ea0ca78b502a7538aed9346bf98d793de2a82db0..0000000000000000000000000000000000000000
--- a/UltraFlow/commons/metrics.py
+++ /dev/null
@@ -1,315 +0,0 @@
-from scipy import stats
-import torch
-import torch.nn as nn
-import numpy as np
-from math import sqrt, ceil
-from sklearn.linear_model import LinearRegression
-from sklearn.metrics import ndcg_score, recall_score
-import os
-import pickle
-import dgl
-from typing import Union, List
-from torch import Tensor
-from statistics import stdev
-
-def affinity_loss(affinity_pred,labels,sec_pred,bg_prot,config):
-    loss = nn.functional.mse_loss(affinity_pred, labels)
-    if config.model.aux_w != 0:
-        loss += config.train.aux_w * nn.functional.cross_entropy(sec_pred, bg_prot.ndata['s'])
-    return loss
-
-def Accurate_num(outputs,y):
-    _, y_pred_label = torch.max(outputs, dim=1)
-    return torch.sum(y_pred_label == y.data).item()
-
-def RMSE(y,f):
-    rmse = sqrt(((y - f)**2).mean(axis=0))
-    return rmse
-
-def MAE(y,f):
-    mae = (np.abs(y-f)).mean()
-    return mae
-
-def SD(y,f):
-    f,y = f.reshape(-1,1),y.reshape(-1,1)
-    lr = LinearRegression()
-    lr.fit(f,y)
-    y_ = lr.predict(f)
-    sd = (((y - y_) ** 2).sum() / (len(y) - 1)) ** 0.5
-    return sd
-
-def Pearson(y,f):
-    y,f = y.flatten(),f.flatten()
-    rp = np.corrcoef(y, f)[0,1]
-    return rp
-
-def Spearman(y,f):
-    y, f = y.flatten(), f.flatten()
-    rp = stats.spearmanr(y, f)
-    return rp[0]
-
-def NDCG(y,f,k=None):
-    y, f = y.flatten(), f.flatten()
-    return ndcg_score(np.expand_dims(y, axis=0), np.expand_dims(f,axis=0),k=k)
-
-def Recall(y, f, positive_threshold = 7.5):
-    y, f = y.flatten(), f.flatten()
-    y_class = y > positive_threshold
-    f_class = f > positive_threshold
-
-    return recall_score(y_class, f_class)
-
-def Enrichment_Factor(y, f, positive_threshold = 7.5, top_percentage = 0.001):
-    y, f = y.flatten(), f.flatten()
-    y_class = y > positive_threshold
-    f_class = f > positive_threshold
-
-    data = list(zip(y_class.tolist(), f_class.tolist()))
-    data.sort(key=lambda 
x:x[1], reverse=True) - - y_class, f_class = map(list, zip(*data)) - - total_active_rate = sum(y_class) / len(y_class) - top_num = ceil(len(y_class) * top_percentage) - top_active_rate = sum(y_class[:top_num]) / top_num - - er = top_active_rate / total_active_rate - - return er - -def Auxiliary_Weight_Balance(aux_type='Q8'): - if os.path.exists('loss_weight.pkl'): - with open('loss_weight.pkl','rb') as f: - w = pickle.load(f) - return w[aux_type] - -def RMSD(ligs_coords_pred, ligs_coords): - rmsds = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1))).item()) - return rmsds - -def KabschRMSD(ligs_coords_pred, ligs_coords): - rmsds = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - lig_coords_pred_mean = lig_coords_pred.mean(dim=0, keepdim=True) # (1,3) - lig_coords_mean = lig_coords.mean(dim=0, keepdim=True) # (1,3) - - A = (lig_coords_pred - lig_coords_pred_mean).transpose(0, 1) @ (lig_coords - lig_coords_mean) - - U, S, Vt = torch.linalg.svd(A) - - corr_mat = torch.diag(torch.tensor([1, 1, torch.sign(torch.det(A))], device=lig_coords_pred.device)) - rotation = (U @ corr_mat) @ Vt - translation = lig_coords_pred_mean - torch.t(rotation @ lig_coords_mean.t()) # (1,3) - - lig_coords = (rotation @ lig_coords.t()).t() + translation - rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) - return torch.tensor(rmsds).mean() - - -class RMSDmedian(nn.Module): - def __init__(self) -> None: - super(RMSDmedian, self).__init__() - - def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: - rmsds = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) - return torch.median(torch.tensor(rmsds)) - - -class RMSDfraction(nn.Module): - def __init__(self, distance) -> None: - super(RMSDfraction, self).__init__() - self.distance = distance - - def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: - rmsds = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - rmsds.append(torch.sqrt(torch.mean(torch.sum(((lig_coords_pred - lig_coords) ** 2), dim=1)))) - count = torch.tensor(rmsds) < self.distance - return 100 * count.sum() / len(count) - - -def CentroidDist(ligs_coords_pred, ligs_coords): - distances = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)).item()) - return distances - - -class CentroidDistMedian(nn.Module): - def __init__(self) -> None: - super(CentroidDistMedian, self).__init__() - - def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: - distances = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0))) - return torch.median(torch.tensor(distances)) - - -class CentroidDistFraction(nn.Module): - def __init__(self, distance) -> None: - super(CentroidDistFraction, self).__init__() - self.distance = distance - - def forward(self, ligs_coords_pred: List[Tensor], ligs_coords: List[Tensor]) -> Tensor: - distances = [] - for lig_coords_pred, lig_coords in zip(ligs_coords_pred, ligs_coords): - 
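            # NOTE (editorial): as with RMSDfraction above, the forward pass reports
            # a percentage, e.g. CentroidDistFraction(2.0)(preds, targets) gives the
            # share of complexes whose ligand-centroid error is below 2 A.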
distances.append(torch.linalg.norm(lig_coords_pred.mean(dim=0)-lig_coords.mean(dim=0)))
-        count = torch.tensor(distances) < self.distance
-        return 100 * count.sum() / len(count)
-
-class MeanPredictorLoss(nn.Module):
-
-    def __init__(self, loss_func) -> None:
-        super(MeanPredictorLoss, self).__init__()
-        self.loss_func = loss_func
-
-    def forward(self, x1: Tensor, targets: Tensor) -> Tensor:
-        return self.loss_func(torch.full_like(targets, targets.mean()), targets)
-
-
-def compute_mmd(source, target, batch_size=1000, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
-    """
-    Calculate the maximum mean discrepancy (MMD) between two sample sets with a
-    multi-kernel Gaussian estimate; based on an open-source implementation.
-    Args:
-        source (pytorch tensor): the pytorch tensor containing data samples of the source distribution.
-        target (pytorch tensor): the pytorch tensor containing data samples of the target distribution.
-    :rtype:
-        :class:`float`
-    """
-    n_source = int(source.size()[0])
-    n_target = int(target.size()[0])
-    n_samples = n_source + n_target
-
-    total = torch.cat([source, target], dim=0)
-    total0 = total.unsqueeze(0)
-    total1 = total.unsqueeze(1)
-
-    if fix_sigma:
-        bandwidth = fix_sigma
-    else:
-        bandwidth, id = 0.0, 0
-        while id < n_samples:
-            bandwidth += torch.sum((total0 - total1[id:id + batch_size]) ** 2)
-            id += batch_size
-        bandwidth /= n_samples ** 2 - n_samples
-
-    bandwidth /= kernel_mul ** (kernel_num // 2)
-    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
-
-    XX_kernel_val = [0 for _ in range(kernel_num)]
-    for i in range(kernel_num):
-        XX_kernel_val[i] += torch.sum(
-            torch.exp(-((total0[:, :n_source] - total1[:n_source, :]) ** 2) / bandwidth_list[i]))
-    XX = sum(XX_kernel_val) / (n_source * n_source)
-
-    YY_kernel_val = [0 for _ in range(kernel_num)]
-    id = n_source
-    while id < n_samples:
-        for i in range(kernel_num):
-            YY_kernel_val[i] += torch.sum(
-                torch.exp(-((total0[:, n_source:] - total1[id:id + batch_size, :]) ** 2) / bandwidth_list[i]))
-        id += batch_size
-    YY = sum(YY_kernel_val) / (n_target * n_target)
-
-    XY_kernel_val = [0 for _ in range(kernel_num)]
-    id = n_source
-    while id < n_samples:
-        for i in range(kernel_num):
-            XY_kernel_val[i] += torch.sum(
-                torch.exp(-((total0[:, id:id + batch_size] - total1[:n_source, :]) ** 2) / bandwidth_list[i]))
-        id += batch_size
-    XY = sum(XY_kernel_val) / (n_source * n_target)
-
-    return XX.item() + YY.item() - 2 * XY.item()
-
-
-def get_matric_dict(rmsds, centroids, kabsch_rmsds=None):
-    rmsd_mean = sum(rmsds)/len(rmsds)
-    centroid_mean = sum(centroids) / len(centroids)
-    rmsd_std = stdev(rmsds)
-    centroid_std = stdev(centroids)
-
-    # rmsd < 2
-    count = torch.tensor(rmsds) < 2.0
-    rmsd_less_than_2 = 100 * count.sum().item() / len(count)
-
-    # rmsd < 5
-    count = torch.tensor(rmsds) < 5.0
-    rmsd_less_than_5 = 100 * count.sum().item() / len(count)
-
-    # centroid < 2
-    count = torch.tensor(centroids) < 2.0
-    centroid_less_than_2 = 100 * count.sum().item() / len(count)
-
-    # centroid < 5
-    count = torch.tensor(centroids) < 5.0
-    centroid_less_than_5 = 100 * count.sum().item() / len(count)
-
-    rmsd_percentiles = np.percentile(np.array(rmsds), [25, 50, 75]).round(4)
-    centroid_percentiles = np.percentile(np.array(centroids), [25, 50, 75]).round(4)
-
-    metrics_dict = {'rmsd mean': rmsd_mean, 'rmsd std': rmsd_std,
-                    'rmsd 25%': rmsd_percentiles[0], 'rmsd 50%': rmsd_percentiles[1], 'rmsd 75%': rmsd_percentiles[2],
-                    'centroid mean': centroid_mean, 'centroid std': centroid_std,
-                    'centroid 25%': 
centroid_percentiles[0], 'centroid 50%': centroid_percentiles[1], 'centroid 75%': centroid_percentiles[2],
-                    'rmsd less than 2': rmsd_less_than_2, 'rmsd less than 5': rmsd_less_than_5,
-                    'centroid less than 2': centroid_less_than_2, 'centroid less than 5': centroid_less_than_5,
-                    }
-
-    if kabsch_rmsds is not None:
-        kabsch_rmsd_mean = sum(kabsch_rmsds) / len(kabsch_rmsds)
-        kabsch_rmsd_std = stdev(kabsch_rmsds)
-        metrics_dict['kabsch rmsd mean'] = kabsch_rmsd_mean
-        metrics_dict['kabsch rmsd std'] = kabsch_rmsd_std
-
-    return metrics_dict
-
-def get_sbap_regression_metric_dict(np_y, np_f):
-    rmse, mae, pearson, spearman, sd_ = RMSE(np_y, np_f), \
-                                        MAE(np_y, np_f),\
-                                        Pearson(np_y,np_f), \
-                                        Spearman(np_y, np_f),\
-                                        SD(np_y, np_f)
-
-    metrics_dict = {'RMSE': rmse, 'MAE': mae, 'Pearson': pearson, 'Spearman': spearman, 'SD':sd_}
-    return metrics_dict
-
-def get_sbap_matric_dict(np_y, np_f):
-    rmse, mae, pearson, spearman, sd_ = RMSE(np_y, np_f), \
-                                        MAE(np_y, np_f),\
-                                        Pearson(np_y,np_f), \
-                                        Spearman(np_y, np_f),\
-                                        SD(np_y, np_f)
-
-    recall, ndcg = Recall(np_y, np_f), NDCG(np_y, np_f)
-    enrichment_factor = Enrichment_Factor(np_y, np_f)
-
-    metrics_dict = {'RMSE': rmse, 'MAE': mae, 'Pearson': pearson, 'Spearman': spearman, 'SD':sd_,
-                    'Recall': recall, 'NDCG': ndcg, 'EF1%':enrichment_factor
-                    }
-    return metrics_dict
-
-def get_matric_output_str(matric_dict):
-    matric_str = ''
-    for key in matric_dict.keys():
-        if not 'less than' in key:
-            matric_str += '| {}: {:.4f} '.format(key, matric_dict[key])
-        else:
-            matric_str += '| {}: {:.4f}% '.format(key, matric_dict[key])
-    return matric_str
-
-def get_unseen_matric(rmsds, centroids, names, unseen_file_path):
-    with open(unseen_file_path, 'r') as f:
-        unseen_names = f.read().strip().split('\n')
-    unseen_rmsds, unseen_centroids = [], []
-    for name, rmsd, centroid in zip(names, rmsds, centroids):
-        if name in unseen_names:
-            unseen_rmsds.append(rmsd)
-            unseen_centroids.append(centroid)
-    return get_matric_dict(unseen_rmsds, unseen_centroids)
\ No newline at end of file
diff --git a/UltraFlow/commons/torch_prepare.py b/UltraFlow/commons/torch_prepare.py
deleted file mode 100644
index 5d8d07fe4aa0b312f812e78988a9a797bf78d5ca..0000000000000000000000000000000000000000
--- a/UltraFlow/commons/torch_prepare.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import copy
-import torch
-import torch.nn as nn
-import warnings
-import dgl
-import os
-from UltraFlow import runner, dataset
-from .utils import get_run_dir
-
-# customize exp lr scheduler with min lr
-class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
-    def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
-        self.gamma = gamma
-        self.min_lr = min_lr
-        super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose)
-
-    def get_lr(self):
-        if not self._get_lr_called_within_step:
-            warnings.warn("To get the last learning rate computed by the scheduler, "
-                          "please use `get_last_lr()`.", UserWarning)
-
-        if self.last_epoch == 0:
-            return self.base_lrs
-        return [max(group['lr'] * self.gamma, self.min_lr)
-                for group in self.optimizer.param_groups]
-
-    def _get_closed_form_lr(self):
-        return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
-                for base_lr in self.base_lrs]
-
-
-def get_scheduler(config, optimizer):
-    if config.type == 'plateau':
-        return torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer,
-            factor=config.factor,
-            patience=config.patience,
-        )
-    elif config.type == 'expmin':
-        return 
ExponentialLR_with_minLr( - optimizer, - gamma=config.factor, - min_lr=config.min_lr, - ) - else: - raise NotImplementedError('Scheduler not supported: %s' % config.type) - -def get_optimizer(config, model): - if config.type == "Adam": - return torch.optim.Adam( - filter(lambda p: p.requires_grad, model.parameters()), - lr=config.lr, - weight_decay=config.weight_decay) - else: - raise NotImplementedError('Optimizer not supported: %s' % config.type) - -def get_optimizer_ablation(config, model, interact_ablation): - if config.type == "Adam": - return torch.optim.Adam( - filter(lambda p: p.requires_grad, list(model.parameters()) + list(interact_ablation.parameters()) ) , - lr=config.lr, - weight_decay=config.weight_decay) - else: - raise NotImplementedError('Optimizer not supported: %s' % config.type) - -def get_dataset(config, ddp=False): - if config.data.dataset_name == 'chembl_in_pdbbind_smina': - if config.data.split_type == 'assay_specific': - if ddp and config.train.use_memory_efficient_dataset == 'v1': - train_data, val_data = dataset.load_memoryefficient_ChEMBL_Dock(config) - test_data = None - elif config.train.use_memory_efficient_dataset == 'v2': - train_data, val_data = dataset.load_ChEMBL_Dock_v2(config) - test_data = None - else: - train_data, val_data = dataset.load_ChEMBL_Dock(config) - test_data = None - - names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, IC50_flag, Kd_flag, Ki_flag, K_flag, assay_d\ - = dataset.load_complete_dataset(config.data.finetune_total_names, config.data.finetune_dataset_name, config.data.labels_path, config) - - train_names, valid_names, test_names = dataset.split_names(names, config) - finetune_val_data = dataset.select_according_names(valid_names, - names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, - IC50_flag, Kd_flag, Ki_flag, K_flag, - assay_d, config) - - return train_data, val_data, test_data, finetune_val_data - -def get_finetune_dataset(config): - names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, IC50_flag, Kd_flag, Ki_flag, K_flag, assay_d\ - = dataset.load_complete_dataset(config.data.finetune_total_names, config.data.finetune_dataset_name, config.data.labels_path, config) - - train_names, valid_names, test_names = dataset.split_names(names, config) - - train_data = dataset.select_according_names(train_names, - names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, - IC50_flag, Kd_flag, Ki_flag, K_flag, - assay_d, config) - - val_data = dataset.select_according_names(valid_names, - names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, - IC50_flag, Kd_flag, Ki_flag, K_flag, - assay_d, config) - - test_data = dataset.select_according_names(test_names, - names, lig_graphs, lig_d3_info, prot_graphs, inter_graphs, labels, - IC50_flag, Kd_flag, Ki_flag, K_flag, - assay_d, config) - - # train_data = dataset.pdbbind_finetune(config.data.finetune_train_names, config.data.finetune_dataset_name, - # config.data.labels_path, config) - # val_data = dataset.pdbbind_finetune(config.data.finetune_valid_names, config.data.finetune_dataset_name, - # config.data.labels_path, config) - # test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name, - # config.data.labels_path, config) - - # train_data = dataset.pdbbind_finetune(config.data.finetune_train_names, config.data.finetune_dataset_name, - # config.data.labels_path, config) - # val_data = dataset.pdbbind_finetune(config.data.finetune_valid_names, config.data.finetune_dataset_name, - # 
config.data.labels_path, config) - # test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name, - # config.data.labels_path, config) - - generalize_csar_data = dataset.pdbbind_finetune(config.data.generalize_csar_test, config.data.generalize_dataset_name, - config.data.generalize_labels_path, config) - - return train_data, val_data, test_data, generalize_csar_data - -def get_test_dataset(config): - test_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name, - config.data.labels_path, config) - - generalize_csar_data = dataset.pdbbind_finetune(config.data.generalize_csar_test, config.data.generalize_dataset_name, - config.data.generalize_labels_path, config) - - return test_data, generalize_csar_data - -def get_dataset_example(config): - example_data = dataset.pdbbind_finetune(config.data.finetune_test_names, config.data.finetune_dataset_name, - config.data.labels_path, config) - - return example_data - -def get_model(config): - return globals()[config.model.model_type](config).to(config.train.device) - -def repeat_data(data, num_repeat): - datas = [copy.deepcopy(data) for i in range(num_repeat)] - g_ligs, g_prots, g_inters = list(zip(*datas)) - return dgl.batch(g_ligs), dgl.batch(g_prots), dgl.batch(g_inters) - -def clip_norm(vec, limit, p=2): - norm = torch.norm(vec, dim=-1, p=2, keepdim=True) - denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) - return vec * denom \ No newline at end of file diff --git a/UltraFlow/commons/visualize.py b/UltraFlow/commons/visualize.py deleted file mode 100644 index bdb0de4d2d50993d9ff448f94e262912b6a12a6c..0000000000000000000000000000000000000000 --- a/UltraFlow/commons/visualize.py +++ /dev/null @@ -1,364 +0,0 @@ -import os -import pandas as pd -import torch -from prody import writePDB -from rdkit import Chem as Chem -from rdkit.Chem.rdchem import BondType as BT -from openbabel import openbabel, pybel -from io import BytesIO -from .process_mols import read_molecules_crossdock, read_molecules, read_rdmol -from .geomop import canonical_protein_ligand_orientation -from collections import defaultdict -import numpy as np - -BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} -BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())} - -def simply_modify_coords(pred_coords,file_path,file_type='mol2',pos_no=None,file_label='pred'): - with open(file_path,'r') as f: - lines = f.read().strip().split('\n') - index = 0 - while index < len(lines): - if '@ATOM' in lines[index]: - break - index += 1 - for coord in pred_coords: - index += 1 - new_x = '{:.4f}'.format(coord[0]).rjust(10, ' ') - new_y = '{:.4f}'.format(coord[1]).rjust(10, ' ') - new_z = '{:.4f}'.format(coord[2]).rjust(10, ' ') - new_coord_str = new_x + new_y + new_z - lines[index] = lines[index][:16] + new_coord_str + lines[index][46:] - - if pos_no is not None: - with open('{}_{}_{}.{}'.format(os.path.join(os.path.dirname(file_path),os.path.basename(file_path).split('.')[0]), file_label, pos_no, file_type),'w') as f: - f.write('\n'.join(lines)) - else: - with open('{}_{}.{}'.format(os.path.join(os.path.dirname(file_path),os.path.basename(file_path).split('.')[0]), file_label, file_type),'w') as f: - f.write('\n'.join(lines)) - -def set_new_coords_for_protein_atom(m_prot, new_coords): - for index,atom in enumerate(m_prot): - atom.setCoords(new_coords[index]) - return m_prot - -def save_ligand_file(m_lig, output_path, file_type='mol2'): - - return - -def save_protein_file(m_prot, 
diff --git a/UltraFlow/commons/visualize.py b/UltraFlow/commons/visualize.py
deleted file mode 100644
index bdb0de4d2d50993d9ff448f94e262912b6a12a6c..0000000000000000000000000000000000000000
--- a/UltraFlow/commons/visualize.py
+++ /dev/null
@@ -1,364 +0,0 @@
-import os
-import pandas as pd
-import torch
-from prody import writePDB
-from rdkit import Chem as Chem
-from rdkit.Chem.rdchem import BondType as BT
-from openbabel import openbabel, pybel
-from io import BytesIO
-from .process_mols import read_molecules_crossdock, read_molecules, read_rdmol
-from .geomop import canonical_protein_ligand_orientation
-from collections import defaultdict
-import numpy as np
-
-BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
-BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
-
-def simply_modify_coords(pred_coords, file_path, file_type='mol2', pos_no=None, file_label='pred'):
-    with open(file_path, 'r') as f:
-        lines = f.read().strip().split('\n')
-    index = 0
-    while index < len(lines):
-        if '@ATOM' in lines[index]:
-            break
-        index += 1
-    for coord in pred_coords:
-        index += 1
-        new_x = '{:.4f}'.format(coord[0]).rjust(10, ' ')
-        new_y = '{:.4f}'.format(coord[1]).rjust(10, ' ')
-        new_z = '{:.4f}'.format(coord[2]).rjust(10, ' ')
-        new_coord_str = new_x + new_y + new_z
-        lines[index] = lines[index][:16] + new_coord_str + lines[index][46:]
-
-    if pos_no is not None:
-        with open('{}_{}_{}.{}'.format(os.path.join(os.path.dirname(file_path), os.path.basename(file_path).split('.')[0]), file_label, pos_no, file_type), 'w') as f:
-            f.write('\n'.join(lines))
-    else:
-        with open('{}_{}.{}'.format(os.path.join(os.path.dirname(file_path), os.path.basename(file_path).split('.')[0]), file_label, file_type), 'w') as f:
-            f.write('\n'.join(lines))
-
-def set_new_coords_for_protein_atom(m_prot, new_coords):
-    for index, atom in enumerate(m_prot):
-        atom.setCoords(new_coords[index])
-    return m_prot
-
-def save_ligand_file(m_lig, output_path, file_type='mol2'):
-
-    return
-
-def save_protein_file(m_prot, output_path, file_type='pdb'):
-    if file_type == 'pdb':
-        writePDB(output_path, m_prot)
-    return
-
-def generated_to_xyz(data):
-    ptable = Chem.GetPeriodicTable()
-    num_atoms, atom_type, atom_coords = data
-    xyz = "%d\n\n" % (num_atoms, )
-    for i in range(num_atoms):
-        symb = ptable.GetElementSymbol(int(atom_type[i]))
-        x, y, z = atom_coords[i].clone().cpu().tolist()
-        xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
-    return xyz
-
-def generated_to_sdf(data):
-    xyz = generated_to_xyz(data)
-    obConversion = openbabel.OBConversion()
-    obConversion.SetInAndOutFormats("xyz", "sdf")
-
-    mol = openbabel.OBMol()
-    obConversion.ReadString(mol, xyz)
-    sdf = obConversion.WriteString(mol)
-    return sdf
-
-def sdf_to_rdmol(sdf):
-    stream = BytesIO(sdf.encode())
-    suppl = Chem.ForwardSDMolSupplier(stream)
-    for mol in suppl:
-        return mol
-    return None
-
-def generated_to_rdmol(data):
-    sdf = generated_to_sdf(data)
-    return sdf_to_rdmol(sdf)
-
-def generated_to_rdmol_trajectory(trajectory):
-    sdf_trajectory = ''
-    for data in trajectory:
-        sdf_trajectory += generated_to_sdf(data)
-    return sdf_trajectory
-
-def filter_rd_mol(rdmol):
-    ring_info = rdmol.GetRingInfo()
-    ring_info.AtomRings()
-    rings = [set(r) for r in ring_info.AtomRings()]
-
-    # 3-3 ring intersection
-    for i, ring_a in enumerate(rings):
-        if len(ring_a) != 3: continue
-        for j, ring_b in enumerate(rings):
-            if i <= j: continue
-            inter = ring_a.intersection(ring_b)
-            if (len(ring_b) == 3) and (len(inter) > 0):
-                return False
-    return True
-
-def save_sdf_mol(rdmol, save_path, suffix='test'):
-    writer = Chem.SDWriter(os.path.join(save_path, 'visualize_dir', f'{suffix}.sdf'))
-    writer.SetKekulize(False)
-    try:
-        writer.write(rdmol, confId=0)
-    except:
-        writer.close()
-        return False
-    writer.close()
-    return True
-
-def sdf_string_save_sdf_file(sdf_string, save_path, suffix='test'):
-    with open(os.path.join(save_path, 'visualize_dir', f'{suffix}.sdf'), 'w') as f:
-        f.write(sdf_string)
-    return
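Reviewer note: the xyz → sdf → RDKit round trip above relies on Open Babel to perceive bonds from geometry, so bond orders are inferred, not authoritative. A minimal usage sketch of these helpers, assuming atom types are atomic numbers and coords is a torch tensor of shape [N, 3] (the water example is purely illustrative):

```python
import torch

# (num_atoms, atomic numbers, coordinates) as expected by generated_to_xyz
data = (3, [8, 1, 1], torch.tensor([[ 0.00, 0.00, 0.00],
                                    [ 0.96, 0.00, 0.00],
                                    [-0.24, 0.93, 0.00]]))

rdmol = generated_to_rdmol(data)  # may be None if RDKit cannot parse the SDF
if rdmol is not None and filter_rd_mol(rdmol):
    print(rdmol.GetNumAtoms())
```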
-
-def visualize_generate_full_trajectory(trajectory, index, dataset, save_path, move_truth=True,
-                                       name_suffix='pred_trajectory', canonical_oritentaion=True):
-    if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-        lig_path = index
-        lig_path_split = lig_path.split('/')
-        lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-        prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-        if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-            os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-        name = index[:-4]
-
-        assert prot_path.endswith('_rec.pdb')
-        molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                            dataset.lig_type, dataset.prot_graph_type,
-                                                            dataset.dataset_path, dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-        prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-    elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-        name = index
-        molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                  dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                  init_type=None, chain_cut=dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-        if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-        else:
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-    lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-    if dataset.canonical_oritentaion and canonical_oritentaion:
-        new_trajectory = []
-        _, _, _, _, _, rotation, translation = canonical_protein_ligand_orientation(lig_coords, prot_coords)
-        for coords in trajectory:
-            new_trajectory.append((coords @ rotation.T) - translation)
-        trajectory = new_trajectory
-
-    trajectory_data = []
-    num_atoms = len(trajectory[0])
-    for coords in trajectory:
-        data = (num_atoms, lig_node_type, coords)
-        trajectory_data.append(data)
-    sdf_file_string = generated_to_rdmol_trajectory(trajectory_data)
-
-    if name_suffix is None:
-        sdf_string_save_sdf_file(sdf_file_string, save_path, suffix=name)
-    else:
-        sdf_string_save_sdf_file(sdf_file_string, save_path, suffix=f'{name}_{name_suffix}')
-
-    if move_truth:
-        output_path = os.path.join(save_path, 'visualize_dir')
-        cmd = f'cp {prot_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct} {output_path}'
-        os.system(cmd)
-
-    return
-
-def visualize_generated_coordinates(coords, index, dataset, save_path, move_truth=True, name_suffix=None, canonical_oritentaion=True):
-
-    if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-        lig_path = index
-        lig_path_split = lig_path.split('/')
-        lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-        prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-        if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-            os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-        name = index[:-4]
-
-        assert prot_path.endswith('_rec.pdb')
-        molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                            dataset.lig_type, dataset.prot_graph_type, dataset.dataset_path, dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-        prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-    elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-        name = index
-        molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                  dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                  init_type=None, chain_cut=dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-        if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-        else:
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-    lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-    if dataset.canonical_oritentaion and canonical_oritentaion:
-        _, _, _, _, _, rotation, translation = canonical_protein_ligand_orientation(lig_coords, prot_coords)
-        coords = (coords @ rotation.T) - translation
-
-    num_atoms = len(coords)
-
-    data = (num_atoms, lig_node_type, coords)
-    sdf_string = generated_to_sdf(data)
-
-    sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-    with open(sdf_path, 'w') as f:
-        f.write(sdf_string)
-
-    if move_truth:
-        lig_path_direct_sdf = os.path.join(dataset.dataset_path, name, f'{name}_ligand.sdf')
-        output_path = os.path.join(save_path, 'visualize_dir')
-        cmd = f'cp {prot_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct_sdf} {output_path}'
-        os.system(cmd)
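Reviewer note: the de-canonicalisation in both functions above is `y = x @ R.T - t`. Assuming the forward map used by `canonical_protein_ligand_orientation` is `x = (y + t) @ R` with `R` orthonormal (this is an inference from the inverse, not confirmed by the diff), the two are exact inverses. A quick self-check in numpy:

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=(10, 3))                  # original-frame coordinates
R, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # a random orthonormal matrix
t = rng.normal(size=3)

x = (y + t) @ R          # assumed canonicalisation
y_back = (x @ R.T) - t   # inverse applied by the visualisation code
assert np.allclose(y, y_back)
```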
visualize_predicted_pocket(binding_site_flag, index, dataset, save_path, move_truth=True, name_suffix=None, canonical_oritentaion=True):
-    if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-        os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-    if dataset.dataset_name in ['crossdock2020', 'crossdock2020_test']:
-        lig_path = index
-        lig_path_split = lig_path.split('/')
-        lig_dir, lig_base = lig_path_split[0], lig_path_split[1]
-        prot_path = os.path.join(lig_dir, lig_base[:10] + '.pdb')
-
-        if not os.path.exists(os.path.join(save_path, 'visualize_dir', lig_dir)):
-            os.makedirs(os.path.join(save_path, 'visualize_dir', lig_dir))
-
-        name = index[:-4]
-
-        assert prot_path.endswith('_rec.pdb')
-        molecular_representation = read_molecules_crossdock(lig_path, prot_path, dataset.ligcut, dataset.protcut,
-                                                            dataset.lig_type, dataset.prot_graph_type, dataset.dataset_path, dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, lig_path)
-        prot_path_direct = os.path.join(dataset.dataset_path, prot_path)
-
-    elif dataset.dataset_name in ['pdbbind2020', 'pdbbind2016']:
-        name = index
-        molecular_representation = read_molecules(index, dataset.dataset_path, dataset.prot_graph_type,
-                                                  dataset.ligcut, dataset.protcut, dataset.lig_type,
-                                                  init_type=None, chain_cut=dataset.chaincut)
-
-        lig_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-        if os.path.exists(os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')):
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein_processed.pdb')
-        else:
-            prot_path_direct = os.path.join(dataset.dataset_path, name, f'{name}_protein.pdb')
-
-    lig_coords, _, _, lig_node_type, _, prot_coords, _, _, _, _, _, _, _, _, _ = molecular_representation
-
-    coords = torch.from_numpy(prot_coords[binding_site_flag.cpu()])
-
-    num_atoms = len(coords)
-
-    data = (num_atoms, [6] * num_atoms, coords)
-    sdf_string = generated_to_sdf(data)
-
-    sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-    with open(sdf_path, 'w') as f:
-        f.write(sdf_string)
-
-    if move_truth:
-        output_path = os.path.join(save_path, 'visualize_dir')
-        cmd = f'cp {prot_path_direct} {output_path}'
-        cmd += f' && cp {lig_path_direct} {output_path}'
-        os.system(cmd)
-
-def visualize_predicted_link_map(pred_prob, true_prob, pdb_name, dataset, save_path):
-    """
-    :param pred_prob: [N, M], torch.tensor
-    :param true_prob: [N, M], torch.tensor
-    :param pdb_name: string
-    :param dataset:
-    :param save_path:
-    :return:
-    """
-    if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-        os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-    pd.DataFrame(pred_prob.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_pred.csv'))
-    pd.DataFrame(true_prob.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_true.csv'))
-
-def visualize_edge_coef_map(feats_coef, coords_coef, pdb_name, dataset, save_path, layer_index):
-    if not os.path.exists(os.path.join(save_path, 'visualize_dir')):
-        os.makedirs(os.path.join(save_path, 'visualize_dir'))
-
-    pd.DataFrame(feats_coef.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_feats_coef_layer_{layer_index}.csv'))
-    pd.DataFrame(coords_coef.tolist()).to_csv(os.path.join(save_path, 'visualize_dir', f'{pdb_name}_coords_coef_layer_{layer_index}.csv'))
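Reviewer note: the two helpers above persist attention/link maps as plain CSV via pandas, which writes the row index as the first column. A sketch of reading one back for plotting (function name is illustrative):

```python
import os
import pandas as pd

def load_link_map(save_path: str, pdb_name: str, kind: str = 'pred'):
    path = os.path.join(save_path, 'visualize_dir', f'{pdb_name}_link_map_{kind}.csv')
    # index_col=0 drops the pandas row index that to_csv wrote out
    return pd.read_csv(path, index_col=0).to_numpy()  # shape [N, M]
```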
-
-def collect_bond_dists(index, dataset, save_path, name_suffix='pred'):
-    """
-    Collect the length of every chemical bond in one ligand, comparing the
-    ground-truth, predicted, and initial geometries.
-
-    Args:
-        index (str): complex name used to locate the ligand files.
-        dataset: dataset object providing `dataset_path`.
-        save_path (str): run directory whose `visualize_dir` holds the
-            `{index}_{name_suffix}.sdf` and `{index}_init.sdf` files.
-        name_suffix (str): suffix of the predicted sdf file.
-
-    :rtype: :class:`list` of tuples `(z1, z2, bond_type, gd_dist, pred_dist, init_dist)`,
-        one entry per bond, with `z1 <= z2` the atomic numbers of the bonded atoms.
-    """
-    name = index
-    bonds_dist = []
-
-    lig_path_mol2 = os.path.join(dataset.dataset_path, name, f'{name}_ligand.mol2')
-    lig_path_sdf = os.path.join(dataset.dataset_path, name, f'{name}_ligand.sdf')
-    rdmol = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
-    if rdmol is None:  # read mol2 file if sdf file cannot be sanitized
-        rdmol = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
-    gd_atom_coords = rdmol.GetConformer().GetPositions()
-
-    pred_sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_{name_suffix}.sdf')
-    pred_m_lig = next(pybel.readfile('sdf', pred_sdf_path))
-    pred_atom_coords = np.array([atom.coords for atom in pred_m_lig], dtype=np.float32)
-    assert len(pred_atom_coords) == len(gd_atom_coords)
-
-    init_sdf_path = os.path.join(save_path, 'visualize_dir', f'{name}_init.sdf')
-    init_m_lig = next(pybel.readfile('sdf', init_sdf_path))
-    init_atom_coords = np.array([atom.coords for atom in init_m_lig], dtype=np.float32)
-    assert len(init_atom_coords) == len(gd_atom_coords)
-
-    # GetBonds() already yields each bond exactly once, so no index-order filter is needed
-    for bond in rdmol.GetBonds():
-        start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
-        start_idx, end_idx = start_atom.GetIdx(), end_atom.GetIdx()
-        start_atom_type, end_atom_type = start_atom.GetAtomicNum(), end_atom.GetAtomicNum()
-        bond_type = BOND_TYPES[bond.GetBondType()]
-
-        gd_bond_dist = np.linalg.norm(gd_atom_coords[start_idx] - gd_atom_coords[end_idx])
-        pred_bond_dist = np.linalg.norm(pred_atom_coords[start_idx] - pred_atom_coords[end_idx])
-        init_bond_dist = np.linalg.norm(init_atom_coords[start_idx] - init_atom_coords[end_idx])
-
-        z1, z2 = min(start_atom_type, end_atom_type), max(start_atom_type, end_atom_type)
-        bonds_dist.append((z1, z2, bond_type, gd_bond_dist, pred_bond_dist, init_bond_dist))
-
-    return bonds_dist
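Reviewer note: each tuple returned by `collect_bond_dists` is `(z1, z2, bond_type, gd_dist, pred_dist, init_dist)`. A sketch of summarising the per-bond-type error of the predicted geometry with pandas (function and column names are illustrative):

```python
import pandas as pd

def bond_length_mae(bonds_dist):
    df = pd.DataFrame(bonds_dist, columns=['z1', 'z2', 'bond_type',
                                           'gd', 'pred', 'init'])
    df['abs_err'] = (df['pred'] - df['gd']).abs()
    # mean absolute error per (atom pair, bond type) combination
    return df.groupby(['z1', 'z2', 'bond_type'])['abs_err'].mean()
```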
diff --git a/UltraFlow/data/INDEX_general_PL_data.2016 b/UltraFlow/data/INDEX_general_PL_data.2016
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/INDEX_general_PL_data.2020 b/UltraFlow/data/INDEX_general_PL_data.2020
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/INDEX_refined_data.2020 b/UltraFlow/data/INDEX_refined_data.2020
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/P49841/P49841_valid_chains.pdb b/UltraFlow/data/chembl/P49841/P49841_valid_chains.pdb
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/P49841/P49841_valid_pvalue.smi b/UltraFlow/data/chembl/P49841/P49841_valid_pvalue.smi
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/P49841/P49841_valid_smiles.smi b/UltraFlow/data/chembl/P49841/P49841_valid_smiles.smi
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/P49841/visualize_dir/total_vs.sdf b/UltraFlow/data/chembl/P49841/visualize_dir/total_vs.sdf
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_chains.pdb b/UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_chains.pdb
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_pvalue.smi b/UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_pvalue.smi
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_smiles.smi b/UltraFlow/data/chembl/Q9Y233/Q9Y233_valid_smiles.smi
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/chembl/Q9Y233/visualize_dir/total_vs.sdf b/UltraFlow/data/chembl/Q9Y233/visualize_dir/total_vs.sdf
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/core_set b/UltraFlow/data/core_set
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/csar_2016 b/UltraFlow/data/csar_2016
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/csar_2020 b/UltraFlow/data/csar_2020
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/csar_new_2016 b/UltraFlow/data/csar_new_2016
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/horizontal_test.pkl b/UltraFlow/data/horizontal_test.pkl
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/horizontal_train.pkl b/UltraFlow/data/horizontal_train.pkl
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/horizontal_valid.pkl b/UltraFlow/data/horizontal_valid.pkl
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdb2016_total b/UltraFlow/data/pdb2016_total
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdb_after_2016 b/UltraFlow/data/pdb_after_2016
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_general_gign_train b/UltraFlow/data/pdbbind2016_general_gign_train
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_general_gign_valid b/UltraFlow/data/pdbbind2016_general_gign_valid
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_general_train b/UltraFlow/data/pdbbind2016_general_train
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_general_valid b/UltraFlow/data/pdbbind2016_general_valid
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_test b/UltraFlow/data/pdbbind2016_test
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_train b/UltraFlow/data/pdbbind2016_train
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_train_M b/UltraFlow/data/pdbbind2016_train_M
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_valid b/UltraFlow/data/pdbbind2016_valid
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2016_valid_M b/UltraFlow/data/pdbbind2016_valid_M
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_finetune_test b/UltraFlow/data/pdbbind2020_finetune_test
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_finetune_train b/UltraFlow/data/pdbbind2020_finetune_train
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_finetune_valid b/UltraFlow/data/pdbbind2020_finetune_valid
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_vstrain1 b/UltraFlow/data/pdbbind2020_vstrain1
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_vstrain2 b/UltraFlow/data/pdbbind2020_vstrain2
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_vstrain3 b/UltraFlow/data/pdbbind2020_vstrain3
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_vsvalid1 b/UltraFlow/data/pdbbind2020_vsvalid1
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_vsvalid2 b/UltraFlow/data/pdbbind2020_vsvalid2
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind2020_vsvalid3 b/UltraFlow/data/pdbbind2020_vsvalid3
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind_2020_casf_test b/UltraFlow/data/pdbbind_2020_casf_test
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind_2020_casf_train b/UltraFlow/data/pdbbind_2020_casf_train
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/pdbbind_2020_casf_valid b/UltraFlow/data/pdbbind_2020_casf_valid
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/tankbind_vtrain b/UltraFlow/data/tankbind_vtrain
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/data/timesplit_test b/UltraFlow/data/timesplit_test
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/dataset/__init__.py b/UltraFlow/dataset/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/dataset/chembl.py b/UltraFlow/dataset/chembl.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/dataset/dataset_test.ipynb b/UltraFlow/dataset/dataset_test.ipynb
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/dataset/inference_dataset.py b/UltraFlow/dataset/inference_dataset.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/dataset/screening.py b/UltraFlow/dataset/screening.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/dataset/utils.py b/UltraFlow/dataset/utils.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/runner/__init__.py b/UltraFlow/runner/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/runner/asrp_runner.py b/UltraFlow/runner/asrp_runner.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/runner/finetune_runner.py b/UltraFlow/runner/finetune_runner.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/runner/reproduce_runner.py b/UltraFlow/runner/reproduce_runner.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/runner/sbap_runner.py b/UltraFlow/runner/sbap_runner.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/UltraFlow/runner/utils.py b/UltraFlow/runner/utils.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000