Spaces:
Running
Running
'''
Code from https://github.com/blender-nlp/MolT5
```bibtex
@article{edwards2022translation,
  title={Translation between Molecules and Natural Language},
  author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
  journal={arXiv preprint arXiv:2204.11817},
  year={2022}
}
```
'''
import argparse | |
import csv | |
import os.path as osp | |
import numpy as np | |
from rdkit import Chem | |
from rdkit.Chem import MACCSkeys | |
from rdkit import DataStructs | |
from rdkit.Chem import AllChem | |
from rdkit import RDLogger | |
RDLogger.DisableLog('rdApp.*') | |
def evaluate(input_file, morgan_r, verbose=False):
    """Score generated SMILES against ground-truth SMILES from a TSV file.

    The file is read with ``csv.DictReader`` (tab-delimited, quoting
    disabled) and is expected to have the header columns ``description``,
    ``ground truth`` and ``output``. A row is *valid* when its ``output``
    SMILES parses into an RDKit molecule; a row that raises while being read
    or parsed (missing column, unparsable output, ...) is counted as a bad
    molecule instead.

    Args:
        input_file: Path to the tab-separated generations file.
        morgan_r: Radius passed to the Morgan fingerprint.
        verbose: When true, print the validity, a progress line every 100
            molecules, and the three average similarities.

    Returns:
        A tuple ``(validity, maccs, rdk, morgan)`` where ``validity`` is the
        fraction of rows whose output parsed (``0.0`` when the file has no
        rows) and the other three are the mean Tanimoto similarities between
        ground-truth and output molecules using MACCS keys, RDKit
        fingerprints and Morgan fingerprints. The means are ``nan`` when no
        row is valid.
    """
    outputs = []
    bad_mols = 0

    with open(input_file) as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for line in reader:
            try:
                gt_smi = line['ground truth']
                ot_smi = line['output']
                # NOTE(review): only the output molecule is checked for a
                # failed parse; a ground truth that fails to parse is stored
                # as None and presumably makes the fingerprint step below
                # fail -- confirm inputs always carry valid references.
                gt_m = Chem.MolFromSmiles(gt_smi)
                ot_m = Chem.MolFromSmiles(ot_smi)
                if ot_m is None:
                    raise ValueError('Bad SMILES')
                outputs.append((line['description'], gt_m, ot_m))
            except Exception:
                # Best-effort per row: any failure marks the row invalid.
                # `Exception` (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                bad_mols += 1

    total = len(outputs) + bad_mols
    # A file with no data rows has no valid outputs; avoid dividing by zero.
    validity_score = len(outputs) / total if total else 0.0
    if verbose:
        print('validity:', validity_score)

    MACCS_sims = []
    morgan_sims = []
    RDK_sims = []

    for i, (desc, gt_m, ot_m) in enumerate(outputs):
        if verbose and i % 100 == 0:
            print(i, 'processed.')

        MACCS_sims.append(DataStructs.FingerprintSimilarity(
            MACCSkeys.GenMACCSKeys(gt_m),
            MACCSkeys.GenMACCSKeys(ot_m),
            metric=DataStructs.TanimotoSimilarity))
        RDK_sims.append(DataStructs.FingerprintSimilarity(
            Chem.RDKFingerprint(gt_m),
            Chem.RDKFingerprint(ot_m),
            metric=DataStructs.TanimotoSimilarity))
        morgan_sims.append(DataStructs.TanimotoSimilarity(
            AllChem.GetMorganFingerprint(gt_m, morgan_r),
            AllChem.GetMorganFingerprint(ot_m, morgan_r)))

    # The mean of zero similarities is undefined: report nan explicitly
    # rather than letting numpy warn about an empty slice.
    nan = float('nan')
    maccs_sims_score = np.mean(MACCS_sims) if MACCS_sims else nan
    rdk_sims_score = np.mean(RDK_sims) if RDK_sims else nan
    morgan_sims_score = np.mean(morgan_sims) if morgan_sims else nan

    if verbose:
        print('Average MACCS Similarity:', maccs_sims_score)
        print('Average RDK Similarity:', rdk_sims_score)
        print('Average Morgan Similarity:', morgan_sims_score)

    return validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score
if __name__ == "__main__":
    # Command-line entry point: evaluate a generations file verbosely.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input_file',
        type=str,
        default='caption2smiles_example.txt',
        help='path where test generations are saved',
    )
    cli.add_argument(
        '--morgan_r',
        type=int,
        default=2,
        help='morgan fingerprint radius',
    )
    opts = cli.parse_args()

    evaluate(opts.input_file, opts.morgan_r, verbose=True)