'''
Code from https://github.com/blender-nlp/MolT5

```bibtex
@article{edwards2022translation,
  title={Translation between Molecules and Natural Language},
  author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
  journal={arXiv preprint arXiv:2204.11817},
  year={2022}
}
```
'''
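
# Expected input format (inferred from the DictReader lookups in evaluate() below,
# so this example row is illustrative rather than authoritative): a tab-separated
# file with a header row containing the columns 'description', 'ground truth' and
# 'output', where the last two hold the reference and generated SMILES strings:
#
#   description<TAB>ground truth<TAB>output
#   The molecule is ethanol ...<TAB>CCO<TAB>CCO
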
import pickle
import argparse
import csv
import os.path as osp

import numpy as np

# load metric dependencies
from nltk.translate.bleu_score import corpus_bleu
#from nltk.translate.meteor_score import meteor_score
from Levenshtein import distance as lev

from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def evaluate(input_fp, verbose=False):
    outputs = []

    with open(osp.join(input_fp)) as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for n, line in enumerate(reader):
            gt_smi = line['ground truth']
            ot_smi = line['output']
            outputs.append((line['description'], gt_smi, ot_smi))
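    # outputs now holds one (description, ground-truth SMILES, generated SMILES)
    # triple per row of the input file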

    bleu_scores = []
    #meteor_scores = []

    references = []
    hypotheses = []

    for i, (smi, gt, out) in enumerate(outputs):
        if i % 100 == 0:
            if verbose:
                print(i, 'processed.')

        gt_tokens = [c for c in gt]
        out_tokens = [c for c in out]

        references.append([gt_tokens])
        hypotheses.append(out_tokens)

        # mscore = meteor_score([gt], out)
        # meteor_scores.append(mscore)
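    # corpus_bleu uses uniform weights over 1- to 4-grams by default; since the
    # "tokens" built above are single characters, this is a character-level BLEU
    # over SMILES strings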
    # BLEU score
    bleu_score = corpus_bleu(references, hypotheses)
    if verbose:
        print('BLEU score:', bleu_score)

    # Meteor score
    # _meteor_score = np.mean(meteor_scores)
    # print('Average Meteor score:', _meteor_score)
    rouge_scores = []

    references = []
    hypotheses = []

    levs = []

    num_exact = 0
    bad_mols = 0

    for i, (smi, gt, out) in enumerate(outputs):
        hypotheses.append(out)
        references.append(gt)

        try:
            m_out = Chem.MolFromSmiles(out)
            m_gt = Chem.MolFromSmiles(gt)
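            # compare InChI strings so that two different SMILES spellings of
            # the same molecule still count as an exact match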
            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt): num_exact += 1
            #if gt == out: num_exact += 1 #old version that didn't standardize strings
        except:
            bad_mols += 1
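        # character-level edit distance between generated and ground-truth SMILES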
        levs.append(lev(out, gt))

    # Exact matching score
    exact_match_score = num_exact/(i+1)
    if verbose:
        print('Exact Match:')
        print(exact_match_score)

    # Levenshtein score
    levenshtein_score = np.mean(levs)
    if verbose:
        print('Levenshtein:')
        print(levenshtein_score)
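    # validity: fraction of outputs for which SMILES parsing and InChI conversion
    # succeeded (bad_mols counts the rows where the try block above raised)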
    validity_score = 1 - bad_mols/len(outputs)
    if verbose:
        print('validity:', validity_score)

    return bleu_score, exact_match_score, levenshtein_score, validity_score

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, default='caption2smiles_example.txt',
                        help='path where test generations are saved')

    args = parser.parse_args()
    evaluate(args.input_file, verbose=True)
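
# Example usage, as a sketch. The filenames below are placeholders: <this_script>.py
# is whatever this file is saved as, and my_generations.txt stands for a
# tab-separated file in the format described at the top of this script.
#
#   python <this_script>.py --input_file my_generations.txt
#
# Or, from Python:
#
#   bleu, exact_match, levenshtein, validity = evaluate('my_generations.txt', verbose=True)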