Spaces:
Running
Running
File size: 3,049 Bytes
7dd9869 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
'''
Code from https://github.com/blender-nlp/MolT5
```bibtex
@article{edwards2022translation,
title={Translation between Molecules and Natural Language},
author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
journal={arXiv preprint arXiv:2204.11817},
year={2022}
}
```
'''
import pickle
import argparse
import csv
import os.path as osp
import numpy as np
#load metric stuff
from nltk.translate.bleu_score import corpus_bleu
#from nltk.translate.meteor_score import meteor_score
from Levenshtein import distance as lev
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
def evaluate(input_fp, verbose=False):
    """Score generated molecule strings against their ground truth.

    Reads a tab-separated file whose header row provides the columns
    ``description``, ``ground truth`` and ``output``. The last two are
    parsed with ``Chem.MolFromSmiles`` and compared over all data rows.

    Args:
        input_fp: Path to the tab-separated generations file.
        verbose: When true, print progress and each metric as it is computed.

    Returns:
        A tuple ``(bleu_score, exact_match_score, levenshtein_score,
        validity_score)`` where:

        * ``bleu_score`` is the corpus BLEU computed with every character
          of a string treated as one token;
        * ``exact_match_score`` is the fraction of rows whose output and
          ground truth convert to the same InChI string;
        * ``levenshtein_score`` is the mean edit distance between output
          and ground truth;
        * ``validity_score`` is ``1 - bad_mols / rows``, where ``bad_mols``
          counts rows for which parsing or InChI conversion of either
          string raised.

    Raises:
        ValueError: If the file contains no data rows (the averages would
            otherwise be undefined).
        KeyError: If one of the expected columns is missing from the header.
    """
    # newline="" lets the csv module do its own line-ending handling.
    with open(input_fp, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        outputs = [
            (row['description'], row['ground truth'], row['output'])
            for row in reader
        ]

    if not outputs:
        # Fail early with a clear message instead of an obscure
        # NameError / ZeroDivisionError further down.
        raise ValueError('no data rows found in {!r}'.format(input_fp))

    # Character-level BLEU: each reference/hypothesis is a list of characters.
    # NOTE: METEOR scoring is intentionally not computed here.
    references = []
    hypotheses = []
    for i, (_, gt, out) in enumerate(outputs):
        if verbose and i % 100 == 0:
            print(i, 'processed.')
        references.append([list(gt)])
        hypotheses.append(list(out))

    bleu_score = corpus_bleu(references, hypotheses)
    if verbose:
        print('BLEU score:', bleu_score)

    levs = []
    num_exact = 0
    bad_mols = 0
    for _, gt, out in outputs:
        try:
            m_out = Chem.MolFromSmiles(out)
            m_gt = Chem.MolFromSmiles(gt)
            # Comparing InChI strings standardizes the two molecules, so
            # differently written but equivalent strings still match.
            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt):
                num_exact += 1
        except Exception:
            # An unparsable string yields None from MolFromSmiles, which
            # makes MolToInchi raise; count the row as a bad molecule.
            # `Exception` (not a bare except) keeps Ctrl-C working.
            bad_mols += 1
        levs.append(lev(out, gt))

    total = len(outputs)

    # Exact matching score
    exact_match_score = num_exact / total
    if verbose:
        print('Exact Match:')
        print(exact_match_score)

    # Levenshtein score
    levenshtein_score = np.mean(levs)
    if verbose:
        print('Levenshtein:')
        print(levenshtein_score)

    validity_score = 1 - bad_mols / total
    if verbose:
        print('validity:', validity_score)

    return bleu_score, exact_match_score, levenshtein_score, validity_score
if __name__ == "__main__":
    # Script entry point: read the generations file given on the command
    # line (or the bundled example) and print every metric.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input_file',
        type=str,
        default='caption2smiles_example.txt',
        help='path where test generations are saved',
    )
    cli_args = cli.parse_args()
    evaluate(cli_args.input_file, verbose=True)
|