Spaces:
Running
Running
import os | |
import math | |
from openbabel import pybel | |
from openbabel import openbabel | |
import dgl | |
import pickle | |
import numpy as np | |
import torch | |
import scipy.spatial as spatial | |
from functools import partial | |
from prody import * | |
from rdkit import Chem as Chem | |
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges | |
from rdkit.Chem.rdchem import BondType as BT | |
from rdkit.Chem import AllChem | |
from Bio.PDB import get_surface, PDBParser | |
from Bio.PDB.PDBExceptions import PDBConstructionWarning | |
from scipy.special import softmax | |
from scipy.spatial.transform import Rotation | |
import pandas as pd | |
# Quieten Open Babel: create a message handler with output level 0 and stop
# the global error log.
ob_log_handler = pybel.ob.OBMessageHandler()
ob_log_handler.SetOutputLevel(0)
pybel.ob.obErrorLog.StopLogging()
# RDKit bond-type enum value -> integer code, in the enumeration order of
# ``BondType.names``.
BOND_TYPES = {bond_type: code for code, bond_type in enumerate(BT.names.values())}
# Integer code -> RDKit bond-type name (same enumeration order as above).
BOND_NAMES = dict(enumerate(BT.names.keys()))
# Protein graph type -> name of the PDB file associated with it.
graph_type_filename = {
    'atom_pocket': 'valid_pocket.pdb',
    'atom_complete': 'valid_chains.pdb',
}

# Three-letter residue name -> integer code (0..19).
ResDict = {
    res_name: code
    for code, res_name in enumerate((
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
        'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO',
        'SER', 'THR', 'TRP', 'TYR', 'VAL',
    ))
}

# One-letter secondary-structure code -> integer code; a blank is code 7.
SSEDict = {sse_code: code for code, sse_code in enumerate('HBEGITS ')}

# Number of secondary-structure classes, and the code used for residues that
# are not in ``ResDict``.
SSEType = 8
UNKOWN_RES = 20
# Vocabulary of every categorical feature. A feature value is encoded as its
# position in the corresponding list; lists ending in 'misc' use that last
# slot as the catch-all for values that are not listed.
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)) + ['misc'],
    'possible_chirality_list': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'possible_degree_list': list(range(0, 11)) + ['misc'],
    'possible_numring_list': list(range(0, 7)) + ['misc'],
    'possible_implicit_valence_list': list(range(0, 7)) + ['misc'],
    'possible_formal_charge_list': list(range(-5, 6)) + ['misc'],
    'possible_numH_list': list(range(0, 9)) + ['misc'],
    'possible_number_radical_e_list': list(range(0, 5)) + ['misc'],
    'possible_hybridization_list': ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'],
    'possible_is_aromatic_list': [False, True],
    'possible_is_in_ring3_list': [False, True],
    'possible_is_in_ring4_list': [False, True],
    'possible_is_in_ring5_list': [False, True],
    'possible_is_in_ring6_list': [False, True],
    'possible_is_in_ring7_list': [False, True],
    'possible_is_in_ring8_list': [False, True],
    'possible_amino_acids': [
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
        'HIP', 'HIE', 'TPO', 'HID', 'LEV', 'MEU', 'PTR', 'GLV', 'CYT', 'SEP',
        'HIZ', 'CYM', 'GLM', 'ASQ', 'TYS', 'CYX', 'GLZ', 'misc',
    ],
    'possible_atom_type_2': [
        'C*', 'CA', 'CB', 'CD', 'CE', 'CG', 'CH', 'CZ', 'N*', 'ND', 'NE', 'NH',
        'NZ', 'O*', 'OD', 'OE', 'OG', 'OH', 'OX', 'S*', 'SD', 'SG', 'misc',
    ],
    'possible_atom_type_3': [
        'C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG',
        'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE',
        'NE1', 'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1', 'OD2', 'OE1', 'OE2',
        'OG', 'OG1', 'OH', 'OXT', 'SD', 'SG', 'misc',
    ],
}
# Each *_feature_dims value is a pair: the vocabulary size of every
# categorical feature (in feature order), followed by an integer count.
# For the ligand that count is the number of scalar features; the receptor
# tuples presumably follow the same convention (confirm against the model).
lig_feature_dims = (
    [len(allowable_features[key]) for key in (
        'possible_atomic_num_list',
        'possible_chirality_list',
        'possible_degree_list',
        'possible_formal_charge_list',
        'possible_implicit_valence_list',
        'possible_numH_list',
        'possible_number_radical_e_list',
        'possible_hybridization_list',
        'possible_is_aromatic_list',
        'possible_numring_list',
        'possible_is_in_ring3_list',
        'possible_is_in_ring4_list',
        'possible_is_in_ring5_list',
        'possible_is_in_ring6_list',
        'possible_is_in_ring7_list',
        'possible_is_in_ring8_list',
    )],
    1,  # number of scalar features
)

rec_atom_feature_dims = (
    [len(allowable_features[key]) for key in (
        'possible_amino_acids',
        'possible_atomic_num_list',
        'possible_atom_type_2',
        'possible_atom_type_3',
    )],
    2,
)

rec_residue_feature_dims = (
    [len(allowable_features[key]) for key in ('possible_amino_acids',)],
    2,
)

dbcg_prot_residue_feature_dims = [[21], 0]
def safe_index(l, e):
    """Return the index of element ``e`` in list ``l``.

    If ``e`` is not present, return the last index of ``l`` (the feature
    vocabularies in this module keep their catch-all entry last).
    """
    try:
        return l.index(e)
    except ValueError:
        # Only "not found" falls back; any other error is a real bug and
        # must not be hidden.
        return len(l) - 1
def lig_atom_featurizer_rdmol(mol):
    """Build the per-atom feature matrix of an RDKit molecule.

    Every atom is described by 16 categorical indices (positions in the
    ``allowable_features`` vocabularies, in the order used by
    ``lig_feature_dims``) followed by its Gasteiger partial charge.

    Parameters
    ----------
    mol: RDKit molecule
        Gasteiger charges are computed on it in place.

    Returns
    -------
    torch.Tensor
        One row per atom, 17 columns.
    """
    # Charges are NaN for 93 molecules in all of PDBbind; a 0 is used then.
    ComputeGasteigerCharges(mol)
    ringinfo = mol.GetRingInfo()
    vocab = allowable_features
    atom_features_list = []
    for idx, atom in enumerate(mol.GetAtoms()):
        g_charge = atom.GetDoubleProp('_GasteigerCharge')
        row = [
            safe_index(vocab['possible_atomic_num_list'], atom.GetAtomicNum()),
            # safe_index (not .index): chirality tags missing from the
            # vocabulary fall back to its last entry, 'CHI_OTHER', instead
            # of raising ValueError.
            safe_index(vocab['possible_chirality_list'], str(atom.GetChiralTag())),
            safe_index(vocab['possible_degree_list'], atom.GetTotalDegree()),
            safe_index(vocab['possible_formal_charge_list'], atom.GetFormalCharge()),
            safe_index(vocab['possible_implicit_valence_list'], atom.GetImplicitValence()),
            safe_index(vocab['possible_numH_list'], atom.GetTotalNumHs()),
            safe_index(vocab['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
            safe_index(vocab['possible_hybridization_list'], str(atom.GetHybridization())),
            vocab['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
            safe_index(vocab['possible_numring_list'], ringinfo.NumAtomRings(idx)),
        ]
        # Membership flags for ring sizes 3..8, in increasing size order.
        row.extend(
            vocab['possible_is_in_ring%d_list' % size].index(
                ringinfo.IsAtomInRingOfSize(idx, size))
            for size in range(3, 9)
        )
        # Non-finite charges (NaN or +/-inf) are replaced by 0.
        row.append(g_charge if np.isfinite(g_charge) else 0.)
        atom_features_list.append(row)
    return torch.tensor(atom_features_list)
def vina_gaussain_1(d):
    """Return ``exp(-(d / 0.5)^2)`` for a tensor ``d``."""
    scaled = d / 0.5
    return torch.exp(-(scaled ** 2))
def vina_gaussain_2(d):
    """Return ``exp(-((d - 3) / 2)^2)`` for a tensor ``d``."""
    scaled = (d - 3) / 2.0
    return torch.exp(-(scaled ** 2))
def vina_repulsion(d):
    """Return ``d^2`` as a tensor for negative ``d`` and 0 otherwise."""
    return torch.tensor(0.0) if d >= 0 else torch.tensor(d ** 2)
def hydrophobic(d):
    """Piecewise-linear term of ``d``, returned as a tensor.

    1 for ``d < 0.5``, ``1.5 - d`` for ``0.5 <= d <= 1.5``, 0 above 1.5.
    """
    if d < 0.5:
        value = 1.0
    elif d <= 1.5:
        value = -d + 1.5
    else:
        value = 0.0
    return torch.tensor(value)
def hydrogen_bonding(d):
    """Piecewise-linear term of ``d``, returned as a tensor.

    1 for ``d < -0.7``, ``-(10/7) * d`` for ``-0.7 <= d <= 0``, 0 above 0.
    """
    if d < -0.7:
        value = 1.0
    elif d <= 0.0:
        value = -(10/7) * d
    else:
        value = 0.0
    return torch.tensor(value)
def CusBondFeaturizer(bond):
    """Return ``[bond order, is aromatic, is in ring]`` as three ints."""
    order = int(bond.GetBondOrder())
    aromatic = int(bond.IsAromatic())
    in_ring = int(bond.IsInRing())
    return [order, aromatic, in_ring]
def CusBondFeaturizer_new(bond):
    """Return a 5-element bond descriptor.

    The first three entries one-hot encode bond orders 1, 2 and 3 (all zero
    for any other order); the last two are the aromatic and in-ring flags.
    """
    order = int(bond.GetBondOrder())
    order_one_hot = [int(order == candidate) for candidate in (1, 2, 3)]
    return order_one_hot + [int(bond.IsAromatic()), int(bond.IsInRing())]
class Featurizer():
    """Calculates atomic features for molecules. Features can encode atom type,
    native pybel properties or any property defined with SMARTS patterns

    Attributes
    ----------
    FEATURE_NAMES: list of strings
        Labels for features (in the same order as features)
    NUM_ATOM_CLASSES: int
        Number of atom codes
    ATOM_CODES: dict
        Dictionary mapping atomic numbers to codes
    NAMED_PROPS: list of string
        Names of atomic properties to retrieve from pybel.Atom object
    CALLABLES: list of callables
        Callables used to calculate custom atomic properties
    SMARTS: list of SMARTS strings
        SMARTS patterns defining additional atomic properties
    """

    def __init__(self, atom_codes=None, atom_labels=None,
                 named_properties=None, save_molecule_codes=True,
                 custom_properties=None, smarts_properties=None,
                 smarts_labels=None):
        """Creates Featurizer with specified types of features. Elements of a
        feature vector will be in a following order: atom type encoding
        (defined by atom_codes), Pybel atomic properties (defined by
        named_properties), custom atomic properties (defined by
        `custom_properties`), molecule code (if present), and additional
        properties defined with SMARTS (defined with `smarts_properties`).

        Parameters
        ----------
        atom_codes: dict, optional
            Dictionary mapping atomic numbers to codes. It will be used for
            one-hot encoding therefore if n different types are used, codes
            should be from 0 to n-1. Multiple atoms can have the same code,
            e.g. you can use {6: 0, 7: 1, 8: 1} to encode carbons with [1, 0]
            and nitrogens and oxygens with [0, 1] vectors. If not provided,
            default encoding is used.
        atom_labels: list of strings, optional
            Labels for atoms codes. It should have the same length as the
            number of used codes, e.g. for `atom_codes={6: 0, 7: 1, 8: 1}` you
            should provide something like ['C', 'O or N']. If not specified
            labels 'atom0', 'atom1' etc are used. If `atom_codes` is not
            specified this argument is ignored.
        named_properties: list of strings, optional
            Names of atomic properties to retrieve from pybel.Atom object. If
            not specified ['hyb', 'heavydegree', 'heterodegree',
            'partialcharge'] is used.
        save_molecule_codes: bool, optional (default True)
            If set to True, there will be an additional feature to save
            molecule code. It is useful when saving molecular complex in a
            single array.
        custom_properties: list of callables, optional
            Custom functions to calculate atomic properties. Each element of
            this list should be a callable that takes pybel.Atom object and
            returns a float. If callable has `__name__` property it is used as
            feature label. Otherwise labels 'func<i>' etc are used, where i is
            the index in `custom_properties` list.
        smarts_properties: list of strings, optional
            Additional atomic properties defined with SMARTS patterns. These
            patterns should match a single atom. If not specified, default
            patterns are used.
        smarts_labels: list of strings, optional
            Labels for properties defined with SMARTS. Should have the same
            length as `smarts_properties`. If not specified labels 'smarts0',
            'smarts1' etc are used. If `smarts_properties` is not specified
            this argument is ignored.

        Raises
        ------
        TypeError
            If an argument has the wrong type.
        ValueError
            If atom codes are not contiguous from 0, a label list has the
            wrong length, or a named property is not a pybel.Atom attribute.
        """
        # Remember names of all features in the correct order
        self.FEATURE_NAMES = []

        if atom_codes is not None:
            if not isinstance(atom_codes, dict):
                raise TypeError('Atom codes should be dict, got %s instead'
                                % type(atom_codes))
            codes = set(atom_codes.values())
            # Codes index a one-hot vector, so they must be exactly 0..n-1.
            for i in range(len(codes)):
                if i not in codes:
                    raise ValueError('Incorrect atom code %s' % i)

            self.NUM_ATOM_CLASSES = len(codes)
            self.ATOM_CODES = atom_codes
            if atom_labels is not None:
                if len(atom_labels) != self.NUM_ATOM_CLASSES:
                    raise ValueError('Incorrect number of atom labels: '
                                     '%s instead of %s'
                                     % (len(atom_labels), self.NUM_ATOM_CLASSES))
            else:
                atom_labels = ['atom%s' % i for i in range(self.NUM_ATOM_CLASSES)]
            self.FEATURE_NAMES += atom_labels
        else:
            self.ATOM_CODES = {}

            metals = ([3, 4, 11, 12, 13] + list(range(19, 32))
                      + list(range(37, 51)) + list(range(55, 84))
                      + list(range(87, 104)))

            # List of tuples (atomic_num, class_name) with atom types to encode.
            atom_classes = [
                (5, 'B'),
                (6, 'C'),
                (7, 'N'),
                (8, 'O'),
                (15, 'P'),
                (16, 'S'),
                (34, 'Se'),
                ([9, 17, 35, 53], 'halogen'),
                (metals, 'metal')
            ]

            for code, (atom, name) in enumerate(atom_classes):
                if isinstance(atom, list):
                    # Several atomic numbers share one class.
                    for a in atom:
                        self.ATOM_CODES[a] = code
                else:
                    self.ATOM_CODES[atom] = code
                self.FEATURE_NAMES.append(name)

            self.NUM_ATOM_CLASSES = len(atom_classes)

        if named_properties is not None:
            if not isinstance(named_properties, (list, tuple, np.ndarray)):
                raise TypeError('named_properties must be a list')
            allowed_props = [prop for prop in dir(pybel.Atom)
                             if not prop.startswith('__')]
            for prop_id, prop in enumerate(named_properties):
                if prop not in allowed_props:
                    # The offending name comes first, then its position.
                    raise ValueError(
                        'named_properties must be in pybel.Atom attributes,'
                        ' %s was given at position %s' % (prop, prop_id)
                    )
            self.NAMED_PROPS = named_properties
        else:
            # pybel.Atom properties to save
            self.NAMED_PROPS = ['hyb', 'heavydegree', 'heterodegree',
                                'partialcharge']
        self.FEATURE_NAMES += self.NAMED_PROPS

        if not isinstance(save_molecule_codes, bool):
            raise TypeError('save_molecule_codes should be bool, got %s '
                            'instead' % type(save_molecule_codes))
        self.save_molecule_codes = save_molecule_codes
        if save_molecule_codes:
            # Remember if an atom belongs to the ligand or to the protein
            self.FEATURE_NAMES.append('molcode')

        self.CALLABLES = []
        if custom_properties is not None:
            for i, func in enumerate(custom_properties):
                if not callable(func):
                    raise TypeError('custom_properties should be list of'
                                    ' callables, got %s instead' % type(func))
                name = getattr(func, '__name__', '')
                if name == '':
                    name = 'func%s' % i
                self.CALLABLES.append(func)
                self.FEATURE_NAMES.append(name)

        if smarts_properties is None:
            # SMARTS definition for other properties
            self.SMARTS = [
                '[#6+0!$(*~[#7,#8,F]),SH0+0v2,s+0,S^3,Cl+0,Br+0,I+0]',
                '[a]',
                '[!$([#1,#6,F,Cl,Br,I,o,s,nX3,#7v5,#15v5,#16v4,#16v6,*+1,*+2,*+3])]',
                '[!$([#6,H0,-,-2,-3]),$([!H0;#7,#8,#9])]',
                '[r]'
            ]
            smarts_labels = ['hydrophobic', 'aromatic', 'acceptor', 'donor',
                             'ring']
        elif not isinstance(smarts_properties, (list, tuple, np.ndarray)):
            raise TypeError('smarts_properties must be a list')
        else:
            self.SMARTS = smarts_properties

        if smarts_labels is not None:
            if len(smarts_labels) != len(self.SMARTS):
                raise ValueError('Incorrect number of SMARTS labels: %s'
                                 ' instead of %s'
                                 % (len(smarts_labels), len(self.SMARTS)))
        else:
            smarts_labels = ['smarts%s' % i for i in range(len(self.SMARTS))]

        # Compile patterns
        self.compile_smarts()
        self.FEATURE_NAMES += smarts_labels

    def compile_smarts(self):
        """Compile every pattern in `self.SMARTS` into a pybel.Smarts object."""
        self.__PATTERNS = [pybel.Smarts(smarts) for smarts in self.SMARTS]

    def encode_num(self, atomic_num):
        """Encode atom type with a binary vector. If atom type is not included in
        the `atom_classes`, its encoding is an all-zeros vector.

        Parameters
        ----------
        atomic_num: int
            Atomic number

        Returns
        -------
        encoding: np.ndarray
            Binary vector encoding atom type (one-hot or null).

        Raises
        ------
        TypeError
            If `atomic_num` is not an int.
        """
        if not isinstance(atomic_num, int):
            raise TypeError('Atomic number must be int, %s was given'
                            % type(atomic_num))

        encoding = np.zeros(self.NUM_ATOM_CLASSES)
        code = self.ATOM_CODES.get(atomic_num)
        if code is not None:
            encoding[code] = 1.0
        return encoding

    def find_smarts(self, molecule):
        """Find atoms that match SMARTS patterns.

        Parameters
        ----------
        molecule: pybel.Molecule

        Returns
        -------
        features: np.ndarray
            NxM binary array, where N is the number of atoms in the `molecule`
            and M is the number of patterns. `features[i, j]` == 1.0 if i'th
            atom has j'th property
        """
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object, %s was given'
                            % type(molecule))

        features = np.zeros((len(molecule.atoms), len(self.__PATTERNS)))

        for (pattern_id, pattern) in enumerate(self.__PATTERNS):
            # findall yields one 1-tuple per match (patterns match a single
            # atom); flatten them and subtract 1 to get array positions.
            atoms_with_prop = np.array(list(*zip(*pattern.findall(molecule))),
                                       dtype=int) - 1
            features[atoms_with_prop, pattern_id] = 1.0
        return features

    def _check_inputs(self, molecule, molcode):
        """Validate the `molecule` and `molcode` arguments of get_features*."""
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object,'
                            ' %s was given' % type(molecule))
        if molcode is None:
            if self.save_molecule_codes is True:
                raise ValueError('save_molecule_codes is set to True,'
                                 ' you must specify code for the molecule')
        elif not isinstance(molcode, (float, int)):
            raise TypeError('molcode must be float, %s was given'
                            % type(molcode))

    def _atom_vector(self, atom):
        """Atom-type encoding + named properties + custom properties of one atom."""
        return np.concatenate((
            self.encode_num(atom.atomicnum),
            [atom.__getattribute__(prop) for prop in self.NAMED_PROPS],
            [func(atom) for func in self.CALLABLES],
        ))

    def _stack_rows(self, rows, dtype):
        """Stack per-atom vectors into a 2-D array, also when `rows` is empty."""
        if not rows:
            n_cols = (self.NUM_ATOM_CLASSES + len(self.NAMED_PROPS)
                      + len(self.CALLABLES))
            return np.zeros((0, n_cols), dtype=dtype)
        return np.array(rows, dtype=dtype)

    def get_features(self, molecule, molcode=None):
        """Get coordinates and features for all heavy atoms in the molecule.

        Parameters
        ----------
        molecule: pybel.Molecule
        molcode: float, optional
            Molecule type. You can use it to encode whether an atom belongs to
            the ligand (1.0) or to the protein (-1.0) etc.

        Returns
        -------
        coords: np.ndarray, shape = (N, 3)
            Coordinates of all heavy atoms in the `molecule`.
        features: np.ndarray, shape = (N, F)
            Features of all heavy atoms in the `molecule`: atom type
            (one-hot encoding), pybel.Atom attributes, custom properties,
            type of a molecule (e.g protein/ligand distinction), and other
            properties defined with SMARTS patterns

        Raises
        ------
        TypeError
            If `molecule` or `molcode` has the wrong type.
        ValueError
            If `save_molecule_codes` is set but `molcode` is missing.
        RuntimeError
            If any computed feature is NaN.
        """
        self._check_inputs(molecule, molcode)

        coords = []
        features = []
        heavy_atoms = []

        for i, atom in enumerate(molecule):
            # ignore hydrogens and dummy atoms (they have atomicnum set to 0)
            if atom.atomicnum > 1:
                heavy_atoms.append(i)
                coords.append(atom.coords)
                features.append(self._atom_vector(atom))

        # reshape keeps the documented (N, 3) shape when N == 0
        coords = np.array(coords, dtype=np.float32).reshape(-1, 3)
        features = self._stack_rows(features, np.float32)
        if self.save_molecule_codes:
            features = np.hstack((features,
                                  molcode * np.ones((len(features), 1))))
        features = np.hstack([features,
                              self.find_smarts(molecule)[heavy_atoms]])

        if np.isnan(features).any():
            raise RuntimeError('Got NaN when calculating features')

        return coords, features

    def get_features_CSAR(self, molecule, protein_idxs, ligand_idxs, molcode=None):
        """Get coordinates and features for all heavy atoms in the molecule,
        plus the same data split into a protein part and a ligand part.

        Parameters
        ----------
        molecule: pybel.Molecule
        protein_idxs: container of int
            Positions (0-based, in iteration order over `molecule`) of the
            atoms that belong to the protein.
        ligand_idxs: container of int
            Positions of the atoms that belong to the ligand. An atom listed
            in both containers is treated as a protein atom.
        molcode: float, optional
            Molecule type. You can use it to encode whether an atom belongs to
            the ligand (1.0) or to the protein (-1.0) etc.

        Returns
        -------
        coords: np.ndarray, shape = (N, 3)
            Coordinates of all heavy atoms in the `molecule`.
        features: np.ndarray, shape = (N, F)
            Features of all heavy atoms in the `molecule`: atom type
            (one-hot encoding), pybel.Atom attributes, custom properties,
            type of a molecule (e.g protein/ligand distinction), and other
            properties defined with SMARTS patterns
        protein_coords, protein_features: np.ndarray
            Same as above, restricted to heavy atoms in `protein_idxs`.
        ligand_coords, ligand_features: np.ndarray
            Same as above, restricted to heavy atoms in `ligand_idxs`.

        Notes
        -----
        The protein and ligand feature arrays never contain the molecule-code
        column, even when `save_molecule_codes` is set; only `features` does.
        """
        self._check_inputs(molecule, molcode)

        coords, protein_coords, ligand_coords = [], [], []
        features, protein_features, ligand_features = [], [], []
        heavy_atoms, protein_heavy_atoms, ligand_heavy_atoms = [], [], []

        for i, atom in enumerate(molecule):
            # ignore hydrogens and dummy atoms (they have atomicnum set to 0)
            if atom.atomicnum <= 1:
                continue
            # Computed once and shared between the "all atoms" list and the
            # protein/ligand subset the atom belongs to.
            vector = self._atom_vector(atom)
            heavy_atoms.append(i)
            coords.append(atom.coords)
            features.append(vector)
            if i in protein_idxs:
                protein_heavy_atoms.append(i)
                protein_coords.append(atom.coords)
                protein_features.append(vector)
            elif i in ligand_idxs:
                ligand_heavy_atoms.append(i)
                ligand_coords.append(atom.coords)
                ligand_features.append(vector)

        coords = np.array(coords, dtype=np.float32).reshape(-1, 3)
        protein_coords = np.array(protein_coords, dtype=np.float32).reshape(-1, 3)
        ligand_coords = np.array(ligand_coords, dtype=np.float32).reshape(-1, 3)

        # One SMARTS search serves all three feature matrices.
        smarts_features = self.find_smarts(molecule)

        features = self._stack_rows(features, np.float32)
        if self.save_molecule_codes:
            features = np.hstack((features,
                                  molcode * np.ones((len(features), 1))))
        features = np.hstack([features, smarts_features[heavy_atoms]])
        protein_features = np.hstack([
            self._stack_rows(protein_features, np.float64),
            smarts_features[protein_heavy_atoms]])
        ligand_features = np.hstack([
            self._stack_rows(ligand_features, np.float64),
            smarts_features[ligand_heavy_atoms]])

        if np.isnan(features).any():
            raise RuntimeError('Got NaN when calculating features')

        return coords, features, protein_coords, protein_features, ligand_coords, ligand_features

    def to_pickle(self, fname='featurizer.pkl'):
        """Save featurizer in a given file. Featurizer can be restored with
        `from_pickle` method.

        Parameters
        ----------
        fname: str, optional
            Path to file in which featurizer will be saved
        """
        # patterns can't be pickled, we need to temporarily remove them
        patterns = self.__PATTERNS[:]
        del self.__PATTERNS
        try:
            with open(fname, 'wb') as f:
                pickle.dump(self, f)
        finally:
            self.__PATTERNS = patterns[:]

    @staticmethod
    def from_pickle(fname):
        """Load pickled featurizer from a given file

        Parameters
        ----------
        fname: str
            Path to file with saved featurizer

        Returns
        -------
        featurizer: Featurizer object
            Loaded featurizer
        """
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # featurizers that come from a trusted source.
        with open(fname, 'rb') as f:
            featurizer = pickle.load(f)
        # Patterns are dropped by to_pickle, so rebuild them from SMARTS.
        featurizer.compile_smarts()
        return featurizer
# Module-level featurizer shared by the reader functions in this file; it is
# built without the molecule-code column.
featurizer = Featurizer(save_molecule_codes=False)
def get_labels_from_names(lables_path,names):
    """Look up the score of every entry in ``names`` in a whitespace-separated
    index file.

    The first six lines of the file are skipped. On every remaining line the
    first field is the entry name and the fourth field is its score.

    Returns
    -------
    list of float
        Scores in the same order as ``names``.
    """
    with open(lables_path, 'rb') as handle:
        content = handle.read().decode()
    records = (row.split() for row in content.strip().split('\n')[6:])
    score_by_name = {fields[0]: float(fields[3]) for fields in records}
    return [score_by_name[name] for name in names]
def get_labels_from_names_csar(lables_path,names):
    """Look up the score of every entry in ``names`` in a comma-separated file.

    The first line of the file is skipped. On every remaining line the second
    field is the entry name and the third field is its score; surrounding
    whitespace of each field is ignored.

    Returns
    -------
    list of float
        Scores in the same order as ``names``.
    """
    with open(lables_path, 'rb') as handle:
        content = handle.read().decode()
    score_by_name = {}
    for row in content.strip().split('\n')[1:]:
        fields = [field.strip() for field in row.split(',')]
        score_by_name[fields[1]] = float(fields[2])
    return [score_by_name[name] for name in names]
def get_lig_coords_ground_truth_from_names(lables_path,names):
    """Placeholder: not implemented, always returns None."""
    return None
def lig_atom_type_obmol(obmol):
    """Return the atomic numbers of all atoms with atomic number > 1 as an
    int64 tensor, in iteration order over ``obmol``."""
    atomic_numbers = []
    for atom in obmol:
        number = atom.atomicnum
        if number > 1:
            atomic_numbers.append(number)
    return torch.tensor(atomic_numbers, dtype=torch.int64)
def lig_atom_type_rdmol(rdmol):
    """Return the atomic number of every atom of ``rdmol`` as an int64 tensor."""
    atomic_numbers = []
    for atom in rdmol.GetAtoms():
        atomic_numbers.append(atom.GetAtomicNum())
    return torch.tensor(atomic_numbers, dtype=torch.int64)
def get_bonded_edges_obmol(pocket):
    """List the covalent bonds between heavy atoms of a pybel molecule.

    Atoms with atomic number <= 1 (hydrogens and dummy atoms) are ignored and
    the remaining atoms are renumbered 0..N-1 in iteration order. This is the
    same selection and order used by ``Featurizer.get_features`` and
    ``lig_atom_type_obmol``, so the returned indices line up with the rows of
    those outputs.

    Parameters
    ----------
    pocket: pybel.Molecule

    Returns
    -------
    list of (int, int, list)
        One ``(source, target, bond_features)`` tuple per bond direction
        (every bond appears once from each of its two atoms);
        ``bond_features`` comes from ``CusBondFeaturizer_new``.
    """
    # Open Babel atom indices are 1-based, hence the extra slot.
    idx_map = [-1] * (len(pocket.atoms) + 1)
    edge_l = []
    idx_new = 0
    for atom in pocket:
        # Skip hydrogens AND dummy atoms (atomic number 0). Skipping only
        # hydrogens would shift the node numbering away from the feature rows.
        if atom.atomicnum <= 1:
            continue
        a1 = atom.idx
        idx_map[a1] = idx_new
        idx_new += 1
        for natom in openbabel.OBAtomAtomIter(atom.OBAtom):
            if natom.GetAtomicNum() <= 1:
                continue
            a2 = natom.GetIdx()
            bond = openbabel.OBAtom.GetBond(natom, atom.OBAtom)
            edge_l.append((a1, a2, CusBondFeaturizer_new(bond)))

    # Translate Open Babel indices into the compact heavy-atom numbering.
    edge_l_new = []
    for a1, a2, t in edge_l:
        a1_, a2_ = idx_map[a1], idx_map[a2]
        # Internal invariant: both ends were kept above, so both are mapped.
        assert a1_ != -1 and a2_ != -1
        edge_l_new.append((a1_, a2_, t))
    return edge_l_new
def get_bonded_edges_rdmol(rdmol):
    """List the bonds of an RDKit molecule as directed edges.

    Parameters
    ----------
    rdmol: RDKit molecule

    Returns
    -------
    list of (int, int, int)
        For every bond, ``(begin, end, type)`` immediately followed by
        ``(end, begin, type)``, where ``type`` is the ``BOND_TYPES`` code of
        the bond. A list (rather than a one-shot ``zip`` iterator) is
        returned so the edges can be traversed more than once, matching
        ``get_bonded_edges_obmol``.
    """
    edges = []
    for bond in rdmol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_code = BOND_TYPES[bond.GetBondType()]
        edges.append((start, end, bond_code))
        edges.append((end, start, bond_code))
    return edges
def D3_info(a, b, c):
    """Geometry of the triangle spanned by the points ``a``, ``b`` and ``c``.

    Parameters
    ----------
    a, b, c: np.ndarray
        Point coordinates; ``a`` is the vertex at which the angle is measured.

    Returns
    -------
    angle: float
        Angle between the vectors ``ab`` and ``ac``, in degrees.
    area: float
        Area of the triangle ``abc``.
    distance: float
        Euclidean length of ``ac``.
    """
    # Angle in space
    ab = b - a  # vector ab
    ac = c - a  # vector ac
    cosine_angle = np.dot(ab, ac) / (np.linalg.norm(ab) * np.linalg.norm(ac))
    # Rounding can push the cosine slightly outside [-1, 1], where arccos is
    # NaN, so clamp both ends. The first test is written as "not >=" so that
    # a NaN cosine (zero-length vector) keeps mapping to -1 as it always did.
    if not cosine_angle >= -1.0:
        cosine_angle = -1.0
    elif cosine_angle > 1.0:
        cosine_angle = 1.0
    angle = np.arccos(cosine_angle)
    # Triangle area
    ab_ = np.sqrt(np.sum(ab ** 2))
    ac_ = np.sqrt(np.sum(ac ** 2))  # Euclidean distance
    area = 0.5 * ab_ * ac_ * np.sin(angle)
    return np.degrees(angle), area, ac_
def D3_info_cal(nodes_ls, g):
    """Summarise the 3-D geometry around one edge.

    ``nodes_ls`` holds the two end nodes of the edge followed by further node
    ids. For each further node, ``D3_info`` is evaluated on the positions
    (``g.ndata['pos']``) of the first node, the second node and that node.

    Returns
    -------
    list of 9 numbers
        max / sum / mean of the angles (scaled by 0.01), of the areas, and of
        the distances (scaled by 0.1). All zeros when ``nodes_ls`` has no
        more than two entries.
    """
    if len(nodes_ls) <= 2:
        return [0, 0, 0, 0, 0, 0, 0, 0, 0]

    first, second = nodes_ls[0], nodes_ls[1]
    Angles, Areas, Distances = [], [], []
    for other in nodes_ls[2:]:
        angle, area, distance = D3_info(g.ndata['pos'][first].numpy(),
                                        g.ndata['pos'][second].numpy(),
                                        g.ndata['pos'][other].numpy())
        Angles.append(angle)
        Areas.append(area)
        Distances.append(distance)

    angle_stats = [np.max(Angles) * 0.01, np.sum(Angles) * 0.01, np.mean(Angles) * 0.01]
    area_stats = [np.max(Areas), np.sum(Areas), np.mean(Areas)]
    distance_stats = [np.max(Distances) * 0.1, np.sum(Distances) * 0.1, np.mean(Distances) * 0.1]
    return angle_stats + area_stats + distance_stats
def bond_feature(g):
    """Compute the ``D3_info_cal`` descriptor of every edge of graph ``g``.

    For each edge the node list passed to ``D3_info_cal`` is the source node,
    the destination node, and then the predecessors of the source node with
    one occurrence of the destination removed.

    Returns
    -------
    torch.Tensor
        Float tensor with one row of 9 values per edge.
    """
    sources, targets = g.find_edges(range(g.number_of_edges()))
    sources, targets = sources.tolist(), targets.tolist()

    node_lists = []
    for source, target in zip(sources, targets):
        others = g.predecessors(source).tolist()
        others.remove(target)
        node_lists.append([source, target] + others)

    descriptors = [D3_info_cal(nodes, g=g) for nodes in node_lists]
    return torch.tensor(descriptors, dtype=torch.float)
def read_molecules_crossdock(lig_path, prot_path, ligcut, protcut, lig_type, prot_graph_type, dataset_path, chain_cut=5.0):
    """Read one ligand/protein pair and featurize both.

    Parameters
    ----------
    lig_path, prot_path: str
        Ligand (SDF) and protein (PDB) paths, relative to `dataset_path`.
    ligcut, protcut:
        When None, bonded edges are extracted for the ligand / protein;
        otherwise the corresponding edge output is None.
    lig_type: str
        'openbabel' or 'rdkit': the toolkit used to read the ligand.
    prot_graph_type: str
        Must start with 'atom' or 'residue'. A value ending in 'pocket'
        selects residues within 12 of the ligand; anything else selects the
        chains chosen by `chain_cut`.
    dataset_path: str
        Root directory of the dataset.
    chain_cut: float or None
        Keep chains that have atoms within this distance of the ligand;
        None keeps the whole protein.

    Returns
    -------
    tuple
        ``(lig_coords, lig_features, lig_edges, lig_node_type, prot_coords,
        prot_features, prot_edges, prot_node_type, sec_features,
        alpha_c_coords, c_coords, n_coords)``. `sec_features` is always None
        (secondary-structure features are not computed here). The last three
        are arrays of backbone CA / C / N coordinates and are only filled for
        residue graphs.

    Raises
    ------
    ValueError
        If the ligand cannot be read with RDKit, or if `lig_type` or
        `prot_graph_type` is not supported.
    """
    lig_path = os.path.join(dataset_path, lig_path)
    prot_path = os.path.join(dataset_path, prot_path)

    if lig_type == 'openbabel':
        m_lig = next(pybel.readfile('sdf', lig_path))
        lig_coords, lig_features = featurizer.get_features(m_lig)
        lig_edges = get_bonded_edges_obmol(m_lig) if ligcut is None else None
        lig_node_type = lig_atom_type_obmol(m_lig)
    elif lig_type == 'rdkit':
        m_lig = read_rdmol(lig_path, sanitize=True, remove_hs=True)
        # Explicit check: an `assert` would vanish under `python -O`.
        if m_lig is None:
            raise ValueError(f'sanitize error : {lig_path}')
        conf = m_lig.GetConformer()
        lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
        lig_edges = get_bonded_edges_rdmol(m_lig) if ligcut is None else None
        lig_node_type = lig_atom_type_rdmol(m_lig)
    else:
        raise ValueError(f'error lig_type : {lig_type}')

    prot_complex = parsePDB(prot_path)
    prot_structure_no_water = prot_complex.select('protein')
    if chain_cut is not None:
        prot_valid_chains = prot_structure_no_water.select(f'same chain as within {chain_cut} of ligand', ligand=lig_coords)
    else:
        prot_valid_chains = prot_structure_no_water
    prot_valid_pocket = prot_structure_no_water.select('same residue as within 12 of ligand', ligand=lig_coords)
    prot_alpha_c = prot_valid_chains.select('calpha')
    prot_pocket_alpha_c = prot_valid_pocket.select('calpha')

    # Secondary-structure features are not computed; callers receive None.
    sec_features = None
    # A "complete" residue has a CA, a C and an N atom.
    alpha_c_coords, c_coords, n_coords, complete_residues = [], [], [], []
    use_pocket = prot_graph_type.endswith('pocket')

    if prot_graph_type.startswith('atom'):
        m_prot = prot_valid_pocket if use_pocket else prot_valid_chains
        # NOTE(review): m_prot is a ProDy selection, while get_features and
        # get_bonded_edges_obmol are written for pybel.Molecule objects;
        # confirm that this branch is actually exercised.
        prot_coords, prot_features = featurizer.get_features(m_prot)
        prot_edges = get_bonded_edges_obmol(m_prot) if protcut is None else None
        prot_node_type = lig_atom_type_obmol(m_prot)
    elif prot_graph_type.startswith('residue'):
        m_prot = prot_pocket_alpha_c if use_pocket else prot_alpha_c
        m_prot_complete = prot_valid_pocket if use_pocket else prot_valid_chains
        prot_coords, prot_features = prot_alpha_c_featurizer(m_prot)
        prot_node_type = prot_residue_type(m_prot)
        prot_edges = None

        hv = m_prot_complete.getHierView()
        for chain in hv:
            for residue in chain:
                alpha_c_coord, c_coord, n_coord = None, None, None
                for atom in residue:
                    atom_name = atom.getName()
                    if atom_name == 'CA':
                        alpha_c_coord = atom.getCoords()
                    if atom_name == 'C':
                        c_coord = atom.getCoords()
                    if atom_name == 'N':
                        n_coord = atom.getCoords()

                if alpha_c_coord is not None and c_coord is not None and n_coord is not None:
                    alpha_c_coords.append(alpha_c_coord)
                    c_coords.append(c_coord)
                    n_coords.append(n_coord)
                    complete_residues.append(True)
                else:
                    complete_residues.append(False)

        # One flag per residue node; keep only residues with a full backbone.
        assert len(complete_residues) == len(prot_coords)
        prot_coords = prot_coords[complete_residues]
        prot_features = prot_features[complete_residues]
        prot_node_type = prot_node_type[complete_residues]
        assert len(alpha_c_coords) == len(prot_coords)
        assert len(c_coords) == len(prot_coords)
        assert len(n_coords) == len(prot_coords)
    else:
        raise ValueError("error prot_graph_type")

    return lig_coords, lig_features, lig_edges, lig_node_type, \
           prot_coords, prot_features, prot_edges, prot_node_type, sec_features, \
           np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords)
def read_ligands_chembl_smina_multi_pose(name, valid_ligand_index, dataset_path, ligcut, lig_type='openbabel', top_N=2,
                                         docking_type='site_specific'):
    """Read up to `top_N` docked poses per pose file for every valid ligand.

    Pose files live in ``<dataset_path>/<name>/ligand_smina_poses`` and are
    named ``<index>.mol2`` (site-specific docking) and ``<index>_blind.mol2``
    (blind docking).

    Parameters
    ----------
    name: str
        Sub-directory of `dataset_path` that holds the poses.
    valid_ligand_index: sequence of bool
        Flag per ligand index; ligands with a false flag are skipped.
    dataset_path: str
    ligcut:
        Unused.
    lig_type: str
        Only 'openbabel' is handled; any other value yields empty outputs.
    top_N: int
        Maximum number of poses read from each pose file.
    docking_type: str
        'site_specific', 'blind' or 'all' (both files).

    Returns
    -------
    tuple of five parallel lists
        Per kept ligand: list of pose coordinate arrays, features, bonded
        edges, node types (the latter three taken from the last pose read),
        and the ligand index. Ligands without a single readable pose are
        skipped and reported on stdout.

    Raises
    ------
    ValueError
        If `docking_type` is not one of the supported values.
    """
    pose_dir = os.path.join(dataset_path, name, 'ligand_smina_poses')
    valid_lig_multi_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list = [], [], [], [], []

    for index, valid in enumerate(valid_ligand_index):
        if not valid or lig_type != 'openbabel':
            continue

        if docking_type == 'site_specific':
            lig_paths_mol2 = [os.path.join(pose_dir, f'{index}.mol2')]
        elif docking_type == 'blind':
            lig_paths_mol2 = [os.path.join(pose_dir, f'{index}_blind.mol2')]
        elif docking_type == 'all':
            lig_paths_mol2 = [os.path.join(pose_dir, f'{index}.mol2'),
                              os.path.join(pose_dir, f'{index}_blind.mol2')]
        else:
            raise ValueError(f'error docking_type : {docking_type}')

        lig_multi_coords = []
        lig_features = lig_edges = lig_node_type = None
        previou_atom_num = -1  # heavy-atom count of the first pose read
        for lig_path_mol2 in lig_paths_mol2:
            m_lig_iter = pybel.readfile('mol2', lig_path_mol2)
            c_m_lig = 0
            while c_m_lig < top_N:
                try:
                    m_lig = next(m_lig_iter)
                    lig_coords, pose_features = featurizer.get_features(m_lig)
                    pose_edges = get_bonded_edges_obmol(m_lig)
                    pose_node_type = lig_atom_type_obmol(m_lig)
                except StopIteration:
                    print(f'{lig_path_mol2} only has {c_m_lig} poses')
                    break
                except Exception as err:
                    # Best effort: give up on this file, keep earlier poses.
                    print(f'{lig_path_mol2}: failed to read pose {c_m_lig}: {err}')
                    break

                # All poses of one ligand must describe the same atoms.
                if previou_atom_num == -1:
                    previou_atom_num = len(lig_coords)
                elif len(lig_coords) != previou_atom_num:
                    print(f'{lig_path_mol2}: pose {c_m_lig} has {len(lig_coords)} heavy atoms, '
                          f'expected {previou_atom_num}')
                    break

                lig_features, lig_edges, lig_node_type = pose_features, pose_edges, pose_node_type
                lig_multi_coords.append(lig_coords)
                c_m_lig += 1

        if not lig_multi_coords:
            # Nothing was read: do not emit an entry with missing (or, in a
            # loop, stale) features for this ligand.
            print(f'{name}: no readable pose for ligand {index}, skipped')
            continue

        valid_lig_multi_coords_list.append(lig_multi_coords)
        valid_lig_features_list.append(lig_features)
        valid_lig_edges_list.append(lig_edges)
        valid_lig_node_type_list.append(lig_node_type)
        valid_index_list.append(index)

    return valid_lig_multi_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list
def read_ligands_chembl_smina(name, valid_ligand_index, dataset_path, ligcut, lig_type='openbabel',docking_type='site_specific'):
    """Read the first smina docking pose of every valid ligand of one target.

    The pose of ligand ``i`` is read from
    ``<dataset_path>/<name>/ligand_smina_poses/<i>.mol2``, or ``<i>_blind.mol2``
    when ``docking_type == 'blind'``.

    Args:
        name: sub-directory of ``dataset_path`` that holds the target.
        valid_ligand_index: iterable of flags; position ``i`` says whether
            ligand ``i`` should be read.
        dataset_path: root directory of the dataset.
        ligcut: not used by this function; kept for signature compatibility.
        lig_type: ``'openbabel'`` or ``'rdkit'``; other values read nothing.
        docking_type: ``'blind'`` selects the blind-docking file, anything
            else the site-specific one.

    Returns:
        Five parallel lists: coordinates, features, bonded edges, node types
        and the indices of the ligands that were actually read.
    """
    valid_lig_coords_list, valid_lig_features_list = [], []
    valid_lig_edges_list, valid_lig_node_type_list, valid_index_list = [], [], []
    file_suffix = '_blind' if docking_type == 'blind' else ''

    for index, valid in enumerate(valid_ligand_index):
        if not valid:
            continue
        lig_path_mol2 = os.path.join(dataset_path, name, 'ligand_smina_poses', f'{index}{file_suffix}.mol2')

        if lig_type == 'openbabel':
            try:
                m_lig = next(pybel.readfile('mol2', lig_path_mol2))
            except Exception:
                # Report the unreadable file and skip it; continuing would
                # silently re-use the molecule of the previous ligand.
                print(lig_path_mol2)
                continue
            lig_coords, lig_features = featurizer.get_features(m_lig)
            lig_edges = get_bonded_edges_obmol(m_lig)
            lig_node_type = lig_atom_type_obmol(m_lig)
        elif lig_type == 'rdkit':
            m_lig = read_rdmol(lig_path_mol2)
            lig_coords = m_lig.GetConformer().GetPositions()
            lig_features = lig_atom_featurizer_rdmol(m_lig)
            lig_edges = get_bonded_edges_rdmol(m_lig)
            lig_node_type = lig_atom_type_rdmol(m_lig)
        else:
            continue

        valid_lig_coords_list.append(lig_coords)
        valid_lig_features_list.append(lig_features)
        valid_lig_edges_list.append(lig_edges)
        valid_lig_node_type_list.append(lig_node_type)
        valid_index_list.append(index)

    return valid_lig_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list
def read_ligands(name, dataset_path, ligcut, lig_type='openbabel'):
    """Read every ligand stored in ``<dataset_path>/<name>/visualize_dir/total_vs.sdf``.

    Molecules that cannot be featurized are reported as ``'<index> error'`` and
    skipped, so the returned lists may be shorter than the SDF file.

    Args:
        name: sub-directory of ``dataset_path`` that holds the target.
        dataset_path: root directory of the dataset.
        ligcut: not used by this function; kept for signature compatibility.
        lig_type: ``'openbabel'`` or ``'rdkit'``; other values read nothing.

    Returns:
        Five parallel lists: coordinates, features, bonded edges, node types
        and the position of each successfully read molecule in the SDF file.
    """
    lig_path_sdf = os.path.join(dataset_path, name, 'visualize_dir', 'total_vs.sdf')
    valid_lig_coords_list, valid_lig_features_list = [], []
    valid_lig_edges_list, valid_lig_node_type_list, valid_index_list = [], [], []

    if lig_type == 'openbabel':
        molecules = pybel.readfile('sdf', lig_path_sdf)
    elif lig_type == 'rdkit':
        molecules = Chem.SDMolSupplier(lig_path_sdf, sanitize=True, removeHs=False)
    else:
        molecules = ()

    for index, m_lig in enumerate(molecules):
        try:
            if lig_type == 'openbabel':
                lig_coords, lig_features = featurizer.get_features(m_lig)
                lig_edges = get_bonded_edges_obmol(m_lig)
                lig_node_type = lig_atom_type_obmol(m_lig)
            else:
                conf = m_lig.GetConformer()
                lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
                lig_edges = get_bonded_edges_rdmol(m_lig)
                lig_node_type = lig_atom_type_rdmol(m_lig)
        except Exception:
            # ``Exception`` (not a bare except) so Ctrl-C still stops the loop.
            print(f'{index} error')
            continue

        valid_lig_coords_list.append(lig_coords)
        valid_lig_features_list.append(lig_features)
        valid_lig_edges_list.append(lig_edges)
        valid_lig_node_type_list.append(lig_node_type)
        valid_index_list.append(index)

    return valid_lig_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list
def read_casf_ligands(name, dataset_path, ligcut, lig_type='openbabel'):
    """Read all multi-pose ligand files found in ``<dataset_path>/<name>``.

    Every file in the directory is opened with Open Babel using its file
    extension as the format. Each molecule in a file is treated as one pose
    of the same ligand.

    Args:
        name: sub-directory of ``dataset_path`` to scan.
        dataset_path: root directory of the dataset.
        ligcut: not used by this function; kept for signature compatibility.
        lig_type: must be ``'openbabel'``.

    Returns:
        Five parallel lists with one entry per ligand file: pose names
        (``'<lig_name>_ligand_<k>'``, 1-based), pose coordinates, features
        (from the last pose), bonded edges and node types (from the first
        pose). Files that yield no molecule are skipped.
    """
    lig_dir = os.path.join(dataset_path, name)
    lig_files = os.listdir(lig_dir)
    assert lig_type == 'openbabel'
    lig_multi_name_list, lig_multi_coords_list = [], []
    lig_features_list, lig_edges_list, lig_node_type_list = [], [], []

    for lig_file in lig_files:
        # Ligand id: first four characters of the last '_'-separated token.
        lig_name = lig_file.split('_')[-1][:4]
        file_type = lig_file.split('.')[-1]
        lig_path = os.path.join(lig_dir, lig_file)

        multi_coords, multi_names = [], []
        # Reset per file so an empty file cannot inherit the previous ligand.
        lig_features = lig_edges = lig_node_type = None
        for index, m_lig in enumerate(pybel.readfile(file_type, lig_path)):
            lig_coords, lig_features = featurizer.get_features(m_lig)
            if index == 0:
                # Topology is shared by all poses, so it is computed once.
                lig_edges = get_bonded_edges_obmol(m_lig)
                lig_node_type = lig_atom_type_obmol(m_lig)
            multi_coords.append(lig_coords)
            multi_names.append(f'{lig_name}_ligand_{index+1}')

        if not multi_coords:
            print(f'{lig_path} contains no readable molecule, skipped')
            continue

        lig_multi_name_list.append(multi_names)
        lig_multi_coords_list.append(multi_coords)
        lig_features_list.append(lig_features)
        lig_edges_list.append(lig_edges)
        lig_node_type_list.append(lig_node_type)

    return lig_multi_name_list, lig_multi_coords_list, lig_features_list, lig_edges_list, lig_node_type_list
def read_proteins(name, dataset_path, prot_graph_type, protcut):
    """Load the pre-extracted protein of one complex and featurize it.

    Args:
        name: sub-directory of ``dataset_path`` that holds the complex.
        dataset_path: root directory of the dataset.
        prot_graph_type: a key starting with ``'atom'`` (looked up in
            ``graph_type_filename``) or with ``'residue'``.
        protcut: for atom graphs, bonded edges are only computed when this
            is ``None``.

    Returns:
        Tuple ``(coords, features, edges, node_type, sec_features,
        alpha_c_coords, c_coords, n_coords)``. The last three are arrays of
        backbone CA/C/N coordinates and are only filled for residue graphs.
        For residue graphs ``coords`` and ``features`` are restricted to
        residues that have all three backbone atoms.

    Raises:
        ValueError: if ``<name>_valid_chains.pdb`` cannot be parsed, or if
            ``prot_graph_type`` has an unknown prefix.
    """
    valid_chains_path = os.path.join(dataset_path, name, f'{name}_valid_chains.pdb')
    try:
        prot_valid_chains = parsePDB(valid_chains_path)
    except Exception as err:
        raise ValueError(valid_chains_path) from err
    prot_alpha_c = prot_valid_chains.select('calpha')
    alpha_c_coords, c_coords, n_coords = [], [], []

    if prot_graph_type.startswith('atom'):
        prot_path = os.path.join(dataset_path, name, f'{name}_{graph_type_filename[prot_graph_type]}')
        m_prot = next(pybel.readfile('pdb', prot_path))
        sec_features = None
        prot_coords_valid, prot_features_valid = featurizer.get_features(m_prot)
        prot_edges = get_bonded_edges_obmol(m_prot) if protcut is None else None
        prot_node_type = lig_atom_type_obmol(m_prot)
    elif prot_graph_type.startswith('residue'):
        sec_features = None
        prot_coords, prot_features = prot_alpha_c_featurizer(prot_alpha_c)
        prot_node_type = prot_residue_type(prot_alpha_c)
        prot_edges = None

        # Keep only residues that carry a complete N/CA/C backbone.
        valid_index = []
        index = 0
        for chain in prot_valid_chains.getHierView():
            for residue in chain:
                backbone = {}
                for atom in residue:
                    atom_name = atom.getName()
                    if atom_name in ('CA', 'C', 'N'):
                        backbone[atom_name] = atom.getCoords()
                if len(backbone) == 3:
                    alpha_c_coords.append(backbone['CA'])
                    c_coords.append(backbone['C'])
                    n_coords.append(backbone['N'])
                    valid_index.append(index)
                # NOTE(review): the counter advances per residue and is used to
                # index the C-alpha features - assumes every residue has a CA.
                index += 1
        prot_coords_valid = prot_coords[valid_index]
        prot_features_valid = prot_features[valid_index]
    else:
        raise ValueError("error prot_graph_type")

    return (prot_coords_valid, prot_features_valid, prot_edges, prot_node_type, sec_features,
            np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords))
def read_molecules(name, dataset_path, prot_graph_type, ligcut, protcut, lig_type='openbabel',init_type='redock_init',
                   chain_cut=5.0, p2rank_base=None, binding_site_type='ligand_center', LAS_mask=True,
                   keep_hs_before_rdkit_generate=False, rd_gen_maxIters=200):
    """Read one protein-ligand complex and build all raw inputs for graph construction.

    Side effects: writes ``<name>_valid_chains.pdb`` and ``<name>_valid_pocket.pdb``
    into the complex directory, and appends to a log file in the working
    directory when RDKit conformer generation fails.

    Args:
        name: sub-directory of ``dataset_path`` that holds the complex.
        dataset_path: root directory of the dataset.
        prot_graph_type: key starting with ``'atom'`` or ``'residue'``; a
            ``'pocket'`` suffix selects the pocket instead of the full chains.
        ligcut: not used by this function; kept for signature compatibility.
        protcut: for atom graphs, bonded edges are only computed when ``None``.
        lig_type: ``'openbabel'`` (reads the mol2 file) or ``'rdkit'`` (reads
            the sdf file, falling back to mol2).
        init_type: ``'rdkit_init'``, ``'redock_init'``, ``'random_init'``;
            anything else disables initial coordinates.
        chain_cut: keep chains within this distance of the ligand; ``None``
            keeps the whole protein.
        p2rank_base: directory with p2rank predictions (only for
            ``binding_site_type == 'p2rank'``).
        binding_site_type: ``'ligand_center'`` or ``'p2rank'``.
        LAS_mask: when true, also compute the ligand LAS mask (requires
            ``lig_type == 'rdkit'``).
        keep_hs_before_rdkit_generate: keep hydrogens when reading the
            molecule used for conformer generation.
        rd_gen_maxIters: forwarded to ``get_rdkit_coords`` as ``maxIters``.

    Returns:
        Tuple ``(lig_coords, lig_features, lig_edges, lig_node_type,
        lig_init_coords, prot_coords, prot_features, prot_edges,
        prot_node_type, sec_features, alpha_c_coords, c_coords, n_coords,
        binding_site, lig_LAS_mask)``; ``binding_site`` has shape ``(1, 3)``.

    Raises:
        ValueError: unknown ``lig_type``, ``prot_graph_type`` or
            ``binding_site_type``, a failing protein selection, or a
            mismatch between generated and reference ligand atom counts.
        AssertionError: ``LAS_mask`` requested with a non-rdkit ligand, or
            generated coordinates not matching the ligand size.
    """
    complex_dir = os.path.join(dataset_path, name)

    #########################Read Ligand########################################################
    lig_path_mol2 = os.path.join(complex_dir, f'{name}_ligand.mol2')
    lig_path_sdf = os.path.join(complex_dir, f'{name}_ligand.sdf')
    if lig_type == 'openbabel':
        m_lig = next(pybel.readfile('mol2', lig_path_mol2))
        lig_coords, lig_features = featurizer.get_features(m_lig)
        lig_edges = get_bonded_edges_obmol(m_lig)
        lig_node_type = lig_atom_type_obmol(m_lig)
    elif lig_type == 'rdkit':
        m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
        if m_lig is None:  # read mol2 file if sdf file cannot be sanitized
            m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
        conf = m_lig.GetConformer()
        lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
        lig_edges = get_bonded_edges_rdmol(m_lig)
        lig_node_type = lig_atom_type_rdmol(m_lig)
    else:
        raise ValueError(f'unknown lig_type: {lig_type!r}')

    #########################Get Ligand Rdkit Init Coordinates###################################
    if init_type == 'rdkit_init':
        remove_hs_before_generate = not keep_hs_before_rdkit_generate
        rd_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=remove_hs_before_generate)
        if rd_lig is None:  # read mol2 file if sdf file cannot be sanitized
            rd_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=remove_hs_before_generate)
        try:
            lig_init_coords = get_rdkit_coords(rd_lig, sanitize=True, remove_hs=True, maxIters=rd_gen_maxIters)
        except Exception as e:
            # Fall back to the reference pose and record the failure.
            lig_init_coords = lig_coords
            with open(f'temp_create_dataset_rdkit_timesplit_no_lig_or_rec_overlap_train_remove_hs_before_generate_{not keep_hs_before_rdkit_generate}.log', 'a') as f:
                f.write('Generating RDKit conformer failed for \n')
                f.write(name)
                f.write('\n')
                f.write(str(e))
                f.write('\n')
        assert len(lig_init_coords) == len(lig_coords)
        rdlig_node_type = lig_atom_type_rdmol(rd_lig)
        # remove all h
        if lig_type == 'openbabel':
            lig_init_coords = lig_init_coords[rdlig_node_type != 1]
        if len(lig_init_coords) != len(lig_coords):
            mismatch = ValueError('{} {}!={}'.format(name, len(lig_init_coords), len(lig_coords)))
            print("error raise:", repr(mismatch))
            raise mismatch
    elif init_type == 'redock_init':
        lig_init_coords = lig_coords
    elif init_type == 'random_init':
        lig_init_coords = np.random.randn(len(lig_coords), 3)
    else:
        lig_init_coords = None

    # random location and orientation
    if lig_init_coords is not None:
        rot_T, rot_b = random_rotation_translation()
        mean_to_remove = lig_init_coords.mean(axis=0, keepdims=True)
        lig_init_coords = (rot_T @ (lig_init_coords - mean_to_remove).T).T + rot_b

    #########################Read Protein########################################################
    processed_path = os.path.join(complex_dir, f'{name}_protein_processed.pdb')
    if os.path.exists(processed_path):
        prot_complex = parsePDB(processed_path)
    else:
        prot_complex = parsePDB(os.path.join(complex_dir, f'{name}_protein.pdb'))
    prot_structure_no_water = prot_complex.select('protein')
    if chain_cut is not None:
        prot_valid_chains = prot_structure_no_water.select(f'same chain as within {chain_cut} of ligand', ligand=lig_coords)
    else:
        prot_valid_chains = prot_structure_no_water
    prot_valid_pocket = prot_structure_no_water.select('same residue as within 12 of ligand', ligand=lig_coords)
    try:
        prot_alpha_c = prot_valid_chains.select('calpha')
        prot_pocket_alpha_c = prot_valid_pocket.select('calpha')
    except Exception as err:
        raise ValueError(f'{name} process pdb error') from err

    alpha_c_coords, c_coords, n_coords = [], [], []
    writePDB(os.path.join(complex_dir, f'{name}_valid_chains.pdb'), prot_valid_chains)
    writePDB(os.path.join(complex_dir, f'{name}_valid_pocket.pdb'), prot_valid_pocket)

    if prot_graph_type.startswith('atom'):
        # Atom graphs are built from the PDB files written just above.
        prot_path = os.path.join(complex_dir, f'{name}_{graph_type_filename[prot_graph_type]}')
        m_prot = next(pybel.readfile('pdb', prot_path))
        sec_features = None
        prot_coords_valid, prot_features_valid = featurizer.get_features(m_prot)
        prot_edges = get_bonded_edges_obmol(m_prot) if protcut is None else None
        prot_node_type = lig_atom_type_obmol(m_prot)
    elif prot_graph_type.startswith('residue'):
        use_pocket = prot_graph_type.endswith('pocket')
        m_prot = prot_pocket_alpha_c if use_pocket else prot_alpha_c
        m_prot_complete = prot_valid_pocket if use_pocket else prot_valid_chains
        sec_features = None
        prot_coords, prot_features = prot_alpha_c_featurizer(m_prot)
        prot_node_type = prot_residue_type(m_prot)
        prot_edges = None

        # Keep only residues that carry a complete N/CA/C backbone.
        valid_index = []
        index = 0
        for chain in m_prot_complete.getHierView():
            for residue in chain:
                backbone = {}
                for atom in residue:
                    atom_name = atom.getName()
                    if atom_name in ('CA', 'C', 'N'):
                        backbone[atom_name] = atom.getCoords()
                if len(backbone) == 3:
                    alpha_c_coords.append(backbone['CA'])
                    c_coords.append(backbone['C'])
                    n_coords.append(backbone['N'])
                    valid_index.append(index)
                # NOTE(review): the counter advances per residue and is used to
                # index the C-alpha features - assumes every residue has a CA.
                index += 1
        prot_coords_valid = prot_coords[valid_index]
        prot_features_valid = prot_features[valid_index]
    else:
        raise ValueError("error prot_graph_type")

    ############################### Read Binding Site ##########################################
    if binding_site_type == 'p2rank':
        p2rank_result_path = os.path.join(p2rank_base, f'{name}_valid_chains.pdb_predictions.csv')
        df = pd.read_csv(p2rank_result_path, usecols= [' center_x',' center_y',' center_z'])
        possible_binding_sites = df.values
        ligand_center = lig_coords.mean(axis=0)
        if len(possible_binding_sites) == 0:
            binding_site = ligand_center
        else:
            # Pick the predicted pocket centre closest to the ligand centre.
            binding_site_index = ((possible_binding_sites - ligand_center) ** 2).sum(axis=1).argmin()
            binding_site = possible_binding_sites[binding_site_index]
    elif binding_site_type == 'ligand_center':
        binding_site = lig_coords.mean(axis=0)
    else:
        raise ValueError(f'unknown binding_site_type: {binding_site_type!r}')

    ############################### Get LAS Mask ##########################################
    if LAS_mask:
        assert lig_type == 'rdkit', 'LAS_mask requires lig_type == "rdkit"'
        lig_LAS_mask = get_LAS_distance_constraint_mask(m_lig)
    else:
        lig_LAS_mask = None

    return (lig_coords, lig_features, lig_edges, lig_node_type, lig_init_coords,
            prot_coords_valid, prot_features_valid, prot_edges, prot_node_type, sec_features,
            np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords),
            binding_site.reshape(1, -1), lig_LAS_mask)
def read_molecules_inference(lig_path, protein_path, prot_graph_type, chain_cut=5.0):
    """Read a ligand file and a protein PDB for inference and featurize both.

    The ligand format is taken from the extension of ``lig_path``. Protein
    residues are kept only when they have N, CA and C atoms; residue features
    are a one-hot encoding over ``ResDict`` plus one "unknown" class.

    Returns:
        Tuple ``(lig_coords, lig_features, lig_edges, lig_node_type,
        prot_coords, prot_features, prot_edges, prot_node_type, sec_features,
        alpha_c_coords, c_coords, n_coords, residue_numbers, chain_ids)``.
        ``prot_coords`` is the plain list of CA coordinates.
    """
    # ---------------------------- ligand ----------------------------
    lig_format = lig_path.split('.')[-1]
    m_lig = next(pybel.readfile(lig_format, lig_path))
    lig_coords, lig_features = featurizer.get_features(m_lig)
    lig_edges = get_bonded_edges_obmol(m_lig)
    lig_node_type = lig_atom_type_obmol(m_lig)

    # ---------------------------- protein ---------------------------
    protein_only = parsePDB(protein_path).select('protein')
    if chain_cut is None:
        chains_sel = protein_only
    else:
        chains_sel = protein_only.select(f'same chain as within {chain_cut} of ligand',
                                         ligand=lig_coords)
    pocket_sel = protein_only.select('same residue as within 12 of ligand', ligand=lig_coords)
    chains_ca = chains_sel.select('calpha')
    pocket_ca = pocket_sel.select('calpha')

    use_pocket = prot_graph_type.endswith('pocket')
    m_prot = pocket_ca if use_pocket else chains_ca
    m_prot_complete = pocket_sel if use_pocket else chains_sel
    sec_features = None
    # The featurizer outputs are replaced further down; the calls are kept so
    # that any error they raise still surfaces.
    prot_alpha_c_featurizer(m_prot)
    prot_residue_type(m_prot)
    prot_edges = None

    alpha_c_coords, c_coords, n_coords = [], [], []
    ca_res_number_valid, residue_name_valid, chain_index_valid = [], [], []
    for chain in m_prot_complete.getHierView():
        for residue in chain:
            ca_res_number = residue.getResnums()[0]
            residue_name = residue.getResname()
            chain_index = residue.getChid()
            found = {'CA': None, 'C': None, 'N': None}
            for atom in residue:
                atom_name = atom.getName()
                if atom_name in found:
                    found[atom_name] = atom.getCoords()
            if any(coord is None for coord in found.values()):
                continue  # incomplete backbone
            alpha_c_coords.append(found['CA'])
            c_coords.append(found['C'])
            n_coords.append(found['N'])
            ca_res_number_valid.append(ca_res_number)
            residue_name_valid.append(residue_name)
            chain_index_valid.append(chain_index)

    prot_coords_valid = alpha_c_coords
    ResIndex_valid = [ResDict.get(res_name, UNKOWN_RES) for res_name in residue_name_valid]
    prot_node_type = torch.tensor(ResIndex_valid, dtype=torch.int64)
    prot_features_valid = torch.tensor(np.eye(UNKOWN_RES + 1)[ResIndex_valid])

    return (lig_coords, lig_features, lig_edges, lig_node_type,
            prot_coords_valid, prot_features_valid, prot_edges, prot_node_type, sec_features,
            np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords),
            ca_res_number_valid, chain_index_valid)
def get_ligand_smiles(name, dataset_path,):
    """Return the SMILES of a ligand and the mapping from molecule atoms to SMILES positions.

    The ligand is read from ``<name>_ligand.sdf``; if that yields ``None`` the
    ``<name>_ligand.mol2`` file is used instead.

    Returns:
        ``(sm, sm2m_order)`` where ``sm`` is the SMILES string and
        ``sm2m_order[a]`` is the position of molecule atom ``a`` in the SMILES
        output order (the inverse of RDKit's ``_smilesAtomOutputOrder``).
    """
    lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2')
    lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf')
    m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
    if m_lig is None:  # read mol2 file if sdf file cannot be sanitized
        m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)

    # MolToSmiles stores the atom output order as a computed private property.
    sm = Chem.MolToSmiles(m_lig)
    props = m_lig.GetPropsAsDict(includePrivate=True, includeComputed=True)
    m_sm_order = list(props['_smilesAtomOutputOrder'])

    # Invert the permutation: atom index -> rank in the SMILES string.
    sm2m_order = [0] * len(m_sm_order)
    for smiles_pos, atom_idx in enumerate(m_sm_order):
        sm2m_order[atom_idx] = smiles_pos
    return sm, sm2m_order
def get_protein_fasta(name, dataset_path,):
    """Return the per-chain sequences of ``<name>_valid_chains.pdb`` and the valid residue indices.

    Returns:
        ``(fasta_list, valid_index)``: one sequence string per chain, and the
        running residue indices (counted over all chains in iteration order)
        of residues that contain N, CA and C atoms.

    Raises:
        ValueError: if the PDB file cannot be parsed.
    """
    try:
        prot_valid_chains = parsePDB(os.path.join(dataset_path, name, f'{name}_valid_chains.pdb'))
    except Exception as err:
        raise ValueError(f'{name} error!') from err

    required_atoms = {'CA', 'C', 'N'}
    fasta_list, valid_index = [], []
    index = 0
    for chain in prot_valid_chains.getHierView():
        fasta_list.append(chain.getSequence())
        for residue in chain:
            atom_names = {atom.getName() for atom in residue}
            if required_atoms <= atom_names:
                valid_index.append(index)
            index += 1
    return fasta_list, valid_index
def prot_p2rank_feats(p2rank_result_path, p2rank_feats_tpye='zscore', pocket_cut=10):
    """Turn a p2rank per-residue CSV into a ``(n_residues, pocket_cut + 2)`` feature tensor.

    Each row holds the residue score (column ``p2rank_feats_tpye``) in exactly
    one slot chosen by the residue's ``pocket`` value:

    * pocket ``1 .. pocket_cut`` -> column ``pocket - 1``
    * pocket ``> pocket_cut``    -> second to last column
    * pocket ``0``               -> last column

    Column names are stripped of surrounding whitespace before lookup.
    """
    table = pd.read_csv(p2rank_result_path)
    table.columns = table.columns.str.strip()
    scores = table[p2rank_feats_tpye].values
    pocket_ids = table['pocket'].values

    width = pocket_cut + 2
    no_pocket_col, overflow_col = width - 1, width - 2
    feats = torch.zeros(len(scores), width)
    for row, score in enumerate(scores):
        pocket_id = pocket_ids[row]
        if pocket_id == 0:
            col = no_pocket_col
        elif pocket_id > pocket_cut:
            col = overflow_col
        else:
            col = pocket_id - 1
        feats[row, col] = score
    return feats
def get_p2rank_feats(name, dataset_path, p2rank_base=None, p2rank_feats_tpye='zscore', pocket_cut=10):
    """Load p2rank residue features for a protein, restricted to residues with a full backbone.

    Reads ``<p2rank_base>/<name>_valid_chains.pdb_residues.csv`` and the matching
    ``<name>_valid_chains.pdb``. If the number of p2rank rows differs from the
    number of C-alpha features, the protein name is appended to
    ``p2rank_feats_error.txt`` and an all-zero tensor is returned instead.

    Returns:
        Tensor with ``pocket_cut + 2`` columns; one row per residue that has
        N, CA and C atoms (or one row per C-alpha feature in the fallback case).

    Raises:
        ValueError: if the PDB file cannot be parsed.
    """
    try:
        prot_valid_chains = parsePDB(os.path.join(dataset_path, name, f'{name}_valid_chains.pdb'))
    except Exception as err:
        raise ValueError(f'{name} error!') from err
    prot_alpha_c = prot_valid_chains.select('calpha')

    p2rank_result_path = os.path.join(p2rank_base, f'{name}_valid_chains.pdb_residues.csv')
    p2rank_features = prot_p2rank_feats(p2rank_result_path, pocket_cut=pocket_cut, p2rank_feats_tpye=p2rank_feats_tpye)
    _, prot_features = prot_alpha_c_featurizer(prot_alpha_c)

    # Explicit comparison instead of try/assert so the check survives ``python -O``.
    if len(p2rank_features) != len(prot_features):
        with open('p2rank_feats_error.txt','a') as f:
            f.write(f'{name}\n')
        # NOTE(review): the fallback has one row per C-alpha feature, while the
        # normal path below has one row per residue with a complete backbone.
        return torch.zeros(len(prot_features), pocket_cut + 2)

    required_atoms = {'CA', 'C', 'N'}
    valid_index = []
    index = 0
    for chain in prot_valid_chains.getHierView():
        for residue in chain:
            atom_names = {atom.getName() for atom in residue}
            if required_atoms <= atom_names:
                valid_index.append(index)
            index += 1
    return p2rank_features[valid_index]
def binarize(x):
    """Return a tensor like ``x`` holding 1 where ``x > 0`` and 0 elsewhere."""
    positive = x > 0
    return positive.to(x.dtype)
#adj - > n_hops connections adj | |
def n_hops_adj(adj, n_hops):
    """Expand an adjacency matrix into a hop-count matrix up to ``n_hops``.

    Entry ``(i, j)`` of the result is the smallest number of hops
    (``1 .. n_hops``) needed to reach ``j`` from ``i``, or 0 when ``i == j``
    or ``j`` is not reachable within ``n_hops``.
    """
    def _to_binary(mat):
        # 1 where positive, 0 elsewhere, keeping dtype and device.
        return (mat > 0).to(mat.dtype)

    identity = torch.eye(adj.size(0), dtype=torch.long, device=adj.device)
    within_one = _to_binary(adj + identity)

    # reachable[k] marks the pairs connected by at most k hops.
    reachable = [identity, within_one]
    while len(reachable) <= n_hops:
        reachable.append(_to_binary(reachable[-1] @ within_one))

    extend_mat = torch.zeros_like(adj)
    for hops in range(1, n_hops + 1):
        newly_reached = reachable[hops] - reachable[hops - 1]
        extend_mat += newly_reached * hops
    return extend_mat
def get_LAS_distance_constraint_mask(mol):
    """Build a binary atom-pair mask for an RDKit molecule.

    A pair ``(i, j)`` is set when the two atoms are at most two bonds apart
    or are distinct members of the same ring returned by ``Chem.GetSymmSSSR``.
    """
    bond_adj = torch.from_numpy(Chem.GetAdjacencyMatrix(mol))
    # 1-hop and 2-hop neighbours
    las_adj = n_hops_adj(bond_adj, 2)
    # add every pair of distinct atoms that share a ring
    for ring in Chem.GetSymmSSSR(mol):
        for atom_a in ring:
            for atom_b in ring:
                if atom_a != atom_b:
                    las_adj[atom_a][atom_b] += 1
    # collapse the counts to a 0/1 mask
    return binarize(las_adj)
def get_lig_graph_geodiff(lig_coords, lig_features, lig_node_type, lig_edges):
    """Build a DGL ligand graph whose edges are exactly the given bonds.

    ``lig_edges`` is an iterable of ``(src, dst, bond_type)`` triples. Node
    data: ``'h'`` (features), ``'node_type'`` and ``'pos'`` (float
    coordinates); edge data: ``'bond_type'`` (int64).
    """
    n_atoms = len(lig_coords)  # number of ligand atoms
    graph = dgl.DGLGraph()
    graph.add_nodes(n_atoms)
    if isinstance(lig_features, np.ndarray):
        graph.ndata['h'] = torch.from_numpy(lig_features)
    else:
        graph.ndata['h'] = lig_features
    graph.ndata['node_type'] = lig_node_type

    bond_src, bond_dst, bond_type = zip(*lig_edges)
    graph.add_edges(np.array(bond_src), np.array(bond_dst))
    graph.ndata['pos'] = torch.tensor(lig_coords, dtype=torch.float)
    graph.edata['bond_type'] = torch.tensor(bond_type, dtype=torch.int64)
    return graph
def get_lig_multi_pose_graph_equibind(lig_multi_coords, lig_features, lig_node_type, max_neighbors=None, cutoff=5.0,
                                      lig_edges=None):
    """Build one EquiBind-style ligand graph per pose.

    All poses share ``lig_features``, ``lig_node_type`` and (optionally)
    ``lig_edges``; only the coordinates differ.

    Args:
        lig_multi_coords: iterable of coordinate arrays, one per pose.
        lig_features: node features shared by all poses.
        lig_node_type: node types shared by all poses.
        max_neighbors: forwarded to ``get_lig_graph_equibind``.
        cutoff: forwarded to ``get_lig_graph_equibind``.
        lig_edges: optional bonded edges forwarded to
            ``get_lig_graph_equibind``; ``None`` builds no bond-type features.

    Returns:
        List of graphs in the order of ``lig_multi_coords``.
    """
    # Keyword arguments keep this call aligned with the callee's signature,
    # whose third positional parameter is ``lig_edges``.
    return [
        get_lig_graph_equibind(lig_coords, lig_features, lig_edges, lig_node_type,
                               max_neighbors=max_neighbors, cutoff=cutoff)
        for lig_coords in lig_multi_coords
    ]
def get_lig_graph_equibind(lig_coords, lig_features, lig_edges, lig_node_type, max_neighbors=None, cutoff=5.0):
    """Build a distance-based ligand graph with EquiBind-style edge and node features.

    Every atom is connected to the atoms closer than ``cutoff`` (limited to
    the ``max_neighbors`` nearest when given, and to the single nearest atom
    when none is within the cutoff).

    Graph data written: node ``'h'``, ``'node_type'``, ``'pos'`` and
    ``'mu_r_norm'``; edge ``'e'`` (output of ``distance_featurizer``) and,
    when ``lig_edges`` is given, edge ``'bond_type'`` (zero rows for edges
    that are not bonds).
    """
    num_nodes = lig_coords.shape[0]
    assert lig_coords.shape[1] == 3
    distance = spatial.distance_matrix(lig_coords, lig_coords)
    sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))

    src_list, dst_list, dist_list, mean_norm_list = [], [], [], []
    for i in range(num_nodes):
        row = distance[i, :]
        dst = list(np.where(row < cutoff)[0])
        dst.remove(i)
        if max_neighbors != None and len(dst) > max_neighbors:
            dst = list(np.argsort(row))[1: max_neighbors + 1]  # closest would be self loop
        if len(dst) == 0:
            dst = list(np.argsort(row))[1:2]  # closest would be the index i itself > self loop
            print(
                f'The lig_radius {cutoff} was too small for one lig atom such that it had no neighbors. So we connected {i} to the closest other lig atom {dst}')
        assert i not in dst

        src = [i] * len(dst)
        src_list.extend(src)
        dst_list.extend(dst)
        neighbor_dist = distance[i, dst]
        dist_list.extend(list(neighbor_dist))

        # Softmax-weighted mean direction towards the neighbours, one weight set per sigma.
        weights = softmax(- neighbor_dist.reshape((1, -1)) ** 2 / sigma, axis=1)  # (sigma_num, neigh_num)
        first_row_sum = weights[0].sum()
        assert first_row_sum > 1 - 1e-2 and weights[0].sum() < 1.01
        diff_vecs = lig_coords[src, :] - lig_coords[dst, :]  # (neigh_num, 3)
        mean_vec = weights.dot(diff_vecs)  # (sigma_num, 3)
        denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1))  # (sigma_num,)
        mean_norm_list.append(np.linalg.norm(mean_vec, axis=1) / denominator)  # (sigma_num,)

    assert len(src_list) == len(dst_list)
    assert len(dist_list) == len(dst_list)

    graph = dgl.graph((torch.tensor(src_list), torch.tensor(dst_list)), num_nodes=num_nodes, idtype=torch.int32)
    if isinstance(lig_features, np.ndarray):
        graph.ndata['h'] = torch.from_numpy(lig_features)
    else:
        graph.ndata['h'] = lig_features
    graph.ndata['node_type'] = lig_node_type
    graph.edata['e'] = distance_featurizer(dist_list, 0.75)  # avg distance = 1.3 So divisor = (4/7)*1.3 = ~0.75
    graph.ndata['pos'] = torch.from_numpy(np.array(lig_coords).astype(np.float32))
    graph.ndata['mu_r_norm'] = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))

    if lig_edges is not None:
        # Map each (src, dst) pair of the distance graph to its edge position.
        edge_position = {pair: idx for idx, pair in enumerate(zip(src_list, dst_list))}
        bond_src_ls, bond_dst_ls, bond_type = list(zip(*lig_edges))
        bond_edge_idx = [edge_position[(bs, bd)] for bs, bd in zip(bond_src_ls, bond_dst_ls)]
        graph.edata['bond_type'] = torch.zeros(len(src_list), len(bond_type[0]))
        graph.edata['bond_type'][bond_edge_idx] = torch.tensor(bond_type).to(torch.float32)
    return graph
def get_lig_graph(lig_coords,lig_features, lig_edges, lig_node_type, cutoff=None):
    """Build a DGL ligand graph from bonds or from a distance cutoff.

    Args:
        lig_coords: atom coordinates, one row per atom.
        lig_features: node features (NumPy array or tensor).
        lig_edges: iterable of ``(src, dst, bond_type)`` triples; used only
            when ``cutoff`` is ``None``.
        lig_node_type: per-atom node types.
        cutoff: when given, connect every pair of distinct atoms closer than
            this distance instead of using ``lig_edges``.

    Returns:
        Graph with node data ``'h'``, ``'node_type'``, ``'pos'`` and edge data
        ``'distance'`` and ``'e'`` (scaled distance concatenated with the
        output of ``bond_feature``). ``'bond_type'`` is set only for
        bond-based graphs, since cutoff graphs have no bond types.
    """
    g_lig = dgl.DGLGraph()
    num_atoms_lig = len(lig_coords)  # number of ligand atoms
    g_lig.add_nodes(num_atoms_lig)
    g_lig.ndata['h'] = torch.from_numpy(lig_features) if isinstance(lig_features, np.ndarray) else lig_features
    g_lig.ndata['node_type'] = lig_node_type
    dis_matrix_lig = spatial.distance_matrix(lig_coords, lig_coords)

    bond_type = None
    if cutoff is None:
        src_ls, dst_ls, bond_type = zip(*lig_edges)
        src_ls, dst_ls = np.array(src_ls), np.array(dst_ls)
    else:
        # no self-loop: zero distances are excluded
        src_ls, dst_ls = np.where((dis_matrix_lig < cutoff) & (dis_matrix_lig != 0))
    g_lig.add_edges(src_ls, dst_ls)

    lig_d = torch.tensor(dis_matrix_lig[src_ls, dst_ls], dtype=torch.float).view(-1, 1)
    g_lig.edata['distance'] = lig_d
    g_lig.edata['e'] = lig_d * 0.1
    g_lig.ndata['pos'] = torch.tensor(lig_coords, dtype=torch.float)
    D3_info = bond_feature(g_lig)
    g_lig.edata['e'] = torch.cat([g_lig.edata['e'], D3_info], dim=-1)
    if bond_type is not None:
        g_lig.edata['bond_type'] = torch.tensor(bond_type, dtype=torch.int64)
    assert not torch.any(torch.isnan(D3_info))
    return g_lig
def get_prot_atom_graph(prot_coords, prot_features, prot_edges, prot_node_type, cutoff=None):
    """Build a DGL atom-level protein graph from bonds or from a distance cutoff.

    Args:
        prot_coords: atom coordinates, one row per atom.
        prot_features: node features (NumPy array or tensor).
        prot_edges: iterable of ``(src, dst, bond_type)`` triples; used only
            when ``cutoff`` is ``None``.
        prot_node_type: per-atom node types.
        cutoff: when given, connect every pair of distinct atoms closer than
            this distance instead of using ``prot_edges``.

    Returns:
        Graph with node data ``'h'``, ``'node_type'``, ``'pos'`` and edge data
        ``'distance'`` and ``'e'`` (scaled distance concatenated with the
        output of ``bond_feature``). ``'bond_type'`` is set only for
        bond-based graphs, since cutoff graphs have no bond types.
    """
    g_prot = dgl.DGLGraph()
    num_atoms_prot = len(prot_coords)
    g_prot.add_nodes(num_atoms_prot)
    g_prot.ndata['h'] = torch.from_numpy(prot_features) if isinstance(prot_features, np.ndarray) else prot_features
    g_prot.ndata['node_type'] = prot_node_type
    dis_matrix_prot = spatial.distance_matrix(prot_coords, prot_coords)

    bond_type = None
    if cutoff is None:
        src_ls, dst_ls, bond_type = zip(*prot_edges)
        src_ls, dst_ls = np.array(src_ls), np.array(dst_ls)
    else:
        # no self-loop: zero distances are excluded
        src_ls, dst_ls = np.where((dis_matrix_prot < cutoff) & (dis_matrix_prot != 0))
    g_prot.add_edges(src_ls, dst_ls)

    prot_d = torch.tensor(dis_matrix_prot[src_ls, dst_ls], dtype=torch.float).view(-1, 1)
    g_prot.edata['distance'] = prot_d
    g_prot.edata['e'] = prot_d * 0.1
    g_prot.ndata['pos'] = torch.tensor(prot_coords, dtype=torch.float)
    D3_info = bond_feature(g_prot)
    g_prot.edata['e'] = torch.cat([g_prot.edata['e'], D3_info], dim=-1)
    if bond_type is not None:
        g_prot.edata['bond_type'] = torch.tensor(bond_type, dtype=torch.int64)
    assert not torch.any(torch.isnan(D3_info))
    return g_prot
def distance_featurizer(dist_list, divisor) -> torch.Tensor:
    """Encode distances with 15 zero-centred Gaussian kernels of growing width.

    Feature ``k`` of a distance ``d`` is ``exp(-(d / divisor) ** 2 / 1.5 ** k)``.

    Args:
        dist_list: sequence of distances (one per edge).
        divisor: scale applied to the distances before squaring; the original
            author recommends about 4/7 of the average distance to encode.

    Returns:
        ``float32`` tensor with one row per entry of ``dist_list``; an empty
        input yields a ``(0, 15)`` tensor.
    """
    # you want to use a divisor that is close to 4/7 times the average distance that you want to encode
    length_scale_list = [1.5 ** x for x in range(15)]
    num_edge = len(dist_list)
    if num_edge == 0:
        # reshape((0, -1)) cannot infer the column count of an empty array.
        return torch.zeros((0, len(length_scale_list)), dtype=torch.float32)

    scaled_sq = (np.array(dist_list) / divisor) ** 2
    transformed_dist = np.array([np.exp(- scaled_sq / float(length_scale))
                                 for length_scale in length_scale_list]).T
    transformed_dist = transformed_dist.reshape((num_edge, -1))
    return torch.from_numpy(transformed_dist.astype(np.float32))
def local_coordinate_system_feature(prot_coords, c_alpha_coords, c_coords, n_coords, prot_d, src_ls, dst_ls):
    """Compute residue-graph edge features from per-residue local frames.

    For every residue an orthonormal frame ``(n, u, v)`` is built from its
    N, CA and C coordinates. For every edge the CA offset and the source
    frame are expressed in the destination frame (12 values) and appended to
    the output of ``distance_featurizer(prot_d, divisor=4)``.

    Returns:
        Tensor with one row per edge: distance features followed by the 12
        orientation features.
    """
    normals, u_axes, v_axes = [], [], []
    for res in range(len(prot_coords)):
        ca = c_alpha_coords[res]
        to_n = n_coords[res] - ca
        to_c = c_coords[res] - ca
        u_i = to_n / np.linalg.norm(to_n)
        t_i = to_c / np.linalg.norm(to_c)
        plane_normal = np.cross(u_i, t_i)
        n_i = plane_normal / np.linalg.norm(plane_normal)
        v_i = np.cross(n_i, u_i)
        assert (math.fabs(
            np.linalg.norm(v_i) - 1.) < 1e-5), "protein utils protein_to_graph_dips, v_i norm larger than 1"
        normals.append(n_i)
        u_axes.append(u_i)
        v_axes.append(v_i)
    n_i_feat, u_i_feat, v_i_feat = np.stack(normals), np.stack(u_axes), np.stack(v_axes)

    orientation_rows = []
    for edge in range(len(prot_d)):
        src, dst = src_ls[edge], dst_ls[edge]
        # rows of the basis are the destination frame vectors n, u, v
        basis_matrix = np.stack((n_i_feat[dst, :], u_i_feat[dst, :], v_i_feat[dst, :]), axis=0)
        offset = c_alpha_coords[src, :] - c_alpha_coords[dst, :]
        projected = [
            np.matmul(basis_matrix, offset),
            np.matmul(basis_matrix, n_i_feat[src, :]),  # shape (3,)
            np.matmul(basis_matrix, u_i_feat[src, :]),
            np.matmul(basis_matrix, v_i_feat[src, :]),
        ]
        orientation_rows.append(np.concatenate(projected, axis=0))  # shape (12,)

    edge_feat_ori_feat = np.stack(orientation_rows, axis=0)  # shape (num_edges, 12)
    edge_feat_ori_feat = torch.from_numpy(edge_feat_ori_feat.astype(np.float32))
    # distance features followed by the 12 orientation features
    return torch.cat([distance_featurizer(prot_d, divisor=4), edge_feat_ori_feat], axis=1)
def get_prot_alpha_c_graph_equibind(prot_coords, prot_features, prot_node_type, sec_features,
                                    alpha_c_coords, c_coords, n_coords,
                                    max_neighbor=None, cutoff=None):
    """Build a residue-level DGL graph with local-frame edge features.

    Each node is connected to every other node closer than ``cutoff``.  If
    that yields more than ``max_neighbor`` neighbours, the ``max_neighbor``
    nearest other nodes are used instead.  A node without any neighbour is
    connected to its single nearest other node.

    Parameters
    ----------
    prot_coords :
        Node coordinates used for the pairwise distance matrix.
    prot_features :
        Stored as ``ndata['h']``.
    prot_node_type :
        Sliced to the node count and stored as ``ndata['node_type']``.
    sec_features :
        Unused; kept for interface compatibility.
    alpha_c_coords, c_coords, n_coords :
        Backbone coordinates, one row per node; must match ``prot_coords``
        in length.
    max_neighbor : int or None
        Optional cap on the number of neighbours per node.
    cutoff :
        Distance threshold; compared with ``<`` against distances, so a
        numeric value is required in practice.

    Returns
    -------
    dgl.DGLGraph
        Graph with ``ndata`` keys ``h``, ``node_type``, ``pos``,
        ``mu_r_norm`` and ``edata`` key ``e``.

    Raises
    ------
    ValueError
        If the coordinate arrays differ in length.
    """
    num_atoms_prot = len(prot_coords)
    # Explicit check: the former assert-inside-try vanished under ``python -O``.
    if not (len(alpha_c_coords) == len(c_coords) == len(n_coords) == num_atoms_prot):
        raise ValueError(f'{len(alpha_c_coords)} == {len(prot_coords)}, {len(c_coords)} == {len(prot_coords)}, {len(n_coords)} == {len(prot_coords)}')

    g_prot = dgl.DGLGraph()
    g_prot.add_nodes(num_atoms_prot)
    g_prot.ndata['h'] = prot_features
    g_prot.ndata['node_type'] = prot_node_type[:num_atoms_prot]

    distances = spatial.distance_matrix(prot_coords, prot_coords)
    # Loop-invariant softmax temperatures, one output column per value.
    sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
    src_list, dst_list, dist_list, mean_norm_list = [], [], [], []
    for i in range(num_atoms_prot):
        row = distances[i, :]
        within = np.where(row < cutoff)[0]
        dst = list(within[within != i])
        if (max_neighbor is not None and len(dst) > max_neighbor) or len(dst) == 0:
            # Drop the node itself explicitly instead of assuming it sorts
            # first, which is not guaranteed when coordinates coincide.
            ranked = np.argsort(row)
            ranked = ranked[ranked != i]
            if len(dst) == 0:
                dst = list(ranked[:1])
                print(
                    f'The c_alpha_cutoff {cutoff} was too small for one c_alpha such that it had no neighbors. So we connected it to the closest other c_alpha')
            else:
                dst = list(ranked[:max_neighbor])
        src = [i] * len(dst)
        src_list.extend(src)
        dst_list.extend(dst)
        valid_dist_np = distances[i, dst]
        dist_list.extend(list(valid_dist_np))

        # Softmax over neighbours of -d^2 / sigma, one row per sigma.
        weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1)  # (sigma_num, neigh_num)
        assert weights[0].sum() > 1 - 1e-2 and weights[0].sum() < 1.01
        diff_vecs = alpha_c_coords[src, :] - alpha_c_coords[dst, :]  # (neigh_num, 3)
        mean_vec = weights.dot(diff_vecs)  # (sigma_num, 3)
        denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1))  # (sigma_num,)
        # Norm of the weighted mean offset relative to the weighted mean norm.
        mean_norm_list.append(np.linalg.norm(mean_vec, axis=1) / denominator)

    assert len(src_list) == len(dst_list)
    assert len(dist_list) == len(dst_list)
    g_prot.add_edges(src_list, dst_list)
    g_prot.edata['e'] = local_coordinate_system_feature(prot_coords, alpha_c_coords, c_coords, n_coords,
                                                        dist_list, src_list, dst_list)
    g_prot.ndata['pos'] = torch.from_numpy(alpha_c_coords.astype(np.float32))
    g_prot.ndata['mu_r_norm'] = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
    return g_prot
def get_prot_alpha_c_graph_ign(prot_coords, prot_features, prot_node_type, sec_features, cutoff=None):
    """Build a protein DGL graph whose edges join nodes closer than ``cutoff``.

    Pairs at distance exactly 0 are excluded, which removes self-loops.
    Edge features ``e`` are the output of ``bond_feature`` concatenated with
    the edge distance scaled by 0.1; the raw distance is also stored as
    ``edata['distance']``.

    ``sec_features`` is accepted but not used.  ``cutoff`` is compared with
    ``<`` against distances, so a numeric value is required in practice.

    Raises
    ------
    AssertionError
        If ``bond_feature`` produced NaN values.
    """
    node_count = len(prot_coords)
    graph = dgl.DGLGraph()
    graph.add_nodes(node_count)
    if isinstance(prot_features, np.ndarray):
        graph.ndata['h'] = torch.from_numpy(prot_features)
    else:
        graph.ndata['h'] = prot_features
    graph.ndata['node_type'] = prot_node_type[:node_count]

    pairwise = spatial.distance_matrix(prot_coords, prot_coords)
    # A zero distance marks the diagonal, so this mask has no self-loops.
    connected = (pairwise < cutoff) & (pairwise != 0)
    src_ls, dst_ls = np.where(connected)
    graph.add_edges(src_ls, dst_ls)

    graph.ndata['pos'] = torch.tensor(prot_coords, dtype=torch.float)
    edge_length = torch.tensor(pairwise[src_ls, dst_ls], dtype=torch.float).view(-1, 1)
    graph.edata['distance'] = edge_length

    # bond_feature is called after 'pos' and 'distance' have been stored.
    geometry = bond_feature(graph)
    graph.edata['e'] = torch.cat([geometry, edge_length * 0.1], dim=-1)
    assert not torch.any(torch.isnan(geometry))
    return graph
def get_interact_graph_fc(lig_coords,prot_coords,cutoff=None):
    """Build a bipartite ligand-protein graph joining every pair at distance > 0.

    Ligand nodes come first (indices ``0 .. len(lig_coords)-1``) and protein
    nodes follow.  Each qualifying pair gets one edge in both directions.
    ``edata['e']`` holds the pair distance as a single column and
    ``ndata['pos']`` the stacked coordinates.  ``cutoff`` is not used.
    """
    lig_count, prot_count = len(lig_coords), len(prot_coords)
    graph = dgl.DGLGraph()
    graph.add_nodes(lig_count + prot_count)

    cross_dist = spatial.distance_matrix(lig_coords, prot_coords)
    lig_idx, prot_idx = np.where(cross_dist > 0)
    prot_nodes = prot_idx + lig_count  # shift protein indices past the ligand block
    graph.add_edges(np.concatenate([lig_idx, prot_nodes]),
                    np.concatenate([prot_nodes, lig_idx]))

    # Same distance for the forward and the reverse edge of each pair.
    pair_dist = cross_dist[lig_idx, prot_idx]
    both_directions = np.concatenate([pair_dist, pair_dist])
    graph.edata['e'] = torch.tensor(both_directions, dtype=torch.float).view(-1, 1)

    lig_pos = torch.tensor(lig_coords, dtype=torch.float)
    prot_pos = torch.tensor(prot_coords, dtype=torch.float)
    graph.ndata['pos'] = torch.cat([lig_pos, prot_pos], dim=0)
    return graph
def get_interact_multi_pose_graph_knn(lig_multi_coords, prot_coords, max_neighbor=None, min_neighbor=None, cutoff=None):
    """Return one ``get_interact_graph_knn`` graph per ligand pose.

    ``lig_multi_coords`` is iterated; every element is passed as the ligand
    coordinates together with the shared protein coordinates and settings.
    """
    return [
        get_interact_graph_knn(pose_coords, prot_coords, max_neighbor, min_neighbor, cutoff)
        for pose_coords in lig_multi_coords
    ]
def get_interact_graph_knn(lig_coords,prot_coords,max_neighbor=None,min_neighbor=None,cutoff=None):
    """Build a directed ligand-protein interaction graph from a distance cutoff.

    Ligand nodes come first (indices ``0 .. len(lig_coords)-1``) and protein
    nodes follow.  Edges ligand->protein are chosen per ligand atom and
    edges protein->ligand per protein atom, so the two directions are
    selected independently.

    Neighbour selection for one atom:
      * all atoms on the other side closer than ``cutoff``;
      * if ``max_neighbor`` is given and exceeded, the ``max_neighbor``
        nearest atoms instead;
      * if ``min_neighbor`` is given and nothing was found, the
        ``min_neighbor`` nearest atoms instead.

    ``cutoff`` is compared with ``<`` against distances, so a numeric value
    is required in practice.

    Returns
    -------
    dgl.DGLGraph
        Graph whose ``edata['d']`` holds the edge distance as one column.
    """
    g_inter = dgl.DGLGraph()
    num_atoms_lig = len(lig_coords)
    num_atoms_prot = len(prot_coords)
    g_inter.add_nodes(num_atoms_lig + num_atoms_prot)
    dis_matrix = spatial.distance_matrix(lig_coords, prot_coords)

    def _select_neighbors(dist_row):
        """Return indices on the opposite side chosen for one atom."""
        chosen = np.where(dist_row < cutoff)[0]
        if max_neighbor is not None and len(chosen) > max_neighbor:
            chosen = np.argsort(dist_row)[:max_neighbor]
        if min_neighbor is not None and len(chosen) == 0:
            chosen = np.argsort(dist_row)[:min_neighbor]
        return chosen

    src_list, dst_list, dis_list = [], [], []
    for i in range(num_atoms_lig):
        dst = _select_neighbors(dis_matrix[i, :])
        src_list.extend([i] * len(dst))
        dst_list.extend(dst + num_atoms_lig)
        dis_list.extend(dis_matrix[i, dst])
    for i in range(num_atoms_prot):
        dst = _select_neighbors(dis_matrix[:, i])
        src_list.extend([i + num_atoms_lig] * len(dst))
        dst_list.extend(dst)
        dis_list.extend(dis_matrix[dst, i])

    # Force an integer dtype so an empty edge list does not become float64.
    src_ls = np.array(src_list, dtype=np.int64)
    dst_ls = np.array(dst_list, dtype=np.int64)
    g_inter.add_edges(src_ls, dst_ls)
    # 'd': distance between the ligand atom and the protein atom of each edge
    g_inter.edata['d'] = torch.tensor(np.array(dis_list), dtype=torch.float).view(-1, 1)
    return g_inter
def get_interact_graph_knn_v2(lig_coords,prot_coords,max_neighbor=None,min_neighbor=None,cutoff=None,):
    """Build the ``get_interact_graph_knn`` graph and add Gaussian distance features.

    Topology and ``edata['d']`` come from ``get_interact_graph_knn`` called
    with the same arguments.  In addition ``edata['e']`` holds, for each
    edge, ``exp(-(d**2 / 10.0) / sigma)`` for ``sigma = 1.5**k`` with
    ``k = 0 .. 14``, i.e. 15 columns.

    Returns
    -------
    dgl.DGLGraph
        Graph with ``edata`` keys ``d`` and ``e``.
    """
    # Reuse the shared graph construction instead of duplicating it.
    g_inter = get_interact_graph_knn(lig_coords, prot_coords, max_neighbor, min_neighbor, cutoff)
    inter_d = g_inter.edata['d']

    squared_distance = inter_d ** 2
    all_sigmas_dist = [1.5 ** x for x in range(15)]
    prot_square_distance_scale = 10.0
    x_rel_mag = torch.cat([torch.exp(-(squared_distance / prot_square_distance_scale) / sigma)
                           for sigma in all_sigmas_dist], dim=-1)
    g_inter.edata['e'] = x_rel_mag
    return g_inter
def pack_graph_and_labels(lig_graphs, prot_graphs, inter_graphs, labels):
    """Bundle the four arguments, unchanged and in order, into one tuple."""
    packed = (lig_graphs, prot_graphs, inter_graphs, labels)
    return packed
# Single-letter class assigned to each secondary-structure code.
# NOTE(review): R/B/G/W look like colour initials - confirm with the author.
SSE_color = {
    'H': 'R',
    'G': 'R',
    'I': 'R',
    'E': 'B',
    'T': 'G',
    'S': 'W',
    'B': 'W',
    ' ': 'W',
}
def ExcuteDSSP(dataset_path,name,dssp = '/data/jiaxianyan/anaconda3/bin/mkdssp'):
    """Run the DSSP executable on ``<dataset_path>/<name>/<name>_protein_chains.pdb``.

    The output is written next to the input with a ``.dssp`` suffix.  The
    command is executed without a shell and with an argument list, so paths
    containing spaces or shell metacharacters are passed through verbatim.

    Parameters
    ----------
    dataset_path : str
        Directory holding one sub-directory per complex.
    name : str
        Complex identifier (sub-directory name and file prefix).
    dssp : str
        DSSP command; it may include extra flags, which are split
        shell-style.

    Returns
    -------
    None
        The exit status is not checked.  A failure to launch the executable
        is reported with ``print`` rather than raised.
    """
    # Local imports: only this function needs them.
    import shlex
    import subprocess

    InPath = os.path.join(dataset_path, name, f'{name}_protein_chains.pdb')
    OutPath = os.path.join(dataset_path, name, f'{name}_protein_chains.pdb.dssp')
    cmd = [*shlex.split(dssp), '-i', InPath, '-o', OutPath]
    try:
        subprocess.run(cmd, check=False)
    except OSError as err:
        print(f'Could not run DSSP command {dssp}: {err}')
def ListToStr(res):
    """Expand run-length pairs into a string.

    ``res`` is an iterable of ``(symbol, count)`` items; each symbol is
    repeated ``count`` times and the pieces are concatenated in order.
    """
    return ''.join(run[0] * run[1] for run in res)
def StrToList(res):
    """Run-length encode ``res`` into ``(symbol, count)`` tuples.

    Encoding starts from an implicit ``'W'`` run of length 0.  The result
    therefore begins with ``('W', 0)`` unless the input itself starts with
    ``'W'``, in which case that first run absorbs the leading ``'W'``
    characters.  An empty input yields ``[('W', 0)]``.
    """
    runs = [['W', 0]]  # seed run; see docstring
    for symbol in res:
        current = runs[-1]
        if symbol == current[0]:
            current[1] += 1
        else:
            runs.append([symbol, 1])
    return [tuple(run) for run in runs]
def AllocateIndex(res,BeginIndex=0):
    """Label every position of a run-length list with a cluster index.

    ``res`` is expanded with ``ListToStr``.  Positions are numbered starting
    at ``BeginIndex``, and the number increases whenever the symbol differs
    from the previous position.

    Returns
    -------
    (list of str, int)
        One ``"<cluster>,<symbol>"`` label per position, and the last used
        cluster index plus one.
    """
    expanded = ListToStr(res)
    cluster_id = BeginIndex
    labels = []
    previous = None
    for position, symbol in enumerate(expanded):
        if position > 0 and symbol != previous:
            cluster_id += 1
        labels.append(f'{cluster_id},{symbol}')
        previous = symbol
    return labels, cluster_id + 1
def SmoothHiearachical(res,threshod=0):
    """Relabel short runs of a run-length list with the label of the run before.

    With ``threshod == 0`` or a single run the input object is returned
    unchanged.  Otherwise:
      * the first run takes the label of the second run;
      * the last run takes the label of the second-to-last run;
      * every interior run of length ``<= threshod`` takes the label already
        assigned to the run before it;
      * run lengths are never changed.
    """
    if threshod == 0 or len(res) == 1:
        return res
    smoothed = [(res[1][0], res[0][1])]
    for run in res[1:-1]:
        _label, length = run
        if length <= threshod:
            smoothed.append((smoothed[-1][0], length))
        else:
            smoothed.append(run)
    smoothed.append((res[-2][0], res[-1][1]))
    return smoothed
def ExtractHiearachical(FilePath):
    """Read per-residue secondary-structure codes from a column-formatted file.

    The first 28 lines are skipped.  Lines containing ``'!'`` split the
    remainder into chains.  For each line the code is taken from column 16.
    If that column is blank and column 18 is ``'>'`` or ``'<'``, the code of
    the following or preceding line in the chain is used instead.

    Per chain, the code string is run-length encoded, passed through
    ``SmoothHiearachical`` with its default threshold, and labelled with
    ``AllocateIndex``.

    NOTE(review): the accumulated results are never returned or stored, so
    the function currently returns ``None``.  This looks like a missing
    ``return`` - confirm the intended return value with the author.
    """
    header_lines = 28
    sse_col, supply_col = 16, 18
    with open(FilePath, 'r') as handle:
        records = handle.read().strip().split('\n')[header_lines:]

    # A '!' line is a chain break; collect (start, stop) slices per chain.
    chain_spans, span_start = [], 0
    for row_no, record in enumerate(records):
        if '!' in record:
            chain_spans.append((span_start, row_no))
            span_start = row_no + 1
    chain_spans.append((span_start, len(records) + 1))

    all_sse = ''
    all_smoothed_sse = ''
    cluster_total = 0
    cluster_labels = []
    for first, last in chain_spans:
        chain = records[first:last]
        chain_sse, chain_color = '', ''
        for row_no, record in enumerate(chain):
            marker, code = record[supply_col], record[sse_col]
            if code == ' ' and marker == '>':
                code = chain[row_no + 1][sse_col]
            elif code == ' ' and marker == '<':
                code = chain[row_no - 1][sse_col]
            chain_sse += code
            chain_color += SSE_color[code]
        all_sse = all_sse + chain_sse
        smoothed_runs = SmoothHiearachical(StrToList(chain_sse))
        chain_labels, cluster_total = AllocateIndex(smoothed_runs, cluster_total)
        all_smoothed_sse = all_smoothed_sse + ListToStr(smoothed_runs)
        cluster_labels.extend(chain_labels)
def prot_alpha_c_featurizer(Structure):
    """Return coordinates and one-hot residue-type features of ``Structure``.

    ``Structure`` must provide ``getCoords()`` and ``getResnames()``.
    Residue names missing from ``ResDict`` map to the extra ``UNKOWN_RES``
    class, so each one-hot row has ``UNKOWN_RES + 1`` entries.

    Returns
    -------
    (coords, torch.Tensor)
        The value of ``getCoords()`` and the one-hot feature tensor.
    """
    coords = Structure.getCoords()
    residue_ids = [ResDict.get(res_name, UNKOWN_RES) for res_name in Structure.getResnames()]
    identity = np.eye(UNKOWN_RES + 1)
    one_hot = torch.tensor(identity[residue_ids])
    return coords, one_hot
def prot_residue_type(Structure):
    """Return the residue-type index of every entry of ``Structure`` as int64.

    Names come from ``Structure.getResnames()``; names missing from
    ``ResDict`` map to ``UNKOWN_RES``.
    """
    residue_ids = []
    for res_name in Structure.getResnames():
        residue_ids.append(ResDict.get(res_name, UNKOWN_RES))
    return torch.tensor(residue_ids, dtype=torch.int64)
def break_molecules_by_rotatable_bonds(mol):
    """Split ``mol`` into fragments by cutting the bonds matched by a SMARTS pattern.

    The pattern selects acyclic single bonds between two atoms that are each
    neither terminal, nor part of a triple bond, nor an amide-type ``[NH]``.

    Returns
    -------
    (tuple, tuple)
        ``frags``: one tuple of atom indices per fragment, as returned by
        ``GetMolFrags``.  ``rotatable_bonds``: the matched
        ``(atom_idx, atom_idx)`` pairs.

    NOTE(review): ``FragmentOnBonds`` is called with its defaults.  If those
    add dummy atoms at the cut points, fragment tuples may contain indices
    beyond the atoms of ``mol`` - confirm against the consumers of ``frags``.
    """
    patt = Chem.MolFromSmarts('[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]')
    rotatable_bonds = mol.GetSubstructMatches(patt)
    if not rotatable_bonds:
        # Nothing to cut: avoid handing FragmentOnBonds an empty bond list
        # and report the molecule's existing connected components.
        return Chem.rdmolops.GetMolFrags(mol), rotatable_bonds
    bs = [mol.GetBondBetweenAtoms(x, y).GetIdx() for x, y in rotatable_bonds]
    mol_broken = Chem.rdmolops.FragmentOnBonds(mol, bs)
    frags = Chem.rdmolops.GetMolFrags(mol_broken)
    return frags, rotatable_bonds
def get_molecule_mass_center(mol):
    """Return the mass-weighted centre of the atoms of ``mol``.

    Positions come from ``mol.GetConformer().GetPositions()`` and weights
    from ``atom.GetMass()`` of each atom in ``mol.GetAtoms()``.  The
    weighted sum is divided by the sum of those same masses, so the result
    is a true weighted average of the listed atoms.

    Returns
    -------
    numpy.ndarray
        Centre of mass, one value per coordinate axis.
    """
    pos = np.asarray(mol.GetConformer().GetPositions(), dtype=float)
    masses = np.array([atom.GetMass() for atom in mol.GetAtoms()], dtype=float)
    # Vectorised replacement for the deprecated np.sum(<generator>) form.
    return masses @ pos / masses.sum()
def get_frag_geo_center(mol, frag):
    """Return the unweighted mean position of the atoms listed in ``frag``.

    Parameters
    ----------
    mol :
        Object whose ``GetConformer().GetPositions()`` yields one coordinate
        row per atom.
    frag :
        Iterable of atom indices into those positions.
    """
    pos = np.asarray(mol.GetConformer().GetPositions(), dtype=float)
    # Fancy-index and average; replaces the deprecated np.sum(<generator>).
    return pos[list(frag)].mean(axis=0)
def get_center_frag(mol, frags):
    """Return the index of the fragment whose centre is nearest the mass centre.

    The previous version returned the smallest squared coordinate
    difference (a float).  The caller in this module uses the result as a
    fragment index, so the index of the closest fragment is returned now.

    Parameters
    ----------
    mol :
        Molecule passed to ``get_molecule_mass_center`` and
        ``get_frag_geo_center``.
    frags :
        Sequence of atom-index tuples, one per fragment.

    Returns
    -------
    int
        Position in ``frags`` of the fragment with the smallest squared
        Euclidean distance between its geometric centre and the mass centre.

    Raises
    ------
    ValueError
        If ``frags`` is empty.
    """
    if len(frags) == 0:
        raise ValueError('frags is empty; cannot select a center fragment')
    mass_center = get_molecule_mass_center(mol)
    frags_geo_center = np.array([get_frag_geo_center(mol, frag) for frag in frags])
    squared_dist = np.sum((frags_geo_center - mass_center) ** 2, axis=1)
    return int(np.argmin(squared_dist))
def get_frag_neighbors(mol, frags, rotatable_bonds):
    """Derive fragment adjacency from the rotatable bonds that were cut.

    Parameters
    ----------
    mol :
        Object providing ``GetNumAtoms()``.
    frags :
        Sequence of atom-index tuples, one per fragment.
    rotatable_bonds :
        Sequence of ``(atom_idx, atom_idx)`` pairs; a bond's position in
        this sequence is its bond index.

    Returns
    -------
    (dict, dict, dict)
        ``frag2neighbour``: fragment index -> neighbouring fragment indices.
        ``frag2atom``: fragment index -> atoms on the far side of each
        rotatable bond touching the fragment.
        ``neighbour2bond``: fragment index -> bond index for each entry of
        ``frag2neighbour``, in the same order.

    Raises
    ------
    ValueError
        If a rotatable bond does not touch exactly two fragment atoms.
    """
    numatoms = mol.GetNumAtoms()
    atom2bond = {atom_idx: [] for atom_idx in range(numatoms)}
    bond2frag = {}
    for b_index, (x, y) in enumerate(rotatable_bonds):
        atom2bond[x].append(b_index)
        atom2bond[y].append(b_index)
        bond2frag[b_index] = []

    frag2neighbour, frag2atom, neighbour2bond = {}, {}, {}
    for f_index, frag in enumerate(frags):
        frag2neighbour[f_index] = []
        frag2atom[f_index] = []
        neighbour2bond[f_index] = []
        for atom in frag:
            # .get(): fragment tuples may carry indices that are not atoms
            # of ``mol`` (e.g. atoms added while fragmenting); skip those.
            for bond in atom2bond.get(atom, ()):
                # Record the fragment *index*, which is what the adjacency
                # dictionaries below are keyed by.
                bond2frag[bond].append(f_index)
                x, y = rotatable_bonds[bond]
                frag2atom[f_index].append(y if x == atom else x)

    for bond, touching in bond2frag.items():
        if len(touching) != 2:
            raise ValueError(
                f'rotatable bond {bond} touches {len(touching)} fragment atoms, expected 2')
        x, y = touching
        frag2neighbour[x].append(y)
        neighbour2bond[x].append(bond)
        frag2neighbour[y].append(x)
        neighbour2bond[y].append(bond)
    return frag2neighbour, frag2atom, neighbour2bond
def bfs(frag_neighbors,neighbour2bond,center_frag_index):
    """Return fragment indices in breadth-first order from ``center_frag_index``.

    Parameters
    ----------
    frag_neighbors :
        Mapping (or sequence) from fragment index to its neighbour indices.
    neighbour2bond :
        Accepted for interface compatibility; not used for the ordering.
    center_frag_index :
        Fragment at which the traversal starts.

    Returns
    -------
    list
        Each reachable fragment index exactly once, start fragment first.
    """
    # The start fragment must be marked as seen up front; otherwise its
    # neighbours would enqueue it a second time.
    visited = {center_frag_index}
    bfs_rank = [center_frag_index]
    head = 0
    while head < len(bfs_rank):
        for neighbor_frag_index in frag_neighbors[bfs_rank[head]]:
            if neighbor_frag_index not in visited:
                visited.add(neighbor_frag_index)
                bfs_rank.append(neighbor_frag_index)
        head += 1
    return bfs_rank
def get_bfs_generate_rank(mol, center_frag_index, frags, rotatable_bonds):
    """Return the breadth-first fragment order starting at ``center_frag_index``.

    Fragment adjacency is derived with ``get_frag_neighbors`` and traversed
    with ``bfs``.

    Raises
    ------
    AssertionError
        If the traversal did not produce exactly one entry per fragment.
    """
    frag_neighbors, frag2atom, neighbour2bond = get_frag_neighbors(mol, frags, rotatable_bonds)
    bfs_rank = bfs(frag_neighbors, neighbour2bond, center_frag_index)
    # Compare lengths; the former list-vs-int comparison could never hold.
    assert len(bfs_rank) == len(frags), 'BFS order does not cover every fragment exactly once'
    return bfs_rank
def get_molecule_tree_by_rotatable_bonds(molecule_path):
    """Load a ``.mol2`` file and return its fragments in breadth-first order.

    The molecule is read without sanitisation and with hydrogens kept.  It
    is split at its rotatable bonds, the centre fragment is chosen with
    ``get_center_frag``, and the fragments are ordered with
    ``get_bfs_generate_rank``.

    Raises
    ------
    ValueError
        If the file could not be parsed into a molecule.
    """
    mol = Chem.MolFromMol2File(molecule_path, sanitize=False, removeHs=False)
    if mol is None:
        # Fail with a clear message instead of an AttributeError further down.
        raise ValueError(f'Could not read a molecule from {molecule_path}')
    frags, rotatable_bonds = break_molecules_by_rotatable_bonds(mol)
    center_frag_index = get_center_frag(mol, frags)
    return get_bfs_generate_rank(mol, center_frag_index, frags, rotatable_bonds)
def get_rdkit_coords(mol, sanitize=True, remove_hs=True, maxIters=200):
    """Embed ``mol`` in 3D with ETKDGv2, MMFF-optimise it and return the coordinates.

    If the first embedding attempt reports failure (``-1``), it is retried
    with random starting coordinates.  Conformer 0 is then optimised with
    MMFF for at most ``maxIters`` iterations.

    Parameters
    ----------
    mol :
        RDKit molecule; the conformer is added to it in place.
    sanitize : bool
        Forwarded to ``Chem.RemoveHs`` when hydrogens are removed.
    remove_hs : bool
        Whether to strip hydrogens before reading the coordinates.
    maxIters : int
        Iteration limit for the MMFF optimisation.

    Returns
    -------
    The array returned by ``GetPositions()`` of the resulting conformer.
    """
    embed_params = AllChem.ETKDGv2()
    embed_status = AllChem.EmbedMolecule(mol, embed_params)
    if embed_status == -1:
        print('rdkit coords could not be generated without using random coords. using random coords now.')
        embed_params.useRandomCoords = True
        AllChem.EmbedMolecule(mol, embed_params)
    # The optimisation runs in both the normal and the retry case.
    AllChem.MMFFOptimizeMolecule(mol, maxIters=maxIters, confId=0)
    result_mol = Chem.RemoveHs(mol, sanitize=sanitize) if remove_hs else mol
    return result_mol.GetConformer().GetPositions()
def read_rdmol_v2(dataset_path, name):
    """Load the ligand of complex ``name``, preferring the SDF file.

    ``<dataset_path>/<name>/<name>_ligand.sdf`` is read with sanitisation
    and hydrogen removal.  If that yields ``None``, the ``.mol2`` file with
    the same stem is tried with the same settings.

    Returns
    -------
    The molecule returned by ``read_rdmol``, or ``None`` if both attempts
    failed.
    """
    lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2')
    lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf')
    m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
    if m_lig is None:  # fall back to mol2 when the sdf could not be sanitized
        m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
    return m_lig
def read_rdmol(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.

    Parameters
    ----------
    molecule_file : str
        Path to the molecule file, which can be of format ``.mol2`` or
        ``.sdf`` or ``.pdbqt`` or ``.pdb``.  For ``.sdf`` only the first
        record is read; for ``.pdbqt`` each line is cut to its first 66
        characters and parsed as PDB.
    sanitize : bool
        Whether sanitization is performed on the RDKit molecule. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the
        sanitization. Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True
        will enforce sanitization. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens
        can be quite slow for large molecules. Default to False.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol or None
        RDKit molecule instance for the loaded molecule, or ``None`` when
        sanitization or hydrogen removal failed.

    Raises
    ------
    ValueError
        If ``molecule_file`` does not end in one of the supported
        extensions.
    """
    # Local import: ``warnings`` is not among the module-level imports
    # visible at the top of this file.
    import warnings

    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            # Keep only the first 66 characters of every line.
            pdb_block = ''.join('{}\n'.format(line[:66]) for line in file)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # Raise (not return) the error so callers never get an exception
        # object in place of a molecule.
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)
        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')
        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception:
        # Best effort: a molecule that cannot be processed is reported as None.
        return None
    return mol
def random_rotation_translation(translation_distance=5.0):
    """Draw a random rotation matrix and a random translation vector.

    The translation has a random direction (a normalised Gaussian sample)
    and a length drawn uniformly between 0 and ``translation_distance``.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        float32 rotation matrix of shape ``(3, 3)`` and float32 translation
        of shape ``(1, 3)``.
    """
    rot_matrix = Rotation.random(num=1).as_matrix().squeeze()

    direction = np.random.randn(1, 3)
    direction = direction / np.sqrt(np.sum(direction * direction))
    magnitude = np.random.uniform(low=0, high=translation_distance)
    shift = direction * magnitude

    rot_tensor = torch.from_numpy(rot_matrix.astype(np.float32))
    shift_tensor = torch.from_numpy(shift.astype(np.float32))
    return rot_tensor, shift_tensor