File size: 3,225 Bytes
d44849f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import regex as re
import os
import sys
from collections import defaultdict
from tqdm import tqdm

def _collapse_spaces(text: str) -> str:
    """Collapse runs of spaces into one, drop newline characters and trim both ends."""
    return re.sub(" +", " ", text).replace("\n", "").strip()


def _normalize_for_matching(text: str) -> str:
    """Return the lowercased form of *text* with every non-word, non-space character removed.

    This is the key used to decide whether a train sentence overlaps with a
    dev/test sentence; both sides must go through this same function.
    """
    return re.sub(" +", " ", re.sub(r"[^\w\s]", "", text)).lower()


def remove_overlaps(in_data_dir: str, out_data_dir: str, benchmark_dir: str) -> None:
    """
    Removes overlapping sentences between train dataset and dev/test dataset from the
    input directory and writes de-duplicated train data to the specified output directory.

    A train sentence pair is dropped when the normalized source side matches a
    normalized dev/test sentence of the source language, or the normalized target
    side matches one of the target language. Normalization collapses repeated
    spaces, removes every character that is neither a word character nor
    whitespace, and lowercases the result.

    Args:
        in_data_dir (str): path of the directory containing train data for each language pair.
            Each sub-directory is named ``<src_lang>-<tgt_lang>`` and holds
            ``train.<src_lang>`` and ``train.<tgt_lang>``. Entries that are not
            directories are skipped.
        out_data_dir (str): path of the directory where the de-duplicated train data will be written for each language pair.
        benchmark_dir (str): path of the directory containing the language-wise monolingual side of dev/test set.
            Each entry's file name is used as the language key.

    Raises:
        ValueError: if a pair directory name does not split into exactly two parts on "-".
    """
    # load dev/test dataset for each language
    devtest_normalized = defaultdict(set)
    for lang in os.listdir(benchmark_dir):
        fname = os.path.join(benchmark_dir, lang)

        # Read as UTF-8 explicitly, like the train files below; relying on the
        # locale default can mis-decode the text and silently miss overlaps.
        with open(fname, "r", encoding="utf-8") as f:
            devtest_normalized[lang] = {
                _normalize_for_matching(_collapse_spaces(sent))
                for sent in f.read().split("\n")
                if sent.strip()
            }

    # process each language pair train dataset to remove overlapping sentences
    for pair in sorted(os.listdir(in_data_dir)):
        pair_in_dir = os.path.join(in_data_dir, pair)
        # stray files (e.g. OS metadata) are not language-pair directories
        if not os.path.isdir(pair_in_dir):
            continue

        print(pair)
        src_lang, tgt_lang = pair.split("-")

        pair_out_dir = os.path.join(out_data_dir, pair)
        os.makedirs(pair_out_dir, exist_ok=True)

        src_infname = os.path.join(pair_in_dir, f"train.{src_lang}")
        tgt_infname = os.path.join(pair_in_dir, f"train.{tgt_lang}")

        src_outfname = os.path.join(pair_out_dir, f"train.{src_lang}")
        tgt_outfname = os.path.join(pair_out_dir, f"train.{tgt_lang}")

        # Look the sets up once per pair instead of once per line. A language
        # with no benchmark file yields an empty set, so nothing is filtered for it.
        src_devtest = devtest_normalized[src_lang]
        tgt_devtest = devtest_normalized[tgt_lang]

        # remove overlapping sentences and write de-duplicated train data to output directory
        with open(src_infname, 'r', encoding='utf-8') as src_infile, \
            open(tgt_infname, 'r', encoding='utf-8') as tgt_infile, \
            open(src_outfname, 'w', encoding='utf-8') as src_outfile, \
            open(tgt_outfname, 'w', encoding='utf-8') as tgt_outfile:

            # NOTE: zip stops at the shorter file, so extra trailing lines in
            # one side are dropped without warning.
            for src_line, tgt_line in tqdm(zip(src_infile, tgt_infile)):
                src_line = _collapse_spaces(src_line)
                tgt_line = _collapse_spaces(tgt_line)

                if (
                    _normalize_for_matching(src_line) in src_devtest
                    or _normalize_for_matching(tgt_line) in tgt_devtest
                ):
                    continue

                src_outfile.write(src_line + "\n")
                tgt_outfile.write(tgt_line + "\n")


if __name__ == "__main__":
    # Positional CLI arguments, in order:
    #   1. directory with the per-pair train data
    #   2. directory that receives the de-duplicated train data
    #   3. directory with the per-language dev/test files
    in_data_dir, out_data_dir, benchmark_dir = sys.argv[1], sys.argv[2], sys.argv[3]

    # make sure the output root exists before any pair directory is created in it
    os.makedirs(out_data_dir, exist_ok=True)

    remove_overlaps(
        in_data_dir,
        out_data_dir,
        benchmark_dir,
    )