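"""Score retrosynthesis SMILES predictions against ground truth.

Reads a JSONL file of model outputs, canonicalizes predictions and targets
with RDKit, aggregates beam candidates across test-time augmentations, and
reports top-n accuracy, maximum-fragment (MaxFrag) accuracy, and
invalid-SMILES rates. Worker processes rely on the module-level `opt`
namespace, so fork-style multiprocessing (the Linux default) is assumed.
"""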
import argparse
import json
import multiprocessing

import pandas as pd
from rdkit import Chem, RDLogger
from tqdm import tqdm

from utils import *

# Silence RDKit parse warnings; invalid SMILES are handled explicitly below.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
def extract_smiles(s):
    """Return the substring between [START_SMILES] and [END_SMILES], or s unchanged."""
    start_token = "[START_SMILES]"
    end_token = "[END_SMILES]"
    start_index = s.find(start_token)
    end_index = s.find(end_token)
    if start_index > -1 and end_index > -1:
        return s[start_index + len(start_token):end_index].strip()
    return s
def canonicalize_smiles_clear_map(smiles, return_max_frag=True):
    """Canonicalize a SMILES string with atom-map numbers removed.

    Returns the canonical SMILES (and, if return_max_frag is True, the
    canonical SMILES of its largest fragment); invalid input yields ''.
    Relies on the module-level `opt` namespace for the --synthon flag.
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=not opt.synthon)
    if mol is None:
        return ('', '') if return_max_frag else ''
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')
    try:
        smi = Chem.MolToSmiles(mol, isomericSmiles=False)
    except Exception:
        return ('', '') if return_max_frag else ''
    if not return_max_frag:
        return smi
    # Recurse on the largest fragment (by atom count) for MaxFrag accuracy.
    sub_smi = smi.split(".")
    sub_mol = [Chem.MolFromSmiles(s, sanitize=not opt.synthon) for s in sub_smi]
    sub_mol_size = [(sub_smi[i], m.GetNumAtoms()) for i, m in enumerate(sub_mol) if m is not None]
    if sub_mol_size:
        largest = max(sub_mol_size, key=lambda x: x[1])[0]
        return smi, canonicalize_smiles_clear_map(largest, return_max_frag=False)
    return smi, ''
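# Candidate scoring: every augmented view of an input votes for its
# deduplicated beam candidates with a reciprocal-rank weight 1/(alpha*k + 1).
# The summed score is then penalized by the best beam position a candidate
# ever reached and by two length terms (0.2 per character of candidate length
# and of length difference from the input), which favor shorter candidates
# whose size is close to the input's.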
def compute_rank(input_smiles, prediction, raw=False, alpha=1.0):
    """Aggregate beam candidates across augmented views into one ranking.

    `prediction` has shape [aug_num][beam_size]; each entry is a
    (canonical_smiles, max_frag_smiles) pair, with '' marking an invalid
    prediction. Returns (candidate -> score dict, per-position invalid counts).
    """
    invalid_rates = [0 for k in range(len(prediction[0]))]
    rank = {}
    highest = {}
    if raw:
        # No test-time augmentation: a single beam, ranked by beam position.
        assert len(prediction) == 1
        for j in range(len(prediction)):
            for k in range(len(prediction[j])):
                if prediction[j][k][0] == "":
                    invalid_rates[k] += 1
            # Drop invalid candidates, then deduplicate preserving beam order.
            de_error = [p for p in prediction[j] if p[0] != ""]
            prediction[j] = list(set(de_error))
            prediction[j].sort(key=de_error.index)
            for k, data in enumerate(prediction[j]):
                rank[data] = 1 / (alpha * k + 1)
    else:
        for j in range(len(prediction)):  # aug_num x beam_size x 2
            for k in range(len(prediction[j])):
                if prediction[j][k][0] == "":
                    invalid_rates[k] += 1
            # Drop invalid candidates, then deduplicate preserving beam order.
            de_error = [p for p in prediction[j] if p[0] != ""]
            prediction[j] = list(set(de_error))
            prediction[j].sort(key=de_error.index)
            for k, data in enumerate(prediction[j]):
                rank[data] = rank.get(data, 0) + 1 / (alpha * k + 1)
                highest[data] = min(k, highest.get(data, k))
        # Apply the position and length penalties described above.
        for key in rank.keys():
            rank[key] -= highest[key]
            rank[key] -= abs(len(key[0]) - len(input_smiles)) * 0.2
            rank[key] -= len(key[0]) * 0.2
    return rank, invalid_rates
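# Each JSONL record is expected to provide an 'input' prompt, a 'targets'
# string wrapped in [START_SMILES]/[END_SMILES] tokens (with the
# 'SPL1T-TH1S-Pl3A5E' fragment separator, stripped during loading), a list
# of 'predictions', and a sample index under 'ds_idx' or 'index'.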
def read_dataset(opt):
    """Load the JSONL prediction file; return (inputs, targets, predictions)."""
    print(f'Reading {opt.path}...')
    with open(opt.path, 'r', encoding='utf-8') as f:
        test_tgt = [json.loads(line) for line in f]
    if opt.raw:
        # Raw mode: keep one record per sample and deduplicate by index.
        test_tgt = test_tgt[::opt.augmentation]
        filtered_tgt = {}
        idx_key = 'ds_idx' if 'ds_idx' in test_tgt[0] else 'index'
        for dic in test_tgt:
            if dic[idx_key] not in filtered_tgt:
                filtered_tgt[dic[idx_key]] = dic
        test_tgt = list(filtered_tgt.values())
        test_tgt.sort(key=lambda x: x[idx_key])
    print(f'{len(test_tgt)} samples read.')
    input_list = [extract_smiles(i['input']) for i in test_tgt]
    gt_list = [i['targets'].replace('[START_SMILES]', '').replace('[END_SMILES]', '').replace('SPL1T-TH1S-Pl3A5E', '').strip().replace(' ', '.') for i in test_tgt]
    pred_list = [[smi.strip().replace(' ', '.') for smi in i['predictions']] for i in test_tgt]
    return input_list, gt_list, pred_list
def main(opt):
input_list, gt_list, pred_list = read_dataset(opt)
    if opt.raw:
        # Raw mode: a single (non-augmented) prediction set per sample.
        opt.augmentation = 1
    print('Reading predictions from file ...')
    # Inputs: one canonical source SMILES per sample (first augmented copy).
    print("Input Length", len(gt_list))
    ras_src_smiles = input_list[::opt.augmentation]
    with multiprocessing.Pool(processes=opt.process_number) as pool:
        ras_src_smiles = pool.map(func=canonicalize_smiles_clear_map, iterable=ras_src_smiles)
    ras_src_smiles = [i[0] for i in ras_src_smiles]
# predictions
print("Prediction Length", len(pred_list))
    # Keep only the text before any '>' separator in each prediction.
    pred_lines = [i.split('>')[0] for d in pred_list for i in d]
    # Truncate to a whole number of samples (or to --length if given).
    data_size = len(pred_lines) // (opt.augmentation * opt.beam_size) if opt.length == -1 else opt.length
    pred_lines = pred_lines[:data_size * (opt.augmentation * opt.beam_size)]
print("Canonicalizing predictions using Process Number ",opt.process_number)
with multiprocessing.Pool(processes=opt.process_number) as pool:
raw_predictions = pool.map(func=canonicalize_smiles_clear_map,iterable=pred_lines)
    # Reshape the flat list into data_size x augmentation x beam_size.
    predictions = [[[] for j in range(opt.augmentation)] for i in range(data_size)]
    for i, line in enumerate(raw_predictions):
        predictions[i // (opt.beam_size * opt.augmentation)][i % (opt.beam_size * opt.augmentation) // opt.beam_size].append(line)
# ground truth
print("Origin Length", len(gt_list))
targets = [''.join(gt_list[i].strip().split(' ')) for i in tqdm(range(0,data_size * opt.augmentation,opt.augmentation))]
with multiprocessing.Pool(processes=opt.process_number) as pool:
targets = pool.map(func=canonicalize_smiles_clear_map, iterable=targets)
print("predictions Length", len(predictions), len(predictions[0]), len(predictions[0][0]))
print("Target Length", len(targets))
ground_truth = targets
print("Origin Target Lentgh, ", len(ground_truth))
print("Cutted Length, ",data_size)
print('\n')
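    # Accuracy counters: accuracy[k] counts samples whose ground truth appears
    # among the top (k+1) ranked candidates; max_frag_accuracy compares only
    # the largest fragment.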
accuracy = [0 for j in range(opt.n_best)]
topn_accuracy_chirality = [0 for _ in range(opt.n_best)]
topn_accuracy_wochirality = [0 for _ in range(opt.n_best)]
topn_accuracy_ringopening = [0 for _ in range(opt.n_best)]
topn_accuracy_ringformation = [0 for _ in range(opt.n_best)]
topn_accuracy_woring = [0 for _ in range(opt.n_best)]
total_chirality = 0
total_ringopening = 0
total_ringformation = 0
atomsize_topk = []
accurate_indices = [[] for j in range(opt.n_best)]
max_frag_accuracy = [0 for j in range(opt.n_best)]
invalid_rates = [0 for j in range(opt.beam_size)]
sorted_invalid_rates = [0 for j in range(opt.beam_size)]
unique_rates = 0
ranked_results = []
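    # Main loop: rank each sample's candidates across augmentations, then
    # update top-n accuracy, MaxFrag accuracy, and invalid-SMILES statistics.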
for i in tqdm(range(len(predictions))):
accurate_flag = False
        if opt.detailed:
            chirality_flag = False
            ringopening_flag = False
            ringformation_flag = False
            size = 0  # fallback when RDKit cannot parse input or target
            pro_mol = Chem.MolFromSmiles(ras_src_smiles[i])
            rea_mol = Chem.MolFromSmiles(ground_truth[i][0])
            try:
                pro_ringnum = pro_mol.GetRingInfo().NumRings()
                rea_ringnum = rea_mol.GetRingInfo().NumRings()
                # Atom-count difference between target (reactants) and input (product).
                size = rea_mol.GetNumAtoms() - pro_mol.GetNumAtoms()
                if ras_src_smiles[i].count("@") > 0 or ground_truth[i][0].count("@") > 0:
                    total_chirality += 1
                    chirality_flag = True
                if pro_ringnum < rea_ringnum:
                    total_ringopening += 1
                    ringopening_flag = True
                if pro_ringnum > rea_ringnum:
                    total_ringformation += 1
                    ringformation_flag = True
            except Exception:
                pass
inputs = input_list[i*opt.augmentation:(i+1)*opt.augmentation]
gt = gt_list[i*opt.augmentation:(i+1)*opt.augmentation]
rank, invalid_rate = compute_rank(ras_src_smiles[i], predictions[i], raw=opt.raw,alpha=opt.score_alpha)
        if opt.detailed:
            rank_ = {k[0]: v for k, v in sorted(rank.items(), key=lambda item: item[1], reverse=True)}
            print('Index', i)
            print('inputs', json.dumps(inputs, indent=4))
            print('targets', json.dumps(gt, indent=4))
            print('input', ras_src_smiles[i])
            print('target', targets[i][0])
            print('rank', json.dumps(rank_, indent=4))
            print('invalid_rate', json.dumps(invalid_rate, indent=4))
            print('\n')
for j in range(opt.beam_size):
invalid_rates[j] += invalid_rate[j]
        # Sort candidates by aggregated score and keep the top n_best.
        rank = sorted(rank.items(), key=lambda x: x[1], reverse=True)[:opt.n_best]
        ranked_results.append([item[0][0] for item in rank])
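        # A hit at beam position j counts toward every top-k total with k >= j.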
for j, item in enumerate(rank):
if item[0][0] == ground_truth[i][0]:
if not accurate_flag:
accurate_flag = True
accurate_indices[j].append(i)
for k in range(j, opt.n_best):
accuracy[k] += 1
if opt.detailed:
atomsize_topk.append((size,j))
if chirality_flag:
for k in range(j,opt.n_best):
topn_accuracy_chirality[k] += 1
else:
for k in range(j,opt.n_best):
topn_accuracy_wochirality[k] += 1
if ringopening_flag:
for k in range(j,opt.n_best):
topn_accuracy_ringopening[k] += 1
if ringformation_flag:
for k in range(j,opt.n_best):
topn_accuracy_ringformation[k] += 1
if not ringopening_flag and not ringformation_flag:
for k in range(j,opt.n_best):
topn_accuracy_woring[k] += 1
if opt.detailed and not accurate_flag:
atomsize_topk.append((size,opt.n_best))
for j, item in enumerate(rank):
if item[0][1] == ground_truth[i][1]:
for k in range(j,opt.n_best):
max_frag_accuracy[k] += 1
break
for j in range(len(rank),opt.beam_size):
sorted_invalid_rates[j] += 1
unique_rates += len(rank)
for i in range(opt.n_best):
if i in [0,1,2,3,4,5,6,7,8,9,19,49]:
print("Top-{} Acc:{:.3f}%, MaxFrag {:.3f}%,".format(i+1,accuracy[i] / data_size * 100,max_frag_accuracy[i] / data_size * 100),
" Invalid SMILES:{:.3f}% Sorted Invalid SMILES:{:.3f}%".format(invalid_rates[i] / data_size / opt.augmentation * 100,sorted_invalid_rates[i] / data_size / opt.augmentation * 100))
print(' '.join([f'{accuracy[i] / data_size * 100:.3f}' for i in [0,2,4,9]]))
print("Unique Rates:{:.3f}%".format(unique_rates / data_size / opt.beam_size * 100))
if opt.detailed:
print_topk = [1,3,5,10]
save_dict = {}
atomsize_topk.sort(key=lambda x:x[0])
differ_now = atomsize_topk[0][0]
topn_accuracy_bydiffer = [0 for _ in range(opt.n_best)]
total_bydiffer = 0
for i,item in enumerate(atomsize_topk):
if differ_now < 11 and differ_now != item[0]:
for j in range(opt.n_best):
if (j+1) in print_topk:
save_dict[f'top-{j+1}_size_{differ_now}'] = topn_accuracy_bydiffer[j] / total_bydiffer * 100
print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(j+1,
differ_now,
topn_accuracy_bydiffer[j] / total_bydiffer * 100,
total_bydiffer/data_size * 100))
total_bydiffer = 0
topn_accuracy_bydiffer = [0 for _ in range(opt.n_best)]
differ_now = item[0]
for k in range(item[1],opt.n_best):
topn_accuracy_bydiffer[k] += 1
total_bydiffer += 1
for j in range(opt.n_best):
if (j + 1) in print_topk:
print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(j + 1,
differ_now,
topn_accuracy_bydiffer[j] / total_bydiffer * 100,
total_bydiffer / data_size * 100))
save_dict[f'top-{j+1}_size_{differ_now}'] = topn_accuracy_bydiffer[j] / total_bydiffer * 100
        for i in range(opt.n_best):
            if (i + 1) in print_topk:
                if total_chirality > 0:
                    print("Top-{} Accuracy with chirality:{:.3f}%".format(i + 1, topn_accuracy_chirality[i] / total_chirality * 100))
                    save_dict[f'top-{i+1}_chirality'] = topn_accuracy_chirality[i] / total_chirality * 100
                print("Top-{} Accuracy without chirality:{:.3f}%".format(i + 1, topn_accuracy_wochirality[i] / (data_size - total_chirality) * 100))
                save_dict[f'top-{i+1}_wochirality'] = topn_accuracy_wochirality[i] / (data_size - total_chirality) * 100
                if total_ringopening > 0:
                    print("Top-{} Accuracy ring opening:{:.3f}%".format(i + 1, topn_accuracy_ringopening[i] / total_ringopening * 100))
                    save_dict[f'top-{i+1}_ringopening'] = topn_accuracy_ringopening[i] / total_ringopening * 100
                if total_ringformation > 0:
                    print("Top-{} Accuracy ring formation:{:.3f}%".format(i + 1, topn_accuracy_ringformation[i] / total_ringformation * 100))
                    save_dict[f'top-{i+1}_ringformation'] = topn_accuracy_ringformation[i] / total_ringformation * 100
                print("Top-{} Accuracy without ring change:{:.3f}%".format(i + 1, topn_accuracy_woring[i] / (data_size - total_ringopening - total_ringformation) * 100))
                save_dict[f'top-{i+1}_woring'] = topn_accuracy_woring[i] / (data_size - total_ringopening - total_ringformation) * 100
        print("total with chirality:", total_chirality)
        print("total ring formation:", total_ringformation)
        print("total ring opening:", total_ringopening)
df = pd.DataFrame(save_dict,index=[0])
df.to_csv("detailed_results.csv")
    if opt.save_accurate_indices != "":
        with open(opt.save_accurate_indices, "w") as f:
            # Save indices of samples whose top-1 prediction was correct.
            for index in accurate_indices[0]:
                f.write(str(index))
                f.write("\n")
    if opt.save_file != "":
        with open(opt.save_file, "w") as f:
            for res in ranked_results:
                for smi in res:
                    f.write(smi)
                    f.write("\n")
                # Pad with blank lines so each sample occupies exactly n_best rows.
                for i in range(len(res), opt.n_best):
                    f.write("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='score.py',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--beam_size', type=int, default=10, help='Beam size')
    parser.add_argument('--n_best', type=int, default=10, help='Number of ranked candidates to evaluate')
    parser.add_argument('--path', type=str, required=True, help='Path to the JSONL file containing predictions and ground truth')
    parser.add_argument('--augmentation', type=int, default=20, help='Number of test-time augmentations per sample')
    parser.add_argument('--score_alpha', type=float, default=1.0, help='Alpha in the reciprocal-rank weight 1/(alpha*k+1)')
    parser.add_argument('--length', type=int, default=-1, help='Number of samples to evaluate (-1 = all)')
    parser.add_argument('--process_number', type=int, default=multiprocessing.cpu_count(), help='Worker processes for canonicalization')
    parser.add_argument('--synthon', action="store_true", default=False, help='Treat SMILES as synthons (skip RDKit sanitization)')
    parser.add_argument('--detailed', action="store_true", default=False, help='Print per-sample details and category breakdowns')
    parser.add_argument('--raw', action="store_true", default=False, help='Score raw predictions without test-time augmentation')
    parser.add_argument('--save_file', type=str, default="", help='Optional path to write ranked predictions')
    parser.add_argument('--save_accurate_indices', type=str, default="", help='Optional path to write indices of top-1 hits')
opt = parser.parse_args()
print(opt)
main(opt)
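# Example invocation (illustrative file names):
#   python score.py --path test_predictions.jsonl --beam_size 10 --n_best 10 \
#       --augmentation 20 --save_file ranked_predictions.txt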