# Spaces:
# Runtime error
# Runtime error
import argparse
import random
import re
import sys

import numpy as np
import textdistance
from rdkit import Chem
from rdkit import RDLogger
# Silence every RDKit logging channel ('rdApp.*') for this process, so parse
# failures on malformed SMILES do not flood stderr.
RDLogger.DisableLog('rdApp.*')
# SMILES token pattern: bracket atoms, B/Br, C/Cl, the remaining organic-subset
# and aromatic atoms, branches, bonds, dots, ring closures (single digit or
# %nn) and other SMILES/SMARTS punctuation. Raw string, so backslashes reach
# the regex engine untouched (no invalid-escape warnings).
_SMI_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smi_tokenizer(smi):
    """Split a SMILES string into space-separated tokens.

    Args:
        smi: SMILES string.

    Returns:
        The tokens of ``smi`` joined by single spaces, e.g. ``"CCO"`` ->
        ``"C C O"``.

    Raises:
        ValueError: if ``smi`` contains characters the token pattern does not
            cover (the tokens would not re-join to ``smi``).
    """
    tokens = _SMI_TOKEN_REGEX.findall(smi)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if ''.join(tokens) != smi:
        raise ValueError(f"SMILES could not be fully tokenized: {smi!r}")
    return ' '.join(tokens)
def clear_map_canonical_smiles(smi, canonical=True, root=-1):
    """Strip atom-map numbers from ``smi`` and re-serialise it with RDKit.

    Args:
        smi: SMILES string, possibly carrying atom-map numbers.
        canonical: forwarded to ``Chem.MolToSmiles``.
        root: atom index the output SMILES starts from (``-1`` lets RDKit
            choose).

    Returns:
        The map-free, isomeric SMILES written by RDKit, or ``smi`` itself,
        unchanged, when RDKit cannot parse it.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # Unparsable input is handed back verbatim, map numbers included.
        return smi
    mapped_atoms = [atom for atom in mol.GetAtoms() if atom.HasProp('molAtomMapNumber')]
    for atom in mapped_atoms:
        atom.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(mol, isomericSmiles=True, rootedAtAtom=root, canonical=canonical)
def get_cano_map_number(smi, root=-1):
    """Return the atom-map numbers of ``smi`` in the atom order of its map-free SMILES.

    ``smi`` is rewritten by ``clear_map_canonical_smiles`` (map numbers
    stripped, rooted at atom index ``root``). Each atom of that map-free
    molecule is matched back to ``smi`` by substructure search, and the map
    number of the matched atom is reported.

    Args:
        smi: atom-mapped SMILES.
        root: atom index (in ``smi``) to root the rewritten SMILES at; ``-1``
            lets RDKit choose.

    Returns:
        A list with one map number per atom of the rewritten molecule, or
        ``None`` when either SMILES fails to parse or no consistent atom
        correspondence is found.
    """
    atommap_mol = Chem.MolFromSmiles(smi)
    if atommap_mol is None:
        return None
    canonical_mol = Chem.MolFromSmiles(clear_map_canonical_smiles(smi, root=root))
    if canonical_mol is None:
        return None
    atom_number = canonical_mol.GetNumAtoms()

    # cano2atommapIdx[i] is the index in atommap_mol of canonical atom i.
    cano2atommapIdx = atommap_mol.GetSubstructMatch(canonical_mol)
    symbols_agree = all(
        canonical_mol.GetAtomWithIdx(i).GetSymbol() == atommap_mol.GetAtomWithIdx(index).GetSymbol()
        for i, index in enumerate(cano2atommapIdx)
    )
    if len(cano2atommapIdx) < atom_number or not symbols_agree:
        # Fall back to matching in the opposite direction and inverting it.
        atommap2canoIdx = canonical_mol.GetSubstructMatch(atommap_mol)
        if len(atommap2canoIdx) != atom_number:
            return None
        cano2atommapIdx = [0] * atom_number
        for i, index in enumerate(atommap2canoIdx):
            cano2atommapIdx[index] = i

    id2atommap = [atom.GetAtomMapNum() for atom in atommap_mol.GetAtoms()]
    return [id2atommap[cano2atommapIdx[i]] for i in range(atom_number)]
def get_root_id(mol, root_map_number):
    """Return the index of the first atom of ``mol`` whose map number is ``root_map_number``.

    Atoms are scanned in ``mol.GetAtoms()`` order. Returns ``-1`` when no atom
    carries that map number.
    """
    matching_indices = (
        idx
        for idx, atom in enumerate(mol.GetAtoms())
        if atom.GetAtomMapNum() == root_map_number
    )
    return next(matching_indices, -1)
def _sample_reactant_roots(rea_atom_map_numbers, augmentation):
    """Choose one root atom-map number per reactant for every augmented sample.

    Entry 0 is ``[-1, ...]`` (no explicit root for any reactant). Further
    entries are distinct random combinations, one map number drawn from each
    reactant. If fewer distinct combinations exist than ``augmentation``,
    already-drawn entries (entry 0 included) are re-used at random to pad.
    The result has at least ``augmentation`` entries.
    """
    roots = [[-1] * len(rea_atom_map_numbers)]
    # Count *distinct* combinations: with a repeated map number the rejection
    # loop below could otherwise never reach its target. Plain Python ints
    # also cannot overflow the way an int64 np.prod can.
    max_times = 1 if rea_atom_map_numbers else 0
    for numbers in rea_atom_map_numbers:
        max_times *= len(set(numbers))
    times = min(augmentation, max_times)
    drawn = 0
    while drawn < times:
        candidate = [random.sample(numbers, 1)[0] for numbers in rea_atom_map_numbers]
        if candidate not in roots:
            roots.append(candidate)
            drawn += 1
    if drawn < augmentation:
        roots.extend(random.choices(roots, k=augmentation - times))
    return roots


def _align_reactants(shuffled):
    """Root each reactant at its sampled atom.

    ``shuffled`` holds ``(reactant SMILES, root map number, map numbers)``
    triples. Returns ``(aligned, all_atom_map)``: ``aligned`` lists
    ``(rooted map-free SMILES, map numbers)`` for every reactant that could be
    aligned, and ``all_atom_map`` concatenates their atom-map numbers in the
    order the atoms are written in those SMILES. Reactants that cannot be
    aligned are reported and skipped.
    """
    aligned = []
    all_atom_map = []
    for rea, root_map_number, map_numbers in shuffled:
        mol = Chem.MolFromSmiles(rea)
        rea_root = -1 if mol is None else get_root_id(mol, root_map_number=root_map_number)
        cano_atom_map = None if mol is None else get_cano_map_number(rea, rea_root)
        if cano_atom_map is None:
            print("Reactant Failed to find Canonical Mol with Atom MapNumber")
            continue
        aligned.append((clear_map_canonical_smiles(rea, canonical=True, root=rea_root), map_numbers))
        all_atom_map.extend(cano_atom_map)
    return aligned, all_atom_map


def _align_products(products, pro_atom_map_numbers, aligned, all_atom_map):
    """Root and order the product fragments to follow the aligned reactants.

    For each product fragment, every aligned reactant sharing a map number with
    it is collected once, at its first shared atom in ``all_atom_map`` order.
    The fragment is rooted at that atom of the largest such reactant (the
    earliest one on ties). Fragments are then ordered by that atom's position
    in ``all_atom_map``. A fragment sharing no atom with any aligned reactant
    keeps RDKit's default root and is placed last.

    Returns the joined map-free product SMILES.
    """
    aligned_map_numbers = [map_numbers for _, map_numbers in aligned]
    keyed_products = []
    for pro, pro_map_number in zip(products, pro_atom_map_numbers):
        candidates = []
        selected = set()
        for position, map_number in enumerate(all_atom_map):
            if map_number not in pro_map_number:
                continue
            for rea_index, rea_map_numbers in enumerate(aligned_map_numbers):
                if rea_index not in selected and map_number in rea_map_numbers:
                    selected.add(rea_index)
                    candidates.append((map_number, position, len(rea_map_numbers)))
        if candidates:
            # max() returns the first maximal candidate, matching a stable
            # descending sort followed by taking element 0.
            map_number, position, _ = max(candidates, key=lambda c: c[2])
            pro_root = get_root_id(Chem.MolFromSmiles(pro), root_map_number=map_number)
        else:
            position, pro_root = len(all_atom_map), -1
        keyed_products.append((position, clear_map_canonical_smiles(pro, canonical=True, root=pro_root)))
    keyed_products.sort(key=lambda item: item[0])  # stable: ties keep fragment order
    return ".".join(smi for _, smi in keyed_products)


def get_forward_rsmiles(data):
    """Build root-aligned (R-SMILES) reactant -> product training pairs.

    Args:
        data: dict with keys
            ``'reactant'``: atom-mapped reactant SMILES (fragments joined by ``.``),
            ``'product'``: atom-mapped product SMILES,
            ``'augmentation'``: number of (src, tgt) pairs to produce,
            ``'separated'``: if true, aligned reactant fragments sharing no
            atom-map number with the product are written after a
            ``<separated>`` token.

    Returns:
        dict with ``'status'`` (always 0 here), ``'src_data'`` (tokenised
        reactant strings), ``'tgt_data'`` (tokenised product strings) and
        ``'edit_distance'`` (mean token-level Levenshtein distance between
        paired src/tgt; 0 when no pairs were produced).
    """
    product = data['product']
    augmentation = data['augmentation']
    separated = data['separated']
    return_status = {
        "status": 0,
        "src_data": [],
        "tgt_data": [],
        "edit_distance": 0,
    }
    reactants = data['reactant'].split(".")
    products = product.split(".")
    rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactants]
    pro_atom_map_numbers = [set(map(int, re.findall(r"(?<=:)\d+", pro))) for pro in products]
    full_pro_atom_map_numbers = set(map(int, re.findall(r"(?<=:)\d+", product)))
    reactant_roots = _sample_reactant_roots(rea_atom_map_numbers, augmentation)

    for k in range(augmentation):
        # Random fragment order per sample; each reactant keeps its root and map numbers.
        shuffled = list(zip(reactants, reactant_roots[k], rea_atom_map_numbers))
        random.shuffle(shuffled)
        aligned, all_atom_map = _align_reactants(shuffled)
        pro_smi = _align_products(products, pro_atom_map_numbers, aligned, all_atom_map)

        if separated:
            # Fragments contributing no mapped atom to the product go after <separated>.
            main_part = [smi for smi, nums in aligned if set(nums) & full_pro_atom_map_numbers]
            reagent_part = [smi for smi, nums in aligned if not set(nums) & full_pro_atom_map_numbers]
            reactant_tokens = smi_tokenizer(".".join(main_part))
            if reagent_part:
                reactant_tokens += " <separated> " + smi_tokenizer(".".join(reagent_part))
        else:
            reactant_tokens = smi_tokenizer(".".join(smi for smi, _ in aligned))
        return_status['src_data'].append(reactant_tokens)
        return_status['tgt_data'].append(smi_tokenizer(pro_smi))

    edit_distances = [
        textdistance.levenshtein.distance(src.split(), tgt.split())
        for src, tgt in zip(return_status['src_data'], return_status['tgt_data'])
    ]
    if edit_distances:
        return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
def _sample_product_roots(pro_atom_map_numbers, augmentation):
    """Choose the product root atom-map number for each augmented sample.

    ``augmentation == 999`` is a special value: one root per product atom, in
    the order the map numbers appear. Otherwise the first root is ``-1`` (no
    explicit root), followed by distinct random map numbers. When the product
    has too few distinct map numbers, all of them are used and the remaining
    slots are filled by random re-draws from the roots collected so far.
    Exactly ``max(augmentation, 0)`` roots are returned in that case.
    """
    if augmentation == 999:
        return list(pro_atom_map_numbers)
    # Distinct map numbers in first-seen order. Duplicates must not count
    # towards the target, or the rejection loop below could never finish.
    distinct = list(dict.fromkeys(pro_atom_map_numbers))
    product_roots = [-1]
    if len(distinct) < augmentation:
        product_roots.extend(distinct)
        product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
    else:
        while len(product_roots) < augmentation:
            candidate = random.sample(pro_atom_map_numbers, 1)[0]
            if candidate not in product_roots:
                product_roots.append(candidate)
    return product_roots[:augmentation]


def get_retro_rsmiles(data):
    """Build root-aligned (R-SMILES) product -> reactant training pairs.

    For every sampled product root, the product is written as map-free SMILES
    starting at that atom. Each reactant fragment sharing a mapped atom with
    the product is written starting at the first such atom (in product SMILES
    order), and reactants are ordered by that position. Fragments sharing no
    mapped atom with the product are dropped.

    Args:
        data: dict with ``'product'`` and ``'reactant'`` (atom-mapped SMILES,
            fragments joined by ``.``) and ``'augmentation'`` (number of
            pairs; ``999`` means one pair per product atom).

    Returns:
        dict with ``'status'`` (0, ``"invalid_p"`` or ``"error_mapping"``;
        on ``"error_mapping"`` the pairs built so far are returned),
        ``'src_data'`` (tokenised products), ``'tgt_data'`` (tokenised
        reactants) and ``'edit_distance'`` (mean token-level Levenshtein
        distance; 0 when no pairs were produced).
    """
    product = data['product']
    augmentation = data['augmentation']
    return_status = {
        "status": 0,
        "src_data": [],
        "tgt_data": [],
        "edit_distance": 0,
    }
    pro_mol = Chem.MolFromSmiles(product)
    if pro_mol is None:
        return_status["status"] = "invalid_p"
        return return_status
    reactants = data['reactant'].split(".")
    # Loop-invariant: the map numbers of each reactant fragment.
    rea_atom_map_numbers = [set(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactants]
    pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
    product_roots = _sample_product_roots(pro_atom_map_numbers, augmentation)

    for pro_root_atom_map in product_roots:
        pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
        cano_atom_map = get_cano_map_number(product, root=pro_root)
        if cano_atom_map is None:
            return_status["status"] = "error_mapping"
            return return_status
        pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)

        keyed_reactants = []
        for rea, rea_map_numbers in zip(reactants, rea_atom_map_numbers):
            # First product atom (in rooted product order) this reactant contributes.
            hit = next(
                ((position, map_number)
                 for position, map_number in enumerate(cano_atom_map)
                 if map_number in rea_map_numbers),
                None,
            )
            if hit is None:
                continue
            position, map_number = hit
            rea_root = get_root_id(Chem.MolFromSmiles(rea), root_map_number=map_number)
            keyed_reactants.append((position, clear_map_canonical_smiles(rea, canonical=True, root=rea_root)))
        keyed_reactants.sort(key=lambda item: item[0])  # stable: ties keep input order
        reactant_smi = ".".join(smi for _, smi in keyed_reactants)

        return_status['src_data'].append(smi_tokenizer(pro_smi))
        return_status['tgt_data'].append(smi_tokenizer(reactant_smi))

    edit_distances = [
        textdistance.levenshtein.distance(src.split(), tgt.split())
        for src, tgt in zip(return_status['src_data'], return_status['tgt_data'])
    ]
    if edit_distances:
        return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
def multi_process(data):
    """Quality-check one atom-mapped reaction and build product -> reactant pairs.

    Args:
        data: dict with ``'product'`` and ``'reactant'`` (atom-mapped SMILES),
            ``'augmentation'`` (number of pairs) and ``'root_aligned'``.
            If ``'root_aligned'`` is true, pairs come from ``get_retro_rsmiles``.
            Otherwise they are the canonical pair plus ``augmentation - 1``
            random-SMILES pairs.

    Returns:
        dict with ``'status'``, ``'src_data'`` (tokenised products),
        ``'tgt_data'`` (tokenised reactants) and ``'edit_distance'``. When a
        quality check fails, ``'status'`` holds its reason string and no pairs
        are built.
    """
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    return_status = {
        "status": 0,
        "src_data": [],
        "tgt_data": [],
        "edit_distance": 0,
    }

    # Quality checks run in a fixed order, and a later failure overwrites an
    # earlier one. Molecule-based checks are skipped when parsing failed,
    # since GetAtoms()/GetNumAtoms() on None would raise.
    if product == "":
        return_status["status"] = "empty_p"
    if reactant == "":
        return_status["status"] = "empty_r"
    if rea_mol is None:
        return_status["status"] = "invalid_r"
    elif rea_mol.GetNumAtoms() < 5:
        return_status["status"] = "small_r"
    if pro_mol is None:
        return_status["status"] = "invalid_p"
    else:
        if pro_mol.GetNumAtoms() == 1:
            return_status["status"] = "small_p"
        if not all(a.HasProp('molAtomMapNumber') for a in pro_mol.GetAtoms()):
            return_status["status"] = "error_mapping_p"
    if return_status["status"] != 0:
        return return_status

    if data['root_aligned']:
        return get_retro_rsmiles(data)

    # Plain augmentation: canonical pair first, then random SMILES of the same molecules.
    pro_atom_map_numbers = set(map(int, re.findall(r"(?<=:)\d+", product)))
    cano_product = clear_map_canonical_smiles(product)
    # Keep only reactant fragments contributing at least one mapped atom to the product.
    cano_reactanct = ".".join(
        clear_map_canonical_smiles(rea)
        for rea in reactant.split(".")
        if set(map(int, re.findall(r"(?<=:)\d+", rea))) & pro_atom_map_numbers
    )
    return_status['src_data'].append(smi_tokenizer(cano_product))
    return_status['tgt_data'].append(smi_tokenizer(cano_reactanct))
    cano_pro_mol = Chem.MolFromSmiles(cano_product)
    rea_mols = [Chem.MolFromSmiles(rea) for rea in cano_reactanct.split(".")]
    for _ in range(int(augmentation - 1)):
        pro_smi = Chem.MolToSmiles(cano_pro_mol, doRandom=True)
        rea_smi = ".".join(Chem.MolToSmiles(rea_mol, doRandom=True) for rea_mol in rea_mols)
        return_status['src_data'].append(smi_tokenizer(pro_smi))
        return_status['tgt_data'].append(smi_tokenizer(rea_smi))

    edit_distances = [
        textdistance.levenshtein.distance(src.split(), tgt.split())
        for src, tgt in zip(return_status['src_data'], return_status['tgt_data'])
    ]
    if edit_distances:
        return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
def main():
    """Command-line entry point: print R-SMILES pairs for one atom-mapped reaction.

    ``-rxn`` takes ``reactants>reagents>products`` (the reagent part is not
    used). ``-mode retro`` writes product -> reactants; any other value writes
    reactants -> product.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-rxn', type=str, required=True)
    parser.add_argument('-mode', type=str, default="retro",)
    parser.add_argument('-forward_mode', type=str, default="separated",)
    parser.add_argument("-augmentation", type=int, default=1)
    parser.add_argument("-seed", type=int, default=33)
    args = parser.parse_args()
    print(args)
    # Seed Python's RNG so root sampling and fragment shuffling are reproducible.
    random.seed(args.seed)

    rxn_parts = args.rxn.split(">")
    if len(rxn_parts) != 3:
        parser.error("-rxn must look like reactants>reagents>products")
    reactant, _reagent, product = rxn_parts
    pt = re.compile(r':(\d+)]')
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    if len(rids) == 0 or len(pids) == 0:
        print("No atom mapping found!")
        sys.exit(1)
    if args.mode == "retro":
        args.input, args.output = product, reactant
    else:
        args.input, args.output = reactant, product
    print("Original input:", args.input)
    print("Original output:", args.output)
    src_smi = clear_map_canonical_smiles(args.input)
    tgt_smi = clear_map_canonical_smiles(args.output)
    # clear_map_canonical_smiles echoes unparsable SMILES back unchanged, so
    # validity must be checked with RDKit directly.
    if (src_smi == "" or tgt_smi == ""
            or Chem.MolFromSmiles(args.input) is None
            or Chem.MolFromSmiles(args.output) is None):
        print("Invalid SMILES!")
        sys.exit(1)
    print("Canonical input:", src_smi)
    print("Canonical output:", tgt_smi)

    # Heuristic 1:1 check: identical sorted map-number lists, no repeats on either side.
    mapping_check = (
        rids == pids
        and len(set(rids)) == len(rids)
        and len(set(pids)) == len(pids)
    )
    if not mapping_check:
        print("The quality of the atom mapping may not be good enough, which can affect the effect of root alignment.")
    data = {
        'product': product,
        'reactant': reactant,
        'augmentation': args.augmentation,
        'separated': args.forward_mode == "separated"
    }
    if args.mode == "retro":
        res = get_retro_rsmiles(data)
    else:
        res = get_forward_rsmiles(data)
    for index, (src, tgt) in enumerate(zip(res['src_data'], res['tgt_data'])):
        print(f"ID:{index}")
        print(f"R-SMILES input:{''.join(src.split())}")
        print(f"R-SMILES output:{''.join(tgt.split())}")
    print("Avg. edit distance:", res['edit_distance'])


if __name__ == '__main__':
    main()