jiang
init commit
650c5f6
raw
history blame
7.99 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import os.path as op
from collections import namedtuple
from multiprocessing import cpu_count
from typing import List, Optional
import sentencepiece as sp
from fairseq.data.encoders.byte_bpe import ByteBPE
from fairseq.data.encoders.byte_utils import byte_encode
from fairseq.data.encoders.bytes import Bytes
from fairseq.data.encoders.characters import Characters
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
# Names of the dataset splits; they are used as file-name prefixes
# (e.g. "train.<lang>", "valid.moses.<lang>") throughout this script.
SPLITS = ["train", "valid", "test"]
def _convert_xml(in_path: str, out_path: str):
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
ss = s.strip()
if not ss.startswith("<seg"):
continue
ss = ss.replace("</seg>", "").split('">')
assert len(ss) == 2
f_o.write(ss[1].strip() + "\n")
def _convert_train(in_path: str, out_path: str):
with open(in_path) as f, open(out_path, "w") as f_o:
for s in f:
ss = s.strip()
if ss.startswith("<"):
continue
f_o.write(ss.strip() + "\n")
def _get_bytes(in_path: str, out_path: str):
    """Write ``Bytes.encode`` of every stripped input line to ``out_path``.

    Args:
        in_path: Path of the text file to read (decoded as UTF-8).
        out_path: Path of the file to write; it is overwritten.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for s in f:
            f_o.write(Bytes.encode(s.strip()) + "\n")
def _get_chars(in_path: str, out_path: str):
    """Write ``Characters.encode`` of every stripped input line to ``out_path``.

    Args:
        in_path: Path of the text file to read (decoded as UTF-8).
        out_path: Path of the file to write; it is overwritten.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for s in f:
            f_o.write(Characters.encode(s.strip()) + "\n")
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
    """Run every line of ``in_path`` through a ``MosesTokenizer``.

    The tokenizer is built from an argument tuple carrying ``src`` and
    ``tgt`` as its source/target languages, with dash splitting and escaping
    left enabled (both ``moses_no_*`` flags are ``False``). Each input line is
    stripped, encoded and written as one line of ``out_path``.

    Args:
        in_path: Path of the text file to read (decoded as UTF-8).
        out_path: Path of the tokenized file to write; it is overwritten.
        src: Value passed as ``moses_source_lang``.
        tgt: Value passed as ``moses_target_lang``.
    """
    Args = namedtuple(
        "Args",
        [
            "moses_source_lang",
            "moses_target_lang",
            "moses_no_dash_splits",
            "moses_no_escape",
        ],
    )
    args = Args(
        moses_source_lang=src,
        moses_target_lang=tgt,
        moses_no_dash_splits=False,
        moses_no_escape=False,
    )
    pretokenizer = MosesTokenizer(args)
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for s in f:
            f_o.write(pretokenizer.encode(s.strip()) + "\n")
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
    """Byte-encode the ``src`` and ``tgt`` files into a single output file.

    Reads ``<in_path_prefix>.<src>`` and then ``<in_path_prefix>.<tgt>``,
    passes every stripped line through ``byte_encode`` and writes the results
    to ``out_path`` in that order, one per line.

    Args:
        in_path_prefix: Common path prefix of the two input files.
        src: Suffix (language code) of the first input file.
        tgt: Suffix (language code) of the second input file.
        out_path: Path of the combined file to write; it is overwritten.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(out_path, "w", encoding="utf-8") as f_o:
        for lang in [src, tgt]:
            with open(f"{in_path_prefix}.{lang}", encoding="utf-8") as f:
                for s in f:
                    f_o.write(byte_encode(s.strip()) + "\n")
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
    """Train a SentencePiece BPE model on the text file ``in_path``.

    The trainer is invoked with full character coverage, the ``identity``
    normalization rule and one thread per CPU reported by ``cpu_count()``.

    Args:
        in_path: Path of the training text, passed as ``--input``.
        model_prefix: Output prefix, passed as ``--model_prefix``.
        vocab_size: Vocabulary size, passed as ``--vocab_size``.
    """
    # Insertion order of the dict fixes the order of the flags on the
    # command line handed to the trainer.
    flags = {
        "input": in_path,
        "model_prefix": model_prefix,
        "model_type": "bpe",
        "vocab_size": vocab_size,
        "character_coverage": "1.0",
        "normalization_rule_name": "identity",
        "num_threads": cpu_count(),
    }
    # NOTE(review): flags are joined with single spaces, so paths containing
    # whitespace presumably cannot be passed this way — confirm with the
    # SentencePiece trainer documentation.
    command = " ".join(f"--{key}={value}" for key, value in flags.items())
    sp.SentencePieceTrainer.Train(command)
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
    """Encode every line of ``in_path`` with a ``ByteBPE`` tokenizer.

    Args:
        model_path: Model file handed to ``ByteBPE`` as
            ``sentencepiece_model_path``.
        in_path: Path of the text file to read (decoded as UTF-8).
        out_path: Path of the encoded file to write; it is overwritten.
    """
    Args = namedtuple("Args", ["sentencepiece_model_path"])
    args = Args(sentencepiece_model_path=model_path)
    tokenizer = ByteBPE(args)
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for s in f:
            f_o.write(tokenizer.encode(s.strip()) + "\n")
def _apply_bpe(model_path: str, in_path: str, out_path: str):
    """Encode every line of ``in_path`` with a ``SentencepieceBPE`` tokenizer.

    Args:
        model_path: Model file handed to ``SentencepieceBPE`` as
            ``sentencepiece_model``.
        in_path: Path of the text file to read (decoded as UTF-8).
        out_path: Path of the encoded file to write; it is overwritten.
    """
    Args = namedtuple("Args", ["sentencepiece_model"])
    args = Args(sentencepiece_model=model_path)
    tokenizer = SentencepieceBPE(args)
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for s in f:
            f_o.write(tokenizer.encode(s.strip()) + "\n")
def _concat_files(in_paths: List[str], out_path: str):
with open(out_path, "w") as f_o:
for p in in_paths:
with open(p) as f:
for r in f:
f_o.write(r)
def preprocess_iwslt17(
    root: str,
    src: str,
    tgt: str,
    bpe_size: Optional[int],
    need_chars: bool,
    bbpe_size: Optional[int],
    need_bytes: bool,
):
    """Prepare the IWSLT17 ``src``-``tgt`` bitext under ``root``.

    Steps, all reading from and writing into ``root``:

    1. Extract plain text from ``<root>/<src>-<tgt>/`` into
       ``{train,valid,test}.<lang>`` (dev2010 becomes ``valid``, tst2015
       becomes ``test``).
    2. Moses-pretokenize each of them into ``<split>.moses.<lang>``.
    3. Optionally derive further tokenizations of the Moses output:
       ``<split>.moses.bpe<N>.<lang>``, ``<split>.moses.bytes.<lang>``,
       ``<split>.moses.chars.<lang>`` and ``<split>.moses.bbpe<N>.<lang>``.

    Args:
        root: Working directory holding the ``<src>-<tgt>`` download folder.
        src: Source language code.
        tgt: Target language code.
        bpe_size: BPE vocabulary size, or ``None`` to skip BPE.
        need_chars: Whether to produce the characters tokenization.
        bbpe_size: Byte-level BPE vocabulary size, or ``None`` to skip it.
        need_bytes: Whether to produce the bytes tokenization.
    """
    langs = [src, tgt]

    # extract bitext
    in_root = op.join(root, f"{src}-{tgt}")
    for lang in langs:
        _convert_train(
            op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
            op.join(root, f"train.{lang}"),
        )
        _convert_xml(
            op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
            op.join(root, f"valid.{lang}"),
        )
        _convert_xml(
            op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
            op.join(root, f"test.{lang}"),
        )

    # pre-tokenize
    # NOTE(review): the same (src, tgt) pair is passed for both languages'
    # files — confirm which of the two MosesTokenizer uses when encoding.
    for lang in langs:
        for split in SPLITS:
            pretokenize(
                op.join(root, f"{split}.{lang}"),
                op.join(root, f"{split}.moses.{lang}"),
                src,
                tgt,
            )

    # tokenize with BPE vocabulary
    if bpe_size is not None:
        # learn vocabulary on the concatenation of both training sides;
        # the file names follow src/tgt rather than a fixed language pair
        concated_train_path = op.join(root, "train.all")
        _concat_files(
            [op.join(root, f"train.moses.{lang}") for lang in langs],
            concated_train_path,
        )
        bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
        _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
        os.remove(concated_train_path)
        # apply
        for lang in langs:
            for split in SPLITS:
                _apply_bpe(
                    bpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
                )

    # tokenize with bytes vocabulary
    if need_bytes:
        for lang in langs:
            for split in SPLITS:
                _get_bytes(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bytes.{lang}"),
                )

    # tokenize with characters vocabulary
    if need_chars:
        for lang in langs:
            for split in SPLITS:
                _get_chars(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.chars.{lang}"),
                )

    # tokenize with byte-level BPE vocabulary
    if bbpe_size is not None:
        # learn vocabulary on the byte-encoded training data of both sides
        bchar_path = op.join(root, "train.bchar")
        _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
        bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
        _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
        os.remove(bchar_path)
        # apply
        for lang in langs:
            for split in SPLITS:
                _apply_bbpe(
                    bbpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
                )
def main():
    """Parse the command line and preprocess the IWSLT17 fr-en bitext.

    Options:
        --root: Data directory (default ``data``).
        --bpe-vocab: BPE vocabulary size; omitted means no BPE output.
        --bbpe-vocab: Byte-level BPE vocabulary size; omitted means no BBPE
            output.
        --byte-vocab: Flag enabling the bytes tokenization.
        --char-vocab: Flag enabling the characters tokenization.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="data")
    # The two help literals below are concatenated implicitly, so the first
    # one must end with a space to keep the sentences apart.
    parser.add_argument(
        "--bpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BPE of size K. "
        "Default to None (disabled).",
    )
    parser.add_argument(
        "--bbpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BBPE of size K. "
        "Default to None (disabled).",
    )
    parser.add_argument(
        "--byte-vocab",
        action="store_true",
        help="Generate tokenized bitext with bytes vocabulary",
    )
    parser.add_argument(
        "--char-vocab",
        action="store_true",
        help="Generate tokenized bitext with chars vocabulary",
    )
    args = parser.parse_args()
    # Positional order follows preprocess_iwslt17's signature:
    # root, src, tgt, bpe_size, need_chars, bbpe_size, need_bytes.
    preprocess_iwslt17(
        args.root,
        "fr",
        "en",
        args.bpe_vocab,
        args.char_vocab,
        args.bbpe_vocab,
        args.byte_vocab,
    )
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()