|
import shutil |
|
import os, sys |
|
from subprocess import check_call, check_output |
|
import glob |
|
import argparse |
|
import shutil |
|
import pathlib |
|
import itertools |
|
|
|
def call_output(cmd):
    """Run *cmd* through the shell and hand back its captured stdout (bytes).

    Both the command line and the raw captured bytes are echoed to stdout.
    A non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    banner = f"Executing: {cmd}"
    print(banner)
    captured = check_output(cmd, shell=True)
    print(captured)
    return captured
|
|
|
def call(cmd):
    """Echo a shell command line, then execute it via the shell.

    Returns None; a non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    command_line = cmd
    print(command_line)
    check_call(command_line, shell=True)
|
|
|
|
|
def _require_env(name, message):
    """Return the value of environment variable *name*.

    If the variable is unset or contains only whitespace, *message* is written
    to stderr and the process exits with status -1. The returned value is not
    stripped.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        print(message, file=sys.stderr)
        sys.exit(-1)
    return value


# Root folder under which the SPM model/vocab are stored and the default
# --data_root (ML50) lives.
WORKDIR_ROOT = _require_env(
    'WORKDIR_ROOT',
    'Please specify your working directory root in OS environment variable WORKDIR_ROOT. Exiting...',
)

# Path of the sentencepiece `spm_encode.py` script used by encode_spm().
SPM_PATH = _require_env(
    'SPM_PATH',
    'Please install sentencepiece from https://github.com/google/sentencepiece '
    'and set SPM_PATH pointing to the installed spm_encode.py. Exiting...',
)
|
|
|
|
|
# mBART-50 sentencepiece model and its 250k fairseq dictionary, cached under
# WORKDIR_ROOT.
SPM_MODEL = f'{WORKDIR_ROOT}/sentence.bpe.model'
SPM_VOCAB = f'{WORKDIR_ROOT}/dict_250k.txt'

# Script invoked (via `python`) to encode raw text into pieces.
SPM_ENCODE = SPM_PATH


def _download_if_missing(url, dest):
    """Fetch *url* to *dest* with wget, unless *dest* already exists.

    The download goes to a temporary ``<dest>.part`` sibling and is renamed
    into place only after wget succeeds. ``wget -O`` creates its output file
    up front, so downloading straight to *dest* would leave an empty or
    partial file after a failure, which the existence check would then
    accept on the next run. A failed wget raises
    ``subprocess.CalledProcessError`` (via ``call``).
    """
    import shlex

    if os.path.exists(dest):
        return
    tmp = f'{dest}.part'
    call(f"wget {shlex.quote(url)} -O {shlex.quote(tmp)}")
    os.replace(tmp, dest)


_download_if_missing(
    'https://dl.fbaipublicfiles.com/fairseq/models/mbart50/sentence.bpe.model', SPM_MODEL)
_download_if_missing(
    'https://dl.fbaipublicfiles.com/fairseq/models/mbart50/dict_250k.txt', SPM_VOCAB)
|
|
|
|
|
|
|
def get_data_size(raw):
    """Return the number of lines in the file at path *raw*.

    Counts newline bytes exactly like ``wc -l`` does, so a final line with no
    trailing newline is not counted. The file is read directly in binary
    chunks rather than through a shell, so paths containing spaces or shell
    metacharacters work.

    Raises:
        OSError: if *raw* cannot be opened (e.g. FileNotFoundError).
    """
    count = 0
    with open(raw, 'rb') as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b''):
            count += chunk.count(b'\n')
    return count
|
|
|
def encode_spm(model, direction, prefix='', splits=('train', 'test', 'valid'), pairs_per_shard=None):
    """Sentencepiece-encode the raw parallel text of one language direction.

    For each split whose source and target raw files
    ``{RAW_DIR}/{split}{prefix}.{direction}.{src|tgt}`` both exist, runs
    ``python SPM_ENCODE`` with *model* and writes
    ``{BPE_DIR}/{direction}{prefix}/{split}.bpe.{src|tgt}``. Splits missing
    either side are skipped silently.

    Args:
        model: path of the sentencepiece model.
        direction: ``"{src}-{tgt}"`` language pair (exactly one ``-``).
        prefix: suffix appended to split names and the output folder name.
        splits: split names to encode.
        pairs_per_shard: unused; kept for signature compatibility.

    Relies on the module globals RAW_DIR and BPE_DIR (assigned in the
    ``__main__`` section) and SPM_ENCODE.
    """
    import shlex

    src, tgt = direction.split('-')
    out_dir = f'{BPE_DIR}/{direction}{prefix}'

    for split in splits:
        src_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{src}'
        tgt_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{tgt}'
        if not (os.path.exists(src_raw) and os.path.exists(tgt_raw)):
            continue
        # Build the command from discrete, quoted arguments. The previous
        # backslash-continued string had no space after `{model}`, so without
        # indentation the next flag was glued onto the model path.
        args = [
            'python', SPM_ENCODE,
            '--model', model,
            '--output_format=piece',
            '--inputs', src_raw, tgt_raw,
            '--outputs', f'{out_dir}/{split}.bpe.{src}', f'{out_dir}/{split}.bpe.{tgt}',
        ]
        # call() echoes the command itself, so no separate print here.
        call(' '.join(shlex.quote(arg) for arg in args))
|
|
|
|
|
def binarize_(
    bpe_dir,
    databin_dir,
    direction, spm_vocab=SPM_VOCAB,
    splits=('train', 'test', 'valid'),
    workers=8,
):
    """Run ``fairseq-preprocess`` on the BPE files found in *bpe_dir*.

    *databin_dir* is deleted and recreated first, so afterwards it only holds
    the output of this call.

    Args:
        bpe_dir: folder holding ``{split}.bpe.{lang}`` files.
        databin_dir: output folder for the binarized data.
        direction: ``"{src}-{tgt}"`` language pair (exactly one ``-``).
        spm_vocab: a single dictionary path (used with
            ``--joined-dictionary``), or a ``(src_dict, tgt_dict)`` tuple.
        splits: which of train/valid/test to binarize. A split is passed on
            only when matching ``{split}.bpe*`` files exist.
        workers: value for ``--workers`` (previously hard-coded to 8).

    Nothing is executed when no requested split has input files.
    """
    import shlex

    q = shlex.quote
    src, tgt = direction.split('-')

    # Best-effort reset of the output folder; makedirs also creates parents.
    try:
        shutil.rmtree(databin_dir, ignore_errors=True)
        os.makedirs(databin_dir, exist_ok=True)
    except OSError as error:
        print(error)

    cmds = [
        "fairseq-preprocess",
        f"--source-lang {q(src)} --target-lang {q(tgt)}",
        f"--destdir {q(databin_dir + '/')}",
        f"--workers {int(workers)}",
    ]
    if isinstance(spm_vocab, tuple):
        src_vocab, tgt_vocab = spm_vocab
        cmds += [f"--srcdict {q(src_vocab)}", f"--tgtdict {q(tgt_vocab)}"]
    else:
        cmds += ["--joined-dictionary", f"--srcdict {q(spm_vocab)}"]

    input_options = []
    for split, flag in (("train", "--trainpref"), ("valid", "--validpref"), ("test", "--testpref")):
        pref = f"{bpe_dir}/{split}.bpe"
        # Escape bpe_dir so characters like `[` in it are not glob patterns.
        if split in splits and glob.glob(f"{glob.escape(bpe_dir)}/{split}.bpe*"):
            input_options.append(f"{flag} {q(pref)}")

    if input_options:
        # call() already echoes the command.
        call(" ".join(cmds + input_options))
|
|
|
|
|
def binarize(
    databin_dir,
    direction, spm_vocab=SPM_VOCAB, prefix='',
    splits=('train', 'test', 'valid'),
    pairs_per_shard=None,
):
    """Binarize one direction's BPE data and move the results into *databin_dir*.

    Intermediate output goes to ``{BPE_DIR}/{direction}{prefix}_databin``.

    - Without sharding (*pairs_per_shard* is None), all *splits* are
      binarized together and the ``*.bin``, ``*.idx`` and ``dict*`` files are
      moved into *databin_dir*.
    - With sharding, valid/test are binarized once. Each
      ``{bpe_dir}/shard*`` folder is then binarized as train into
      ``{databin_dir}/{shard}``, which also gets symlinks to the shared
      valid/test binaries.

    *pairs_per_shard* is only checked against None. BPE_DIR is a module
    global assigned in the ``__main__`` section.
    """

    def move_databin_files(from_folder, to_folder):
        """Move binarized data and dict files into *to_folder*, replacing existing ones."""
        os.makedirs(to_folder, exist_ok=True)
        for bin_file in (glob.glob(f"{from_folder}/*.bin")
                         + glob.glob(f"{from_folder}/*.idx")
                         + glob.glob(f"{from_folder}/dict*")):
            # Pass the full destination *file* path. Given only a directory,
            # shutil.move refuses to overwrite an existing file, so reruns
            # used to keep stale binaries (the error was merely printed).
            target = os.path.join(to_folder, os.path.basename(bin_file))
            try:
                shutil.move(bin_file, target)
            except OSError as error:
                print(error)

    bpe_databin_dir = f"{BPE_DIR}/{direction}{prefix}_databin"
    bpe_dir = f"{BPE_DIR}/{direction}{prefix}"

    if pairs_per_shard is None:
        binarize_(bpe_dir, bpe_databin_dir, direction, spm_vocab=spm_vocab, splits=splits)
        move_databin_files(bpe_databin_dir, databin_dir)
        return

    binarize_(
        bpe_dir, bpe_databin_dir, direction,
        spm_vocab=spm_vocab, splits=[s for s in splits if s != "train"])
    # Shard outputs land in sub-folders, so this non-recursive listing is not
    # affected by the per-shard binarize_ calls below.
    eval_files = glob.glob(f"{bpe_databin_dir}/valid.*") + glob.glob(f"{bpe_databin_dir}/test.*")

    for shard_bpe_dir in sorted(glob.glob(f"{bpe_dir}/shard*")):
        shard_str = os.path.basename(shard_bpe_dir)
        shard_folder = f"{bpe_databin_dir}/{shard_str}"
        databin_shard_folder = f"{databin_dir}/{shard_str}"
        print(f'working from {shard_folder} to {databin_shard_folder}')
        os.makedirs(databin_shard_folder, exist_ok=True)
        binarize_(
            shard_bpe_dir, shard_folder, direction,
            spm_vocab=spm_vocab, splits=["train"])

        for test_data in eval_files:
            link = f"{databin_shard_folder}/{os.path.basename(test_data)}"
            try:
                # Replace links left by a previous run.
                if os.path.islink(link):
                    os.remove(link)
                # Use an absolute target: a relative one is resolved against
                # the link's own directory and would dangle.
                os.symlink(os.path.abspath(test_data), link)
            except OSError as error:
                print(error)

        move_databin_files(shard_folder, databin_shard_folder)
|
|
|
|
|
def load_langs(path):
    """Read a language-list file with one language code per line.

    Each line is stripped of surrounding whitespace, and blank lines are
    skipped, so a stray empty line never yields an empty language code. The
    file is decoded as UTF-8 rather than the locale's default encoding.
    """
    with open(path, encoding='utf-8') as fr:
        return [line.strip() for line in fr if line.strip()]
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # `--data-root` alias for consistency with the other hyphenated flags;
    # dest stays `data_root`.
    parser.add_argument("--data_root", "--data-root", default=f"{WORKDIR_ROOT}/ML50")
    parser.add_argument("--raw-folder", default='raw')
    parser.add_argument("--bpe-folder", default='bpe')
    parser.add_argument("--databin-folder", default='databin')

    args = parser.parse_args()

    # Module globals read by encode_spm() and binarize().
    DATA_PATH = args.data_root
    RAW_DIR = f'{DATA_PATH}/{args.raw_folder}'
    BPE_DIR = f'{DATA_PATH}/{args.bpe_folder}'
    DATABIN_DIR = f'{DATA_PATH}/{args.databin_folder}'
    os.makedirs(BPE_DIR, exist_ok=True)
    os.makedirs(DATABIN_DIR, exist_ok=True)

    raw_files = itertools.chain(
        glob.glob(f'{RAW_DIR}/train*'),
        glob.glob(f'{RAW_DIR}/valid*'),
        glob.glob(f'{RAW_DIR}/test*'),
    )

    # Raw files are named {split}.{src}-{tgt}.{lang}, so each direction shows
    # up once per split and per side. Deduplicate (keeping first-seen order),
    # or every direction would be wiped and re-processed several times. Skip
    # names whose second field is not a single `src-tgt` pair, since the
    # helpers unpack direction.split('-') into exactly two parts.
    found = []
    for file_path in raw_files:
        parts = os.path.basename(file_path).split('.')
        if len(parts) >= 2 and parts[1].count('-') == 1:
            found.append(parts[1])
    directions = list(dict.fromkeys(found))

    prefix = ""
    splits = ['train', 'valid', 'test']
    spm_model, spm_vocab = SPM_MODEL, SPM_VOCAB
    for direction in directions:
        bpe_direction_dir = f'{BPE_DIR}/{direction}{prefix}'
        # Start each direction from an empty BPE folder.
        try:
            shutil.rmtree(bpe_direction_dir, ignore_errors=True)
            os.makedirs(bpe_direction_dir, exist_ok=True)
        except OSError as error:
            print(error)
        encode_spm(spm_model, direction=direction, splits=splits)
        binarize(DATABIN_DIR, direction, spm_vocab=spm_vocab, splits=splits)
|
|