|
|
|
|
|
|
|
|
|
|
|
|
|
"""

Generate n-best translations using a trained model, then write them out as
text and binarize them so they can be rescored.

"""
|
|
|
import os
import subprocess
import sys
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate, preprocess

from examples.noisychannel import rerank_options, rerank_utils
|
|
|
|
|
def gen_and_reprocess_nbest(args):
    """Generate (or load) an n-best list and prepare it for rescoring.

    STEP 1: unless ``args.nbest_list`` is given, run fairseq ``generate``
    with ``args.gen_model`` and capture its stdout in
    ``<pre_gen>/generate_output_bpe.txt`` (skipped if that file exists).
    STEP 2: if a scoring model that is not the generation model has no
    score file yet, write the n-best as plain-text files for rescoring.
    STEP 3: binarize those files with fairseq ``preprocess``.

    Args:
        args: parsed reranking options. ``args.score_dict_dir`` is set to
            ``args.data`` in place when it is ``None``.

    Returns:
        The ``rerank_utils.BitextOutputFromGen`` parsed from the n-best file.

    Raises:
        AssertionError: for unsupported flag combinations.
        subprocess.CalledProcessError: if ``apply_bpe.py`` fails
            (only with ``args.diff_bpe``).
    """
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    _validate_args(args)

    # Language direction of the first scorer. It is also used for the
    # backwards and right-to-left binarization in STEP 3.
    # NOTE(review): backwards2 is never consulted for the direction --
    # confirm that is intended when only model 2 is a backwards model.
    if args.backwards1:
        scorer1_src, scorer1_tgt = args.target_lang, args.source_lang
    else:
        scorer1_src, scorer1_tgt = args.source_lang, args.target_lang

    os.makedirs(
        os.path.dirname(__file__) + "/rerank_data/" + args.data_dir_name,
        exist_ok=True,
    )
    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    for directory in (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        lm_preprocessed_dir,
        backwards_preprocessed_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    # A missing score file only triggers STEP 2/3 for a scorer that is not
    # the generation model (scored on untruncated sources). With a
    # predefined n-best list, model 2 never triggers it.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = args.nbest_list is not None or (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    score2_file = None
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    # STEP 1: obtain the n-best list.
    using_nbest = args.nbest_list is not None
    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list
    else:
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if not os.path.isfile(predictions_bpe_file):
            _run_generation(args, predictions_bpe_file)

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.post_process,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        _apply_rescore_bpe(args, gen_output, pre_gen)

    score1_missing = not os.path.isfile(score1_file) and not rerank1_is_gen
    score2_missing = (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    )
    if not (score1_missing or score2_missing):
        return gen_output

    print(
        "STEP 2: process the output of generate.py so we have clean text files with the translations"
    )
    fw_rescore_file, bw_rescore_file = _write_rescore_files(args, gen_output, pre_gen)

    print("STEP 3: binarize the translations")
    # NOTE(review): `not rerank1_is_gen` on its own is enough to enable
    # left-to-right binarization -- confirm that is intended.
    needs_left_to_right = (
        not args.right_to_left1
        or (args.score_model2 is not None and not args.right_to_left2)
        or not rerank1_is_gen
    )
    # fw_rescore_file is None when both scorers are right-to-left: STEP 2
    # wrote no left-to-right text, so there is nothing to binarize here.
    if needs_left_to_right and fw_rescore_file is not None:
        if args.backwards1 or args.backwards2:
            if args.backwards_score_dict_dir is not None:
                bw_dict = args.backwards_score_dict_dir
            else:
                bw_dict = args.score_dict_dir
            _binarize(
                scorer1_src,
                scorer1_tgt,
                pre_gen + bw_rescore_file,
                bw_dict,
                backwards_preprocessed_dir,
            )
        _binarize(
            scorer1_src,
            scorer1_tgt,
            pre_gen + fw_rescore_file,
            args.score_dict_dir,
            left_to_right_preprocessed_dir,
        )

    if args.right_to_left1 or args.right_to_left2:
        _binarize(
            scorer1_src,
            scorer1_tgt,
            pre_gen + "/right_to_left_rescore_data",
            args.score_dict_dir,
            right_to_left_preprocessed_dir,
        )

    return gen_output


def _validate_args(args):
    """Reject reranking flag combinations this pipeline does not support.

    Raises:
        AssertionError: on an unsupported combination.
    """
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"
    if args.nbest_list is not None:
        assert args.score_model2 is None
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"


def _generation_params(args):
    """Build the fairseq ``generate`` command line used in STEP 1.

    Beam size, n-best size and batch size are all ``args.num_rescore``.
    NOTE(review): ``args.batch_size`` is not forwarded. ``--batch-size`` was
    previously listed twice and argparse keeps the last occurrence, so
    generation has always effectively used ``num_rescore`` as batch size.
    """
    params = [
        args.data,
        "--path", args.gen_model,
        "--shard-id", str(args.shard_id),
        "--num-shards", str(args.num_shards),
        "--nbest", str(args.num_rescore),
        "--beam", str(args.num_rescore),
        "--batch-size", str(args.num_rescore),
        "--gen-subset", args.gen_subset,
        "--source-lang", args.source_lang,
        "--target-lang", args.target_lang,
    ]
    if args.sampling:
        params.append("--sampling")
    return params


def _run_generation(args, output_file):
    """STEP 1: run fairseq ``generate`` and write its stdout to ``output_file``."""
    print("STEP 1: generate predictions using the p(T|S) model with bpe")
    print(args.data)
    gen_parser = options.get_generation_parser()
    input_args = options.parse_args_and_arch(gen_parser, _generation_params(args))
    print(input_args)
    with open(output_file, "w") as f, redirect_stdout(f):
        generate.main(input_args)


def _apply_rescore_bpe(args, gen_output, pre_gen):
    """Re-segment the ``no_bpe_*`` n-best text with ``args.rescore_bpe_code``.

    Writes ``source_gen_bpe``/``target_gen_bpe``/``reference_gen_bpe`` files
    under ``pre_gen``. It then runs subword-nmt's ``apply_bpe.py`` (with the
    current interpreter) to produce ``pre_gen/rescore_data.<lang>`` for the
    source and target languages.

    Raises:
        subprocess.CalledProcessError: if ``apply_bpe.py`` exits non-zero.
    """
    rerank_utils.write_reprocessed(
        gen_output.no_bpe_source,
        gen_output.no_bpe_hypo,
        gen_output.no_bpe_target,
        pre_gen + "/source_gen_bpe." + args.source_lang,
        pre_gen + "/target_gen_bpe." + args.target_lang,
        pre_gen + "/reference_gen_bpe." + args.target_lang,
    )
    apply_bpe_script = os.path.join(
        os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
    )
    for lang, input_stem in (
        (args.source_lang, "/source_gen_bpe."),
        (args.target_lang, "/target_gen_bpe."),
    ):
        # sys.executable: run apply_bpe.py under the same interpreter rather
        # than whatever "python" resolves to on PATH.
        subprocess.run(
            [
                sys.executable,
                apply_bpe_script,
                "-c", args.rescore_bpe_code,
                "--input", pre_gen + input_stem + lang,
                "--output", pre_gen + "/rescore_data." + lang,
            ],
            shell=False,
            check=True,
        )


def _write_nbest_text(args, gen_output, pre_gen, stem, reference, **kwargs):
    """Write ``gen_output`` to ``pre_gen + stem + ".<lang>"`` plus a reference file.

    Extra keyword arguments are forwarded to ``rerank_utils.write_reprocessed``.
    """
    rerank_utils.write_reprocessed(
        gen_output.source,
        gen_output.hypo,
        gen_output.target,
        pre_gen + stem + "." + args.source_lang,
        pre_gen + stem + "." + args.target_lang,
        pre_gen + reference,
        bpe_symbol=args.post_process,
        **kwargs,
    )


def _write_rescore_files(args, gen_output, pre_gen):
    """STEP 2: write the n-best as plain-text files for the scorers.

    Returns:
        ``(fw_rescore_file, bw_rescore_file)``: file stems relative to
        ``pre_gen`` (with a leading "/") for the forward and backwards
        binarization. Returns ``(None, None)`` when both scorers are
        right-to-left, since no left-to-right text is written then.
    """
    rescore_file = "/rescore_data"
    fw_rescore_file = bw_rescore_file = None

    if not args.right_to_left1 or not args.right_to_left2:
        if not args.diff_bpe:
            # With diff_bpe, rescore_data.<lang> was already written by
            # apply_bpe.py.
            _write_nbest_text(args, gen_output, pre_gen, rescore_file, "/reference_file")

        if args.prefix_len is not None:
            bw_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
            _write_nbest_text(
                args, gen_output, pre_gen, bw_rescore_file, "/reference_file",
                prefix_len=args.prefix_len,
            )
        elif args.target_prefix_frac is not None:
            bw_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
            _write_nbest_text(
                args, gen_output, pre_gen, bw_rescore_file, "/reference_file",
                target_prefix_frac=args.target_prefix_frac,
            )
        else:
            bw_rescore_file = rescore_file

        if args.source_prefix_frac is not None:
            fw_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )
            _write_nbest_text(
                args, gen_output, pre_gen, fw_rescore_file, "/reference_file",
                source_prefix_frac=args.source_prefix_frac,
            )
        else:
            fw_rescore_file = rescore_file

    if args.right_to_left1 or args.right_to_left2:
        _write_nbest_text(
            args, gen_output, pre_gen,
            "/right_to_left_rescore_data", "/right_to_left_reference_file",
            right_to_left=True,
        )

    return fw_rescore_file, bw_rescore_file


def _binarize(source_lang, target_lang, trainpref, dict_dir, destdir):
    """Run fairseq ``preprocess`` on ``trainpref`` into ``destdir``.

    Uses the fixed dictionaries ``dict_dir/dict.<lang>.txt``.
    """
    preprocess_param = [
        "--source-lang", source_lang,
        "--target-lang", target_lang,
        "--trainpref", trainpref,
        "--srcdict", dict_dir + "/dict." + source_lang + ".txt",
        "--tgtdict", dict_dir + "/dict." + target_lang + ".txt",
        "--destdir", destdir,
    ]
    preprocess_parser = options.get_preprocessing_parser()
    input_args = preprocess_parser.parse_args(preprocess_param)
    preprocess.main(input_args)
|
|
|
|
|
def cli_main():
    """Parse the reranking command line and run n-best generation/reprocessing."""
    reranking_parser = rerank_options.get_reranking_parser()
    gen_and_reprocess_nbest(options.parse_args_and_arch(reranking_parser))
|
|
|
|
|
# Allow running this file directly as a command-line script.
if __name__ == "__main__":
    cli_main()
|
|