""" |
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of |
candidate hypotheses. |
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" |
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_. |
""" |
import argparse |
import random |
import sys |
from itertools import chain |
import numpy as np |
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu |
def main():
    """Command-line entry point.

    ``--sys`` files are parsed with ``load_sys`` and their pairwise BLEU is
    printed; if ``--output`` is also given, a readable dump is written via
    ``merge``. A ``--ref`` file is parsed with ``load_ref``: together with
    ``--sys`` it is scored by ``multi_ref``, on its own by ``intra_ref``.
    """
    parser = argparse.ArgumentParser(sys.argv[0])
    parser.add_argument(
        "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
    )
    parser.add_argument(
        "--ref", default="", metavar="FILE", help="path to references"
    )
    parser.add_argument(
        "--output",
        default="",
        metavar="FILE",
        help="print outputs into a pretty format",
    )
    opts = parser.parse_args()

    has_sys = bool(opts.sys)
    if has_sys:
        src, tgt, hypos, log_probs = load_sys(opts.sys)
        print("pairwise BLEU: %.2f" % pairwise(hypos))
        if opts.output:
            merge(src, tgt, hypos, log_probs, opts.output)

    # Nothing more to do without a reference file.
    if not opts.ref:
        return
    refs = load_ref(opts.ref)[2]
    if has_sys:
        multi_ref(refs, hypos)
    else:
        intra_ref(refs)
def dictolist(d):
    """Return the values of ``d`` as a list, ordered by ascending key."""
    return [d[key] for key in sorted(d)]
def load_sys(paths):
    """Parse generation logs into per-sentence sources, targets and hypotheses.

    Every file in ``paths`` is read as UTF-8 and scanned for tab-separated
    lines tagged ``S-<id>`` (source text), ``T-<id>`` (target text) and
    ``D-<id>`` (score, then hypothesis text); all other lines are ignored.
    Entries from all files are merged by integer id; hypotheses for the same
    id are appended in the order they are read.

    Args:
        paths: iterable of file paths.

    Returns:
        A tuple ``(src, tgt, hypos, log_probs)`` of lists ordered by ascending
        id. ``hypos[i]`` is a list of hypothesis strings and ``log_probs[i]``
        the parallel list of their scores as floats.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                # Strip only the line terminator. A bare rstrip() would also
                # remove the tab in front of an empty last field, which
                # shifted the id slice and dropped the hypothesis column.
                fields = line.rstrip("\r\n").split("\t")
                tag = fields[0]
                if not tag.startswith(("S-", "T-", "D-")):
                    continue
                i = int(tag[2:])
                if tag.startswith("S-"):
                    src[i] = fields[1].rstrip()
                elif tag.startswith("T-"):
                    tgt[i] = fields[1].rstrip()
                else:
                    hypos.setdefault(i, []).append(fields[2].rstrip())
                    log_probs.setdefault(i, []).append(float(fields[1]))
    # Order each mapping's values by sentence id.
    return tuple([d[k] for k in sorted(d)] for d in (src, tgt, hypos, log_probs))
def load_ref(path):
    """Parse a reference file into sources, targets and reference groups.

    The file is read as UTF-8. Lines starting with ``S-`` and ``T-`` contribute
    their second tab-separated field to ``src`` and ``tgt``. Each maximal run
    of consecutive lines starting with ``R`` becomes one list (second field of
    each line) in ``refs``. Any other line is skipped.

    Args:
        path: path of the reference file.

    Returns:
        ``(src, tgt, refs)`` where ``refs`` is a list of lists of strings.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    src, tgt, refs = [], [], []
    i, n = 0, len(lines)
    while i < n:
        line = lines[i]
        if line.startswith("S-"):
            src.append(line.split("\t")[1].rstrip())
            i += 1
        elif line.startswith("T-"):
            tgt.append(line.split("\t")[1].rstrip())
            i += 1
        elif line.startswith("R"):
            group = []
            while i < n and lines[i].startswith("R"):
                group.append(lines[i].split("\t")[1].rstrip())
                i += 1
            refs.append(group)
        else:
            # Unrecognised line (e.g. a blank separator): skip it. Previously
            # such a line never advanced `i` and the loop spun forever.
            i += 1
    return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
    """Write a human-readable dump of sources, targets and scored hypotheses.

    For each sentence the file gets: the source line, the target line, a blank
    line, one ``\\t<log_prob>\\t<hypothesis>`` line per hypothesis (hypothesis
    whitespace-stripped), and a dashed separator line. The output is UTF-8 so
    non-ASCII text does not depend on the platform's locale encoding.
    Sentences past the shortest of the four input lists are dropped by ``zip``.
    """
    with open(path, "w", encoding="utf-8") as f:
        for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
            f.write(s + "\n")
            f.write(t + "\n")
            f.write("\n")
            for h, lp in zip(hs, lps):
                f.write("\t%f\t%s\n" % (lp, h.strip()))
            f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams):
    """Return the sacrebleu corpus BLEU score with tokenization disabled.

    Thin wrapper over sacrebleu's ``corpus_bleu`` that passes
    ``tokenize="none"`` and returns only the ``.score`` attribute.
    """
    result = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
    return result.score
def sentence_bleu(hypothesis, reference):
    """Return a smoothed BLEU score for one hypothesis against one reference.

    The raw n-gram statistics come from sacrebleu's ``corpus_bleu`` with its
    default tokenization (unlike ``corpus_bleu`` above, which disables it).
    One is added to the matched and total counts at indices 1-3 (presumably
    the 2- to 4-gram orders; verify against the sacrebleu version in use), and
    the score is then recomputed with ``smooth_method="exp"``.
    """
    stats = _corpus_bleu(hypothesis, reference)
    # Add-one smoothing on the higher-order n-gram statistics.
    for order in (1, 2, 3):
        stats.counts[order] += 1
        stats.totals[order] += 1
    smoothed = compute_bleu(
        stats.counts,
        stats.totals,
        stats.sys_len,
        stats.ref_len,
        smooth_method="exp",
    )
    return smoothed.score
def pairwise(sents):
    """Return corpus BLEU over all ordered pairs of distinct entries per group.

    For every group in ``sents`` and every ordered pair of positions
    ``(i, j)`` with ``i != j``, entry ``j`` is scored as a hypothesis against
    entry ``i`` as the single reference; all pairs form one corpus.
    """
    pairs = [
        (ref, hyp)
        for group in sents
        for i, ref in enumerate(group)
        for j, hyp in enumerate(group)
        if i != j
    ]
    refs = [ref for ref, _ in pairs]
    hyps = [hyp for _, hyp in pairs]
    return corpus_bleu(hyps, [refs])
def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    Args:
        refs: one list of reference strings per sentence.
        hypos: one list of hypothesis strings per sentence, aligned with
            ``refs``.

    Prints two lines:

    - the average number of distinct references that are the best
      sentence-BLEU match for at least one hypothesis;
    - the mean corpus BLEU obtained by holding out each reference position in
      turn and scoring all hypotheses against the remaining ones.

    Raises:
        ValueError: if ``refs`` and ``hypos`` differ in length or are empty.
    """
    if len(refs) != len(hypos):
        raise ValueError(
            "got references for %d sentences but hypotheses for %d"
            % (len(refs), len(hypos))
        )
    if not refs:
        raise ValueError("no sentences to score")
    print("#refs covered: %.2f" % _avg_refs_covered(refs, hypos))
    print(
        "average multi-reference BLEU (leave-one-out): %.2f"
        % _loo_multi_ref_bleu(refs, hypos)
    )


def _avg_refs_covered(refs, hypos):
    """Return the mean per-sentence count of references chosen as a best match."""
    covered = 0
    for rs, hs in zip(refs, hypos):
        picked = set()
        for h in hs:
            scores = [sentence_bleu(h, r) for r in rs]
            top = scores[int(np.argmax(scores))]
            # Break ties at random so tied references are not always
            # credited to the first one.
            picked.add(random.choice([k for k, sc in enumerate(scores) if sc == top]))
        covered += len(picked)
    return covered / len(refs)


def _loo_multi_ref_bleu(refs, hypos):
    """Return corpus BLEU averaged over each held-out reference position."""
    # Transpose to position-major; zip truncates to the smallest group size.
    refs_t = list(zip(*refs))  # refs_t[r][i]: r-th reference of sentence i
    hypos_t = list(zip(*hypos))  # hypos_t[j][i]: j-th hypothesis of sentence i
    k = len(hypos_t)
    m = len(refs_t)
    # Sentence-major hypothesis stream; each reference is repeated k times so
    # every reference stream stays aligned with it.
    flat_hypos = [hypos_t[j][i] for i in range(len(hypos_t[0])) for j in range(k)]
    duplicated_refs = [[ref for ref in refs_r for _ in range(k)] for refs_r in refs_t]
    loo_bleus = [
        corpus_bleu(flat_hypos, duplicated_refs[:held] + duplicated_refs[held + 1 :])
        for held in range(m)
    ]
    return np.mean(loo_bleus)
def intra_ref(refs):
    """Print inter-reference agreement statistics.

    First prints the pairwise BLEU among each sentence's references. Then each
    reference position in turn is treated as the hypothesis stream and scored
    against the remaining positions; all of these are concatenated into a
    single corpus BLEU.
    """
    print("ref pairwise BLEU: %.2f" % pairwise(refs))
    by_position = list(zip(*refs))
    # Every position serves once as hypotheses, in position order.
    hyp_stream = list(chain.from_iterable(by_position))
    ref_streams = [[] for _ in range(len(by_position) - 1)]
    for held in range(len(by_position)):
        others = by_position[:held] + by_position[held + 1 :]
        for stream, other in zip(ref_streams, others):
            stream.extend(other)
    score = corpus_bleu(hyp_stream, ref_streams)
    print("multi-reference BLEU (leave-one-out): %.2f" % score)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()