File size: 2,452 Bytes
d44849f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import sys
from tqdm import tqdm
from typing import List, Tuple


def _count_lines(path: str) -> int:
    """Returns the number of lines in the UTF-8 text file at ``path``."""
    # A context manager guarantees the handle is closed once counting is done.
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for _ in f)


def remove_large_sentences(
    src_path: str, tgt_path: str, max_len: int = 200
) -> Tuple[int, List[str], List[str]]:
    """
    Removes large sentences from a parallel dataset of source and target data.

    A sentence pair is dropped when either side has more than ``max_len``
    tokens. Tokens are obtained by stripping the line and splitting on single
    space characters, so consecutive spaces yield empty tokens that also
    count toward the length.

    Args:
        src_path (str): path to the file containing the source language data.
        tgt_path (str): path to the file containing the target language data.
        max_len (int): maximum number of tokens allowed on either side of a
            pair; pairs exceeding it are removed. Defaults to 200.

    Returns:
        Tuple[int, List[str], List[str]]: a tuple of
            - an integer representing the number of sentence pairs removed
            - a list of strings containing the source language data after removing large sentences
            - a list of strings containing the target language data after removing large sentences
          The kept lines are returned exactly as read from the files
          (including their line terminators).

    Raises:
        AssertionError: if the two files do not have the same number of lines.
    """
    src_num_lines = _count_lines(src_path)
    tgt_num_lines = _count_lines(tgt_path)

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; otherwise zip() below would silently truncate to the
    # shorter file and misalign nothing but also hide the data problem.
    if src_num_lines != tgt_num_lines:
        raise AssertionError(
            f"line count mismatch: {src_path} has {src_num_lines} lines, "
            f"{tgt_path} has {tgt_num_lines} lines"
        )

    count = 0
    new_src_lines, new_tgt_lines = [], []

    with open(src_path, encoding="utf-8") as f1, open(tgt_path, encoding="utf-8") as f2:
        # The pre-computed line count lets tqdm show a proper progress total.
        for src_line, tgt_line in tqdm(zip(f1, f2), total=src_num_lines):
            src_tokens = src_line.strip().split(" ")
            tgt_tokens = tgt_line.strip().split(" ")

            if len(src_tokens) > max_len or len(tgt_tokens) > max_len:
                count += 1
                continue

            new_src_lines.append(src_line)
            new_tgt_lines.append(tgt_line)

    return count, new_src_lines, new_tgt_lines


def create_txt(out_file: str, lines: List[str]) -> None:
    """
    Creates a text file and writes the given list of lines to file.

    Each entry is written on its own line: entries that already end with a
    newline are written unchanged, and a newline is appended to those that
    do not. An empty list produces an empty file.

    Args:
        out_file (str): path to the output file to be created.
        lines (List[str]): a list of strings to be written to the output file.
    """
    # The context manager closes the file even if a write fails part-way.
    with open(out_file, "w", encoding="utf-8") as outfile:
        # Decide per line instead of from the first line only, so an empty
        # list is handled and mixed inputs never merge or double-space lines.
        outfile.writelines(
            line if line.endswith("\n") else line + "\n" for line in lines
        )


if __name__ == "__main__":
    # Command-line entry point. Positional arguments:
    #   1. src_path      - input file with source language sentences
    #   2. tgt_path      - input file with target language sentences
    #   3. new_src_path  - output file for the filtered source sentences
    #   4. new_tgt_path  - output file for the filtered target sentences

    # Exit with a usage message instead of a bare IndexError when arguments
    # are missing. Extra arguments are ignored, as before.
    if len(sys.argv) < 5:
        sys.exit(
            f"Usage: {sys.argv[0]} <src_path> <tgt_path> <new_src_path> <new_tgt_path>"
        )

    src_path = sys.argv[1]
    tgt_path = sys.argv[2]
    new_src_path = sys.argv[3]
    new_tgt_path = sys.argv[4]

    count, new_src_lines, new_tgt_lines = remove_large_sentences(src_path, tgt_path)
    print(f"{count} lines removed due to seq_len > 200")
    create_txt(new_src_path, new_src_lines)
    create_txt(new_tgt_path, new_tgt_lines)