|
|
|
|
|
import argparse |
|
import fileinput |
|
import logging |
|
import os |
|
import sys |
|
|
|
from fairseq.models.transformer import TransformerModel |
|
|
|
|
|
logging.getLogger().setLevel(logging.INFO) |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser(description="") |
|
parser.add_argument("--en2fr", required=True, help="path to en2fr model") |
|
parser.add_argument( |
|
"--fr2en", required=True, help="path to fr2en mixture of experts model" |
|
) |
|
parser.add_argument( |
|
"--user-dir", help="path to fairseq examples/translation_moe/src directory" |
|
) |
|
parser.add_argument( |
|
"--num-experts", |
|
type=int, |
|
default=10, |
|
help="(keep at 10 unless using a different model)", |
|
) |
|
parser.add_argument( |
|
"files", |
|
nargs="*", |
|
default=["-"], |
|
help='input files to paraphrase; "-" for stdin', |
|
) |
|
args = parser.parse_args() |
|
|
|
if args.user_dir is None: |
|
args.user_dir = os.path.join( |
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), |
|
"translation_moe", |
|
"src", |
|
) |
|
if os.path.exists(args.user_dir): |
|
logging.info("found user_dir:" + args.user_dir) |
|
else: |
|
raise RuntimeError( |
|
"cannot find fairseq examples/translation_moe/src " |
|
"(tried looking here: {})".format(args.user_dir) |
|
) |
|
|
|
logging.info("loading en2fr model from:" + args.en2fr) |
|
en2fr = TransformerModel.from_pretrained( |
|
model_name_or_path=args.en2fr, |
|
tokenizer="moses", |
|
bpe="sentencepiece", |
|
).eval() |
|
|
|
logging.info("loading fr2en model from:" + args.fr2en) |
|
fr2en = TransformerModel.from_pretrained( |
|
model_name_or_path=args.fr2en, |
|
tokenizer="moses", |
|
bpe="sentencepiece", |
|
user_dir=args.user_dir, |
|
task="translation_moe", |
|
).eval() |
|
|
|
def gen_paraphrases(en): |
|
fr = en2fr.translate(en) |
|
return [ |
|
fr2en.translate(fr, inference_step_args={"expert": i}) |
|
for i in range(args.num_experts) |
|
] |
|
|
|
logging.info("Type the input sentence and press return:") |
|
for line in fileinput.input(args.files): |
|
line = line.strip() |
|
if len(line) == 0: |
|
continue |
|
for paraphrase in gen_paraphrases(line): |
|
print(paraphrase) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|