# ReactXT / data_provider / r_smiles.py
# (file-viewer metadata from the original hosting page: uploader "SyrWin",
#  commit "init" 95f97c5, size 20.2 kB — preserved here as a comment)
import numpy as np
import argparse
import re
import random
import textdistance
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
def smi_tokenizer(smi):
    """Tokenize a SMILES string into space-separated chemical tokens.

    Splits `smi` into bracket atoms (e.g. ``[CH3:1]``), two-letter elements
    (Br/Cl), single-letter organic-subset atoms, bonds, ring-closure digits,
    and punctuation, then rejoins them with single spaces.

    Args:
        smi: SMILES string.

    Returns:
        The tokenized SMILES, tokens separated by one space.

    Raises:
        AssertionError: if the concatenated tokens do not reproduce `smi`
            exactly (i.e. some character was not matched by the pattern).
    """
    # Raw string: the original non-raw literal relied on invalid escape
    # sequences (\[ \( \. ...), which emit warnings on modern Python.
    # Note "\\\\" in the old non-raw literal denoted the two-character
    # regex \\ (one literal backslash), written here as \\.
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    assert smi == ''.join(tokens)
    return ' '.join(tokens)
def clear_map_canonical_smiles(smi, canonical=True, root=-1):
    """Strip atom-map numbers from a SMILES and re-emit it via RDKit.

    Args:
        smi: input SMILES (possibly atom-mapped).
        canonical: passed through to ``Chem.MolToSmiles``.
        root: atom index used as ``rootedAtAtom`` (-1 = RDKit default root).

    Returns:
        The mapping-free SMILES rooted at `root`, or `smi` unchanged when
        RDKit cannot parse it.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # Unparseable input: return it untouched rather than raising.
        return smi
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(mol, isomericSmiles=True, rootedAtAtom=root, canonical=canonical)
def get_cano_map_number(smi,root=-1):
    """Return the atom-map numbers of `smi` reordered into the atom order of
    its rooted canonical (mapping-free) SMILES, or None if no consistent
    atom correspondence can be established.
    """
    atommap_mol = Chem.MolFromSmiles(smi)
    # Re-parse the de-mapped, root-aligned canonical form of the same molecule.
    canonical_mol = Chem.MolFromSmiles(clear_map_canonical_smiles(smi,root=root))
    # Substructure match: for each canonical-mol atom i, the matched index in
    # the atom-mapped mol.
    cano2atommapIdx = atommap_mol.GetSubstructMatch(canonical_mol)
    # Sanity-check the match element-wise by comparing atomic symbols.
    correct_mapped = [canonical_mol.GetAtomWithIdx(i).GetSymbol() == atommap_mol.GetAtomWithIdx(index).GetSymbol() for i,index in enumerate(cano2atommapIdx)]
    atom_number = len(canonical_mol.GetAtoms())
    if np.sum(correct_mapped) < atom_number or len(cano2atommapIdx) < atom_number:
        # Forward match failed or was partial: retry in the opposite direction
        # and invert it into cano->atommap index form.
        cano2atommapIdx = [0] * atom_number
        atommap2canoIdx = canonical_mol.GetSubstructMatch(atommap_mol)
        if len(atommap2canoIdx) != atom_number:
            # Neither direction yields a full-molecule match.
            return None
        for i, index in enumerate(atommap2canoIdx):
            cano2atommapIdx[index] = i
    id2atommap = [atom.GetAtomMapNum() for atom in atommap_mol.GetAtoms()]
    # Map numbers listed in canonical atom order.
    return [id2atommap[cano2atommapIdx[i]] for i in range(atom_number)]
def get_root_id(mol, root_map_number):
    """Return the index of the first atom in `mol` whose atom-map number
    equals `root_map_number`, or -1 when no atom carries that map number.
    """
    return next(
        (idx for idx, atom in enumerate(mol.GetAtoms())
         if atom.GetAtomMapNum() == root_map_number),
        -1,
    )
def get_forward_rsmiles(data):
    """Generate root-aligned SMILES augmentations for the forward (reactant ->
    product) direction.

    `data` is a dict with keys 'product', 'reactant' (atom-mapped SMILES),
    'augmentation' (number of augmented pairs to emit) and 'separated'
    (whether to split reagents from reactants with a "<separated>" token).
    Returns a dict with 'status', tokenized 'src_data'/'tgt_data' lists and
    the mean token-level 'edit_distance' between the pairs.
    """
    pt = re.compile(r':(\d+)]')
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    separated = data['separated']
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    """checking data quality"""
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    return_status = {
        "status":0,
        "src_data":[],
        "tgt_data":[],
        "edit_distance":0,
    }
    reactant = reactant.split(".")
    product = product.split(".")
    # Atom-map numbers per reactant fragment.
    rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
    # Upper bound on distinct root combinations (product over fragments).
    max_times = np.prod([len(map_numbers) for map_numbers in rea_atom_map_numbers])
    times = min(augmentation, max_times)
    # First entry uses RDKit's default root (-1) for every fragment.
    reactant_roots = [[-1 for _ in reactant]]
    j = 0
    # Sample `times` distinct random root combinations (rejection sampling).
    while j < times:
        reactant_roots.append([random.sample(rea_atom_map_numbers[k], 1)[0] for k in range(len(reactant))])
        if reactant_roots[-1] in reactant_roots[:-1]:
            reactant_roots.pop()
        else:
            j += 1
    if j < augmentation:
        # Not enough distinct combinations: pad with repeats.
        reactant_roots.extend(random.choices(reactant_roots, k=augmentation - times))
        times = augmentation
    reversable = False  # no reverse
    assert times == augmentation
    if reversable:
        times = int(times / 2)
    pro_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", pro))) for pro in product]
    full_pro_atom_map_numbers = set(map(int, re.findall(r"(?<=:)\d+", ".".join(product))))
    for k in range(times):
        # Shuffle fragments jointly with their chosen roots and map numbers.
        tmp = list(zip(reactant, reactant_roots[k],rea_atom_map_numbers))
        random.shuffle(tmp)
        reactant_k, reactant_roots_k,rea_atom_map_numbers_k = [i[0] for i in tmp], [i[1] for i in tmp], [i[2] for i in tmp]
        aligned_reactants = []
        aligned_products = []
        aligned_products_order = []
        all_atom_map = []
        for i, rea in enumerate(reactant_k):
            rea_root_atom_map = reactant_roots_k[i]
            # Convert the chosen root map number into an atom index.
            rea_root = get_root_id(Chem.MolFromSmiles(rea), root_map_number=rea_root_atom_map)
            cano_atom_map = get_cano_map_number(rea, rea_root)
            if cano_atom_map is None:
                # NOTE(review): skipping here shortens aligned_reactants, so
                # indices in rea_atom_map_numbers_k may no longer line up with
                # aligned_reactants below — confirm intended behavior.
                print(f"Reactant Failed to find Canonical Mol with Atom MapNumber")
                continue
            rea_smi = clear_map_canonical_smiles(rea, canonical=True, root=rea_root)
            aligned_reactants.append(rea_smi)
            all_atom_map.extend(cano_atom_map)
        # Root each product fragment at the map number shared with the largest
        # contributing reactant, and remember its position for ordering.
        for i, pro_map_number in enumerate(pro_atom_map_numbers):
            reactant_candidates = []
            selected_reactant = []
            for j, map_number in enumerate(all_atom_map):
                if map_number in pro_map_number:
                    for rea_index, rea_atom_map_number in enumerate(rea_atom_map_numbers_k):
                        if map_number in rea_atom_map_number and rea_index not in selected_reactant:
                            selected_reactant.append(rea_index)
                            reactant_candidates.append((map_number, j, len(rea_atom_map_number)))
            # select maximal reactant
            reactant_candidates.sort(key=lambda x: x[2], reverse=True)
            map_number = reactant_candidates[0][0]
            j = reactant_candidates[0][1]
            pro_root = get_root_id(Chem.MolFromSmiles(product[i]), root_map_number=map_number)
            pro_smi = clear_map_canonical_smiles(product[i], canonical=True, root=pro_root)
            aligned_products.append(pro_smi)
            aligned_products_order.append(j)
        # Order product fragments by the position of their shared atom in the
        # concatenated reactant atom-map sequence.
        sorted_products = sorted(list(zip(aligned_products, aligned_products_order)), key=lambda x: x[1])
        aligned_products = [item[0] for item in sorted_products]
        pro_smi = ".".join(aligned_products)
        if separated:
            # Split fragments that share atoms with the product (reactants)
            # from those that do not (reagents).
            reactants = []
            reagents = []
            for i,cano_atom_map in enumerate(rea_atom_map_numbers_k):
                if len(set(cano_atom_map) & full_pro_atom_map_numbers) > 0:
                    reactants.append(aligned_reactants[i])
                else:
                    reagents.append(aligned_reactants[i])
            rea_smi = ".".join(reactants)
            reactant_tokens = smi_tokenizer(rea_smi)
            if len(reagents) > 0 :
                reactant_tokens += " <separated> " + smi_tokenizer(".".join(reagents))
        else:
            rea_smi = ".".join(aligned_reactants)
            reactant_tokens = smi_tokenizer(rea_smi)
        product_tokens = smi_tokenizer(pro_smi)
        return_status['src_data'].append(reactant_tokens)
        return_status['tgt_data'].append(product_tokens)
        if reversable:
            # Emit the reversed fragment order as an extra augmentation.
            aligned_reactants.reverse()
            aligned_products.reverse()
            pro_smi = ".".join(aligned_products)
            rea_smi = ".".join(aligned_reactants)
            product_tokens = smi_tokenizer(pro_smi)
            reactant_tokens = smi_tokenizer(rea_smi)
            return_status['src_data'].append(reactant_tokens)
            return_status['tgt_data'].append(product_tokens)
    # Mean token-level Levenshtein distance over all emitted pairs.
    edit_distances = []
    for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
    return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
def get_retro_rsmiles(data):
    """Generate root-aligned SMILES augmentations for the retro (product ->
    reactant) direction.

    `data` is a dict with keys 'product', 'reactant' (atom-mapped SMILES) and
    'augmentation' (number of augmented pairs; the special value 999 uses
    every product atom as a root once). Returns a dict with 'status',
    tokenized 'src_data'/'tgt_data' lists and the mean token-level
    'edit_distance'.
    """
    pt = re.compile(r':(\d+)]')
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    """checking data quality"""
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    return_status = {
        "status":0,
        "src_data":[],
        "tgt_data":[],
        "edit_distance":0,
    }
    pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
    reactant = reactant.split(".")
    reversable = False  # no shuffle
    # augmentation = 100
    if augmentation == 999:
        # Exhaustive mode: one augmentation per product atom-map number.
        product_roots = pro_atom_map_numbers
        times = len(product_roots)
    else:
        # First entry uses RDKit's default root (-1).
        product_roots = [-1]
        # reversable = len(reactant) > 1
        max_times = len(pro_atom_map_numbers)
        times = min(augmentation, max_times)
        if times < augmentation:  # times = max_times
            # Fewer atoms than requested augmentations: use all roots, then
            # pad with repeats up to `augmentation`.
            product_roots.extend(pro_atom_map_numbers)
            product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
        else:  # times = augmentation
            # Rejection-sample distinct random roots until we have enough.
            while len(product_roots) < times:
                product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
                # pro_atom_map_numbers.remove(product_roots[-1])
                if product_roots[-1] in product_roots[:-1]:
                    product_roots.pop()
        times = len(product_roots)
        assert times == augmentation
    if reversable:
        times = int(times / 2)
    # candidates = []
    for k in range(times):
        pro_root_atom_map = product_roots[k]
        # Convert the chosen root map number into an atom index.
        pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
        cano_atom_map = get_cano_map_number(product, root=pro_root)
        if cano_atom_map is None:
            return_status["status"] = "error_mapping"
            return return_status
        pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
        aligned_reactants = []
        aligned_reactants_order = []
        rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
        used_indices = []
        for i, rea_map_number in enumerate(rea_atom_map_numbers):
            for j, map_number in enumerate(cano_atom_map):
                # select mapping reactans
                if map_number in rea_map_number:
                    # Root this reactant fragment at the first product atom
                    # (in canonical order) that it shares.
                    rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
                    rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
                    aligned_reactants.append(rea_smi)
                    aligned_reactants_order.append(j)
                    used_indices.append(i)
                    break
        # Order reactant fragments by their shared atom's canonical position.
        sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
        aligned_reactants = [item[0] for item in sorted_reactants]
        reactant_smi = ".".join(aligned_reactants)
        product_tokens = smi_tokenizer(pro_smi)
        reactant_tokens = smi_tokenizer(reactant_smi)
        return_status['src_data'].append(product_tokens)
        return_status['tgt_data'].append(reactant_tokens)
        if reversable:
            # Emit the reversed reactant order as an extra augmentation.
            aligned_reactants.reverse()
            reactant_smi = ".".join(aligned_reactants)
            product_tokens = smi_tokenizer(pro_smi)
            reactant_tokens = smi_tokenizer(reactant_smi)
            return_status['src_data'].append(product_tokens)
            return_status['tgt_data'].append(reactant_tokens)
    assert len(return_status['src_data']) == data['augmentation']
    # Mean token-level Levenshtein distance over all emitted pairs.
    edit_distances = []
    for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
    return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
def multi_process(data):
    """Worker entry point: validate one atom-mapped reaction and emit its
    augmented (product, reactant) token pairs.

    `data` is a dict with keys 'product', 'reactant' (atom-mapped SMILES),
    'augmentation' (number of pairs; 999 = one per product atom) and
    'root_aligned' (True = root-aligned augmentation as in get_retro_rsmiles;
    False = plain random SMILES via RDKit's doRandom).

    Returns a dict with 'status' (0 on success, otherwise an error tag such
    as "invalid_p" or "error_mapping"), tokenized 'src_data'/'tgt_data'
    lists, and the mean token-level 'edit_distance'.
    """
    pt = re.compile(r':(\d+)]')
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    """checking data quality"""
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    return_status = {
        "status":0,
        "src_data":[],
        "tgt_data":[],
        "edit_distance":0,
    }
    # if ",".join(rids) != ",".join(pids): # mapping is not 1:1
    #     return_status["status"] = "error_mapping"
    # if len(set(rids)) != len(rids): # mapping is not 1:1
    #     return_status["status"] = "error_mapping"
    # if len(set(pids)) != len(pids): # mapping is not 1:1
    #     return_status["status"] = "error_mapping"
    if "" == product:
        return_status["status"] = "empty_p"
    if "" == reactant:
        return_status["status"] = "empty_r"
    # Bug fix: the size checks below used independent `if`s, so a None mol
    # (unparseable SMILES) crashed with AttributeError on .GetAtoms() instead
    # of reporting "invalid_r"/"invalid_p". Chain them with elif.
    if rea_mol is None:
        return_status["status"] = "invalid_r"
    elif len(rea_mol.GetAtoms()) < 5:
        return_status["status"] = "small_r"
    if pro_mol is None:
        return_status["status"] = "invalid_p"
    elif len(pro_mol.GetAtoms()) == 1:
        return_status["status"] = "small_p"
    elif not all([a.HasProp('molAtomMapNumber') for a in pro_mol.GetAtoms()]):
        return_status["status"] = "error_mapping_p"
    """finishing checking data quality"""
    if return_status['status'] == 0:
        pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
        reactant = reactant.split(".")
        if data['root_aligned']:
            reversable = False  # no shuffle
            # augmentation = 100
            if augmentation == 999:
                # Exhaustive mode: one augmentation per product atom.
                product_roots = pro_atom_map_numbers
                times = len(product_roots)
            else:
                product_roots = [-1]
                # reversable = len(reactant) > 1
                max_times = len(pro_atom_map_numbers)
                times = min(augmentation, max_times)
                if times < augmentation:  # times = max_times
                    # Use every root, then pad with repeats.
                    product_roots.extend(pro_atom_map_numbers)
                    product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
                else:  # times = augmentation
                    # Rejection-sample distinct random roots.
                    while len(product_roots) < times:
                        product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
                        # pro_atom_map_numbers.remove(product_roots[-1])
                        if product_roots[-1] in product_roots[:-1]:
                            product_roots.pop()
                times = len(product_roots)
                assert times == augmentation
            if reversable:
                times = int(times / 2)
            # candidates = []
            for k in range(times):
                pro_root_atom_map = product_roots[k]
                pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
                cano_atom_map = get_cano_map_number(product, root=pro_root)
                if cano_atom_map is None:
                    return_status["status"] = "error_mapping"
                    return return_status
                pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
                aligned_reactants = []
                aligned_reactants_order = []
                rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
                used_indices = []
                for i, rea_map_number in enumerate(rea_atom_map_numbers):
                    for j, map_number in enumerate(cano_atom_map):
                        # select mapping reactans
                        if map_number in rea_map_number:
                            # Root the fragment at the first shared product atom.
                            rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
                            rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
                            aligned_reactants.append(rea_smi)
                            aligned_reactants_order.append(j)
                            used_indices.append(i)
                            break
                # Order fragments by their shared atom's canonical position.
                sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
                aligned_reactants = [item[0] for item in sorted_reactants]
                reactant_smi = ".".join(aligned_reactants)
                product_tokens = smi_tokenizer(pro_smi)
                reactant_tokens = smi_tokenizer(reactant_smi)
                return_status['src_data'].append(product_tokens)
                return_status['tgt_data'].append(reactant_tokens)
                if reversable:
                    aligned_reactants.reverse()
                    reactant_smi = ".".join(aligned_reactants)
                    product_tokens = smi_tokenizer(pro_smi)
                    reactant_tokens = smi_tokenizer(reactant_smi)
                    return_status['src_data'].append(product_tokens)
                    return_status['tgt_data'].append(reactant_tokens)
            assert len(return_status['src_data']) == data['augmentation']
        else:
            # Baseline augmentation: canonical pair first, then random SMILES.
            cano_product = clear_map_canonical_smiles(product)
            # Keep only reactant fragments that share atoms with the product.
            cano_reactanct = ".".join([clear_map_canonical_smiles(rea) for rea in reactant if len(set(map(int, re.findall(r"(?<=:)\d+", rea))) & set(pro_atom_map_numbers)) > 0 ])
            return_status['src_data'].append(smi_tokenizer(cano_product))
            return_status['tgt_data'].append(smi_tokenizer(cano_reactanct))
            pro_mol = Chem.MolFromSmiles(cano_product)
            rea_mols = [Chem.MolFromSmiles(rea) for rea in cano_reactanct.split(".")]
            for i in range(int(augmentation-1)):
                pro_smi = Chem.MolToSmiles(pro_mol,doRandom=True)
                rea_smi = [Chem.MolToSmiles(rea_mol,doRandom=True) for rea_mol in rea_mols]
                rea_smi = ".".join(rea_smi)
                return_status['src_data'].append(smi_tokenizer(pro_smi))
                return_status['tgt_data'].append(smi_tokenizer(rea_smi))
        # Mean token-level Levenshtein distance; computed only on success so
        # np.mean is never fed an empty list (which would yield nan + warning).
        edit_distances = []
        for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
            edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
        return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
if __name__ == '__main__':
    # CLI: augment a single atom-mapped reaction "reactant>reagent>product"
    # into root-aligned R-SMILES pairs and print them.
    parser = argparse.ArgumentParser()
    parser.add_argument('-rxn',type=str,required=True)
    parser.add_argument('-mode',type=str,default="retro",)
    parser.add_argument('-forward_mode',type=str,default="separated",)
    parser.add_argument("-augmentation",type=int,default=1)
    parser.add_argument("-seed",type=int,default=33)
    args = parser.parse_args()
    print(args)
    # Bug fix: -seed was parsed but never applied, so runs were not
    # reproducible. Seed the RNG used by the augmentation samplers.
    random.seed(args.seed)
    reactant,reagent,product = args.rxn.split(">")
    pt = re.compile(r':(\d+)]')
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    if len(rids) == 0 or len(pids) == 0:
        print("No atom mapping found!")
        exit(1)
    # Retro: product -> reactant; forward: reactant -> product.
    if args.mode == "retro":
        args.input = product
        args.output = reactant
    else:
        args.input = reactant
        args.output = product
    print("Original input:", args.input)
    print("Original output:",args.output)
    src_smi = clear_map_canonical_smiles(args.input)
    tgt_smi = clear_map_canonical_smiles(args.output)
    if src_smi == "" or tgt_smi == "":
        print("Invalid SMILES!")
        exit(1)
    print("Canonical input:", src_smi)
    print("Canonical output:",tgt_smi)
    # Warn (but do not abort) when the atom mapping is not a clean 1:1
    # correspondence between reactant and product.
    mapping_check = True
    if ",".join(rids) != ",".join(pids): # mapping is not 1:1
        mapping_check = False
    if len(set(rids)) != len(rids): # mapping is not 1:1
        mapping_check = False
    if len(set(pids)) != len(pids): # mapping is not 1:1
        mapping_check = False
    if not mapping_check:
        print("The quality of the atom mapping may not be good enough, which can affect the effect of root alignment.")
    data = {
        'product':product,
        'reactant':reactant,
        'augmentation':args.augmentation,
        'separated':args.forward_mode == "separated"
    }
    if args.mode == "retro":
        res = get_retro_rsmiles(data)
    else:
        res = get_forward_rsmiles(data)
    for index,(src,tgt) in enumerate(zip(res['src_data'], res['tgt_data'])):
        print(f"ID:{index}")
        print(f"R-SMILES input:{''.join(src.split())}")
        print(f"R-SMILES output:{''.join(tgt.split())}")
    print("Avg. edit distance:", res['edit_distance'])