|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import math |
|
import os |
|
import subprocess |
|
import sys |
|
import tempfile |
|
from collections import defaultdict |
|
from itertools import combinations |
|
|
|
|
|
def read_translations(path, n_repeats):
    """Group the lines of a translations file into per-segment lists.

    The file is read line by line and every ``n_repeats`` consecutive lines
    are collected into one segment. Each line is whitespace-normalized:
    runs of whitespace collapse to a single space and leading/trailing
    whitespace (including the newline) is dropped.

    Args:
        path: Path of the text file to read.
        n_repeats: Number of consecutive lines that make up one segment.

    Returns:
        A ``defaultdict(list)`` mapping the zero-based segment index to the
        list of that segment's ``n_repeats`` normalized lines. Trailing
        lines that do not fill a complete group are not included.
    """
    translations = defaultdict(list)
    group = []
    # Use a context manager so the handle is closed deterministically
    # instead of relying on garbage collection.
    with open(path) as handle:
        for line in handle:
            group.append(" ".join(line.split()))
            if len(group) == n_repeats:
                # Completed segments are numbered consecutively from 0, so
                # the current size of the mapping is the next segment index.
                translations[len(translations)] = group
                group = []
    return translations
|
|
|
|
|
def generate_input(translations, n_repeats):
    """Write all pairwise combinations of each segment's translations to disk.

    Two temporary files are created. For every segment (in sorted key
    order) and every index pair ``(idx1, idx2)`` produced by
    ``combinations(range(n_repeats), 2)``, the translation at ``idx1`` is
    written as one line of the "mt" file and the translation at ``idx2`` as
    the corresponding line of the "ref" file.

    Args:
        translations: Mapping of segment id to a list of translation
            strings; every list must have exactly ``n_repeats`` items.
        n_repeats: Number of translations per segment.

    Returns:
        A ``(ref_path, mt_path)`` tuple with the paths of the two files.
        The files are left on disk for the caller to consume and remove.

    Raises:
        AssertionError: If a segment does not hold exactly ``n_repeats``
            translations.
    """
    ref_fd, ref_path = tempfile.mkstemp()
    mt_fd, mt_path = tempfile.mkstemp()
    # mkstemp() hands back already-open descriptors; wrapping them (rather
    # than reopening the paths) avoids leaking them, and the ``with`` block
    # guarantees both files are flushed and closed before they are used.
    with os.fdopen(ref_fd, "w") as ref_fh, os.fdopen(mt_fd, "w") as mt_fh:
        for segid in sorted(translations):
            candidates = translations[segid]
            assert len(candidates) == n_repeats
            for idx1, idx2 in combinations(range(n_repeats), 2):
                mt_fh.write(candidates[idx1].strip() + "\n")
                ref_fh.write(candidates[idx2].strip() + "\n")
    sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
    return ref_path, mt_path
|
|
|
|
|
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
    """Run the Meteor jar on a hypothesis/reference pair and capture stdout.

    Invokes ``java -Xmx2G -jar <metric_path> <mt_path> <ref_path>`` with the
    fixed options ``-p "0.5 0.2 0.6 0.75" -norm -l <lang>`` and redirects
    the process's standard output to a newly created temporary file. Both
    input files are deleted afterwards, whether or not the call succeeded.

    Args:
        ref_path: Path of the reference file; removed after the run.
        mt_path: Path of the hypothesis file; removed after the run.
        metric_path: Path passed to ``java -jar``.
        lang: Value passed to the ``-l`` option. Defaults to ``"en"``.

    Returns:
        Path of the temporary file holding the captured standard output.
        The caller is responsible for removing it.
    """
    out_fd, out_path = tempfile.mkstemp()
    command = [
        "java",
        "-Xmx2G",
        "-jar",
        metric_path,
        mt_path,
        ref_path,
        "-p",
        "0.5 0.2 0.6 0.75",
        "-norm",
        "-l",
        lang,
    ]
    try:
        # Reuse the descriptor from mkstemp() and close it as soon as the
        # child exits so that the output is fully flushed for the reader.
        with os.fdopen(out_fd, "w") as out_fh:
            status = subprocess.call(command, stdout=out_fh)
    finally:
        # The inputs are single-use; remove them even if launching failed.
        os.remove(ref_path)
        os.remove(mt_path)
    if status != 0:
        # Keep going (as before), but make the failure visible instead of
        # silently producing an empty or partial score file.
        sys.stderr.write("\nWarning: Meteor exited with status %d" % status)
    sys.stderr.write("\nSaved Meteor output to %s" % out_path)
    return out_path
|
|
|
|
|
def read_output(meteor_output_path, n_repeats):
    """Average per-pair scores from a Meteor output file, segment by segment.

    Only lines starting with ``"Segment "`` are considered; the score is the
    second tab-separated field of such a line. Scores are consumed in groups
    of ``C(n_repeats, 2)`` (the number of unordered translation pairs per
    segment) and each complete group contributes its arithmetic mean. The
    output file is deleted after it has been read.

    Args:
        meteor_output_path: Path of the file to parse; removed afterwards.
        n_repeats: Number of translations per segment; must be at least 2.

    Returns:
        A list of floats, one average per complete group of scores, in file
        order. A trailing incomplete group is ignored.

    Raises:
        ValueError: If ``n_repeats`` is smaller than 2, or if a score field
            cannot be parsed as a float.
    """
    if n_repeats < 2:
        raise ValueError("n_repeats must be at least 2, got %r" % (n_repeats,))
    # "n choose 2" in exact integer arithmetic; avoids float division of
    # factorials and keeps the group-size comparison int-to-int.
    n_combinations = n_repeats * (n_repeats - 1) // 2
    average_scores = []
    pending = []
    with open(meteor_output_path) as handle:
        for line in handle:
            if not line.startswith("Segment "):
                continue
            pending.append(float(line.strip().split("\t")[1]))
            if len(pending) == n_combinations:
                average_scores.append(sum(pending) / n_combinations)
                pending = []
    os.remove(meteor_output_path)
    return average_scores
|
|
|
|
|
def main():
    """Command-line entry point: score repeated translations with Meteor.

    Reads the translations file, writes every within-segment translation
    pair to temporary files, runs Meteor on them, averages the resulting
    scores per segment and writes one average per line to the output file.
    Progress messages go to standard error.
    """
    parser = argparse.ArgumentParser(
        description="Average pairwise Meteor scores over repeated translations."
    )
    # Every path below is used unconditionally, so a missing option is
    # rejected up front instead of failing later with a TypeError on None.
    parser.add_argument(
        "-i", "--infile", required=True,
        help="text file with repeat_times consecutive lines per segment",
    )
    parser.add_argument(
        "-n", "--repeat_times", type=int, required=True,
        help="number of translations per segment (at least 2)",
    )
    parser.add_argument(
        "-m", "--meteor", required=True,
        help="path handed to 'java -jar'",
    )
    parser.add_argument(
        "-o", "--output", required=True,
        help="file to which one average score per line is written",
    )
    parser.add_argument(
        "-l", "--lang", default="en",
        help="value for Meteor's -l option (default: en)",
    )
    args = parser.parse_args()
    if args.repeat_times < 2:
        # Pairs cannot be formed from fewer than two translations.
        parser.error("--repeat_times must be at least 2")

    translations = read_translations(args.infile, args.repeat_times)
    sys.stderr.write("\nGenerating input for Meteor...")
    ref_path, mt_path = generate_input(translations, args.repeat_times)
    sys.stderr.write("\nRunning Meteor...")
    out_path = run_meteor(ref_path, mt_path, args.meteor, args.lang)
    sys.stderr.write("\nReading output...")
    scores = read_output(out_path, args.repeat_times)
    sys.stderr.write("\nWriting results...")
    # The context manager closes the file; no explicit close() is needed.
    with open(args.output, "w") as out_fh:
        out_fh.writelines("{}\n".format(score) for score in scores)
|
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|