|
|
|
import argparse
import codecs
import os
import shutil
import sys
import tempfile
import time

import pyrouge

from onmt.utils.logging import init_logger, logger
|
|
|
|
|
def test_rouge(cand, ref): |
|
"""Calculate ROUGE scores of sequences passed as an iterator |
|
e.g. a list of str, an open file, StringIO or even sys.stdin |
|
""" |
|
current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) |
|
tmp_dir = ".rouge-tmp-{}".format(current_time) |
|
try: |
|
if not os.path.isdir(tmp_dir): |
|
os.mkdir(tmp_dir) |
|
os.mkdir(tmp_dir + "/candidate") |
|
os.mkdir(tmp_dir + "/reference") |
|
candidates = [line.strip() for line in cand] |
|
references = [line.strip() for line in ref] |
|
assert len(candidates) == len(references) |
|
cnt = len(candidates) |
|
for i in range(cnt): |
|
if len(references[i]) < 1: |
|
continue |
|
with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", |
|
encoding="utf-8") as f: |
|
f.write(candidates[i]) |
|
with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", |
|
encoding="utf-8") as f: |
|
f.write(references[i]) |
|
r = pyrouge.Rouge155() |
|
r.model_dir = tmp_dir + "/reference/" |
|
r.system_dir = tmp_dir + "/candidate/" |
|
r.model_filename_pattern = 'ref.#ID#.txt' |
|
r.system_filename_pattern = r'cand.(\d+).txt' |
|
rouge_results = r.convert_and_evaluate() |
|
results_dict = r.output_to_dict(rouge_results) |
|
return results_dict |
|
finally: |
|
pass |
|
if os.path.isdir(tmp_dir): |
|
shutil.rmtree(tmp_dir) |
|
|
|
|
|
def rouge_results_to_str(results_dict):
    """Format selected ROUGE F-scores from ``results_dict`` as one line.

    Reads the ``rouge_1``, ``rouge_2``, ``rouge_3``, ``rouge_l`` and
    ``rouge_su*`` ``_f_score`` entries, multiplies each by 100 and renders
    them with two decimals, separated by slashes.

    Args:
        results_dict: mapping that contains the five F-score keys above.

    Returns:
        The formatted summary string.

    Raises:
        KeyError: if any of the five keys is missing.
    """
    metric_keys = (
        "rouge_1_f_score",
        "rouge_2_f_score",
        "rouge_3_f_score",
        "rouge_l_f_score",
        "rouge_su*_f_score",
    )
    percentages = [results_dict[key] * 100 for key in metric_keys]
    template = ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}"
    return template.format(*percentages)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: score a candidate file (or stdin) against
    # a reference file and log the ROUGE summary line.
    init_logger('test_rouge.log')
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', type=str, default="candidate.txt",
                        help='candidate file')
    parser.add_argument('-r', type=str, default="reference.txt",
                        help='reference file')
    args = parser.parse_args()

    # "-c stdin" (any letter case) reads candidates from standard input
    # instead of opening a file.
    read_from_stdin = args.c.upper() == "STDIN"
    if read_from_stdin:
        candidates = sys.stdin
    else:
        candidates = codecs.open(args.c, encoding="utf-8")
    try:
        # Close both inputs deterministically once scoring is done.
        with codecs.open(args.r, encoding="utf-8") as references:
            results_dict = test_rouge(candidates, references)
    finally:
        # Only close what this script opened; leave sys.stdin alone.
        if not read_from_stdin:
            candidates.close()

    logger.info(rouge_results_to_str(results_dict))
|
|