|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Extract random target-side phrase constraints from "source<TAB>target" lines read on stdin."""
|
|
|
import argparse |
|
import random |
|
import sys |
|
|
|
from sacrebleu import extract_ngrams |
|
|
|
|
|
def get_phrase(words, index, length):
    """Remove ``length`` consecutive items from ``words`` and return them joined.

    Args:
        words: list of tokens; modified in place (the phrase is deleted).
        index: position of the first token of the phrase (non-negative).
        length: number of tokens in the phrase.

    Returns:
        The removed tokens joined with single spaces.
    """
    # The phrase must lie fully inside ``words``; a negative index would make
    # the returned slice and the deleted items disagree.
    assert 0 <= index < len(words) - length + 1
    phr = " ".join(words[index : index + length])
    # One slice deletion instead of popping the same index ``length`` times.
    del words[index : index + length]
    return phr
|
|
|
|
|
def _sample_constraints(tokens, number, length):
    """Randomly choose up to ``number`` non-overlapping phrases from ``tokens``.

    Each phrase is at most ``length`` tokens long. Phrases are drawn only from
    spans of ``tokens`` that no earlier phrase has used, so every token ends up
    in at most one phrase.

    Returns:
        The chosen phrases as space-joined strings, ordered by where they
        occur in ``tokens``.
    """
    # Unused spans, each stored as (offset of its first token, its tokens).
    # Keeping the offset means a phrase's position is known exactly; the
    # old approach of string-searching the target could match inside a
    # longer word (e.g. "at" in "cat") or return -1.
    spans = [(0, list(tokens))]
    chosen = []  # (token offset, phrase)
    for _ in range(number):
        if not spans:
            break
        offset, span = spans.pop(random.choice(range(len(spans))))
        start = random.choice(range(len(span)))
        end = min(len(span), start + length)
        chosen.append((offset + start, " ".join(span[start:end])))
        # Return the untouched pieces on either side to the pool. The right
        # piece is kept whenever it is non-empty, including a single token.
        if start > 0:
            spans.append((offset, span[:start]))
        if end < len(span):
            spans.append((offset + end, span[end:]))
    return [phrase for _, phrase in sorted(chosen)]


def main(args):
    """Read ``source<TAB>target`` lines from stdin and print sampled constraints.

    For every input line containing a tab, up to ``args.number`` random,
    non-overlapping phrases of at most ``args.len`` tokens are drawn from the
    target side. They are printed after the source, tab-separated, in the
    order they occur in the target. Targets with fewer than ``args.len``
    tokens yield no constraints. Lines without a tab are echoed
    (right-stripped) with no constraints. A line with more than one tab
    raises ``ValueError``, because the split must yield exactly two fields.

    Args:
        args: parsed command-line namespace with ``number``, ``len``,
            ``add_sos``, ``add_eos`` and ``seed`` attributes.
    """
    # NOTE: a seed of 0 (the CLI default) is treated as "do not seed".
    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []
        source = line.rstrip()
        if "\t" in line:
            source, target = line.split("\t")
            tokens = target.split()
            if args.add_sos:
                tokens.insert(0, "<s>")
            if args.add_eos:
                tokens.append("</s>")
            if len(tokens) >= args.len:
                constraints = _sample_constraints(tokens, args.number, args.len)
        print(source, *constraints, sep="\t")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface; the parsed namespace is handed straight to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--number", "-n", default=1, type=int, help="number of phrases"
    )
    cli.add_argument(
        "--len", "-l", default=1, type=int, help="phrase length"
    )
    cli.add_argument(
        "--add-sos", action="store_true", default=False, help="add <s> token"
    )
    cli.add_argument(
        "--add-eos", action="store_true", default=False, help="add </s> token"
    )
    # With the default seed of 0, main() skips its random.seed() call.
    cli.add_argument("--seed", "-s", type=int, default=0)
    main(cli.parse_args())
|
|