|
import re |
|
import os |
|
import sys |
|
from tqdm import tqdm |
|
|
|
|
|
def remove_bpe(line, bpe_symbol="@@ "):
    """Undo BPE segmentation on a single line of text.

    Args:
        line: text to clean; any newline characters in it are dropped.
        bpe_symbol: continuation marker to delete (marker plus the
            separating space by default).

    Returns:
        The line with every occurrence of ``bpe_symbol`` removed, trailing
        whitespace stripped, and exactly one ``"\\n"`` appended.
    """
    # A space is appended first so that a marker sitting at the very end of
    # the line (without its separating space) is still removed.
    padded = line.replace("\n", "") + " "
    merged = padded.replace(bpe_symbol, "")
    return merged.rstrip() + "\n"
|
|
|
|
|
def remove_bpe_fn(i=sys.stdin, o=sys.stdout, bpe="@@ "):
    """Copy lines from ``i`` to ``o`` with the BPE marker removed.

    Args:
        i: iterable of input lines (defaults to standard input).
        o: object with a ``write`` method (defaults to standard output).
        bpe: marker passed to ``remove_bpe`` for every line.
    """
    # tqdm wraps the input iterable so progress is reported while streaming.
    for raw_line in tqdm(i):
        o.write(remove_bpe(raw_line, bpe))
|
|
|
|
|
def reprocess(fle):
    """Parse a generate.py output file into per-sentence dictionaries.

    Only lines that *start* with a tag of the form ``<letter>-<id>`` (letter
    in ``S``, ``T``, ``H``, ``P``; id a decimal integer) are interpreted.
    Every other line is ignored, even if a tag-like token appears somewhere
    inside it.

    * ``S`` lines: the remainder of the line is stored as the source.
    * ``H`` lines: the remainder must begin with a score (plain decimal,
      scientific notation, or ``-inf``); the score and the text after it are
      appended to per-id lists in file order.
    * ``P`` lines: the remainder is split on whitespace and converted to a
      list of floats; one list is appended per line.
    * ``T`` lines: recognised but skipped, so ``target_dict`` is always
      returned empty.

    Args:
        fle: path of the file to read.

    Returns:
        Tuple ``(source_dict, hypothesis_dict, score_dict, target_dict,
        pos_score_dict)``; every dictionary is keyed by the integer id.
        Stored strings keep their trailing newline.

    Raises:
        AssertionError: an ``H`` line does not begin with a recognisable
            score.
        ValueError: a ``P`` line contains a token that is not a float.
    """
    # Tag at the very beginning of a line, e.g. "H-12\t".
    prefix_re = re.compile(r"[STHP][-]\d+\s*")
    # Score at the beginning of an H payload. The fractional part is optional
    # so that values such as "-1e-05" or "0" are recognised as a whole.
    score_re = re.compile(
        r"(\s*[-]?\d+(?:[.]\d+)?(?:e[+-]?\d+)?\s*)|(\s*(-inf)\s*)"
    )

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    with open(fle, 'r') as f:
        for raw in f:
            # Every stored string ends with exactly one newline, including a
            # final line that lacks one in the file.
            line = raw if raw.endswith("\n") else raw + "\n"

            # match() anchors at the line start; an unanchored search would
            # also pick up tag-like tokens inside untagged lines.
            prefix = prefix_re.match(line)
            if prefix is None:
                continue

            tag = prefix.group()
            line_type = tag[0]
            id_num = int(tag[2:])  # int() tolerates the trailing whitespace
            payload = line[prefix.end():]

            if line_type == "H":
                hypo = score_re.match(payload)
                if hypo is None:
                    # Raised explicitly (same type and message as before) so
                    # the check is not stripped under ``python -O``.
                    raise AssertionError(
                        "regular expression failed to find the hypothesis scoring"
                    )
                score = float(hypo.group())
                hypothesis_dict.setdefault(id_num, []).append(payload[hypo.end():])
                score_dict.setdefault(id_num, []).append(score)
            elif line_type == "S":
                source_dict[id_num] = payload
            elif line_type == "T":
                # Target lines are intentionally not collected.
                continue
            elif line_type == "P":
                pos_scores = [float(x) for x in payload.split()]
                pos_score_dict.setdefault(id_num, []).append(pos_scores)

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
|
|
|
|
|
def get_hypo_and_ref(fle, hyp_file, ref_input, ref_file, rank=0):
    """Write the ``rank``-th hypothesis and the matching reference per id.

    The hypotheses come from ``reprocess(fle)``; the references are the lines
    of ``ref_input``, indexed by sentence id. Ids are written in ascending
    order and ids without any hypothesis are skipped, so line *k* of
    ``hyp_file`` and line *k* of ``ref_file`` always belong together. If no
    hypotheses are found, both output files are created empty.

    Args:
        fle: path of the generation output to parse.
        hyp_file: path of the hypothesis file to write.
        ref_input: path of the reference file, one reference per line.
        ref_file: path of the reference file to write.
        rank: index into each id's hypothesis list (file order).

    Raises:
        AssertionError: ``rank`` is out of range for some id.
        IndexError: ``ref_input`` has no line for some id.
    """
    with open(ref_input, 'r') as f:
        refs = f.readlines()
    _, hypo_dict, _, _, _ = reprocess(fle)

    ids = sorted(hypo_dict)

    # Validate everything before touching the output files so that a bad
    # argument cannot leave half-written results behind.
    for idx in ids:
        n_hyps = len(hypo_dict[idx])
        if not -n_hyps <= rank < n_hyps:
            # AssertionError is kept for compatibility with the former
            # ``assert``; raising explicitly survives ``python -O``.
            raise AssertionError(
                "rank {} out of range: id {} has {} hypotheses".format(
                    rank, idx, n_hyps
                )
            )
    if ids and ids[-1] >= len(refs):
        raise IndexError(
            "no reference line for id {} ({} has {} lines)".format(
                ids[-1], ref_input, len(refs)
            )
        )

    # Context managers close both files even if a write fails.
    with open(hyp_file, "w") as f_hyp, open(ref_file, "w") as f_ref:
        for idx in ids:
            f_hyp.write(hypo_dict[idx][rank])
            f_ref.write(refs[idx])
|
|
|
|
|
def recover_bpe(hyp_file):
    """Write a copy of ``hyp_file`` with BPE markers removed.

    The result is written next to the input as ``hyp_file + ".nobpe"``, using
    the default marker of ``remove_bpe_fn``.

    Args:
        hyp_file: path of the BPE-segmented hypothesis file.
    """
    # The handles are passed directly (no eval-based name lookup) and the
    # context manager closes both files even if processing fails midway.
    with open(hyp_file, "r") as f_hyp, open(hyp_file + ".nobpe", "w") as f_hyp_out:
        remove_bpe_fn(i=f_hyp, o=f_hyp_out)
|
|
|
|
|
if __name__ == "__main__":
    # Expected arguments: <generation output> <reference file>.
    # Fail with a readable message instead of a bare IndexError.
    if len(sys.argv) < 3:
        sys.exit(
            "usage: {} <generate_output> <reference_file>".format(sys.argv[0])
        )

    filename = sys.argv[1]
    ref_in = sys.argv[2]

    # Outputs are placed in the same directory as the generation output.
    out_dir = os.path.dirname(filename)
    hypo_file = os.path.join(out_dir, "hypo.out")
    ref_out = os.path.join(out_dir, "ref.out")

    get_hypo_and_ref(filename, hypo_file, ref_in, ref_out)
    # Produces "hypo.out.nobpe" alongside "hypo.out".
    recover_bpe(hypo_file)
|
|