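# NOTE: the three lines below are not Python; they look like page-status text
# captured together with the source and are kept here only as a comment.
# Spaces:
# Runtime error
# Runtime error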
import re | |
import os | |
import sys | |
from tqdm import tqdm | |
def remove_bpe(line, bpe_symbol="@@ "):
    """Undo BPE segmentation on a single line of text.

    Args:
        line: The text to clean; any newline characters in it are dropped.
        bpe_symbol: The continuation marker to delete (marker plus the
            separator that follows it).

    Returns:
        The line with every ``bpe_symbol`` removed, trailing whitespace
        stripped, and exactly one ``"\\n"`` appended.
    """
    flattened = "".join(line.split("\n"))
    # Pad with one space so a marker sitting at the very end of the line
    # (where its trailing separator is missing) is still matched and removed.
    padded = flattened + " "
    merged = padded.replace(bpe_symbol, "")
    return merged.rstrip() + "\n"
def remove_bpe_fn(i=sys.stdin, o=sys.stdout, bpe="@@ "):
    """Stream lines from ``i`` to ``o`` with BPE markers removed.

    Args:
        i: Iterable of input lines (defaults to standard input).
        o: Object with a ``write`` method receiving the cleaned lines
            (defaults to standard output).
        bpe: The BPE marker passed through to ``remove_bpe``.
    """
    # tqdm wraps the input so progress is reported while lines are consumed;
    # each line is cleaned and written as soon as it is read.
    for raw_line in tqdm(i):
        o.write(remove_bpe(raw_line, bpe))
def reprocess(fle):
    """Parse the output file of a ``generate.py`` translation run.

    Only lines that *start* with an ``S-<id>``, ``T-<id>``, ``H-<id>`` or
    ``P-<id>`` tag are used; every other line is ignored. There may be
    several ``H``/``P`` lines per id, so those values are lists.

    Args:
        fle: Path of the generation output file to read.

    Returns:
        A 5-tuple of dicts keyed by the integer sentence id:

        * ``source_dict``: id -> source text (newline-terminated).
        * ``hypothesis_dict``: id -> list of hypothesis texts
          (newline-terminated), in file order.
        * ``score_dict``: id -> list of float scores, parallel to
          ``hypothesis_dict``.
        * ``target_dict``: always empty; ``T`` lines are recognised but
          not stored.
        * ``pos_score_dict``: id -> list of per-position score lists.

    Raises:
        AssertionError: If an ``H`` line has no parsable score directly
            after its tag.
        ValueError: If a ``P`` line contains a token that is not a float.
    """
    # "[^\S\n]*" is "horizontal whitespace only": separators are skipped but
    # the newline that terminates the line is never swallowed, so an empty
    # source/hypothesis is stored as "\n" instead of "" (which would silently
    # drop a line when the texts are written out one per line).
    tag_re = re.compile(r"([STHP])-(\d+)[^\S\n]*")
    # Accepts plain and exponent floats with any number of mantissa digits
    # (e.g. "-0.5", "5", "1e-05") as well as inf/nan.
    score_re = re.compile(
        r"[^\S\n]*"
        r"([-+]?(?:\d+(?:[.]\d*)?|[.]\d+)(?:e[-+]?\d+)?|[-+]?(?:inf|nan))"
        r"[^\S\n]*"
    )

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    with open(fle, 'r') as f:
        for raw in f:
            # Guarantee exactly one trailing newline, including on a final
            # line that lacks one.
            line = raw.rstrip("\n") + "\n"

            # match(), not search(): the tag must open the line. Otherwise a
            # tag-like token inside running text (e.g. "US-7") would be
            # mistaken for a real tag.
            tag = tag_re.match(line)
            if tag is None:
                continue

            line_type = tag.group(1)
            id_num = int(tag.group(2))
            rest = line[tag.end():]

            if line_type == "H":
                # Anchored as well, so a number inside the hypothesis text
                # can never be picked up as the score.
                scored = score_re.match(rest)
                if scored is None:
                    raise AssertionError(
                        "regular expression failed to find the hypothesis scoring"
                    )
                hypothesis_dict.setdefault(id_num, []).append(rest[scored.end():])
                score_dict.setdefault(id_num, []).append(float(scored.group(1)))
            elif line_type == "S":
                source_dict[id_num] = rest
            elif line_type == "P":
                pos_scores = [float(x) for x in rest.split()]
                pos_score_dict.setdefault(id_num, []).append(pos_scores)
            # "T" (target) lines are intentionally not stored.

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def get_hypo_and_ref(fle, hyp_file, ref_input, ref_file, rank=0):
    """Write line-aligned hypothesis and reference files.

    For every sentence id that has at least one hypothesis in ``fle``, the
    ``rank``-th hypothesis is written to ``hyp_file`` and line ``id`` of
    ``ref_input`` is written to ``ref_file``, in ascending id order.

    Args:
        fle: Path of the ``generate.py`` output, parsed with ``reprocess``.
        hyp_file: Path of the hypothesis file to create/overwrite.
        ref_input: Path of the reference file; line ``k`` is the reference
            for sentence id ``k``.
        ref_file: Path of the filtered reference file to create/overwrite.
        rank: Index into each sentence's hypothesis list.

    Raises:
        KeyError: If no hypothesis was found for sentence id 0.
        AssertionError: If ``rank`` is not smaller than the number of
            hypotheses for sentence id 0.
        IndexError: If some sentence has fewer than ``rank + 1`` hypotheses
            or ``ref_input`` has too few lines.
    """
    with open(ref_input, 'r') as f:
        refs = f.readlines()

    _, hypo_dict, _, _, _ = reprocess(fle)
    assert rank < len(hypo_dict[0])

    # Context managers close both outputs even if a lookup below raises.
    with open(hyp_file, "w") as f_hyp, open(ref_file, "w") as f_ref:
        # Ids are non-negative ints, so walking the sorted keys visits the
        # same ids, in the same order, as scanning 0..max and skipping gaps.
        for idx in sorted(hypo_dict):
            f_hyp.write(hypo_dict[idx][rank])
            f_ref.write(refs[idx])
def recover_bpe(hyp_file):
    """Write a BPE-free copy of ``hyp_file`` to ``hyp_file + ".nobpe"``.

    Args:
        hyp_file: Path of the hypothesis file to read. The output path is
            the same path with the suffix ``.nobpe`` appended; it is
            created or overwritten.
    """
    # The files are passed directly (no eval-based name lookup), and the
    # context managers close both handles even if processing fails midway.
    with open(hyp_file, "r") as f_hyp, open(hyp_file + ".nobpe", "w") as f_hyp_out:
        remove_bpe_fn(i=f_hyp, o=f_hyp_out)
if __name__ == "__main__":
    # Usage: <script> GENERATE_OUTPUT REFERENCE_FILE
    # Writes "hypo.out" and "ref.out" next to GENERATE_OUTPUT, then a
    # BPE-free copy of the hypotheses ("hypo.out.nobpe").
    filename, ref_in = sys.argv[1], sys.argv[2]
    out_dir = os.path.dirname(filename)
    hypo_file = os.path.join(out_dir, "hypo.out")
    ref_out = os.path.join(out_dir, "ref.out")

    get_hypo_and_ref(filename, hypo_file, ref_in, ref_out)
    recover_bpe(hypo_file)