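# mbp / UltraFlow / commons / process_mols.py
# (page header residue) jiaxianustc's picture
# (page header residue) commit message: test
# (page header residue) commit: 3ad8be1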
import os
import math
from openbabel import pybel
from openbabel import openbabel
import dgl
import pickle
import numpy as np
import torch
import scipy.spatial as spatial
from functools import partial
from prody import *
from rdkit import Chem as Chem
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem
from Bio.PDB import get_surface, PDBParser
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from scipy.special import softmax
from scipy.spatial.transform import Rotation
import pandas as pd
# Open Babel logging configuration: a message handler is created with output
# level 0 and the global error log is told to stop logging.
ob_log_handler = pybel.ob.OBMessageHandler()
ob_log_handler.SetOutputLevel(0)
pybel.ob.obErrorLog.StopLogging()
# RDKit bond type object -> integer index, in the enumeration order of BT.names.
BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
# Integer index -> RDKit bond type name (same enumeration order as BOND_TYPES).
BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
# Protein graph type -> suffix of the PDB file name used for that graph type.
graph_type_filename = {'atom_pocket':'valid_pocket.pdb',
'atom_complete':'valid_chains.pdb'}
# Three-letter residue name -> integer id (0..19) for the 20 standard amino acids.
ResDict = {'ALA':0,'ARG':1,'ASN':2,'ASP':3,'CYS':4,
'GLN':5,'GLU':6,'GLY':7,'HIS':8,'ILE':9,
'LEU':10,'LYS':11,'MET':12,'PHE':13,'PRO':14,
'SER':15,'THR':16,'TRP':17,'TYR':18,'VAL':19}
# One-letter secondary-structure code -> integer id.
# NOTE(review): the codes look like DSSP classes -- confirm against the caller.
SSEDict = {'H':0,'B':1,'E':2,'G':3,'I':4,'T':5,'S':6,' ':7}
# SSEType equals the number of entries in SSEDict; UNKOWN_RES is one past the
# largest id in ResDict (presumably the id for non-standard residues -- verify).
SSEType,UNKOWN_RES = 8,20
# Vocabularies used to turn categorical atom / residue properties into integer
# indices (see `safe_index`). Lists that end with 'misc' reserve their last
# slot as a catch-all for values that are not listed explicitly.
allowable_features = {
# atomic numbers 1..118 plus the catch-all slot
'possible_atomic_num_list': list(range(1, 119)) + ['misc'],
# chirality tags, compared against str(atom.GetChiralTag())
'possible_chirality_list': [
'CHI_UNSPECIFIED',
'CHI_TETRAHEDRAL_CW',
'CHI_TETRAHEDRAL_CCW',
'CHI_OTHER'
],
'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
'possible_numring_list': [0, 1, 2, 3, 4, 5, 6, 'misc'],
'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6, 'misc'],
'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
# hybridization names, compared against str(atom.GetHybridization())
'possible_hybridization_list': [
'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'
],
# boolean flags: index 0 is False, index 1 is True
'possible_is_aromatic_list': [False, True],
'possible_is_in_ring3_list': [False, True],
'possible_is_in_ring4_list': [False, True],
'possible_is_in_ring5_list': [False, True],
'possible_is_in_ring6_list': [False, True],
'possible_is_in_ring7_list': [False, True],
'possible_is_in_ring8_list': [False, True],
# residue names: the 20 standard amino acids followed by additional
# three-letter codes, then the catch-all slot
'possible_amino_acids': ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET',
'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'HIP', 'HIE', 'TPO', 'HID', 'LEV', 'MEU',
'PTR', 'GLV', 'CYT', 'SEP', 'HIZ', 'CYM', 'GLM', 'ASQ', 'TYS', 'CYX', 'GLZ', 'misc'],
# atom-name vocabularies of two different granularities
# NOTE(review): presumably PDB atom names truncated to 2 characters ('*' for
# one-character names) and full PDB atom names -- confirm against the
# receptor featurizer, which is not part of this section.
'possible_atom_type_2': ['C*', 'CA', 'CB', 'CD', 'CE', 'CG', 'CH', 'CZ', 'N*', 'ND', 'NE', 'NH', 'NZ', 'O*', 'OD',
'OE', 'OG', 'OH', 'OX', 'S*', 'SD', 'SG', 'misc'],
'possible_atom_type_3': ['C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2',
'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1',
'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 'OXT', 'SD', 'SG', 'misc'],
}
# Ligand atom feature layout: (vocabulary size of every categorical feature,
# number of scalar features). The categorical order below is the same order in
# which `lig_atom_featurizer_rdmol` emits its index columns; the single scalar
# is the Gasteiger charge that function appends last.
lig_feature_dims = (list(map(len, [
allowable_features['possible_atomic_num_list'],
allowable_features['possible_chirality_list'],
allowable_features['possible_degree_list'],
allowable_features['possible_formal_charge_list'],
allowable_features['possible_implicit_valence_list'],
allowable_features['possible_numH_list'],
allowable_features['possible_number_radical_e_list'],
allowable_features['possible_hybridization_list'],
allowable_features['possible_is_aromatic_list'],
allowable_features['possible_numring_list'],
allowable_features['possible_is_in_ring3_list'],
allowable_features['possible_is_in_ring4_list'],
allowable_features['possible_is_in_ring5_list'],
allowable_features['possible_is_in_ring6_list'],
allowable_features['possible_is_in_ring7_list'],
allowable_features['possible_is_in_ring8_list'],
])), 1) # number of scalar features
# Receptor atom feature layout: (categorical vocabulary sizes, 2).
# NOTE(review): the trailing 2 is presumably the number of scalar features, as
# for lig_feature_dims -- confirm against the receptor atom featurizer.
rec_atom_feature_dims = (list(map(len, [
allowable_features['possible_amino_acids'],
allowable_features['possible_atomic_num_list'],
allowable_features['possible_atom_type_2'],
allowable_features['possible_atom_type_3'],
])), 2)
# Receptor residue feature layout: (size of the residue vocabulary, 2).
# NOTE(review): trailing 2 presumably counts scalar features -- confirm.
rec_residue_feature_dims = (list(map(len, [
allowable_features['possible_amino_acids']
])), 2)
# Residue feature layout with one categorical feature of size 21 and 0 as the
# second entry.
# NOTE(review): 21 presumably means the 20 ids of ResDict plus UNKOWN_RES -- verify.
dbcg_prot_residue_feature_dims = [[21],0]
def safe_index(l, e):
    """
    Return index of element e in list l. If e is not present, return the last index

    The vocabularies in `allowable_features` reserve their last slot ('misc')
    for unknown values, which is why the last index is the fallback.

    Parameters
    ----------
    l: list
        Vocabulary to search.
    e: object
        Value to look up.

    Returns
    -------
    int
        Position of `e` in `l`, or `len(l) - 1` when `e` is not in `l`.
    """
    try:
        return l.index(e)
    except ValueError:
        # Only "value not found" selects the fallback slot; any other error
        # (e.g. `l` is not a sequence) is a programming mistake and propagates.
        return len(l) - 1
def lig_atom_featurizer_rdmol(mol):
    """Build the per-atom feature table of an RDKit molecule.

    Every atom yields one row: 16 categorical indices (looked up in
    `allowable_features`, in the order listed by `lig_feature_dims`) followed
    by the Gasteiger partial charge as the only scalar column.

    Parameters
    ----------
    mol: RDKit molecule
        Gasteiger charges are computed on it in place.

    Returns
    -------
    torch.Tensor
        One row per atom, 17 columns.
    """
    # Charges are NaN for 93 molecules in all of PDBbind; those (and infinite
    # values) are replaced by 0 below.
    ComputeGasteigerCharges(mol)
    ring_info = mol.GetRingInfo()
    vocab = allowable_features
    ring_size_vocab = (
        ('possible_is_in_ring3_list', 3),
        ('possible_is_in_ring4_list', 4),
        ('possible_is_in_ring5_list', 5),
        ('possible_is_in_ring6_list', 6),
        ('possible_is_in_ring7_list', 7),
        ('possible_is_in_ring8_list', 8),
    )

    rows = []
    for atom_idx, atom in enumerate(mol.GetAtoms()):
        charge = atom.GetDoubleProp('_GasteigerCharge')
        if np.isnan(charge) or np.isinf(charge):
            charge = 0.

        row = [
            safe_index(vocab['possible_atomic_num_list'], atom.GetAtomicNum()),
            vocab['possible_chirality_list'].index(str(atom.GetChiralTag())),
            safe_index(vocab['possible_degree_list'], atom.GetTotalDegree()),
            safe_index(vocab['possible_formal_charge_list'], atom.GetFormalCharge()),
            safe_index(vocab['possible_implicit_valence_list'], atom.GetImplicitValence()),
            safe_index(vocab['possible_numH_list'], atom.GetTotalNumHs()),
            safe_index(vocab['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
            safe_index(vocab['possible_hybridization_list'], str(atom.GetHybridization())),
            vocab['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
            safe_index(vocab['possible_numring_list'], ring_info.NumAtomRings(atom_idx)),
        ]
        # one membership flag per ring size 3..8
        for key, ring_size in ring_size_vocab:
            row.append(vocab[key].index(ring_info.IsAtomInRingOfSize(atom_idx, ring_size)))
        row.append(charge)
        rows.append(row)

    return torch.tensor(rows)
def vina_gaussain_1(d):
    """Gaussian term exp(-(d / 0.5)^2), evaluated with torch.

    Parameters
    ----------
    d: torch.Tensor
        Distance value(s); `torch.exp` is applied to the result, so a tensor
        is required.

    Returns
    -------
    torch.Tensor
        Same shape as `d`; equals 1 at d == 0.
    """
    scaled = d / 0.5
    return torch.exp(-(scaled ** 2))
def vina_gaussain_2(d):
    """Gaussian term exp(-((d - 3) / 2)^2), evaluated with torch.

    Parameters
    ----------
    d: torch.Tensor
        Distance value(s); `torch.exp` is applied to the result, so a tensor
        is required.

    Returns
    -------
    torch.Tensor
        Same shape as `d`; equals 1 at d == 3.
    """
    shifted = (d - 3) / 2.0
    return torch.exp(-(shifted ** 2))
def vina_repulsion(d):
    """Repulsion term: d^2 for negative `d`, otherwise 0.

    Parameters
    ----------
    d: number
        Scalar distance; it is used in a plain `if`, so it must be a single
        value.

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor holding the term.
    """
    value = 0.0 if d >= 0 else d ** 2
    return torch.tensor(value)
def hydrophobic(d):
    """Piecewise-linear hydrophobic term.

    1 for d < 0.5, falling linearly (1.5 - d) up to d == 1.5, and 0 beyond.

    Parameters
    ----------
    d: number
        Scalar distance.

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor holding the term.
    """
    if d < 0.5:
        value = 1.0
    elif d <= 1.5:
        value = -d + 1.5
    else:
        value = 0.0
    return torch.tensor(value)
def hydrogen_bonding(d):
    """Piecewise-linear hydrogen-bonding term.

    1 for d < -0.7, falling linearly (-(10/7) * d) up to d == 0, and 0 for
    positive `d`.

    Parameters
    ----------
    d: number
        Scalar distance.

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor holding the term.
    """
    if d < -0.7:
        value = 1.0
    elif d <= 0.0:
        value = -(10/7) * d
    else:
        value = 0.0
    return torch.tensor(value)
def CusBondFeaturizer(bond):
    """Describe a bond as ``[bond order, is aromatic, is in ring]``.

    Parameters
    ----------
    bond: object
        Must provide `GetBondOrder()`, `IsAromatic()` and `IsInRing()`
        (the call sites in this file pass Open Babel bonds).

    Returns
    -------
    list of int
        Three integers; the two flags are 0 or 1.
    """
    order = int(bond.GetBondOrder())
    aromatic = int(bond.IsAromatic())
    in_ring = int(bond.IsInRing())
    return [order, aromatic, in_ring]
def CusBondFeaturizer_new(bond):
    """Describe a bond with a one-hot bond order plus two flags.

    Parameters
    ----------
    bond: object
        Must provide `GetBondOrder()`, `IsAromatic()` and `IsInRing()`
        (the call sites in this file pass Open Babel bonds).

    Returns
    -------
    list of int
        ``[order == 1, order == 2, order == 3, is aromatic, is in ring]``,
        each entry 0 or 1. Orders other than 1, 2 or 3 give three zeros.
    """
    order = int(bond.GetBondOrder())
    one_hot_order = [int(order == candidate) for candidate in (1, 2, 3)]
    flags = [int(bond.IsAromatic()), int(bond.IsInRing())]
    return one_hot_order + flags
class Featurizer():
    """Calculates atomic features for molecules. Features can encode atom type,
    native pybel properties or any property defined with SMARTS patterns

    Attributes
    ----------
    FEATURE_NAMES: list of strings
        Labels for features
    NUM_ATOM_CLASSES: int
        Number of atom codes
    ATOM_CODES: dict
        Dictionary mapping atomic numbers to codes
    NAMED_PROPS: list of string
        Names of atomic properties to retrieve from pybel.Atom object
    CALLABLES: list of callables
        Callables used to calculate custom atomic properties
    SMARTS: list of SMARTS strings
        SMARTS patterns defining additional atomic properties

    Notes
    -----
    `FEATURE_NAMES` lists the molecule code *before* the custom properties,
    while the feature arrays built by `get_features` place the custom
    properties before the molecule code. The two orders only differ when both
    `save_molecule_codes` and `custom_properties` are used.
    """

    def __init__(self, atom_codes=None, atom_labels=None,
                 named_properties=None, save_molecule_codes=True,
                 custom_properties=None, smarts_properties=None,
                 smarts_labels=None):
        """Creates Featurizer with specified types of features. Elements of a
        feature vector will be in a following order: atom type encoding
        (defined by atom_codes), Pybel atomic properties (defined by
        named_properties), custom atomic properties (defined
        `custom_properties`), molecule code (if present), and additional
        properties defined with SMARTS (defined with `smarts_properties`).

        Parameters
        ----------
        atom_codes: dict, optional
            Dictionary mapping atomic numbers to codes. It will be used for
            one-hot encoding therefore if n different types are used, codes
            should be from 0 to n-1. Multiple atoms can have the same code,
            e.g. you can use {6: 0, 7: 1, 8: 1} to encode carbons with [1, 0]
            and nitrogens and oxygens with [0, 1] vectors. If not provided,
            default encoding is used.
        atom_labels: list of strings, optional
            Labels for atoms codes. It should have the same length as the
            number of used codes, e.g. for `atom_codes={6: 0, 7: 1, 8: 1}` you
            should provide something like ['C', 'O or N']. If not specified
            labels 'atom0', 'atom1' etc are used. If `atom_codes` is not
            specified this argument is ignored.
        named_properties: list of strings, optional
            Names of atomic properties to retrieve from pybel.Atom object. If
            not specified ['hyb', 'heavydegree', 'heterodegree',
            'partialcharge'] is used.
        save_molecule_codes: bool, optional (default True)
            If set to True, there will be an additional feature to save
            molecule code. It is useful when saving molecular complex in a
            single array.
        custom_properties: list of callables, optional
            Custom functions to calculate atomic properties. Each element of
            this list should be a callable that takes pybel.Atom object and
            returns a float. If callable has `__name__` property it is used as
            feature label. Otherwise labels 'func<i>' etc are used, where i is
            the index in `custom_properties` list.
        smarts_properties: list of strings, optional
            Additional atomic properties defined with SMARTS patterns. These
            patterns should match a single atom. If not specified, default
            patterns are used.
        smarts_labels: list of strings, optional
            Labels for properties defined with SMARTS. Should have the same
            length as `smarts_properties`. If not specified labels 'smarts0',
            'smarts1' etc are used. If `smarts_properties` is not specified
            this argument is ignored.

        Raises
        ------
        TypeError
            If an argument has the wrong type.
        ValueError
            If atom codes are not 0..n-1, a label list has the wrong length,
            or a named property is not an attribute of pybel.Atom.
        """
        # Remember names of all features in the correct order
        self.FEATURE_NAMES = []

        if atom_codes is not None:
            if not isinstance(atom_codes, dict):
                raise TypeError('Atom codes should be dict, got %s instead'
                                % type(atom_codes))
            codes = set(atom_codes.values())
            # codes must be exactly 0..n-1 so they can index a one-hot vector
            for i in range(len(codes)):
                if i not in codes:
                    raise ValueError('Incorrect atom code %s' % i)

            self.NUM_ATOM_CLASSES = len(codes)
            self.ATOM_CODES = atom_codes
            if atom_labels is not None:
                if len(atom_labels) != self.NUM_ATOM_CLASSES:
                    raise ValueError('Incorrect number of atom labels: '
                                     '%s instead of %s'
                                     % (len(atom_labels), self.NUM_ATOM_CLASSES))
            else:
                atom_labels = ['atom%s' % i for i in range(self.NUM_ATOM_CLASSES)]
            self.FEATURE_NAMES += atom_labels
        else:
            self.ATOM_CODES = {}

            metals = ([3, 4, 11, 12, 13] + list(range(19, 32))
                      + list(range(37, 51)) + list(range(55, 84))
                      + list(range(87, 104)))

            # List of tuples (atomic_num, class_name) with atom types to encode.
            atom_classes = [
                (5, 'B'),
                (6, 'C'),
                (7, 'N'),
                (8, 'O'),
                (15, 'P'),
                (16, 'S'),
                (34, 'Se'),
                ([9, 17, 35, 53], 'halogen'),
                (metals, 'metal')
            ]

            for code, (atom, name) in enumerate(atom_classes):
                if isinstance(atom, list):
                    # a class made of several elements shares one code
                    for a in atom:
                        self.ATOM_CODES[a] = code
                else:
                    self.ATOM_CODES[atom] = code
                self.FEATURE_NAMES.append(name)

            self.NUM_ATOM_CLASSES = len(atom_classes)

        if named_properties is not None:
            if not isinstance(named_properties, (list, tuple, np.ndarray)):
                raise TypeError('named_properties must be a list')
            allowed_props = [prop for prop in dir(pybel.Atom)
                             if not prop.startswith('__')]
            for prop_id, prop in enumerate(named_properties):
                if prop not in allowed_props:
                    # argument order follows the message: value first, then
                    # its position
                    raise ValueError(
                        'named_properties must be in pybel.Atom attributes,'
                        ' %s was given at position %s' % (prop, prop_id)
                    )
            self.NAMED_PROPS = named_properties
        else:
            # pybel.Atom properties to save
            self.NAMED_PROPS = ['hyb', 'heavydegree', 'heterodegree',
                                'partialcharge']
        self.FEATURE_NAMES += self.NAMED_PROPS

        if not isinstance(save_molecule_codes, bool):
            raise TypeError('save_molecule_codes should be bool, got %s '
                            'instead' % type(save_molecule_codes))
        self.save_molecule_codes = save_molecule_codes
        if save_molecule_codes:
            # Remember if an atom belongs to the ligand or to the protein
            self.FEATURE_NAMES.append('molcode')

        self.CALLABLES = []
        if custom_properties is not None:
            for i, func in enumerate(custom_properties):
                if not callable(func):
                    raise TypeError('custom_properties should be list of'
                                    ' callables, got %s instead' % type(func))
                name = getattr(func, '__name__', '')
                if name == '':
                    name = 'func%s' % i
                self.CALLABLES.append(func)
                self.FEATURE_NAMES.append(name)

        if smarts_properties is None:
            # SMARTS definition for other properties
            self.SMARTS = [
                '[#6+0!$(*~[#7,#8,F]),SH0+0v2,s+0,S^3,Cl+0,Br+0,I+0]',
                '[a]',
                '[!$([#1,#6,F,Cl,Br,I,o,s,nX3,#7v5,#15v5,#16v4,#16v6,*+1,*+2,*+3])]',
                '[!$([#6,H0,-,-2,-3]),$([!H0;#7,#8,#9])]',
                '[r]'
            ]
            smarts_labels = ['hydrophobic', 'aromatic', 'acceptor', 'donor',
                             'ring']
        elif not isinstance(smarts_properties, (list, tuple, np.ndarray)):
            raise TypeError('smarts_properties must be a list')
        else:
            self.SMARTS = smarts_properties

        if smarts_labels is not None:
            if len(smarts_labels) != len(self.SMARTS):
                raise ValueError('Incorrect number of SMARTS labels: %s'
                                 ' instead of %s'
                                 % (len(smarts_labels), len(self.SMARTS)))
        else:
            smarts_labels = ['smarts%s' % i for i in range(len(self.SMARTS))]

        # Compile patterns
        self.compile_smarts()
        self.FEATURE_NAMES += smarts_labels

    def compile_smarts(self):
        """Compile every string in `self.SMARTS` into a pybel.Smarts pattern.

        The compiled patterns are kept in a private attribute; they are
        removed while pickling (see `to_pickle`) and rebuilt by `from_pickle`.
        """
        self.__PATTERNS = [pybel.Smarts(smarts) for smarts in self.SMARTS]

    def encode_num(self, atomic_num):
        """Encode atom type with a binary vector. If atom type is not included in
        the `atom_classes`, its encoding is an all-zeros vector.

        Parameters
        ----------
        atomic_num: int
            Atomic number

        Returns
        -------
        encoding: np.ndarray
            Binary vector encoding atom type (one-hot or null).

        Raises
        ------
        TypeError
            If `atomic_num` is not an int.
        """
        if not isinstance(atomic_num, int):
            raise TypeError('Atomic number must be int, %s was given'
                            % type(atomic_num))

        encoding = np.zeros(self.NUM_ATOM_CLASSES)
        # Unknown elements keep the all-zeros encoding.
        code = self.ATOM_CODES.get(atomic_num)
        if code is not None:
            encoding[code] = 1.0
        return encoding

    def find_smarts(self, molecule):
        """Find atoms that match SMARTS patterns.

        Parameters
        ----------
        molecule: pybel.Molecule

        Returns
        -------
        features: np.ndarray
            NxM binary array, where N is the number of atoms in the `molecule`
            and M is the number of patterns. `features[i, j]` == 1.0 if i'th
            atom has j'th property

        Raises
        ------
        TypeError
            If `molecule` is not a pybel.Molecule.
        """
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object, %s was given'
                            % type(molecule))

        features = np.zeros((len(molecule.atoms), len(self.__PATTERNS)))

        for (pattern_id, pattern) in enumerate(self.__PATTERNS):
            # `findall` results are transposed with zip and unpacked into one
            # flat list of atom ids; this relies on every match containing a
            # single atom. The "- 1" turns the ids into 0-based row indices.
            atoms_with_prop = np.array(list(*zip(*pattern.findall(molecule))),
                                       dtype=int) - 1
            features[atoms_with_prop, pattern_id] = 1.0
        return features

    def _check_molecule_and_code(self, molecule, molcode):
        """Validate the `molecule` / `molcode` pair shared by the feature getters."""
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object,'
                            ' %s was given' % type(molecule))
        if molcode is None:
            if self.save_molecule_codes is True:
                raise ValueError('save_molecule_codes is set to True,'
                                 ' you must specify code for the molecule')
        elif not isinstance(molcode, (float, int)):
            raise TypeError('motlype must be float, %s was given'
                            % type(molcode))

    def _atom_row(self, atom):
        """Return the per-atom part of a feature row (type one-hot, named and custom properties)."""
        return np.concatenate((
            self.encode_num(atom.atomicnum),
            [atom.__getattribute__(prop) for prop in self.NAMED_PROPS],
            [func(atom) for func in self.CALLABLES],
        ))

    def get_features(self, molecule, molcode=None):
        """Get coordinates and features for all heavy atoms in the molecule.

        Parameters
        ----------
        molecule: pybel.Molecule
        molcode: float, optional
            Molecule type. You can use it to encode whether an atom belongs to
            the ligand (1.0) or to the protein (-1.0) etc.

        Returns
        -------
        coords: np.ndarray, shape = (N, 3)
            Coordinates of all heavy atoms in the `molecule`.
        features: np.ndarray, shape = (N, F)
            Features of all heavy atoms in the `molecule`: atom type
            (one-hot encoding), pybel.Atom attributes, custom properties,
            type of a molecule (e.g protein/ligand distinction), and other
            properties defined with SMARTS patterns

        Raises
        ------
        TypeError
            If `molecule` or `molcode` has the wrong type.
        ValueError
            If molecule codes are enabled but `molcode` is None.
        RuntimeError
            If any computed feature is NaN.
        """
        self._check_molecule_and_code(molecule, molcode)

        coords = []
        features = []
        heavy_atoms = []

        for i, atom in enumerate(molecule):
            # ignore hydrogens and dummy atoms (they have atomicnum set to 0)
            if atom.atomicnum > 1:
                heavy_atoms.append(i)
                coords.append(atom.coords)
                features.append(self._atom_row(atom))

        coords = np.array(coords, dtype=np.float32)
        features = np.array(features, dtype=np.float32)
        if self.save_molecule_codes:
            features = np.hstack((features,
                                  molcode * np.ones((len(features), 1))))
        features = np.hstack([features,
                              self.find_smarts(molecule)[heavy_atoms]])

        if np.isnan(features).any():
            raise RuntimeError('Got NaN when calculating features')

        return coords, features

    def get_features_CSAR(self, molecule, protein_idxs, ligand_idxs, molcode=None):
        """Get coordinates and features for all heavy atoms in the molecule,
        together with separate protein and ligand subsets.

        Parameters
        ----------
        molecule: pybel.Molecule
        protein_idxs: container of int
            0-based positions (in iteration order over `molecule`) of the
            atoms that belong to the protein.
        ligand_idxs: container of int
            0-based positions of the atoms that belong to the ligand. An atom
            listed in both containers is treated as a protein atom.
        molcode: float, optional
            Molecule type. You can use it to encode whether an atom belongs to
            the ligand (1.0) or to the protein (-1.0) etc.

        Returns
        -------
        coords: np.ndarray, shape = (N, 3)
            Coordinates of all heavy atoms in the `molecule`.
        features: np.ndarray, shape = (N, F)
            Features of all heavy atoms in the `molecule`: atom type
            (one-hot encoding), pybel.Atom attributes, custom properties,
            type of a molecule (e.g protein/ligand distinction), and other
            properties defined with SMARTS patterns
        protein_coords, protein_features: np.ndarray
            The same for heavy atoms listed in `protein_idxs`.
        ligand_coords, ligand_features: np.ndarray
            The same for heavy atoms listed in `ligand_idxs`.

        Notes
        -----
        The molecule-code column is appended to `features` only; the protein
        and ligand feature arrays never contain it. The NaN check also only
        covers `features`.

        Raises
        ------
        TypeError
            If `molecule` or `molcode` has the wrong type.
        ValueError
            If molecule codes are enabled but `molcode` is None.
        RuntimeError
            If any entry of `features` is NaN.
        """
        self._check_molecule_and_code(molecule, molcode)

        coords, protein_coords, ligand_coords = [], [], []
        features, protein_features, ligand_features = [], [], []
        heavy_atoms, protein_heavy_atoms, ligand_heavy_atoms = [], [], []

        for i, atom in enumerate(molecule):
            # ignore hydrogens and dummy atoms (they have atomicnum set to 0)
            if atom.atomicnum > 1:
                # The row is computed once and shared between the lists; the
                # arrays built from the lists below copy the data.
                row = self._atom_row(atom)
                heavy_atoms.append(i)
                coords.append(atom.coords)
                features.append(row)
                if i in protein_idxs:
                    protein_heavy_atoms.append(i)
                    protein_coords.append(atom.coords)
                    protein_features.append(row)
                elif i in ligand_idxs:
                    ligand_heavy_atoms.append(i)
                    ligand_coords.append(atom.coords)
                    ligand_features.append(row)

        coords = np.array(coords, dtype=np.float32)
        protein_coords = np.array(protein_coords, dtype=np.float32)
        ligand_coords = np.array(ligand_coords, dtype=np.float32)

        features = np.array(features, dtype=np.float32)
        if self.save_molecule_codes:
            features = np.hstack((features,
                                  molcode * np.ones((len(features), 1))))

        # SMARTS matching covers the whole molecule, so run it once and slice
        # the rows needed for each subset.
        smarts_features = self.find_smarts(molecule)
        features = np.hstack([features, smarts_features[heavy_atoms]])
        protein_features = np.hstack([protein_features,
                                      smarts_features[protein_heavy_atoms]])
        ligand_features = np.hstack([ligand_features,
                                     smarts_features[ligand_heavy_atoms]])

        if np.isnan(features).any():
            raise RuntimeError('Got NaN when calculating features')

        return coords, features, protein_coords, protein_features, ligand_coords, ligand_features

    def to_pickle(self, fname='featurizer.pkl'):
        """Save featurizer in a given file. Featurizer can be restored with
        `from_pickle` method.

        Parameters
        ----------
        fname: str, optional
            Path to file in which featurizer will be saved
        """
        # patterns can't be pickled, we need to temporarily remove them
        patterns = self.__PATTERNS[:]
        del self.__PATTERNS
        try:
            with open(fname, 'wb') as f:
                pickle.dump(self, f)
        finally:
            # restore the patterns even when writing fails
            self.__PATTERNS = patterns[:]

    @staticmethod
    def from_pickle(fname):
        """Load pickled featurizer from a given file

        Parameters
        ----------
        fname: str, optional
            Path to file with saved featurizer

        Returns
        -------
        featurizer: Featurizer object
            Loaded featurizer
        """
        # SECURITY: unpickling executes code stored in the file; only load
        # featurizer files that come from a trusted source.
        with open(fname, 'rb') as f:
            featurizer = pickle.load(f)
        # compiled patterns are not stored in the pickle, rebuild them
        featurizer.compile_smarts()
        return featurizer
# Module-level Featurizer instance used by the reader functions in this file.
# It is built with the default atom classes, named properties and SMARTS
# patterns, and without the molecule-code column, so `get_features` can be
# called on it without a `molcode`.
featurizer = Featurizer(save_molecule_codes=False)
def get_labels_from_names(lables_path,names):
    """Look up the score of every name in a whitespace-separated label file.

    The first six lines of the file are skipped. On every remaining line the
    first field is the entry name and the fourth field is its score.

    Parameters
    ----------
    lables_path: str
        Path of the label file.
    names: iterable of str
        Entry names to look up.

    Returns
    -------
    list of float
        Scores in the order of `names`.

    Raises
    ------
    KeyError
        If a requested name is not present in the file.
    """
    with open(lables_path, 'rb') as handle:
        content = handle.read().decode()

    score_by_name = {}
    for record in content.strip().split('\n')[6:]:
        fields = record.split()
        score_by_name[fields[0]] = float(fields[3])

    return [score_by_name[name] for name in names]
def get_labels_from_names_csar(lables_path,names):
    """Look up the score of every name in a comma-separated label file.

    The first line of the file is skipped. On every remaining line the second
    comma-separated field is the entry name and the third is its score;
    surrounding whitespace of each field is ignored.

    Parameters
    ----------
    lables_path: str
        Path of the label file.
    names: iterable of str
        Entry names to look up.

    Returns
    -------
    list of float
        Scores in the order of `names`.

    Raises
    ------
    KeyError
        If a requested name is not present in the file.
    """
    with open(lables_path, 'rb') as handle:
        content = handle.read().decode()

    score_by_name = {}
    for record in content.strip().split('\n')[1:]:
        fields = [field.strip() for field in record.split(',')]
        score_by_name[fields[1]] = float(fields[2])

    return [score_by_name[name] for name in names]
def get_lig_coords_ground_truth_from_names(lables_path,names):
    """Placeholder: not implemented, always returns None.

    Both arguments are accepted and ignored.
    """
    return None
def lig_atom_type_obmol(obmol):
    """Return the atomic numbers of the heavy atoms of a molecule.

    Parameters
    ----------
    obmol: iterable of atoms
        Each atom must expose `atomicnum` (the call sites in this file pass
        pybel molecules).

    Returns
    -------
    torch.Tensor
        int64 tensor with one entry per atom whose atomic number is greater
        than 1, in iteration order.
    """
    heavy_atomic_nums = []
    for atom in obmol:
        # hydrogens (1) and atoms with atomic number 0 are left out
        if atom.atomicnum > 1:
            heavy_atomic_nums.append(atom.atomicnum)
    return torch.tensor(heavy_atomic_nums, dtype=torch.int64)
def lig_atom_type_rdmol(rdmol):
    """Return the atomic numbers of all atoms of a molecule.

    Parameters
    ----------
    rdmol: object
        Must provide `GetAtoms()` whose items provide `GetAtomicNum()` (the
        call sites in this file pass RDKit molecules).

    Returns
    -------
    torch.Tensor
        int64 tensor with one entry per atom, in `GetAtoms()` order. Unlike
        `lig_atom_type_obmol`, no atom is filtered out.
    """
    atomic_nums = []
    for atom in rdmol.GetAtoms():
        atomic_nums.append(atom.GetAtomicNum())
    return torch.tensor(atomic_nums, dtype=torch.int64)
def get_bonded_edges_obmol(pocket):
    """List the covalent bonds between heavy atoms of a pybel molecule.

    Atoms are renumbered so that edge endpoints refer to heavy atoms only, in
    iteration order. The same filter as in `Featurizer.get_features` and
    `lig_atom_type_obmol` is used (atomic number > 1), so the indices line up
    with the rows those functions return. Hydrogens and dummy atoms (atomic
    number 0) are skipped.

    Parameters
    ----------
    pocket: pybel.Molecule
        Molecule whose bonds are listed.

    Returns
    -------
    list of tuple
        ``(source, target, bond_features)`` for every bond, once per
        direction. `bond_features` is the list produced by
        `CusBondFeaturizer_new`.
    """
    # The map is indexed by `atom.idx`; it has one extra slot because the
    # indices are taken to start at 1.
    idx_map = [-1] * (len(pocket.atoms) + 1)
    edge_l = []
    idx_new = 0
    for atom in pocket:
        if atom.atomicnum <= 1:
            continue
        a1 = atom.idx
        idx_map[a1] = idx_new
        idx_new += 1
        for natom in openbabel.OBAtomAtomIter(atom.OBAtom):
            if natom.GetAtomicNum() <= 1:
                continue
            a2 = natom.GetIdx()
            bond = openbabel.OBAtom.GetBond(natom, atom.OBAtom)
            edge_l.append((a1, a2, CusBondFeaturizer_new(bond)))

    # Second pass: every neighbour has been assigned its new index by now.
    edge_l_new = []
    for a1, a2, t in edge_l:
        a1_, a2_ = idx_map[a1], idx_map[a2]
        # internal invariant: both endpoints passed the heavy-atom filter
        assert a1_ != -1 and a2_ != -1
        edge_l_new.append((a1_, a2_, t))
    return edge_l_new
def get_bonded_edges_rdmol(rdmol):
    """List the bonds of an RDKit molecule as directed edges.

    Parameters
    ----------
    rdmol: object
        Must provide `GetBonds()` (the call sites in this file pass RDKit
        molecules).

    Returns
    -------
    list of tuple
        ``(source, target, bond_type_index)`` for every bond, once per
        direction; the index comes from `BOND_TYPES`. A list is returned
        (rather than a single-use `zip` iterator) so the edges can be
        iterated more than once, like those of `get_bonded_edges_obmol`.
    """
    row, col, edge_type = [], [], []
    for bond in rdmol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]
    return list(zip(row, col, edge_type))
def D3_info(a, b, c):
    """Geometric descriptors of the triangle spanned by three points.

    Parameters
    ----------
    a, b, c: np.ndarray
        Point coordinates; `a` is the vertex at which the angle is measured.

    Returns
    -------
    angle: float
        Angle between the vectors a->b and a->c, in degrees.
    area: float
        Area of the triangle (a, b, c).
    distance: float
        Euclidean length of a->c.

    Notes
    -----
    The cosine is clamped to [-1, 1] before `arccos`: rounding can push it
    slightly outside that range for (anti-)parallel vectors, which would
    otherwise produce NaN. A NaN cosine (zero-length vector) is mapped to -1,
    i.e. an angle of 180 degrees.
    """
    # spatial angle
    ab = b - a  # vector ab
    ac = c - a  # vector ac
    cosine_angle = np.dot(ab, ac) / (np.linalg.norm(ab) * np.linalg.norm(ac))
    if not cosine_angle >= -1.0:
        # also taken for NaN, because every comparison with NaN is False
        cosine_angle = -1.0
    elif cosine_angle > 1.0:
        cosine_angle = 1.0
    angle = np.arccos(cosine_angle)
    # triangle area
    ab_ = np.sqrt(np.sum(ab ** 2))
    ac_ = np.sqrt(np.sum(ac ** 2))  # Euclidean distance
    area = 0.5 * ab_ * ac_ * np.sin(angle)
    return np.degrees(angle), area, ac_
def D3_info_cal(nodes_ls, g):
    """Aggregate the 3D descriptors of one edge over its neighbouring nodes.

    Parameters
    ----------
    nodes_ls: list of int
        ``[source, destination, neighbour_1, neighbour_2, ...]``.
    g: graph
        Its ``ndata['pos']`` entries are read per node id and converted with
        `.numpy()`. It is not accessed when there are no neighbours.

    Returns
    -------
    list
        Nine values: max / sum / mean of the angles (scaled by 0.01), of the
        triangle areas, and of the distances (scaled by 0.1), where each
        triple comes from `D3_info(source, destination, neighbour)`. Nine
        zeros when `nodes_ls` has no neighbour entries.
    """
    if len(nodes_ls) <= 2:
        return [0, 0, 0, 0, 0, 0, 0, 0, 0]

    positions = g.ndata['pos']
    angles, areas, distances = [], [], []
    for neighbor_id in nodes_ls[2:]:
        angle, area, distance = D3_info(
            positions[nodes_ls[0]].numpy(),
            positions[nodes_ls[1]].numpy(),
            positions[neighbor_id].numpy(),
        )
        angles.append(angle)
        areas.append(area)
        distances.append(distance)

    angle_stats = [np.max(angles) * 0.01, np.sum(angles) * 0.01, np.mean(angles) * 0.01]
    area_stats = [np.max(areas), np.sum(areas), np.mean(areas)]
    distance_stats = [np.max(distances) * 0.1, np.sum(distances) * 0.1, np.mean(distances) * 0.1]
    return angle_stats + area_stats + distance_stats
def bond_feature(g):
    """Compute the 3D descriptor vector of every edge of a graph.

    For each edge (source -> destination) the predecessors of the source node,
    except the destination itself, are collected as neighbours and passed to
    `D3_info_cal`.

    Parameters
    ----------
    g: graph
        Must provide `find_edges`, `number_of_edges`, `predecessors` and
        ``ndata['pos']`` (presumably a DGL graph -- confirm with the caller).

    Returns
    -------
    torch.Tensor
        Float tensor with one row of nine descriptors per edge.

    Raises
    ------
    ValueError
        If the destination of an edge is not among the predecessors of its
        source (raised by `list.remove`).
    """
    edge_ids = range(g.number_of_edges())
    src_tensor, dst_tensor = g.find_edges(edge_ids)
    src_list, dst_list = src_tensor.tolist(), dst_tensor.tolist()

    node_groups = []
    for src, dst in zip(src_list, dst_list):
        others = g.predecessors(src).tolist()
        others.remove(dst)
        # [source, destination, remaining neighbours of the source...]
        node_groups.append([src, dst] + others)

    descriptors = [D3_info_cal(group, g=g) for group in node_groups]
    return torch.tensor(descriptors, dtype=torch.float)
def read_molecules_crossdock(lig_path, prot_path, ligcut, protcut, lig_type, prot_graph_type, dataset_path, chain_cut=5.0):
    """Read one ligand / protein pair and featurize both.

    Parameters
    ----------
    lig_path, prot_path: str
        Paths of the ligand SDF file and the protein PDB file, relative to
        `dataset_path`.
    ligcut, protcut: object or None
        Bonded edges are only extracted for the ligand / protein when the
        corresponding value is None; otherwise the edge list is None.
    lig_type: str
        'openbabel' or 'rdkit', the toolkit used to read the ligand.
    prot_graph_type: str
        Must start with 'atom' or 'residue'; a name ending with 'pocket'
        selects the pocket (residues within 12 of the ligand) instead of the
        selected chains.
    dataset_path: str
        Root directory of the dataset.
    chain_cut: float or None, optional
        Chains within this distance of the ligand are kept; None keeps the
        whole protein selection.

    Returns
    -------
    tuple
        ``(lig_coords, lig_features, lig_edges, lig_node_type, prot_coords,
        prot_features, prot_edges, prot_node_type, sec_features,
        alpha_c_coords, c_coords, n_coords)``. `sec_features` is always None;
        the last three are arrays with the CA / C / N coordinates of complete
        residues (empty for atom graphs).

    Raises
    ------
    ValueError
        If `lig_type` or `prot_graph_type` is not supported, or the RDKit
        ligand cannot be read.
    """
    lig_path = os.path.join(dataset_path, lig_path)
    prot_path = os.path.join(dataset_path, prot_path)

    if lig_type == 'openbabel':
        m_lig = next(pybel.readfile('sdf', lig_path))
        lig_coords, lig_features = featurizer.get_features(m_lig)
        lig_edges = get_bonded_edges_obmol(m_lig) if ligcut is None else None
        lig_node_type = lig_atom_type_obmol(m_lig)
    elif lig_type == 'rdkit':
        m_lig = read_rdmol(lig_path, sanitize=True, remove_hs=True)
        # explicit check instead of `assert`, which is stripped under -O
        if m_lig is None:
            raise ValueError(f'sanitize error : {lig_path}')
        conf = m_lig.GetConformer()
        lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
        lig_edges = get_bonded_edges_rdmol(m_lig) if ligcut is None else None
        lig_node_type = lig_atom_type_rdmol(m_lig)
    else:
        # without this branch the ligand variables would be unbound below
        raise ValueError(f'unsupported lig_type: {lig_type!r}')

    prot_complex = parsePDB(prot_path)
    prot_structure_no_water = prot_complex.select('protein')
    if chain_cut is not None:
        prot_valid_chains = prot_structure_no_water.select(f'same chain as within {chain_cut} of ligand', ligand=lig_coords)
    else:
        prot_valid_chains = prot_structure_no_water
    prot_valid_pocket = prot_structure_no_water.select('same residue as within 12 of ligand', ligand=lig_coords)

    prot_alpha_c = prot_valid_chains.select('calpha')
    prot_pocket_alpha_c = prot_valid_pocket.select('calpha')

    # No secondary-structure features are computed in this function.
    sec_features = None
    # "complete" residues have all of CA, C and N
    alpha_c_coords, c_coords, n_coords, complete_residues = [], [], [], []
    use_pocket = prot_graph_type.endswith('pocket')

    if prot_graph_type.startswith('atom'):
        m_prot = prot_valid_pocket if use_pocket else prot_valid_chains
        # NOTE(review): `m_prot` is a ProDy selection here, while
        # `featurizer.get_features` rejects anything that is not a
        # pybel.Molecule -- this branch looks unusable as written; confirm.
        prot_coords, prot_features = featurizer.get_features(m_prot)
        prot_edges = get_bonded_edges_obmol(m_prot) if protcut is None else None
        prot_node_type = lig_atom_type_obmol(m_prot)
    elif prot_graph_type.startswith('residue'):
        m_prot = prot_pocket_alpha_c if use_pocket else prot_alpha_c
        m_prot_complete = prot_valid_pocket if use_pocket else prot_valid_chains
        prot_coords, prot_features = prot_alpha_c_featurizer(m_prot)
        prot_node_type = prot_residue_type(m_prot)
        prot_edges = None

        for chain in m_prot_complete.getHierView():
            for residue in chain:
                # last atom with a given name wins, as with plain assignment
                backbone = {}
                for atom in residue:
                    atom_name = atom.getName()
                    if atom_name in ('CA', 'C', 'N'):
                        backbone[atom_name] = atom.getCoords()

                if len(backbone) == 3:
                    alpha_c_coords.append(backbone['CA'])
                    c_coords.append(backbone['C'])
                    n_coords.append(backbone['N'])
                    complete_residues.append(True)
                else:
                    complete_residues.append(False)

        # one flag per residue is expected to line up with the C-alpha rows
        assert len(complete_residues) == len(prot_coords)
        prot_coords = prot_coords[complete_residues]
        prot_features = prot_features[complete_residues]
        prot_node_type = prot_node_type[complete_residues]
        assert len(alpha_c_coords) == len(prot_coords)
        assert len(c_coords) == len(prot_coords)
        assert len(n_coords) == len(prot_coords)
    else:
        raise ValueError("error prot_graph_type")

    return lig_coords, lig_features, lig_edges, lig_node_type, \
           prot_coords, prot_features, prot_edges, prot_node_type, sec_features,\
           np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords)
def read_ligands_chembl_smina_multi_pose(name, valid_ligand_index, dataset_path, ligcut, lig_type='openbabel', top_N=2,
                                         docking_type='site_specific'):
    """Read up to `top_N` docked poses per valid ligand and featurize them.

    Pose files are looked up in ``<dataset_path>/<name>/ligand_smina_poses``
    as ``<index>.mol2`` (site specific) and ``<index>_blind.mol2`` (blind).

    Parameters
    ----------
    name: str
        Sub-directory of `dataset_path` that holds the poses.
    valid_ligand_index: iterable of bool
        One flag per ligand; ligands with a false flag are skipped.
    dataset_path: str
        Root directory of the dataset.
    ligcut: object
        Unused.
    lig_type: str, optional
        Only 'openbabel' is supported.
    top_N: int, optional
        Maximum number of poses read from each pose file.
    docking_type: str, optional
        'site_specific', 'blind' or 'all' (both files, site specific first).

    Returns
    -------
    tuple of lists
        ``(multi_coords, features, edges, node_types, indices)`` with one
        entry per ligand that was read. `multi_coords` holds the list of pose
        coordinates; features, edges and node types come from the last pose
        that was read. Ligands without any readable pose are skipped.

    Raises
    ------
    ValueError
        If a valid ligand is requested with an unsupported `docking_type` or
        `lig_type`.
    """
    valid_lig_multi_coords_list, valid_lig_features_list = [], []
    valid_lig_edges_list, valid_lig_node_type_list, valid_index_list = [], [], []
    pose_dir = os.path.join(dataset_path, name, 'ligand_smina_poses')

    for index, valid in enumerate(valid_ligand_index):
        if not valid:
            continue

        site_path = os.path.join(pose_dir, f'{index}.mol2')
        blind_path = os.path.join(pose_dir, f'{index}_blind.mol2')
        if docking_type == 'site_specific':
            lig_paths_mol2 = [site_path]
        elif docking_type == 'blind':
            lig_paths_mol2 = [blind_path]
        elif docking_type == 'all':
            lig_paths_mol2 = [site_path, blind_path]
        else:
            raise ValueError(f'unsupported docking_type: {docking_type!r}')
        if lig_type != 'openbabel':
            raise ValueError(f'unsupported lig_type: {lig_type!r}')

        lig_multi_coords = []
        lig_features = lig_edges = lig_node_type = None
        previou_atom_num = -1  # heavy-atom count of the first pose
        for lig_path_mol2 in lig_paths_mol2:
            m_lig_iter = pybel.readfile('mol2', lig_path_mol2)
            c_m_lig = 0
            while c_m_lig < top_N:
                try:
                    m_lig = next(m_lig_iter)
                except StopIteration:
                    print(f'{lig_path_mol2} only has {c_m_lig} poses')
                    break
                try:
                    pose_coords, pose_features = featurizer.get_features(m_lig)
                    pose_edges = get_bonded_edges_obmol(m_lig)
                    pose_node_type = lig_atom_type_obmol(m_lig)
                except Exception as exc:
                    # best effort: keep the poses read so far
                    print(f'{lig_path_mol2}: could not featurize pose {c_m_lig}: {exc}')
                    break

                # All poses of one ligand must have the same number of atoms,
                # because they share one set of features and edges.
                if previou_atom_num == -1:
                    previou_atom_num = len(pose_coords)
                elif len(pose_coords) != previou_atom_num:
                    print(f'{lig_path_mol2}: pose {c_m_lig} has {len(pose_coords)} atoms,'
                          f' expected {previou_atom_num}')
                    break

                lig_features, lig_edges, lig_node_type = pose_features, pose_edges, pose_node_type
                lig_multi_coords.append(pose_coords)
                c_m_lig += 1

        if not lig_multi_coords:
            # nothing was read: do not record features of another ligand
            print(f'ligand {index}: no pose could be read, skipped')
            continue

        valid_lig_multi_coords_list.append(lig_multi_coords)
        valid_lig_features_list.append(lig_features)
        valid_lig_edges_list.append(lig_edges)
        valid_lig_node_type_list.append(lig_node_type)
        valid_index_list.append(index)

    return valid_lig_multi_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list
def read_ligands_chembl_smina(name, valid_ligand_index, dataset_path, ligcut, lig_type='openbabel',docking_type='site_specific'):
    """Read the first docked pose of every valid ligand and featurize it.

    Pose files are looked up in ``<dataset_path>/<name>/ligand_smina_poses``
    as ``<index>_blind.mol2`` when `docking_type` is 'blind' and as
    ``<index>.mol2`` otherwise.

    Parameters
    ----------
    name: str
        Sub-directory of `dataset_path` that holds the poses.
    valid_ligand_index: iterable of bool
        One flag per ligand; ligands with a false flag are skipped.
    dataset_path: str
        Root directory of the dataset.
    ligcut: object
        Unused.
    lig_type: str, optional
        'openbabel' or 'rdkit'; ligands are ignored for any other value.
    docking_type: str, optional
        'blind' selects the blind-docking file.

    Returns
    -------
    tuple of lists
        ``(coords, features, edges, node_types, indices)`` with one entry per
        ligand that was read. With 'openbabel', a ligand whose file cannot be
        read is reported by printing its path and is skipped.
    """
    valid_lig_coords_list, valid_lig_features_list = [], []
    valid_lig_edges_list, valid_lig_node_type_list, valid_index_list = [], [], []

    for index, valid in enumerate(valid_ligand_index):
        if not valid:
            continue
        file_name = f'{index}_blind.mol2' if docking_type == 'blind' else f'{index}.mol2'
        lig_path_mol2 = os.path.join(dataset_path, name, 'ligand_smina_poses', file_name)

        if lig_type == 'openbabel':
            try:
                m_lig = next(pybel.readfile('mol2', lig_path_mol2))
            except Exception:
                # Unreadable or empty file: report it and skip the ligand
                # instead of reusing the molecule of a previous iteration.
                print(lig_path_mol2)
                continue
            lig_coords, lig_features = featurizer.get_features(m_lig)
            lig_edges = get_bonded_edges_obmol(m_lig)
            lig_node_type = lig_atom_type_obmol(m_lig)
        elif lig_type == 'rdkit':
            m_lig = read_rdmol(lig_path_mol2)
            conf = m_lig.GetConformer()
            lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
            lig_edges = get_bonded_edges_rdmol(m_lig)
            lig_node_type = lig_atom_type_rdmol(m_lig)
        else:
            continue

        valid_lig_coords_list.append(lig_coords)
        valid_lig_features_list.append(lig_features)
        valid_lig_edges_list.append(lig_edges)
        valid_lig_node_type_list.append(lig_node_type)
        valid_index_list.append(index)

    return valid_lig_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list
def read_ligands(name, dataset_path, ligcut, lig_type='openbabel'):
    """Read and featurize all ligands of ``<dataset_path>/<name>/visualize_dir/total_vs.sdf``.

    Parameters
    ----------
    name: str
        Sub-directory of `dataset_path`.
    dataset_path: str
        Root directory of the dataset.
    ligcut: object
        Unused.
    lig_type: str, optional
        'openbabel' or 'rdkit'; nothing is read for any other value.

    Returns
    -------
    tuple of lists
        ``(coords, features, edges, node_types, indices)`` with one entry per
        ligand that could be featurized. `indices` holds the position of each
        ligand in the SDF file; ligands that fail are reported with
        ``'<index> error'`` and left out.
    """
    #########################Read Ligand########################################################
    lig_path_sdf = os.path.join(dataset_path, name, 'visualize_dir', 'total_vs.sdf')
    valid_lig_coords_list, valid_lig_features_list = [], []
    valid_lig_edges_list, valid_lig_node_type_list, valid_index_list = [], [], []

    def _record(index, lig_coords, lig_features, lig_edges, lig_node_type):
        # store one successfully featurized ligand
        valid_lig_coords_list.append(lig_coords)
        valid_lig_features_list.append(lig_features)
        valid_lig_edges_list.append(lig_edges)
        valid_lig_node_type_list.append(lig_node_type)
        valid_index_list.append(index)

    if lig_type == 'openbabel':
        m_ligs = pybel.readfile('sdf', lig_path_sdf)
        for index, m_lig in enumerate(m_ligs):
            try:
                lig_coords, lig_features = featurizer.get_features(m_lig)
                lig_edges = get_bonded_edges_obmol(m_lig)
                lig_node_type = lig_atom_type_obmol(m_lig)
            except Exception:
                # best effort per ligand; KeyboardInterrupt / SystemExit are
                # no longer swallowed
                print(f'{index} error')
                continue
            _record(index, lig_coords, lig_features, lig_edges, lig_node_type)
    elif lig_type == 'rdkit':
        supplier = Chem.SDMolSupplier(lig_path_sdf, sanitize=True, removeHs=False)
        for index, m_lig in enumerate(supplier):
            try:
                # failures here include entries the supplier yields as None
                conf = m_lig.GetConformer()
                lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
                lig_edges = get_bonded_edges_rdmol(m_lig)
                lig_node_type = lig_atom_type_rdmol(m_lig)
            except Exception:
                print(f'{index} error')
                continue
            _record(index, lig_coords, lig_features, lig_edges, lig_node_type)

    return valid_lig_coords_list, valid_lig_features_list, valid_lig_edges_list, valid_lig_node_type_list, valid_index_list
def read_casf_ligands(name, dataset_path, ligcut, lig_type='openbabel'):
    """Read all multi-pose ligand files found in ``<dataset_path>/<name>/``.

    Every file in the directory is opened with pybel (format taken from the file
    extension). For each file the coordinates of every pose are collected, while
    bonded edges and node types are taken from the first pose only. The features
    stored for a file are those of the last pose that was read.

    Args:
        name: Sub-directory of ``dataset_path`` to scan.
        dataset_path: Root directory of the dataset.
        ligcut: Not used by this function.
        lig_type: Must be ``'openbabel'`` (enforced with an ``assert``).

    Returns:
        Tuple of five parallel lists with one entry per file: pose names
        (``'<lig_name>_ligand_<k>'``, k starting at 1), pose coordinates,
        features, bonded edges and node types.
    """
    lig_files = os.listdir(os.path.join(dataset_path, name))
    assert lig_type == 'openbabel'
    all_names, all_coords, all_features, all_edges, all_node_types = [], [], [], [], []
    for lig_file in lig_files:
        # Ligand id: first four characters of the last '_'-separated token.
        lig_name = lig_file.split('_')[-1][:4]
        file_type = lig_file.split('.')[-1]
        lig_path = os.path.join(dataset_path, name, lig_file)
        pose_coords, pose_names = [], []
        for pose_number, m_lig in enumerate(pybel.readfile(file_type, lig_path), start=1):
            lig_coords, lig_features = featurizer.get_features(m_lig)
            if pose_number == 1:
                # Topology is read once, from the first pose.
                lig_edges = get_bonded_edges_obmol(m_lig)
                lig_node_type = lig_atom_type_obmol(m_lig)
            pose_coords.append(lig_coords)
            pose_names.append(f'{lig_name}_ligand_{pose_number}')
        all_names.append(pose_names)
        all_coords.append(pose_coords)
        all_features.append(lig_features)
        all_edges.append(lig_edges)
        all_node_types.append(lig_node_type)
    return all_names, all_coords, all_features, all_edges, all_node_types
def read_proteins(name, dataset_path, prot_graph_type, protcut):
    """Load the pre-extracted protein of complex ``name`` and featurize it.

    The structure is parsed from ``<dataset_path>/<name>/<name>_valid_chains.pdb``.

    Args:
        name: Complex identifier (also the sub-directory name).
        dataset_path: Root directory of the dataset.
        prot_graph_type: A key of ``graph_type_filename`` starting with ``'atom'``
            (atom-level features read with pybel) or any string starting with
            ``'residue'`` (one node per residue, taken from the C-alpha selection).
        protcut: For atom graphs, bonded edges are only extracted when this is
            ``None``; otherwise ``prot_edges`` is ``None``.

    Returns:
        8-tuple ``(prot_coords_valid, prot_features_valid, prot_edges,
        prot_node_type, sec_features, alpha_c_coords, c_coords, n_coords)``.
        ``sec_features`` is always ``None``. The three backbone arrays are empty
        for atom graphs; for residue graphs they hold the CA / C / N coordinates
        of every residue that has all three atoms.

    Raises:
        ValueError: If the PDB file cannot be parsed (message is the file path)
            or ``prot_graph_type`` has an unsupported prefix.
    """
    #########################Read Protein########################################################
    valid_chains_path = os.path.join(dataset_path, name, f'{name}_valid_chains.pdb')
    try:
        prot_valid_chains = parsePDB(valid_chains_path)
    except Exception as err:
        # Keep the historic ValueError(path) but preserve the underlying cause.
        raise ValueError(valid_chains_path) from err
    prot_alpha_c = prot_valid_chains.select('calpha')
    alpha_c_coords, c_coords, n_coords = [], [], []
    sec_features = None

    if prot_graph_type.startswith('atom'):
        prot_path = os.path.join(dataset_path, name, f'{name}_{graph_type_filename[prot_graph_type]}')
        m_prot = next(pybel.readfile('pdb', prot_path))
        prot_coords_valid, prot_features_valid = featurizer.get_features(m_prot)
        prot_edges = get_bonded_edges_obmol(m_prot) if protcut is None else None
        prot_node_type = lig_atom_type_obmol(m_prot)
    elif prot_graph_type.startswith('residue'):
        prot_coords, prot_features = prot_alpha_c_featurizer(prot_alpha_c)
        prot_node_type = prot_residue_type(prot_alpha_c)
        prot_edges = None
        valid_index = []
        index = 0  # running residue counter across all chains
        for chain in prot_valid_chains.getHierView():
            for residue in chain:
                # Last atom with a given name wins, as in a sequential scan.
                backbone = {}
                for atom in residue:
                    atom_name = atom.getName()
                    if atom_name in ('CA', 'C', 'N'):
                        backbone[atom_name] = atom.getCoords()
                alpha_c_coord = backbone.get('CA')
                c_coord = backbone.get('C')
                n_coord = backbone.get('N')
                if alpha_c_coord is not None and c_coord is not None and n_coord is not None:
                    alpha_c_coords.append(alpha_c_coord)
                    c_coords.append(c_coord)
                    n_coords.append(n_coord)
                    valid_index.append(index)
                index += 1
        # NOTE(review): valid_index counts residues of the full structure while
        # prot_coords comes from the 'calpha' selection - presumably both have
        # one entry per residue; confirm for residues lacking a CA atom.
        prot_coords_valid = prot_coords[valid_index]
        prot_features_valid = prot_features[valid_index]
    else:
        raise ValueError("error prot_graph_type")

    # Parenthesised tuple instead of trailing-backslash continuations, so the
    # statement no longer depends on a blank line following it.
    return (prot_coords_valid, prot_features_valid, prot_edges, prot_node_type, sec_features,
            np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords))
def read_molecules(name, dataset_path, prot_graph_type, ligcut, protcut, lig_type='openbabel',init_type='redock_init',
                   chain_cut=5.0, p2rank_base=None, binding_site_type='ligand_center', LAS_mask=True,
                   keep_hs_before_rdkit_generate=False, rd_gen_maxIters=200):
    """Load one protein-ligand complex from ``<dataset_path>/<name>/`` and featurize both partners.

    Ligand: read from ``<name>_ligand.mol2`` with pybel (``lig_type='openbabel'``) or from
    ``<name>_ligand.sdf`` with a mol2 fallback (``lig_type='rdkit'``).

    Initial ligand coordinates (``init_type``):
        * ``'rdkit_init'``  - conformer from ``get_rdkit_coords``; falls back to the
          crystal coordinates and appends a note to a log file in the current
          working directory when generation fails.
        * ``'redock_init'`` - the crystal coordinates.
        * ``'random_init'`` - standard-normal random coordinates.
        * anything else     - ``None``.
      Non-``None`` initial coordinates are then centred and given a random rotation
      and translation via ``random_rotation_translation``.

    Protein: parsed from ``<name>_protein_processed.pdb`` if it exists, else
    ``<name>_protein.pdb``. Side effect: ``<name>_valid_chains.pdb`` and
    ``<name>_valid_pocket.pdb`` are written into the complex directory.

    Args:
        ligcut: Not referenced in this function.
        protcut: For atom graphs, bonded edges are only extracted when this is ``None``.
        chain_cut: Distance used in the 'same chain as within ... of ligand' selection;
            ``None`` keeps all protein chains.
        p2rank_base: Directory with ``<name>_valid_chains.pdb_predictions.csv``
            (only read when ``binding_site_type == 'p2rank'``).
        binding_site_type: ``'p2rank'`` (predicted centre closest to the ligand centre,
            ligand centre if there are no predictions) or ``'ligand_center'``.
        LAS_mask: When true, ``lig_type`` must be ``'rdkit'`` (asserted) and the mask from
            ``get_LAS_distance_constraint_mask`` is returned; otherwise ``None``.

    Returns:
        15-tuple: ligand coords, features, edges, node types, initial coords; protein
        coords, features, edges, node types, sec_features (always ``None``); CA, C and N
        coordinate arrays; binding site reshaped to ``(1, -1)``; LAS mask.

    NOTE(review): an unrecognised ``lig_type`` or ``binding_site_type`` leaves
    ``lig_coords`` / ``binding_site`` unassigned and fails later with a NameError.
    """
    #########################Read Ligand########################################################
    lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2')
    lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf')
    if lig_type == 'openbabel':
        m_lig = next(pybel.readfile('mol2', lig_path_mol2))
        lig_coords, lig_features = featurizer.get_features(m_lig)
        lig_edges = get_bonded_edges_obmol(m_lig)
        lig_node_type = lig_atom_type_obmol(m_lig)
    elif lig_type == 'rdkit':
        m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
        if m_lig == None:  # read mol2 file if sdf file cannot be sanitized
            m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
        conf = m_lig.GetConformer()
        lig_coords, lig_features = conf.GetPositions(), lig_atom_featurizer_rdmol(m_lig)
        lig_edges = get_bonded_edges_rdmol(m_lig)
        lig_node_type = lig_atom_type_rdmol(m_lig)
    #########################Get Ligand Rdkit Init Coordinates###################################
    if init_type == 'rdkit_init':
        rd_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=not keep_hs_before_rdkit_generate)
        if rd_lig == None:  # read mol2 file if sdf file cannot be sanitized
            rd_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=not keep_hs_before_rdkit_generate)
        try:
            lig_init_coords = get_rdkit_coords(rd_lig, sanitize=True, remove_hs=True, maxIters=rd_gen_maxIters)
        except Exception as e:
            # Conformer generation failed: fall back to the crystal pose and log the complex.
            lig_init_coords = lig_coords
            with open(f'temp_create_dataset_rdkit_timesplit_no_lig_or_rec_overlap_train_remove_hs_before_generate_{not keep_hs_before_rdkit_generate}.log', 'a') as f:
                f.write('Generating RDKit conformer failed for \n')
                f.write(name)
                f.write('\n')
                f.write(str(e))
                f.write('\n')
                f.flush()
        assert len(lig_init_coords) == len(lig_coords)
        rdlig_node_type = lig_atom_type_rdmol(rd_lig)
        # remove all h (node type 1) so the atoms line up with the openbabel features
        if lig_type == 'openbabel':
            lig_init_coords = lig_init_coords[rdlig_node_type != 1]
        # Re-check the atom count after the optional hydrogen filtering; print, then re-raise.
        try:
            if len(lig_init_coords)!=len(lig_coords):
                raise ValueError('{} {}!={}'.format(name, len(lig_init_coords), len(lig_coords)))
        except ValueError as e:
            print("error raise:", repr(e))
            raise
    elif init_type == 'redock_init':
        lig_init_coords = lig_coords
    elif init_type == 'random_init':
        lig_init_coords = np.random.randn(len(lig_coords),3)
    else:
        lig_init_coords = None
    # random location and orientation: centre the pose, rotate it, then translate it
    if lig_init_coords is not None:
        rot_T, rot_b = random_rotation_translation()
        mean_to_remove = lig_init_coords.mean(axis=0, keepdims=True)
        lig_init_coords = (rot_T @ (lig_init_coords - mean_to_remove).T).T + rot_b
    #########################Read Protein########################################################
    if os.path.exists(os.path.join(dataset_path, name, f'{name}_protein_processed.pdb')):
        prot_complex = parsePDB(os.path.join(dataset_path, name, f'{name}_protein_processed.pdb'))
    else:
        prot_complex = parsePDB(os.path.join(dataset_path, name, f'{name}_protein.pdb'))
    prot_structure_no_water = prot_complex.select('protein')
    if chain_cut is not None:
        prot_valid_chains = prot_structure_no_water.select(f'same chain as within {chain_cut} of ligand', ligand=lig_coords)
    else:
        prot_valid_chains = prot_structure_no_water
    # Pocket: every residue with an atom within 12 (selection units) of the ligand.
    prot_valid_pocket = prot_structure_no_water.select('same residue as within 12 of ligand', ligand=lig_coords)
    try:
        prot_alpha_c = prot_valid_chains.select('calpha')
        prot_pocket_alpha_c = prot_valid_pocket.select('calpha')
    except:
        raise ValueError(f'{name} process pdb error')
    alpha_c_sec_features = None
    prot_pocket_alpha_c_sec_features = None
    alpha_c_coords, c_coords, n_coords = [], [], []
    # Persist both selections; the atom-level branch below reads them back with pybel.
    writePDB(os.path.join(dataset_path, name, f'{name}_valid_chains.pdb'), prot_valid_chains)
    writePDB(os.path.join(dataset_path, name, f'{name}_valid_pocket.pdb'), prot_valid_pocket)
    if prot_graph_type.startswith('atom'):
        prot_path = os.path.join(dataset_path, name, f'{name}_{graph_type_filename[prot_graph_type]}')
        m_prot = next(pybel.readfile('pdb', prot_path))
        sec_features = None
        prot_coords_valid, prot_features_valid = featurizer.get_features(m_prot)
        prot_edges = get_bonded_edges_obmol(m_prot) if protcut is None else None
        prot_node_type = lig_atom_type_obmol(m_prot)
    elif prot_graph_type.startswith('residue'):
        alpha_c_sec_features = None
        prot_pocket_alpha_c_sec_features = None
        # '...pocket' graph types use the pocket selection, all others the valid chains.
        m_prot = prot_pocket_alpha_c if prot_graph_type.endswith('pocket') else prot_alpha_c
        m_prot_complete = prot_valid_pocket if prot_graph_type.endswith('pocket') else prot_valid_chains
        sec_features = prot_pocket_alpha_c_sec_features if prot_graph_type.endswith('pocket') else alpha_c_sec_features
        prot_coords, prot_features = prot_alpha_c_featurizer(m_prot)
        prot_node_type = prot_residue_type(m_prot)
        prot_edges = None
        hv = m_prot_complete.getHierView()
        index = 0
        valid_index, prot_coords_valid, prot_features_valid = [], [], []
        # Keep only residues that have all three backbone atoms (CA, C, N);
        # `index` counts every residue, `valid_index` the ones that are kept.
        for chain in hv:
            for i, residue in enumerate(chain):
                alpha_c_coord, c_coord, n_coord = None, None, None
                for atom in residue:
                    if atom.getName() == 'CA':
                        alpha_c_coord = atom.getCoords()
                    if atom.getName() == 'C':
                        c_coord = atom.getCoords()
                    if atom.getName() == 'N':
                        n_coord = atom.getCoords()
                if alpha_c_coord is not None and c_coord is not None and n_coord is not None:
                    alpha_c_coords.append(alpha_c_coord)
                    c_coords.append(c_coord)
                    n_coords.append(n_coord)
                    valid_index.append(index)
                index += 1
        prot_coords_valid = prot_coords[valid_index]
        prot_features_valid = prot_features[valid_index]
    else:
        raise ValueError("error prot_graph_type")
    ############################### Read Binding Site ##########################################
    if binding_site_type == 'p2rank':
        p2rank_result_path = os.path.join(p2rank_base, f'{name}_valid_chains.pdb_predictions.csv')
        # Column names in the predictions file carry a leading space.
        df = pd.read_csv(p2rank_result_path, usecols= [' center_x',' center_y',' center_z'])
        possible_binding_sites = df.values
        ligand_center = lig_coords.mean(axis=0)
        if len(possible_binding_sites) == 0:
            binding_site = ligand_center
        else:
            # Pick the predicted pocket centre closest (squared distance) to the ligand centre.
            binding_site_index = ((possible_binding_sites - ligand_center) ** 2).sum(axis=1).argmin()
            binding_site = possible_binding_sites[binding_site_index]
    elif binding_site_type == 'ligand_center':
        binding_site = lig_coords.mean(axis=0)
    ############################### Get LAS Mask ##########################################
    if LAS_mask:
        # The mask is built from the RDKit molecule, so the rdkit reader is required.
        assert lig_type == 'rdkit'
        lig_LAS_mask = get_LAS_distance_constraint_mask(m_lig)
    else:
        lig_LAS_mask = None
    return lig_coords, lig_features, lig_edges, lig_node_type, lig_init_coords, \
           prot_coords_valid, prot_features_valid, prot_edges, prot_node_type, sec_features,\
           np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords),\
           binding_site.reshape(1,-1), lig_LAS_mask
def read_molecules_inference(lig_path, protein_path, prot_graph_type, chain_cut=5.0):
    """Featurize an arbitrary ligand file and protein PDB for inference.

    The ligand is read with pybel (format = file extension). The protein is
    restricted either to the chains near the ligand (``chain_cut``) or, when
    ``prot_graph_type`` ends with ``'pocket'``, to the residues within 12 of the
    ligand. Only residues that contain CA, C and N atoms are kept, and their
    residue type is one-hot encoded through ``ResDict`` (unknown names map to
    ``UNKOWN_RES``).

    Returns:
        14-tuple: ligand coords, features, edges, node types; protein CA
        coordinates (as a list), one-hot residue features, ``None`` (edges),
        residue-type index tensor, ``None`` (sec_features); CA, C and N
        coordinate arrays; residue numbers and chain ids of the kept residues.
    """
    #########################Read Ligand########################################################
    lig_format = lig_path.split('.')[-1]
    m_lig = next(pybel.readfile(lig_format, lig_path))
    lig_coords, lig_features = featurizer.get_features(m_lig)
    lig_edges = get_bonded_edges_obmol(m_lig)
    lig_node_type = lig_atom_type_obmol(m_lig)

    #########################Read Protein########################################################
    prot_structure_no_water = parsePDB(protein_path).select('protein')
    if chain_cut is None:
        prot_valid_chains = prot_structure_no_water
    else:
        prot_valid_chains = prot_structure_no_water.select(
            f'same chain as within {chain_cut} of ligand', ligand=lig_coords)
    prot_valid_pocket = prot_structure_no_water.select('same residue as within 12 of ligand', ligand=lig_coords)
    prot_alpha_c = prot_valid_chains.select('calpha')
    prot_pocket_alpha_c = prot_valid_pocket.select('calpha')

    use_pocket = prot_graph_type.endswith('pocket')
    if use_pocket:
        m_prot, m_prot_complete = prot_pocket_alpha_c, prot_valid_pocket
    else:
        m_prot, m_prot_complete = prot_alpha_c, prot_valid_chains
    sec_features = None
    prot_edges = None
    # These two calls are kept although their results are superseded below.
    prot_alpha_c_featurizer(m_prot)
    prot_residue_type(m_prot)

    alpha_c_coords, c_coords, n_coords = [], [], []
    ca_res_number_valid, residue_name_valid, chain_index_valid = [], [], []
    for chain in m_prot_complete.getHierView():
        for residue in chain:
            ca_res_number = residue.getResnums()[0]
            residue_name = residue.getResname()
            chain_index = residue.getChid()
            alpha_c_coord = c_coord = n_coord = None
            for atom in residue:
                atom_name = atom.getName()
                if atom_name == 'CA':
                    alpha_c_coord = atom.getCoords()
                elif atom_name == 'C':
                    c_coord = atom.getCoords()
                elif atom_name == 'N':
                    n_coord = atom.getCoords()
            if alpha_c_coord is None or c_coord is None or n_coord is None:
                continue  # incomplete backbone: residue is dropped
            alpha_c_coords.append(alpha_c_coord)
            c_coords.append(c_coord)
            n_coords.append(n_coord)
            ca_res_number_valid.append(ca_res_number)
            residue_name_valid.append(residue_name)
            chain_index_valid.append(chain_index)

    prot_coords_valid = alpha_c_coords
    ResIndex_valid = [ResDict.get(res_name, UNKOWN_RES) for res_name in residue_name_valid]
    prot_node_type = torch.tensor(ResIndex_valid, dtype=torch.int64)
    one_hot_table = np.eye(UNKOWN_RES + 1)
    prot_features_valid = torch.tensor(one_hot_table[ResIndex_valid])
    return (lig_coords, lig_features, lig_edges, lig_node_type,
            prot_coords_valid, prot_features_valid, prot_edges, prot_node_type, sec_features,
            np.array(alpha_c_coords), np.array(c_coords), np.array(n_coords),
            ca_res_number_valid, chain_index_valid)
def get_ligand_smiles(name, dataset_path,):
    """Return the SMILES of a ligand plus the mapping from molecule to SMILES atom order.

    The ligand is read from ``<name>_ligand.sdf`` and, if that yields ``None``,
    from ``<name>_ligand.mol2`` (both inside ``<dataset_path>/<name>/``).

    Returns:
        ``(sm, sm2m_order)`` where ``sm`` is the SMILES string and
        ``sm2m_order[a]`` is the position at which molecule atom ``a`` was written
        in the SMILES, i.e. the inverse of RDKit's ``_smilesAtomOutputOrder``.
    """
    lig_path_mol2 = os.path.join(dataset_path, name, f'{name}_ligand.mol2')
    lig_path_sdf = os.path.join(dataset_path, name, f'{name}_ligand.sdf')
    m_lig = read_rdmol(lig_path_sdf, sanitize=True, remove_hs=True)
    if m_lig is None:  # read mol2 file if sdf file cannot be sanitized
        m_lig = read_rdmol(lig_path_mol2, sanitize=True, remove_hs=True)
    # MolToSmiles must run first: it is what stores the output order on the molecule.
    sm = Chem.MolToSmiles(m_lig)
    props = m_lig.GetPropsAsDict(includePrivate=True, includeComputed=True)
    m_sm_order = list(props['_smilesAtomOutputOrder'])
    # Invert the permutation: SMILES position -> atom index becomes atom index -> SMILES position.
    sm2m_order = [0] * len(m_sm_order)
    for smiles_pos, atom_idx in enumerate(m_sm_order):
        sm2m_order[atom_idx] = smiles_pos
    return sm, sm2m_order
def get_protein_fasta(name, dataset_path,):
    """Return the per-chain sequences of ``<name>_valid_chains.pdb`` and the usable residue indices.

    Returns:
        ``(fasta_list, valid_index)``: ``fasta_list`` has one sequence string per
        chain; ``valid_index`` lists the running indices (counted across all
        chains) of the residues that contain CA, C and N atoms.

    Raises:
        ValueError: If the PDB file cannot be parsed (message ``'<name> error!'``).
    """
    try:
        prot_valid_chains = parsePDB(os.path.join(dataset_path, name, f'{name}_valid_chains.pdb'))
    except Exception as err:
        # Same exception type and message as before, now chained to the real cause.
        raise ValueError(f'{name} error!') from err

    fasta_list, valid_index = [], []
    index = 0  # running residue counter across all chains
    for chain in prot_valid_chains.getHierView():
        fasta_list.append(chain.getSequence())
        for residue in chain:
            backbone = {}
            for atom in residue:
                atom_name = atom.getName()
                if atom_name in ('CA', 'C', 'N'):
                    backbone[atom_name] = atom.getCoords()
            if all(backbone.get(key) is not None for key in ('CA', 'C', 'N')):
                valid_index.append(index)
            index += 1
    return fasta_list, valid_index
def prot_p2rank_feats(p2rank_result_path, p2rank_feats_tpye='zscore', pocket_cut=10):
    """Spread a per-residue p2rank score over pocket-indexed feature columns.

    The CSV at ``p2rank_result_path`` is read with pandas; column names are
    stripped of surrounding whitespace. Each residue's score (column
    ``p2rank_feats_tpye``) is written into exactly one column of its row:

        * pocket ``k`` with ``1 <= k <= pocket_cut``  -> column ``k - 1``
        * pocket ``k`` with ``k > pocket_cut``        -> column ``pocket_cut``
        * pocket ``0``                                -> column ``pocket_cut + 1``

    Returns:
        ``torch.Tensor`` of shape ``(n_rows, pocket_cut + 2)``, zero elsewhere.
    """
    table = pd.read_csv(p2rank_result_path)
    table.columns = table.columns.str.strip()
    scores = table[p2rank_feats_tpye].values
    pockets = table['pocket'].values

    feat_len = pocket_cut + 2
    no_pocket_col, overflow_col = feat_len - 1, feat_len - 2
    p2rank_feats = torch.zeros(len(scores), feat_len)
    for row, (score, pocket_idx) in enumerate(zip(scores, pockets)):
        if pocket_idx == 0:
            column = no_pocket_col
        elif pocket_idx > pocket_cut:
            column = overflow_col
        else:
            column = pocket_idx - 1
        p2rank_feats[row, column] = score
    return p2rank_feats
def get_p2rank_feats(name, dataset_path, p2rank_base=None, p2rank_feats_tpye='zscore', pocket_cut=10):
    """Return p2rank residue features for the residues of ``<name>_valid_chains.pdb`` with a full backbone.

    The per-residue scores are read from
    ``<p2rank_base>/<name>_valid_chains.pdb_residues.csv`` through
    ``prot_p2rank_feats``. If the number of p2rank rows differs from the number
    of C-alpha features, ``name`` is appended to ``p2rank_feats_error.txt`` in
    the current working directory and an all-zero tensor with one row per
    C-alpha feature is returned instead.

    Returns:
        ``torch.Tensor`` with ``pocket_cut + 2`` columns; rows are restricted to
        residues that contain CA, C and N atoms (except in the fallback case).

    Raises:
        ValueError: If the PDB file cannot be parsed (message ``'<name> error!'``).
    """
    try:
        prot_valid_chains = parsePDB(os.path.join(dataset_path, name, f'{name}_valid_chains.pdb'))
    except Exception as err:
        raise ValueError(f'{name} error!') from err
    prot_alpha_c = prot_valid_chains.select('calpha')
    p2rank_result_path = os.path.join(p2rank_base, f'{name}_valid_chains.pdb_residues.csv')
    p2rank_features = prot_p2rank_feats(p2rank_result_path, pocket_cut=pocket_cut, p2rank_feats_tpye=p2rank_feats_tpye)
    prot_coords, prot_features = prot_alpha_c_featurizer(prot_alpha_c)

    # Explicit comparison instead of assert-in-try: asserts vanish under ``python -O``,
    # which would silently disable this fallback.
    if len(p2rank_features) != len(prot_features):
        with open('p2rank_feats_error.txt', 'a') as f:
            f.write(f'{name}\n')
        return torch.zeros(len(prot_features), pocket_cut + 2)

    valid_index = []
    index = 0  # running residue counter across all chains
    for chain in prot_valid_chains.getHierView():
        for residue in chain:
            backbone = {}
            for atom in residue:
                atom_name = atom.getName()
                if atom_name in ('CA', 'C', 'N'):
                    backbone[atom_name] = atom.getCoords()
            if all(backbone.get(key) is not None for key in ('CA', 'C', 'N')):
                valid_index.append(index)
            index += 1
    p2rank_features_valid = p2rank_features[valid_index]
    return p2rank_features_valid
def binarize(x):
    """Return a tensor like ``x`` holding 1 where ``x > 0`` and 0 elsewhere (same dtype as ``x``)."""
    positive = x > 0
    return positive.to(x.dtype)
# adjacency matrix -> matrix of hop distances (up to n_hops)
def n_hops_adj(adj, n_hops):
    """Label every node pair with the smallest hop count (1..n_hops) that connects it.

    Reachability within ``k`` hops is built by repeatedly multiplying the
    self-loop-augmented adjacency and binarizing. Pairs first reached at hop
    ``k`` receive the value ``k``; the diagonal and pairs farther apart than
    ``n_hops`` stay 0. The result has the dtype and device of ``adj``.
    """
    identity = torch.eye(adj.size(0), dtype=torch.long, device=adj.device)
    within_one = binarize(adj + identity)
    # reach[k] marks the pairs connected by at most k hops (reach[0] = identity).
    reach = [identity, within_one]
    for hop in range(2, n_hops + 1):
        reach.append(binarize(reach[hop - 1] @ within_one))

    extend_mat = torch.zeros_like(adj)
    for hop in range(1, n_hops + 1):
        newly_reached = reach[hop] - reach[hop - 1]
        extend_mat += newly_reached * hop
    return extend_mat
def get_LAS_distance_constraint_mask(mol):
    """Build a 0/1 atom-pair mask for an RDKit molecule.

    A pair is marked when the two atoms are at most two bonds apart
    (``n_hops_adj`` with ``n_hops=2``) or belong to the same ring of the
    symmetrized SSSR. The diagonal is not marked by either rule.
    """
    # Bond adjacency, extended to 2-hop neighbours
    adjacency = torch.from_numpy(Chem.GetAdjacencyMatrix(mol))
    extend_adj = n_hops_adj(adjacency, 2)
    # Add all intra-ring pairs
    for ring in Chem.GetSymmSSSR(mol):
        for atom_a in ring:
            for atom_b in ring:
                if atom_a != atom_b:
                    extend_adj[atom_a][atom_b] += 1
    # Collapse the accumulated counts into a mask
    return binarize(extend_adj)
def get_lig_graph_geodiff(lig_coords, lig_features, lig_node_type, lig_edges):
    """Build a DGL ligand graph whose edges are exactly the bonded edges.

    ``lig_edges`` is an iterable of ``(src, dst, bond_type)`` triples.

    Node data: ``'h'`` (features), ``'node_type'``, ``'pos'`` (float coordinates).
    Edge data: ``'bond_type'`` (int64).
    """
    graph = dgl.DGLGraph()
    graph.add_nodes(len(lig_coords))
    if isinstance(lig_features, np.ndarray):
        graph.ndata['h'] = torch.from_numpy(lig_features)
    else:
        graph.ndata['h'] = lig_features
    graph.ndata['node_type'] = lig_node_type  # schnet\mgcn features
    bond_src, bond_dst, bond_type = list(zip(*lig_edges))
    graph.add_edges(np.array(bond_src), np.array(bond_dst))
    graph.ndata['pos'] = torch.tensor(lig_coords, dtype=torch.float)
    graph.edata['bond_type'] = torch.tensor(bond_type, dtype=torch.int64)
    return graph
def get_lig_multi_pose_graph_equibind(lig_multi_coords, lig_features, lig_node_type, max_neighbors=None, cutoff=5.0,
                                      lig_edges=None):
    """Build one EquiBind-style ligand graph per pose.

    All poses share the same features and node types; only the coordinates
    differ.

    Args:
        lig_multi_coords: Iterable of coordinate arrays, one per pose.
        lig_features: Node features shared by all poses.
        lig_node_type: Node types shared by all poses.
        max_neighbors: Forwarded to ``get_lig_graph_equibind``.
        cutoff: Forwarded to ``get_lig_graph_equibind``.
        lig_edges: Optional bonded edges forwarded to ``get_lig_graph_equibind``;
            ``None`` (default) builds the graphs without bond-type edge data.

    Returns:
        List of graphs in the order of ``lig_multi_coords``.
    """
    # Keyword arguments keep this call aligned with get_lig_graph_equibind's
    # signature (lig_coords, lig_features, lig_edges, lig_node_type, ...); the
    # former positional call shifted every argument after lig_features by one.
    return [
        get_lig_graph_equibind(lig_coords, lig_features, lig_edges=lig_edges, lig_node_type=lig_node_type,
                               max_neighbors=max_neighbors, cutoff=cutoff)
        for lig_coords in lig_multi_coords
    ]
def get_lig_graph_equibind(lig_coords, lig_features, lig_edges, lig_node_type, max_neighbors=None, cutoff=5.0):
    """Build an EquiBind-style radius graph over the ligand atoms.

    Every atom is connected to the other atoms closer than ``cutoff``; if that
    exceeds ``max_neighbors`` the nearest ``max_neighbors`` are used, and an atom
    without any neighbour is linked to its closest atom (a message is printed).

    Node data: ``'h'``, ``'node_type'``, ``'pos'`` (float32) and ``'mu_r_norm'``
    (one value per sigma in 1, 2, 5, 10, 30).
    Edge data: ``'e'`` (``distance_featurizer`` with divisor 0.75) and, when
    ``lig_edges`` is given, ``'bond_type'`` (zero rows for non-bonded edges).

    ``lig_edges`` is an iterable of ``(src, dst, bond_feature_vector)`` triples;
    every bonded pair must also be an edge of the radius graph.
    """
    num_nodes = lig_coords.shape[0]
    assert lig_coords.shape[1] == 3
    pairwise = spatial.distance_matrix(lig_coords, lig_coords)
    sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))

    src_list, dst_list, dist_list, mean_norm_list = [], [], [], []
    for atom in range(num_nodes):
        row = pairwise[atom, :]
        neighbors = list(np.where(row < cutoff)[0])
        neighbors.remove(atom)
        if max_neighbors is not None and len(neighbors) > max_neighbors:
            neighbors = list(np.argsort(row))[1: max_neighbors + 1]  # closest would be self loop
        if len(neighbors) == 0:
            neighbors = list(np.argsort(row))[1:2]  # closest would be the index i itself > self loop
            print(
                f'The lig_radius {cutoff} was too small for one lig atom such that it had no neighbors. So we connected {atom} to the closest other lig atom {neighbors}')
        assert atom not in neighbors
        sources = [atom] * len(neighbors)
        src_list.extend(sources)
        dst_list.extend(neighbors)
        neighbor_dist = pairwise[atom, neighbors]
        dist_list.extend(list(neighbor_dist))

        # Softmax weights over the neighbours, one row per sigma
        weights = softmax(- neighbor_dist.reshape((1, -1)) ** 2 / sigma, axis=1)  # (sigma_num, neigh_num)
        first_row_sum = weights[0].sum()
        assert first_row_sum > 1 - 1e-2 and first_row_sum < 1.01
        diff_vecs = lig_coords[sources, :] - lig_coords[neighbors, :]  # (neigh_num, 3)
        mean_vec = weights.dot(diff_vecs)  # (sigma_num, 3)
        denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1))  # (sigma_num,)
        mean_norm_list.append(np.linalg.norm(mean_vec, axis=1) / denominator)  # (sigma_num,)

    assert len(src_list) == len(dst_list)
    assert len(dist_list) == len(dst_list)
    graph = dgl.graph((torch.tensor(src_list), torch.tensor(dst_list)), num_nodes=num_nodes, idtype=torch.int32)
    if isinstance(lig_features, np.ndarray):
        graph.ndata['h'] = torch.from_numpy(lig_features)
    else:
        graph.ndata['h'] = lig_features
    graph.ndata['node_type'] = lig_node_type  # schnet\mgcn features
    graph.edata['e'] = distance_featurizer(dist_list, 0.75)  # avg distance = 1.3 So divisor = (4/7)*1.3 = ~0.75
    graph.ndata['pos'] = torch.from_numpy(np.array(lig_coords).astype(np.float32))
    graph.ndata['mu_r_norm'] = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))

    if lig_edges is not None:
        # Locate each bond among the radius-graph edges and copy its features there.
        edge_position = {pair: idx for idx, pair in enumerate(zip(src_list, dst_list))}
        bond_src_ls, bond_dst_ls, bond_type = list(zip(*lig_edges))
        bond_edge_idx = [edge_position[(bs, bd)] for bs, bd in zip(bond_src_ls, bond_dst_ls)]
        graph.edata['bond_type'] = torch.zeros(len(src_list), len(bond_type[0]))
        graph.edata['bond_type'][bond_edge_idx] = torch.tensor(bond_type).to(torch.float32)
    return graph
def get_lig_graph(lig_coords,lig_features, lig_edges, lig_node_type, cutoff=None):
    """Build a DGL ligand graph from bonded edges or from a distance cutoff.

    Args:
        lig_coords: Atom coordinates.
        lig_features: Node features (numpy array or tensor).
        lig_edges: Iterable of ``(src, dst, bond_type)`` triples; only used when
            ``cutoff`` is ``None``.
        lig_node_type: Per-node type tensor.
        cutoff: ``None`` to use ``lig_edges``; otherwise every ordered pair of
            distinct atoms with ``0 < distance < cutoff`` becomes an edge.

    Returns:
        Graph with node data ``'h'``, ``'node_type'``, ``'pos'`` and edge data
        ``'distance'`` and ``'e'`` (``0.1 * distance`` concatenated with the
        output of ``bond_feature``). ``'bond_type'`` edge data is attached only
        for bonded graphs (``cutoff is None``); cutoff graphs have no bond
        information, and asking for it used to raise ``UnboundLocalError``.
    """
    g_lig = dgl.DGLGraph()
    num_atoms_lig = len(lig_coords)  # number of ligand atom_level
    g_lig.add_nodes(num_atoms_lig)
    g_lig.ndata['h'] = torch.from_numpy(lig_features) if isinstance(lig_features, np.ndarray) else lig_features
    g_lig.ndata['node_type'] = lig_node_type  # schnet\mgcn features
    dis_matrix_lig = spatial.distance_matrix(lig_coords, lig_coords)
    bond_type = None
    if cutoff is None:
        src_ls, dst_ls, bond_type = list(zip(*lig_edges))
        src_ls, dst_ls = np.array(src_ls), np.array(dst_ls)
    else:
        src_ls, dst_ls = np.where((dis_matrix_lig < cutoff) & (dis_matrix_lig != 0))  # no self-loop
    g_lig.add_edges(src_ls, dst_ls)
    lig_d = torch.tensor(dis_matrix_lig[src_ls, dst_ls], dtype=torch.float).view(-1, 1)
    g_lig.edata['distance'] = lig_d
    g_lig.edata['e'] = lig_d * 0.1  # g.edata['e'] ~ (n_bond1+n_bond2) * k
    g_lig.ndata['pos'] = torch.tensor(lig_coords, dtype=torch.float)
    D3_info = bond_feature(g_lig)
    g_lig.edata['e'] = torch.cat([g_lig.edata['e'], D3_info], dim=-1)
    if bond_type is not None:
        g_lig.edata['bond_type'] = torch.tensor(bond_type,dtype=torch.int64)
    assert not torch.any(torch.isnan(D3_info))
    return g_lig
def get_prot_atom_graph(prot_coords, prot_features, prot_edges, prot_node_type, cutoff=None):
    """Build a DGL atom-level protein graph from bonded edges or from a distance cutoff.

    Args:
        prot_coords: Atom coordinates.
        prot_features: Node features (numpy array or tensor).
        prot_edges: Iterable of ``(src, dst, bond_type)`` triples; only used when
            ``cutoff`` is ``None``.
        prot_node_type: Per-node type tensor.
        cutoff: ``None`` to use ``prot_edges``; otherwise every ordered pair of
            distinct atoms with ``0 < distance < cutoff`` becomes an edge.

    Returns:
        Graph with node data ``'h'``, ``'node_type'``, ``'pos'`` and edge data
        ``'distance'`` and ``'e'`` (``0.1 * distance`` concatenated with the
        output of ``bond_feature``). ``'bond_type'`` edge data is attached only
        for bonded graphs (``cutoff is None``); cutoff graphs have no bond
        information, and asking for it used to raise ``UnboundLocalError``.
    """
    g_prot = dgl.DGLGraph()
    num_atoms_prot = len(prot_coords)
    g_prot.add_nodes(num_atoms_prot)
    g_prot.ndata['h'] = torch.from_numpy(prot_features) if isinstance(prot_features, np.ndarray) else prot_features
    g_prot.ndata['node_type'] = prot_node_type  # schnet\mgcn features
    dis_matrix_prot = spatial.distance_matrix(prot_coords, prot_coords)
    bond_type = None
    if cutoff is None:
        src_ls, dst_ls, bond_type = list(zip(*prot_edges))
        src_ls, dst_ls = np.array(src_ls), np.array(dst_ls)
    else:
        src_ls, dst_ls = np.where((dis_matrix_prot < cutoff) & (dis_matrix_prot != 0))  # no self-loop
    g_prot.add_edges(src_ls, dst_ls)
    prot_d = torch.tensor(dis_matrix_prot[src_ls, dst_ls], dtype=torch.float).view(-1, 1)
    g_prot.edata['distance'] = prot_d
    g_prot.edata['e'] = prot_d * 0.1  # g.edata['e'] ~ (n_bond1+n_bond2) * k
    g_prot.ndata['pos'] = torch.tensor(prot_coords, dtype=torch.float)
    D3_info = bond_feature(g_prot)
    g_prot.edata['e'] = torch.cat([g_prot.edata['e'], D3_info], dim=-1)
    if bond_type is not None:
        g_prot.edata['bond_type'] = torch.tensor(bond_type, dtype=torch.int64)
    assert not torch.any(torch.isnan(D3_info))
    return g_prot
def distance_featurizer(dist_list, divisor) -> torch.Tensor:
    """Encode distances with 15 zero-centred Gaussian kernels of growing width.

    Column ``k`` of the result is ``exp(-(d / divisor) ** 2 / 1.5 ** k)`` for
    ``k = 0..14``.

    Args:
        dist_list: Sequence of distances (may be empty).
        divisor: Scale applied to the distances before squaring. You want a
            divisor that is close to 4/7 times the average distance that you
            want to encode.

    Returns:
        float32 tensor of shape ``(len(dist_list), 15)``. An empty input yields
        a ``(0, 15)`` tensor (it previously failed while reshaping).
    """
    length_scale_list = [1.5 ** x for x in range(15)]
    scaled_sq = (np.asarray(dist_list).reshape(-1) / divisor) ** 2
    # Stacking along axis 1 gives (num_edge, 15) directly and is well-defined for
    # zero edges, unlike reshape((0, -1)).
    transformed_dist = np.stack(
        [np.exp(-scaled_sq / float(length_scale)) for length_scale in length_scale_list], axis=1)
    return torch.from_numpy(transformed_dist.astype(np.float32))
def local_coordinate_system_feature(prot_coords, c_alpha_coords, c_coords, n_coords, prot_d, src_ls, dst_ls):
    """Compute orientation-aware edge features from per-residue local frames.

    For every residue a right-handed orthonormal frame ``(n, u, v)`` is built
    from its N, CA and C atoms (``u`` points from CA to N, ``n`` is normal to
    the N-CA-C plane, ``v = n x u``). For every edge the relative CA position
    and the three frame vectors of the source residue are expressed in the
    frame of the destination residue, giving 12 values.

    Returns:
        float32 tensor with one row per edge: the 15 columns produced by
        ``distance_featurizer(prot_d, divisor=4)`` followed by the 12
        orientation values (27 columns in total).
    """
    normals, u_axes, v_axes = [], [], []
    for res in range(len(prot_coords)):
        ca = c_alpha_coords[res]
        to_n = n_coords[res] - ca
        to_c = c_coords[res] - ca
        u_i = to_n / np.linalg.norm(to_n)
        t_i = to_c / np.linalg.norm(to_c)
        plane_normal = np.cross(u_i, t_i)
        n_i = plane_normal / np.linalg.norm(plane_normal)
        v_i = np.cross(n_i, u_i)
        assert (math.fabs(
            np.linalg.norm(v_i) - 1.) < 1e-5), "protein utils protein_to_graph_dips, v_i norm larger than 1"
        normals.append(n_i)
        u_axes.append(u_i)
        v_axes.append(v_i)
    n_i_feat, u_i_feat, v_i_feat = np.stack(normals), np.stack(u_axes), np.stack(v_axes)

    per_edge = []
    for edge in range(len(prot_d)):
        src, dst = src_ls[edge], dst_ls[edge]
        # rows of the basis matrix are the destination frame vectors n, u, v
        basis_matrix = np.stack((n_i_feat[dst, :], u_i_feat[dst, :], v_i_feat[dst, :]), axis=0)
        p_ij = np.matmul(basis_matrix, c_alpha_coords[src, :] - c_alpha_coords[dst, :])
        q_ij = np.matmul(basis_matrix, n_i_feat[src, :])  # shape (3,)
        k_ij = np.matmul(basis_matrix, u_i_feat[src, :])
        t_ij = np.matmul(basis_matrix, v_i_feat[src, :])
        per_edge.append(np.concatenate((p_ij, q_ij, k_ij, t_ij), axis=0))  # shape (12,)
    orientation = torch.from_numpy(np.stack(per_edge, axis=0).astype(np.float32))  # (num_edges, 12)
    return torch.cat([distance_featurizer(prot_d, divisor=4), orientation], axis=1)
def get_prot_alpha_c_graph_equibind(prot_coords, prot_features, prot_node_type, sec_features,
                                    alpha_c_coords, c_coords, n_coords,
                                    max_neighbor=None, cutoff=None):
    """Build an EquiBind-style residue graph with local-frame edge features.

    Every residue is connected to the residues closer than ``cutoff`` (distances
    measured on ``prot_coords``); if that exceeds ``max_neighbor`` the nearest
    ``max_neighbor`` are used, and a residue without any neighbour is linked to
    its closest residue (a message is printed).

    Args:
        prot_coords: Residue coordinates used for the neighbour search.
        prot_features: Node features, stored as ``ndata['h']``.
        prot_node_type: Node types; the first ``len(prot_coords)`` entries are used.
        sec_features: Not used by this function.
        alpha_c_coords, c_coords, n_coords: Backbone atom coordinates, one row
            per residue, used for ``'pos'``, ``'mu_r_norm'`` and the edge features.
        max_neighbor: Optional cap on the number of neighbours per residue.
        cutoff: Neighbour distance threshold.

    Returns:
        Graph with node data ``'h'``, ``'node_type'``, ``'pos'``, ``'mu_r_norm'``
        and edge data ``'e'`` from ``local_coordinate_system_feature``.

    Raises:
        ValueError: If the backbone arrays and ``prot_coords`` differ in length.
    """
    num_atoms_prot = len(prot_coords)  # number of pocket atom_level
    # Explicit validation: the former assert-in-try check disappears under ``python -O``.
    if not (len(alpha_c_coords) == num_atoms_prot
            and len(c_coords) == num_atoms_prot
            and len(n_coords) == num_atoms_prot):
        raise ValueError(f'{len(alpha_c_coords)} == {len(prot_coords)}, {len(c_coords)} == {len(prot_coords)}, {len(n_coords)} == {len(prot_coords)}')
    g_prot = dgl.DGLGraph()
    g_prot.add_nodes(num_atoms_prot)
    g_prot.ndata['h'] = prot_features
    g_prot.ndata['node_type'] = prot_node_type[:num_atoms_prot]
    distances = spatial.distance_matrix(prot_coords, prot_coords)
    sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
    src_list = []
    dst_list = []
    dist_list = []
    mean_norm_list = []
    for i in range(num_atoms_prot):
        dst = list(np.where(distances[i, :] < cutoff)[0])
        dst.remove(i)
        if max_neighbor is not None and len(dst) > max_neighbor:
            dst = list(np.argsort(distances[i, :]))[1: max_neighbor + 1]
        if len(dst) == 0:
            dst = list(np.argsort(distances[i, :]))[1:2]  # choose second because first is i itself
            print(
                f'The c_alpha_cutoff {cutoff} was too small for one c_alpha such that it had no neighbors. So we connected it to the closest other c_alpha')
        assert i not in dst
        src = [i] * len(dst)
        src_list.extend(src)
        dst_list.extend(dst)
        valid_dist_np = distances[i, dst]
        dist_list.extend(list(valid_dist_np))
        weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1)  # (sigma_num, neigh_num)
        assert weights[0].sum() > 1 - 1e-2 and weights[0].sum() < 1.01
        diff_vecs = alpha_c_coords[src, :] - alpha_c_coords[dst, :]  # (neigh_num, 3)
        mean_vec = weights.dot(diff_vecs)  # (sigma_num, 3)
        denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1))  # (sigma_num,)
        mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator  # (sigma_num,)
        mean_norm_list.append(mean_vec_ratio_norm)
    assert len(src_list) == len(dst_list)
    assert len(dist_list) == len(dst_list)
    g_prot.add_edges(src_list, dst_list)
    g_prot.edata['e'] = local_coordinate_system_feature(prot_coords, alpha_c_coords, c_coords, n_coords,
                                                        dist_list, src_list, dst_list)
    residue_representatives_loc_feat = torch.from_numpy(alpha_c_coords.astype(np.float32))
    g_prot.ndata['pos'] = residue_representatives_loc_feat
    g_prot.ndata['mu_r_norm'] = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
    return g_prot
def get_prot_alpha_c_graph_ign(prot_coords, prot_features, prot_node_type, sec_features, cutoff=None):
    """Build a residue-level distance-cutoff graph with ``bond_feature`` edge features.

    Every ordered pair of distinct residues with ``0 < distance < cutoff`` is
    an edge. ``sec_features`` is accepted but not used.

    Node data: ``'h'``, ``'node_type'`` (first ``len(prot_coords)`` entries), ``'pos'``.
    Edge data: ``'distance'`` and ``'e'`` (output of ``bond_feature`` followed
    by ``0.1 * distance``).
    """
    num_nodes = len(prot_coords)  # number of pocket atom_level
    g_prot = dgl.DGLGraph()
    g_prot.add_nodes(num_nodes)
    if isinstance(prot_features, np.ndarray):
        g_prot.ndata['h'] = torch.from_numpy(prot_features)
    else:
        g_prot.ndata['h'] = prot_features
    g_prot.ndata['node_type'] = prot_node_type[:num_nodes]

    dis_matrix = spatial.distance_matrix(prot_coords, prot_coords)
    within_cutoff = (dis_matrix < cutoff) & (dis_matrix != 0)  # no self-loop
    src_ls, dst_ls = np.where(within_cutoff)
    g_prot.add_edges(src_ls, dst_ls)
    g_prot.ndata['pos'] = torch.tensor(prot_coords, dtype=torch.float)
    prot_d = torch.tensor(dis_matrix[src_ls, dst_ls], dtype=torch.float).view(-1, 1)
    g_prot.edata['distance'] = prot_d

    # 3D information computed from the graph, then the scaled distance is appended
    D3_info_th = bond_feature(g_prot)
    g_prot.edata['e'] = torch.cat([D3_info_th, prot_d * 0.1], dim=-1)
    assert not torch.any(torch.isnan(D3_info_th))
    return g_prot
def get_interact_graph_fc(lig_coords,prot_coords,cutoff=None):
    """Build a bipartite ligand-protein graph connecting every pair at non-zero distance.

    Ligand atoms are nodes ``0..n_lig-1``; protein nodes follow with an offset
    of ``n_lig``. Each connected pair yields two directed edges (ligand->protein
    first, then protein->ligand). ``cutoff`` is accepted but not used.

    Node data: ``'pos'`` (ligand coordinates followed by protein coordinates).
    Edge data: ``'e'`` (the pair distance, shape ``(num_edges, 1)``).
    """
    num_atoms_lig, num_atoms_prot = len(lig_coords), len(prot_coords)
    g_inter = dgl.DGLGraph()
    g_inter.add_nodes(num_atoms_lig + num_atoms_prot)

    dis_matrix = spatial.distance_matrix(lig_coords, prot_coords)
    lig_idx, prot_idx = np.where(dis_matrix > 0)
    prot_nodes = prot_idx + num_atoms_lig  # shift protein indices behind the ligand nodes
    g_inter.add_edges(np.concatenate([lig_idx, prot_nodes]),
                      np.concatenate([prot_nodes, lig_idx]))

    # Same distance for both directions of a pair
    pair_dist = dis_matrix[lig_idx, prot_idx]
    inter_d = torch.tensor(np.concatenate([pair_dist, pair_dist]), dtype=torch.float).view(-1, 1)
    g_inter.edata['e'] = inter_d
    lig_pos = torch.tensor(lig_coords, dtype=torch.float)
    prot_pos = torch.tensor(prot_coords, dtype=torch.float)
    g_inter.ndata['pos'] = torch.cat([lig_pos, prot_pos], dim=0)
    return g_inter
def get_interact_multi_pose_graph_knn(lig_multi_coords, prot_coords, max_neighbor=None, min_neighbor=None, cutoff=None):
    """Return one ``get_interact_graph_knn`` graph per ligand pose, in input order."""
    return [
        get_interact_graph_knn(pose_coords, prot_coords, max_neighbor, min_neighbor, cutoff)
        for pose_coords in lig_multi_coords
    ]
def get_interact_graph_knn(lig_coords,prot_coords,max_neighbor=None,min_neighbor=None,cutoff=None):
    """Build a bipartite ligand-protein neighbour graph.

    Ligand atoms are nodes ``0..n_lig-1``; protein nodes follow with an offset
    of ``n_lig``. Each ligand atom gets edges to the protein nodes closer than
    ``cutoff``; if there are more than ``max_neighbor`` the nearest
    ``max_neighbor`` are used, and if there are none the nearest
    ``min_neighbor`` are used. The same rule is then applied from every protein
    node towards the ligand atoms.

    Edge data: ``'d'`` (distance, shape ``(num_edges, 1)``).
    """
    num_atoms_lig, num_atoms_prot = len(lig_coords), len(prot_coords)
    g_inter = dgl.DGLGraph()
    g_inter.add_nodes(num_atoms_lig + num_atoms_prot)
    dis_matrix = spatial.distance_matrix(lig_coords, prot_coords)
    src_list, dst_list, dis_list = [], [], []

    # ligand -> protein edges
    for lig_idx in range(num_atoms_lig):
        row = dis_matrix[lig_idx, :]
        picked = np.where(row < cutoff)[0]
        if max_neighbor is not None and len(picked) > max_neighbor:
            picked = list(np.argsort(row))[:max_neighbor]
        if min_neighbor is not None and len(picked) == 0:
            picked = list(np.argsort(row))[:min_neighbor]
        src_list.extend([lig_idx] * len(picked))
        dst_list.extend([prot_idx + num_atoms_lig for prot_idx in picked])
        dis_list.extend(list(dis_matrix[lig_idx, picked]))

    # protein -> ligand edges
    for prot_idx in range(num_atoms_prot):
        column = dis_matrix[:, prot_idx]
        picked = list(np.where(column < cutoff)[0])
        if max_neighbor is not None and len(picked) > max_neighbor:
            picked = list(np.argsort(column))[:max_neighbor]
        if min_neighbor is not None and len(picked) == 0:
            picked = list(np.argsort(column))[:min_neighbor]
        src_list.extend([prot_idx + num_atoms_lig] * len(picked))
        dst_list.extend(picked)
        dis_list.extend(list(dis_matrix[picked, prot_idx]))

    g_inter.add_edges(np.array(src_list), np.array(dst_list))
    # 'd', distance between ligand atom_level and pocket atom_level
    g_inter.edata['d'] = torch.tensor(np.array(dis_list), dtype=torch.float).view(-1, 1)
    return g_inter
def get_interact_graph_knn_v2(lig_coords, prot_coords, max_neighbor=None, min_neighbor=None, cutoff=None,):
    """Build a directed ligand<->protein interaction graph with radial edge features.

    Nodes ``0 .. n_lig - 1`` are ligand atoms and nodes
    ``n_lig .. n_lig + n_prot - 1`` are protein atoms. Ligand->protein edges
    are added first, then protein->ligand edges.

    Parameters
    ----------
    lig_coords : array-like
        Ligand coordinates, one row per atom.
    prot_coords : array-like
        Protein coordinates, one row per atom.
    max_neighbor : int or None
        If given and a node has more partners than this within ``cutoff``,
        only its ``max_neighbor`` nearest partners are kept.
    min_neighbor : int or None
        If given and a node has no partner within ``cutoff``, its
        ``min_neighbor`` nearest partners are used instead.
    cutoff : float or None
        Partners must be strictly closer than this distance. ``None`` means
        no distance limit (previously ``None`` raised a ``TypeError``).

    Returns
    -------
    dgl.DGLGraph
        Graph with two edge features: ``'d'``, the edge length with shape
        ``(num_edges, 1)``, and ``'e'``, fifteen values
        ``exp(-(d**2 / 10.0) / 1.5**k)`` for ``k = 0 .. 14`` with shape
        ``(num_edges, 15)``.
    """
    num_atoms_lig = len(lig_coords)
    num_atoms_prot = len(prot_coords)
    g_inter = dgl.DGLGraph()
    g_inter.add_nodes(num_atoms_lig + num_atoms_prot)
    # dis_matrix[i, j] is the distance between ligand atom i and protein atom j.
    dis_matrix = spatial.distance_matrix(lig_coords, prot_coords)
    radius = np.inf if cutoff is None else cutoff

    def _pick_partners(distances):
        # Indices (into the opposite molecule) this node should connect to.
        partners = np.where(distances < radius)[0]
        if max_neighbor is not None and len(partners) > max_neighbor:
            partners = np.argsort(distances)[:max_neighbor]
        if min_neighbor is not None and len(partners) == 0:
            partners = np.argsort(distances)[:min_neighbor]
        return partners

    src_list, dst_list, dis_list = [], [], []
    # Ligand -> protein edges: protein node ids are offset by the ligand size.
    for lig_idx in range(num_atoms_lig):
        partners = _pick_partners(dis_matrix[lig_idx, :])
        src_list.extend([lig_idx] * len(partners))
        dst_list.extend(partners + num_atoms_lig)
        dis_list.extend(dis_matrix[lig_idx, partners])
    # Protein -> ligand edges.
    for prot_idx in range(num_atoms_prot):
        partners = _pick_partners(dis_matrix[:, prot_idx])
        src_list.extend([prot_idx + num_atoms_lig] * len(partners))
        dst_list.extend(partners)
        dis_list.extend(dis_matrix[partners, prot_idx])

    # Explicit integer dtype so an empty edge list does not become float64.
    g_inter.add_edges(np.array(src_list, dtype=np.int64), np.array(dst_list, dtype=np.int64))

    # 'd': distance between the two atoms joined by each edge.
    inter_d = torch.tensor(np.array(dis_list), dtype=torch.float).view(-1, 1)
    # 'e': the squared distance expanded over a geometric ladder of widths.
    squared_distance = inter_d ** 2
    all_sigmas_dist = [1.5 ** x for x in range(15)]
    prot_square_distance_scale = 10.0
    x_rel_mag = torch.cat(
        [torch.exp(-(squared_distance / prot_square_distance_scale) / sigma) for sigma in all_sigmas_dist],
        dim=-1,
    )
    g_inter.edata['e'] = x_rel_mag
    g_inter.edata['d'] = inter_d
    return g_inter
def pack_graph_and_labels(lig_graphs, prot_graphs, inter_graphs, labels):
    """Bundle the four inputs, unchanged and in argument order, into one tuple."""
    packed = (lig_graphs, prot_graphs, inter_graphs, labels)
    return packed
# Maps each one-letter secondary-structure code to a one-letter class
# (presumably a display colour, going by the name -- TODO confirm).
SSE_color = {
    'H': 'R',
    'G': 'R',
    'I': 'R',
    'E': 'B',
    'T': 'G',
    'S': 'W',
    'B': 'W',
    ' ': 'W',
}
def ExcuteDSSP(dataset_path, name, dssp='/data/jiaxianyan/anaconda3/bin/mkdssp'):
    """Run the DSSP executable on one complex's protein-chain PDB file.

    Reads ``<dataset_path>/<name>/<name>_protein_chains.pdb`` and asks the
    executable to write ``<same path>.dssp``.

    Parameters
    ----------
    dataset_path : str
        Root directory of the dataset.
    name : str
        Complex identifier; used as both the sub-directory and file prefix.
    dssp : str
        Path to the DSSP executable.

    Returns
    -------
    None
        The executable's exit status is not inspected. A failure to launch
        it is reported on stderr rather than raised.
    """
    # Local imports keep this block self-contained.
    import subprocess
    import sys

    in_path = os.path.join(dataset_path, name, f'{name}_protein_chains.pdb')
    out_path = os.path.join(dataset_path, name, f'{name}_protein_chains.pdb.dssp')
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact, unlike a formatted command string.
    try:
        subprocess.run([dssp, '-i', in_path, '-o', out_path], check=False)
    except OSError as err:
        print(f'Could not launch DSSP executable {dssp!r}: {err}', file=sys.stderr)
def ListToStr(res):
    """Expand a run-length list into a flat string.

    Each item of ``res`` is indexed as ``(symbol, count)``; the symbol is
    repeated ``count`` times and the pieces are concatenated in order.
    """
    return ''.join(run[0] * run[1] for run in res)
def StrToList(res):
    """Run-length encode ``res`` into a list of ``(symbol, count)`` tuples.

    Scanning starts from a sentinel run ``('W', 0)``. Unless ``res`` begins
    with ``'W'``, that zero-length run is the first item of the result, and
    an empty input yields ``[('W', 0)]``.
    """
    runs = []
    current, count = 'W', 0
    for symbol in res:
        if symbol == current:
            count += 1
            continue
        # Symbol changed: close the previous run and open a new one.
        runs.append((current, count))
        current, count = symbol, 1
    runs.append((current, count))
    return runs
def AllocateIndex(res, BeginIndex=0):
    """Label every position of a run-length list with a cluster index.

    ``res`` is first expanded with ``ListToStr``. The cluster index starts at
    ``BeginIndex`` and increases by one whenever a symbol differs from the
    one before it.

    Returns
    -------
    tuple
        ``(labels, next_index)``: ``labels`` holds one ``"<cluster>,<symbol>"``
        string per position, and ``next_index`` is the last cluster index
        used plus one (``BeginIndex + 1`` when there are no positions).
    """
    flat = ListToStr(res)
    cluster = BeginIndex
    labels = []
    previous = None
    for position, code in enumerate(flat):
        if position > 0 and code != previous:
            cluster += 1
        labels.append(f'{cluster},{code}')
        previous = code
    return labels, cluster + 1
def SmoothHiearachical(res, threshod=0):
    """Relabel short runs of a run-length list.

    Parameters
    ----------
    res : list
        List of ``(symbol, count)`` runs.
    threshod : int
        Interior runs whose count is ``<= threshod`` take the (already
        smoothed) symbol of the run before them. ``0`` disables smoothing.

    Returns
    -------
    list
        ``res`` itself when smoothing is disabled or ``res`` has fewer than
        two runs; otherwise a new list of the same length. In that list the
        first run always takes the symbol of the second run and the last run
        always takes the symbol of the second-to-last run, each keeping its
        own count.
    """
    # Fewer than two runs leaves nothing to borrow a symbol from; an empty
    # list used to fail on res[1].
    if threshod == 0 or len(res) <= 1:
        return res
    smoothed = [None] * len(res)
    smoothed[0] = (res[1][0], res[0][1])
    smoothed[-1] = (res[-2][0], res[-1][1])
    for index in range(1, len(res) - 1):
        _, length = res[index]
        if length <= threshod:
            smoothed[index] = (smoothed[index - 1][0], length)
        else:
            smoothed[index] = res[index]
    return smoothed
def ExtractHiearachical(FilePath):
    """Read per-residue secondary-structure codes and cluster them per chain.

    The first 28 lines of the file are skipped (presumably the DSSP header --
    TODO confirm for the DSSP version in use). Remaining lines are split into
    chains at every line containing ``'!'``. For each residue line the code is
    taken from character 16, except that a blank code on a line whose
    character 18 is ``'>'`` (or ``'<'``) is replaced by the code of the next
    (or previous) line of the same chain.

    Parameters
    ----------
    FilePath : str
        Path of the text file to parse.

    Returns
    -------
    tuple
        ``(ResSSE, ResSmoothSSE, ClusterIndexs, NumClusters)``: the codes of
        all chains concatenated, the same after ``SmoothHiearachical`` (called
        with its default threshold), the ``"<cluster>,<code>"`` labels from
        ``AllocateIndex``, and the next unused cluster index. Earlier versions
        computed these values but returned ``None``.

    Raises
    ------
    KeyError
        If a code is not a key of ``SSE_color``.
    IndexError
        If a line is shorter than 19 characters or a neighbour lookup runs
        past the end of a chain.
    """
    with open(FilePath, 'r') as f:
        lines = f.read().strip().split('\n')[28:]

    # Split into (start, stop) line ranges, one per chain; '!' lines are breaks.
    borders, start = [], 0
    for index, line in enumerate(lines):
        if '!' in line:
            borders.append((start, index))
            start = index + 1
    borders.append((start, len(lines) + 1))

    SSEIndex, SupplyIndex = 16, 18
    ResSSE = ''
    ResSmoothSSE = ''
    NumClusters = 0
    ClusterIndexs = []
    for first, last in borders:
        OneChain = lines[first:last]
        codes = []
        for index, line in enumerate(OneChain):
            if line[SupplyIndex] == '>' and line[SSEIndex] == ' ':
                code = OneChain[index + 1][SSEIndex]
            elif line[SupplyIndex] == '<' and line[SSEIndex] == ' ':
                code = OneChain[index - 1][SSEIndex]
            else:
                code = line[SSEIndex]
            # The original looked every code up in SSE_color (its result was
            # unused); keep rejecting unknown codes the same way.
            if code not in SSE_color:
                raise KeyError(code)
            codes.append(code)
        SSE = ''.join(codes)
        ResSSE = ResSSE + SSE
        Smoothres = SmoothHiearachical(StrToList(SSE))
        # Cluster numbering continues across chains.
        ClusterIndex, NumClusters = AllocateIndex(Smoothres, NumClusters)
        ResSmoothSSE = ResSmoothSSE + ListToStr(Smoothres)
        ClusterIndexs.extend(ClusterIndex)
    return ResSSE, ResSmoothSSE, ClusterIndexs, NumClusters
def prot_alpha_c_featurizer(Structure):
    """Return coordinates and one-hot residue-type features for ``Structure``.

    Residue names are mapped through ``ResDict``; names missing from it fall
    into the extra ``UNKOWN_RES`` class.

    Returns
    -------
    tuple
        ``(coords, features)``: ``coords`` is whatever
        ``Structure.getCoords()`` returns, and ``features`` is a tensor with
        one row of length ``UNKOWN_RES + 1`` per residue name.
    """
    coords = Structure.getCoords()
    residue_ids = [ResDict.get(res_name, UNKOWN_RES) for res_name in Structure.getResnames()]
    # Row k of the identity matrix is the one-hot vector for class k.
    one_hot_table = np.eye(UNKOWN_RES + 1)
    features = torch.tensor(one_hot_table[residue_ids])
    return coords, features
def prot_residue_type(Structure):
    """Return the residue-type index of every entry of ``Structure.getResnames()``.

    Names missing from ``ResDict`` are mapped to ``UNKOWN_RES``. The result
    is a one-dimensional ``torch.int64`` tensor.
    """
    residue_ids = []
    for res_name in Structure.getResnames():
        residue_ids.append(ResDict.get(res_name, UNKOWN_RES))
    return torch.tensor(residue_ids, dtype=torch.int64)
def break_molecules_by_rotatable_bonds(mol):
    """Cut ``mol`` at every bond matched by the rotatable-bond SMARTS pattern.

    Returns
    -------
    tuple
        ``(frags, rotatable_bonds)``: ``frags`` is the result of
        ``GetMolFrags`` on the fragmented molecule, and ``rotatable_bonds``
        holds the matched atom-index pairs.

    Notes
    -----
    NOTE(review): ``FragmentOnBonds`` is called with its default options. If
    those add dummy atoms at the cut points, ``frags`` will contain indices
    beyond the atoms of ``mol`` -- confirm against the RDKit documentation.
    """
    rotatable_pattern = Chem.MolFromSmarts('[!$([NH]!@C(=O))&!D1&!$(*#*)]-&!@[!$([NH]!@C(=O))&!D1&!$(*#*)]')
    rotatable_bonds = mol.GetSubstructMatches(rotatable_pattern)
    bond_ids = []
    for begin_atom, end_atom in rotatable_bonds:
        bond_ids.append(mol.GetBondBetweenAtoms(begin_atom, end_atom).GetIdx())
    fragmented = Chem.rdmolops.FragmentOnBonds(mol, bond_ids)
    return Chem.rdmolops.GetMolFrags(fragmented), rotatable_bonds
def get_molecule_mass_center(mol):
    """Return the centre of mass of the atoms of ``mol``.

    Positions come from ``mol.GetConformer().GetPositions()`` and weights
    from each atom's ``GetMass()``. The weighted sum is divided by the sum of
    those same masses, so atoms that are not explicitly present (for example
    implicit hydrogens) cannot skew the result.

    Returns
    -------
    np.ndarray
        One coordinate vector (the mass-weighted mean position).
    """
    positions = np.asarray(mol.GetConformer().GetPositions(), dtype=float)
    masses = np.array([atom.GetMass() for atom in mol.GetAtoms()], dtype=float)
    # masses @ positions is sum_i m_i * r_i.
    return masses @ positions / masses.sum()
def get_frag_geo_center(mol, frag):
    """Return the unweighted mean position of the atoms listed in ``frag``.

    Parameters
    ----------
    mol
        Object whose ``GetConformer().GetPositions()`` returns one position
        row per atom.
    frag : sequence of int
        Atom indices belonging to the fragment.
    """
    positions = np.asarray(mol.GetConformer().GetPositions(), dtype=float)
    # list(...) forces row selection; indexing with a tuple would be read as
    # one index per axis.
    return positions[list(frag)].mean(axis=0)
def get_center_frag(mol, frags):
    """Return the index of the fragment closest to the molecule's centre of mass.

    Closeness is the squared Euclidean distance between each fragment's
    geometric centre (``get_frag_geo_center``) and the centre of mass
    (``get_molecule_mass_center``).

    Returns
    -------
    int
        Position in ``frags`` of the nearest fragment. Earlier versions
        returned the smallest squared coordinate difference (a float), which
        cannot be used as the fragment index the caller expects.
    """
    mass_center = get_molecule_mass_center(mol)
    frag_centers = np.array([get_frag_geo_center(mol, frag) for frag in frags])
    squared_dist = ((frag_centers - mass_center) ** 2).sum(axis=1)
    return int(np.argmin(squared_dist))
def get_frag_neighbors(mol, frags, rotatable_bonds):
    """Derive fragment adjacency from the rotatable bonds that separate them.

    Parameters
    ----------
    mol
        Unused; kept so existing call sites stay valid.
    frags : sequence of sequence of int
        Atom indices of each fragment.
    rotatable_bonds : sequence of (int, int)
        Atom-index pair of each cut bond; a bond's id is its position here.

    Returns
    -------
    tuple
        ``(frag2neighbour, frag2atom, neighbour2bond)``, all keyed by
        fragment index. ``frag2neighbour[f]`` lists the fragments sharing a
        cut bond with ``f``, ``neighbour2bond[f]`` lists the matching bond
        ids in the same order, and ``frag2atom[f]`` lists, for every cut bond
        touching ``f``, the atom at the other end of that bond.

    Raises
    ------
    ValueError
        If some bond does not have exactly two end atoms found in ``frags``.
    """
    # Atom index -> ids of the cut bonds it takes part in. Atoms with no
    # entry (including indices not present in the original molecule) are
    # simply skipped below.
    atom2bond = {}
    for b_index, (begin, end) in enumerate(rotatable_bonds):
        atom2bond.setdefault(begin, []).append(b_index)
        atom2bond.setdefault(end, []).append(b_index)
    bond2frag = {b_index: [] for b_index in range(len(rotatable_bonds))}

    frag2neighbour, frag2atom, neighbour2bond = {}, {}, {}
    for f_index, frag in enumerate(frags):
        frag2neighbour[f_index] = []
        frag2atom[f_index] = []
        neighbour2bond[f_index] = []
        for atom in frag:
            for bond in atom2bond.get(atom, ()):
                # Record the fragment *index*: the dictionaries returned
                # below are keyed by index, not by the atom tuple.
                bond2frag[bond].append(f_index)
                x, y = rotatable_bonds[bond]
                frag2atom[f_index].append(y if x == atom else x)

    for bond, (first, second) in bond2frag.items():
        frag2neighbour[first].append(second)
        neighbour2bond[first].append(bond)
        frag2neighbour[second].append(first)
        neighbour2bond[second].append(bond)
    return frag2neighbour, frag2atom, neighbour2bond
def bfs(frag_neighbors, neighbour2bond, center_frag_index):
    """Return fragment indices in breadth-first order from ``center_frag_index``.

    Parameters
    ----------
    frag_neighbors : mapping
        Fragment index -> list of neighbouring fragment indices. Indices are
        expected to lie in ``0 .. len(frag_neighbors) - 1``.
    neighbour2bond : mapping
        Not used for the returned order; kept so existing call sites stay
        valid.
    center_frag_index : int
        Fragment to start from.

    Returns
    -------
    list of int
        Every fragment reachable from the start, each exactly once, starting
        with ``center_frag_index``.
    """
    visited = [False] * len(frag_neighbors)
    # Mark the start as visited up front; otherwise its first neighbour
    # would enqueue it a second time.
    visited[center_frag_index] = True
    bfs_rank = [center_frag_index]
    head = 0
    while head < len(bfs_rank):
        for neighbor in frag_neighbors[bfs_rank[head]]:
            if not visited[neighbor]:
                visited[neighbor] = True
                bfs_rank.append(neighbor)
        head += 1
    return bfs_rank
def get_bfs_generate_rank(mol, center_frag_index, frags, rotatable_bonds):
    """Order the fragments breadth-first, starting from the centre fragment.

    Builds the fragment adjacency with ``get_frag_neighbors`` and walks it
    with ``bfs``.

    Returns
    -------
    list of int
        The fragment indices returned by ``bfs``.

    Raises
    ------
    AssertionError
        If the traversal did not reach every fragment in ``frags``.
    """
    frag_neighbors, frag2atom, neighbour2bond = get_frag_neighbors(mol, frags, rotatable_bonds)
    bfs_rank = bfs(frag_neighbors, neighbour2bond, center_frag_index)
    # Compare counts: comparing the list itself with an int is always False.
    assert len(set(bfs_rank)) == len(frags), 'BFS did not reach every fragment'
    return bfs_rank
def get_molecule_tree_by_rotatable_bonds(molecule_path):
    """Load a ``.mol2`` ligand and return its fragment generation order.

    The molecule is read without sanitisation and with hydrogens kept, cut
    with ``break_molecules_by_rotatable_bonds``, and ordered by
    ``get_bfs_generate_rank`` starting from the value returned by
    ``get_center_frag``.
    """
    ligand = Chem.MolFromMol2File(molecule_path, sanitize=False, removeHs=False)
    fragments, bonds = break_molecules_by_rotatable_bonds(ligand)
    root = get_center_frag(ligand, fragments)
    return get_bfs_generate_rank(ligand, root, fragments, bonds)
def get_rdkit_coords(mol, sanitize=True, remove_hs=True, maxIters=200):
    """Embed ``mol`` in 3D with ETKDGv2, optimise with MMFF, and return coordinates.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        Molecule to embed; a conformer is added to it in place.
    sanitize : bool
        Passed to ``Chem.RemoveHs`` when hydrogens are removed.
    remove_hs : bool
        Whether to strip hydrogens before reading the coordinates.
    maxIters : int
        Iteration limit passed to ``MMFFOptimizeMolecule``.

    Returns
    -------
    The positions of the molecule's conformer, as returned by
    ``GetPositions()``.
    """
    params = AllChem.ETKDGv2()
    embed_status = AllChem.EmbedMolecule(mol, params)
    if embed_status == -1:
        # First attempt failed: retry once starting from random coordinates.
        print('rdkit coords could not be generated without using random coords. using random coords now.')
        params.useRandomCoords = True
        AllChem.EmbedMolecule(mol, params)
    AllChem.MMFFOptimizeMolecule(mol, maxIters=maxIters, confId=0)
    embedded = Chem.RemoveHs(mol, sanitize=sanitize) if remove_hs else mol
    return embedded.GetConformer().GetPositions()
def read_rdmol_v2(dataset_path, name):
    """Load the ligand of complex ``name``, preferring the SDF file.

    Tries ``<dataset_path>/<name>/<name>_ligand.sdf`` first, sanitised and
    with hydrogens removed. If ``read_rdmol`` returns ``None`` for it, the
    ``.mol2`` file next to it is read with the same options.

    Returns
    -------
    The molecule returned by ``read_rdmol``, or ``None`` if both attempts
    returned ``None``.
    """
    complex_dir = os.path.join(dataset_path, name)
    m_lig = read_rdmol(os.path.join(complex_dir, f'{name}_ligand.sdf'), sanitize=True, remove_hs=True)
    if m_lig is None:  # fall back to mol2 when the sdf could not be processed
        m_lig = read_rdmol(os.path.join(complex_dir, f'{name}_ligand.mol2'), sanitize=True, remove_hs=True)
    return m_lig
def read_rdmol(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """Load a molecule from a file of format ``.mol2`` or ``.sdf`` or ``.pdbqt`` or ``.pdb``.

    Parameters
    ----------
    molecule_file : str
        Path to the molecule file, which can be of format ``.mol2`` or ``.sdf``
        or ``.pdbqt`` or ``.pdb``. For ``.sdf`` only the first record is read.
    sanitize : bool
        Whether sanitization is performed after loading. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
        ``sanitize`` to be True. Default to False.
    remove_hs : bool
        Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
        slow for large molecules. Default to False.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol or None
        RDKit molecule instance for the loaded molecule. ``None`` is returned
        when sanitization or hydrogen removal fails.

    Raises
    ------
    ValueError
        If ``molecule_file`` does not end with a supported extension.
    """
    # Local import: ``warnings`` is not among the module-level imports
    # visible at the top of this file.
    import warnings

    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        # Keep only the first 66 characters of every line so the block can
        # be parsed as PDB.
        with open(molecule_file) as file:
            pdb_block = ''.join('{}\n'.format(line[:66]) for line in file)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # Raise (not return) the error so callers never receive an exception
        # object in place of a molecule.
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)
        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')
        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception:
        return None
    return mol
def random_rotation_translation(translation_distance=5.0):
    """Sample a random rotation matrix and a random translation vector.

    Parameters
    ----------
    translation_distance : float
        Upper bound of the uniform distribution the translation length is
        drawn from (the lower bound is 0).

    Returns
    -------
    tuple of torch.Tensor
        ``(rotation, translation)`` as float32 tensors of shape ``(3, 3)``
        and ``(1, 3)``.
    """
    rot_matrix = Rotation.random(num=1).as_matrix().squeeze()
    # Random direction: a Gaussian sample scaled to unit length.
    direction = np.random.randn(1, 3)
    direction = direction / np.sqrt((direction * direction).sum())
    magnitude = np.random.uniform(low=0, high=translation_distance)
    shift = direction * magnitude
    rotation_tensor = torch.from_numpy(rot_matrix.astype(np.float32))
    translation_tensor = torch.from_numpy(shift.astype(np.float32))
    return rotation_tensor, translation_tensor