# File: FoldMark / protenix/data/substructure_perms.py
# Repository page header: commit 89c0b51, "Add large file" (Zaixi)
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from collections import defaultdict
import numpy as np
from rdkit import Chem
def neutralize_atoms(mol: Chem.Mol):
    """
    Set the formal charge of neutralizable atoms to zero, in place.

    An atom is selected when it either carries a +1 charge, has at least one
    hydrogen and has no negatively charged neighbor, or carries a -1 charge,
    is not element 4 or 5 and has no positively charged neighbor.  The
    explicit hydrogen count of each selected atom is adjusted by its former
    charge so that the valence stays consistent.

    Args:
        mol: molecule to modify; it is mutated and also returned.

    Returns:
        The same ``mol`` object that was passed in.
    """
    charged_atom_query = Chem.MolFromSmarts(
        "[+1!h0!$([*]~[-1,-2,-3,-4]),-1!#4!#5!$([*]~[+1,+2,+3,+4])]"
    )
    # The query is a single-atom pattern, so every match is a 1-tuple.
    for match in mol.GetSubstructMatches(charged_atom_query):
        atom = mol.GetAtomWithIdx(match[0])
        formal_charge = atom.GetFormalCharge()
        total_hs = atom.GetTotalNumHs()
        atom.SetFormalCharge(0)
        # Removing a +1 charge drops one H; removing a -1 charge adds one.
        atom.SetNumExplicitHs(total_hs - formal_charge)
        atom.UpdatePropertyCache()
    return mol
def recursive_permutation(atom_inds, permutation_list, res):
    """
    Expand one atom mapping by every combination of the given permutations.

    Args:
        atom_inds: list of atom (positional) indices; not modified, except
            that the very same list object is appended to ``res`` when
            ``permutation_list`` is empty.
        permutation_list: list of groups, each group being an iterable of
            alternative orderings (tuples) of the same atom indices.  One
            ordering is picked from every group, first group outermost.
        res: output list; each fully expanded mapping is appended to it.

    Returns:
        None.  Results are accumulated in ``res``.
    """

    def _apply(mapping, ordering):
        # Slots of `mapping` that currently hold a value of this group are
        # overwritten, left to right, with the values of `ordering`.
        slots = [pos for pos, value in enumerate(mapping) if value in ordering]
        for rank, slot in enumerate(slots):
            mapping[slot] = ordering[rank]
        return mapping

    if not permutation_list:
        # Nothing left to expand: this mapping is complete.
        res.append(atom_inds)
        return

    # Work on a copy so the caller's list of groups is left untouched.
    remaining_groups = permutation_list.copy()
    first_group = remaining_groups.pop(0)
    for ordering in first_group:
        recursive_permutation(
            _apply(atom_inds.copy(), ordering), remaining_groups, res
        )
def augment_atom_maps_with_conjugate_terminal_groups(
    original_maps, atomic_number_mapping, terminal_group_tuples, MaxMatches=1e6
):
    """
    Augment atom maps from GetSubstructMatches with extra symmetry from conjugated terminal groups.

    Parameters
    --------------
    original_maps: Tuple(Tuples), all possible atom index mappings, note we require that the mappings should range from 0 to n_heavy_atom-1 (a.k.a. no gap in indexing)
    atomic_number_mapping: dict, mapping from atom (positional) indices to its atomic numbers, for splitting/removing different types of atoms in each terminal group
    terminal_group_tuples: Tuple(Tuples), a group of pair of atoms whose bonds match the SMARTS string. Ex: ((0, 1), (2, 1), (10, 9), (11, 9), (12, 9), (14, 13), (15, 13))
    MaxMatches: int, cutoff for total number of matches (n_original_perm * n_conjugate perm)

    Returns
    --------------
    augmented_maps: Tuple(Tuples), original_maps augmented by multiplying the permutations induced by terminal_group_tuples.
        Duplicates are removed and the first-seen order is kept. An empty
        original_maps yields an empty tuple.
    """

    def _terminal_atom_cluster_from_pairs(edges):
        # Undirected adjacency: every atom maps to the set of atoms it is
        # paired with.  A shared atom therefore collects all of its partners.
        graph = defaultdict(set)
        for u, v in edges:
            graph[u].add(v)
            graph[v].add(u)
        return graph

    def _split_sets_by_mapped_values(list_of_sets, mapping):
        # Split every set into subsets whose members share the same mapped
        # value (here: the same atomic number).
        result = []
        for s in list_of_sets:
            mapped_sets = {}
            # Sorted iteration keeps the order of the produced subsets
            # independent of set iteration order.
            for elem in sorted(s):
                mapped_sets.setdefault(mapping.get(elem), set()).add(elem)
            result.extend(mapped_sets.values())
        return result

    # Nothing to augment; this also avoids dividing by zero below.
    if len(original_maps) == 0:
        return tuple()

    # group terminal group tuples with common atom_indices: [{0, 2}, {10, 11, 12}, {14, 15}]
    terminal_atom_clusters = _terminal_atom_cluster_from_pairs(terminal_group_tuples)

    # Number of groups that may be permuted before the estimated total
    # (3 ** n_groups * n_original_perm) exceeds MaxMatches; always at least 1.
    budget = MaxMatches / len(original_maps)
    if budget > 1:
        MaxTerminalGroups = max(1, int(np.ceil(np.emath.logn(3, budget))))
    else:
        MaxTerminalGroups = 1

    # if MaxTerminalGroups is less than the total number terminal groups, keep the first {MaxTerminalGroups} groups (to remove randomness).
    # Sets only define a partial (subset) order, so sort on the sorted
    # member list to obtain a well-defined, reproducible order.
    perm_groups = sorted(
        (atom_inds for atom_inds in terminal_atom_clusters.values() if len(atom_inds) > 1),
        key=sorted,
    )[:MaxTerminalGroups]

    # within each terminal group, if there are different atom types, split by atom type (if only one left, discard)
    perm_groups = _split_sets_by_mapped_values(perm_groups, atomic_number_mapping)
    perm_groups = [p for p in perm_groups if len(p) > 1]

    # all permutations according to symmetric conjugate terminal atoms: [[(0, 2), (2, 0)], [(10, 11, 12), (10, 12, 11), (11, 10, 12), (11, 12, 10), (12, 10, 11), (12, 11, 10)], [(14, 15), (15, 14)]]
    perm_groups = [sorted(itertools.permutations(g)) for g in perm_groups]

    # recursively permute the original mappings
    augmented_maps = []
    for initial_mapping in original_maps:
        recursive_permutation(list(initial_mapping), perm_groups, augmented_maps)

    # Convert to the same data type as in original_maps and remove duplicates
    # (original_maps might have already permuted some of the conjugate
    # terminal group indices).  dict.fromkeys keeps the first-seen order, so
    # the result is reproducible and a later truncation keeps the leading,
    # least-permuted mappings instead of an arbitrary subset.
    return tuple(dict.fromkeys(tuple(a) for a in augmented_maps))
def _get_substructure_perms(
    mol: Chem.Mol,
    Neutralize: bool = False,
    CheckStereochem: bool = True,
    SymmetrizeConjugatedTerminal: bool = True,
    MaxMatches: int = 512,
) -> np.ndarray:
    """
    Enumerate atom permutations that map the molecule onto itself.

    Hydrogens are stripped before matching; in the returned array every atom
    that was removed keeps its own index, only the remaining atoms are
    permuted.  Each atom of the input ``mol`` gets an ``ori_idx_w_h`` string
    property as a side effect.

    Args:
        mol: molecule to analyse.
        CheckStereochem: whether to assure stereochem does not change after permutation
        Neutralize: if true, neutralize the mol before computing the permutations
        SymmetrizeConjugatedTerminal: if true, consider symmetrization of conjugated terminal groups
        MaxMatches: int, cutoff for total number of matches

    Returns:
        Array of shape [num_perms, num_atoms], indexed by the atom indices of
        the input molecule.
        NOTE(review): the array is created with np.zeros without a dtype, so
        the indices are stored as floats - confirm callers expect that.

    Raises:
        AssertionError: if the self-match returns no mapping at all.
    """
    # Remember each atom's index in the hydrogen-containing input molecule.
    n_atoms_w_h = 0
    for atom in mol.GetAtoms():
        atom.SetProp("ori_idx_w_h", str(atom.GetIdx()))
        n_atoms_w_h += 1

    # Attention !!!
    # Remove Hs; Otherwise, there will be too many matches.
    heavy_mol = Chem.RemoveHs(mol)
    if Neutralize:
        heavy_mol = neutralize_atoms(heavy_mol)

    # Match the molecule against itself: every match is one candidate permutation.
    self_matches = heavy_mol.GetSubstructMatches(
        heavy_mol, uniquify=False, maxMatches=MaxMatches
    )
    base_perms = np.array(self_matches)
    assert len(base_perms) > 0, "no matches found, error"

    if CheckStereochem:
        # Keep only permutations that send every atom to an atom of the same
        # canonical rank (ties are deliberately not broken).
        ranks = np.array(
            list(Chem.rdmolfiles.CanonicalRankAtoms(heavy_mol, breakTies=False))
        )
        n_rank_agreements = (ranks[base_perms] == ranks[None]).sum(-1)
        base_perms = base_perms[n_rank_agreements == heavy_mol.GetNumAtoms()]

    # Terminal conjugate groups: (terminal atom, neighbor) pairs.
    terminal_pattern = Chem.MolFromSmarts(
        "[O,N;D1;$([O,N;D1]-[*]=[O,N;D1]),$([O,N;D1]=[*]-[O,N;D1])]~[*]"
    )
    terminal_group_tuples = heavy_mol.GetSubstructMatches(terminal_pattern)
    # Only augment if the user asks for it and such pairs exist.
    if SymmetrizeConjugatedTerminal and len(terminal_group_tuples) > 0:
        atomic_number_mapping = {
            pos: atom.GetAtomicNum() for pos, atom in enumerate(heavy_mol.GetAtoms())
        }
        augmented = augment_atom_maps_with_conjugate_terminal_groups(
            tuple(tuple(row) for row in base_perms),
            atomic_number_mapping,
            terminal_group_tuples,
            MaxMatches,
        )
        base_perms = np.array(augmented)
        if len(base_perms) > MaxMatches:
            base_perms = base_perms[:MaxMatches]

    # Translate between indices of the stripped molecule and the input molecule.
    new_to_ori_idx_map = {}
    ori_to_new_idx_map = {}
    for atom in heavy_mol.GetAtoms():
        ori_idx = int(atom.GetProp("ori_idx_w_h"))
        new_to_ori_idx_map[atom.GetIdx()] = ori_idx
        ori_to_new_idx_map[ori_idx] = atom.GetIdx()

    # Express the permutation targets in input-molecule indices.
    base_perms = np.vectorize(new_to_ori_idx_map.get)(base_perms)

    perms = np.zeros(shape=(base_perms.shape[0], n_atoms_w_h))
    for ori_idx in range(n_atoms_w_h):
        new_idx = ori_to_new_idx_map.get(ori_idx)
        if new_idx is None:
            # The position of the H atom will not be exchanged.
            perms[:, ori_idx] = ori_idx
        else:
            perms[:, ori_idx] = base_perms[:, new_idx]
    return perms
def get_substructure_perms(
    mol: Chem.Mol,
    CheckStereochem: bool = True,
    SymmetrizeConjugatedTerminal: bool = True,
    MaxMatches: int = 512,
    KeepProtonation: bool = False,
) -> np.ndarray:
    """
    Compute atom permutations of ``mol``, optionally merging protonation states.

    Args:
        mol: molecule to analyse.
        CheckStereochem: whether to assure stereochem does not change after permutation
        SymmetrizeConjugatedTerminal: if true, consider symmetrization of conjugated terminal groups
        MaxMatches: int, cutoff for total number of returned permutations
        KeepProtonation: if true, only the molecule as given is used;
            otherwise permutations of the molecule as given and of its
            neutralized form are merged and de-duplicated.

    Returns:
        Array with one permutation per row, as produced by
        ``_get_substructure_perms``.  If more than ``MaxMatches`` rows are
        found, ``MaxMatches`` of them are drawn at random without replacement
        using NumPy's global random state.
    """
    kwargs = {
        "CheckStereochem": CheckStereochem,
        "SymmetrizeConjugatedTerminal": SymmetrizeConjugatedTerminal,
        "MaxMatches": MaxMatches,
    }

    if KeepProtonation:
        perms = _get_substructure_perms(mol, Neutralize=False, **kwargs)
    else:
        # Have to deduplicate permutations across the two protonation states.
        # np.vstack is used because np.row_stack is a deprecated alias of it.
        perms = np.unique(
            np.vstack(
                (
                    _get_substructure_perms(mol, Neutralize=False, **kwargs),
                    _get_substructure_perms(mol, Neutralize=True, **kwargs),
                )
            ),
            axis=0,
        )

    nperm = len(perms)
    if nperm > MaxMatches:
        # Random subset of rows; not reproducible unless the caller seeds
        # np.random beforehand.
        perms = perms[np.random.choice(range(nperm), MaxMatches, replace=False)]
    return perms
def test():
    """Smoke test: print the permutation array of a few example SMILES."""
    smiles_cases = (
        "C1=CC=CC=C1",
        "CC(=O)OC1=CC=CC=C1C(=O)O",
        "C[C@H](CCC(=O)O)[C@H]1CC[C@@H]2[C@@]1(CC[C@H]3[C@H]2CC=C4[C@@]3(CC[C@@H](C4)O)C)C",
        "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    )
    for smiles in smiles_cases:
        print(smiles)
        perms = get_substructure_perms(Chem.MolFromSmiles(smiles))
        # Shape first, then the transposed array (one column per permutation).
        print(perms.shape)
        print(perms.T)
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()