File size: 2,408 Bytes
26fd00c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#!/usr/bin/env python3 -u
import argparse
import fileinput
import logging
import os
import sys
from fairseq.models.transformer import TransformerModel
# Lower the root logger's threshold to INFO so the progress messages logged
# by this script (model paths, input prompt) are not filtered out.
_root_logger = logging.getLogger()
_root_logger.setLevel(logging.INFO)
def main(argv=None):
    """Paraphrase English text by round-trip translation through French.

    Every non-empty input line is translated with the ``--en2fr`` model and
    the result is translated back once per expert of the ``--fr2en``
    mixture-of-experts model. Each paraphrase is printed on its own line.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` (the default) lets argparse read ``sys.argv``,
            which is what the script entry point relies on.

    Raises:
        RuntimeError: If ``--user-dir`` is omitted and the default
            ``translation_moe/src`` directory next to this example cannot
            be found.
        SystemExit: Raised by argparse on invalid command-line arguments,
            including a ``--num-experts`` value below 1.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    args = parser.parse_args(argv)

    # range() over a non-positive count would silently yield no paraphrases,
    # so reject it up front with a normal argparse usage error.
    if args.num_experts < 1:
        parser.error("--num-experts must be a positive integer")

    if args.user_dir is None:
        # Fall back to the translation_moe sources that sit next to this
        # example's parent directory. Only this auto-detected location is
        # verified here; an explicit --user-dir is passed through as given.
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            "translation_moe",
            "src",
        )
        if os.path.exists(args.user_dir):
            logging.info("found user_dir:%s", args.user_dir)
        else:
            raise RuntimeError(
                "cannot find fairseq examples/translation_moe/src "
                "(tried looking here: {})".format(args.user_dir)
            )

    logging.info("loading en2fr model from:%s", args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info("loading fr2en model from:%s", args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=args.user_dir,
        task="translation_moe",
    ).eval()

    def gen_paraphrases(en):
        """Return one back-translation of ``en`` per expert index."""
        # The forward translation is shared by all experts, so do it once.
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={"expert": i})
            for i in range(args.num_experts)
        ]

    logging.info("Type the input sentence and press return:")
    # The context manager closes the fileinput stream even if translation
    # raises part-way through the input.
    with fileinput.input(files=args.files) as lines:
        for line in lines:
            line = line.strip()
            if not line:
                continue
            for paraphrase in gen_paraphrases(line):
                print(paraphrase)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|