import regex as re import os import sys from collections import defaultdict from tqdm import tqdm def remove_overlaps(in_data_dir: str, out_data_dir: str, benchmark_dir: str): """ Removes overlapping sentences between train dataset and dev/test dataset from the input directory and writes de-duplicated train data to the specified output directory. Args: in_data_dir (str): path of the directory containing train data for each language pair. out_data_dir (str): path of the directory where the de-duplicated train data will be written for each language pair. benchmark_dir (str): path of the directory containing the language-wise monolingual side of dev/test set. """ # load dev/test dataset for each language devtest_normalized = defaultdict(set) for lang in os.listdir(benchmark_dir): fname = os.path.join(benchmark_dir, lang) with open(fname, "r") as f: sents = [sent for sent in f.read().split("\n") if sent.strip()] sents = [re.sub(" +", " ", sent).replace("\n", "").strip() for sent in sents] sents = [re.sub(" +", " ", re.sub(r"[^\w\s]", "", x)).lower() for x in sents] devtest_normalized[lang] = set(sents) # process each language pair train dataset to remove overlapping sentences pairs = sorted(os.listdir(in_data_dir)) for pair in pairs: print(pair) src_lang, tgt_lang = pair.split("-") src_infname = os.path.join(in_data_dir, pair, f"train.{src_lang}") tgt_infname = os.path.join(in_data_dir, pair, f"train.{tgt_lang}") src_outfname = os.path.join(out_data_dir, pair, f"train.{src_lang}") tgt_outfname = os.path.join(out_data_dir, pair, f"train.{tgt_lang}") os.makedirs(os.path.join(out_data_dir, pair), exist_ok=True) # remove overlapping sentences and write de-duplicated train data to output directory with open(src_infname, 'r', encoding='utf-8') as src_infile, \ open(tgt_infname, 'r', encoding='utf-8') as tgt_infile, \ open(src_outfname, 'w', encoding='utf-8') as src_outfile, \ open(tgt_outfname, 'w', encoding='utf-8') as tgt_outfile: for src_line, tgt_line in tqdm(zip(src_infile, tgt_infile)): src_line = re.sub(" +", " ", src_line).replace("\n", "").strip() tgt_line = re.sub(" +", " ", tgt_line).replace("\n", "").strip() src_line_normalized = re.sub(" +", " ", re.sub(r"[^\w\s]", "", src_line)).lower() tgt_line_normalized = re.sub(" +", " ", re.sub(r"[^\w\s]", "", tgt_line)).lower() if src_line_normalized in devtest_normalized[src_lang] or tgt_line_normalized in devtest_normalized[tgt_lang]: continue src_outfile.write(src_line + "\n") tgt_outfile.write(tgt_line + "\n") if __name__ == "__main__": in_data_dir = sys.argv[1] out_data_dir = sys.argv[2] benchmark_dir = sys.argv[3] os.makedirs(out_data_dir, exist_ok=True) remove_overlaps(in_data_dir, out_data_dir, benchmark_dir)