File size: 4,310 Bytes
dcc5cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import argparse
import unicodedata
import re
from tqdm import tqdm

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import re
import unicodedata

# ASCII punctuation and symbols, plus the two guillemets.
PUNCTS = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~«»'

# Maps a single non-ASCII punctuation character to its ASCII replacement.
#
# Keys written as \uXXXX escapes are fullwidth forms that render almost
# identically to their ASCII counterparts. They are spelled as escapes so an
# editor or a normalizing tool cannot silently fold them to ASCII: an ASCII
# key here would be a no-op at best ("," -> ",") and would corrupt ordinary
# text at worst ("1" -> '"', "." -> ". ").
UNICODE_PUNCT = {
    "\uff0c": ",",  # FULLWIDTH COMMA
    "。": ".",
    "、": ",",
    "„": '"',
    "”": '"',
    "“": '"',
    "«": '"',
    "»": '"',
    # FULLWIDTH DIGIT ONE; NOTE(review): presumably inherited from the
    # upstream punctuation table this mapping derives from -- confirm.
    "\uff11": '"',
    "」": '"',
    "「": '"',
    "《": '"',
    "》": '"',
    "´": "'",
    "∶": ":",
    "\uff1a": ":",  # FULLWIDTH COLON
    "\uff1f": "?",  # FULLWIDTH QUESTION MARK
    "\uff01": "!",  # FULLWIDTH EXCLAMATION MARK
    "\uff08": "(",  # FULLWIDTH LEFT PARENTHESIS
    "\uff09": ")",  # FULLWIDTH RIGHT PARENTHESIS
    "\uff1b": ";",  # FULLWIDTH SEMICOLON
    "–": "-",
    "—": " - ",
    "\uff0e": ". ",  # FULLWIDTH FULL STOP
    "\uff5e": "~",  # FULLWIDTH TILDE
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "\uff05": "%",  # FULLWIDTH PERCENT SIGN
    "►": "-",
    "■": " ",  # added for Mimir
}

# Character class matching any single key of UNICODE_PUNCT. Keys are escaped
# so a future key such as "]" or "^" cannot break the class.
UNICODE_PUNCT_RE = re.compile(f"[{''.join(map(re.escape, UNICODE_PUNCT))}]")


def replace_unicode_punct(text: str) -> str:
    """Swap every character that is a key of UNICODE_PUNCT for its mapped value.

    Characters without an entry in the table are copied through unchanged.
    """
    pieces = []
    for char in text:
        pieces.append(UNICODE_PUNCT.get(char, char))
    return "".join(pieces)


def remove_unicode_punct(text: str) -> str:
    """Delete every character matched by UNICODE_PUNCT_RE.

    Unlike replace_unicode_punct, matched characters are dropped outright
    instead of being mapped to a replacement, in a single regex pass.
    """
    without_punct = UNICODE_PUNCT_RE.sub("", text)
    return without_punct


def strip_accents(line: str) -> str:
    """Strip accents from a piece of text.

    The text is decomposed with NFD, so each accent becomes a separate
    combining character, and every nonspacing mark (category "Mn") is then
    dropped. The result is left in decomposed form; it is not recomposed.

    Args:
        line: Text to process.

    Returns:
        The NFD form of ``line`` without any "Mn" characters.
    """
    nfd = unicodedata.normalize("NFD", line)
    # There is deliberately no "nothing changed" shortcut based on lengths:
    # a precomposed character such as U+00E9 has length 1, decomposes to two
    # code points and is back to length 1 once the accent is removed, so a
    # length comparison against the input cannot tell whether anything was
    # stripped.
    return "".join(c for c in nfd if unicodedata.category(c) != "Mn")


# Character class matching every control character: code points 0-31 and
# 127-159.
NON_PRINTING_CHARS_RE = re.compile(
    "[" + "".join(chr(cp) for cp in (*range(0, 32), *range(127, 160))) + "]"
)
# Matches a single decimal digit.
DIGIT_RE = re.compile(r"\d")
# Union of the two character classes above: the patterns are concatenated and
# the "][" seam between them is removed so they form one class.
PUNCT_OR_NON_PRINTING_CHARS_RE = re.compile(
    "".join((UNICODE_PUNCT_RE.pattern, NON_PRINTING_CHARS_RE.pattern)).replace(
        "][", ""
    )
)


def remove_non_printing_char(text: str) -> str:
    """Delete every character matched by NON_PRINTING_CHARS_RE from ``text``."""
    printable_only = NON_PRINTING_CHARS_RE.sub("", text)
    return printable_only


def normalize(line: str, accent=True, case=True, numbers=True, punct=1) -> str:
    """Normalize a single line of text.

    The line is stripped of surrounding whitespace first; an empty result is
    returned as is. The enabled steps then run in this order: lowercasing,
    accent stripping, digit masking, punctuation handling. Non-printing
    characters are always removed last.

    Args:
        line: Text to normalize.
        accent: When true, run strip_accents.
        case: When true, lowercase the text.
        numbers: When true, replace every digit with "0".
        punct: 1 replaces characters via replace_unicode_punct, 2 deletes
            them via remove_unicode_punct, any other value skips the step.

    Returns:
        The normalized text.
    """
    text = line.strip()
    if text == "":
        return text

    text = text.lower() if case else text
    text = strip_accents(text) if accent else text
    text = DIGIT_RE.sub("0", text) if numbers else text

    if punct == 2:
        text = remove_unicode_punct(text)
    elif punct == 1:
        text = replace_unicode_punct(text)

    return remove_non_printing_char(text)


def slow_normalize_for_dedup(line: str) -> str:
    """Normalize ``line`` for deduplication by delegating to normalize().

    Accents are kept, case is folded, digits are masked and the characters
    matched by UNICODE_PUNCT_RE are removed (punct mode 2).
    """
    options = {"accent": False, "case": True, "numbers": True, "punct": 2}
    return normalize(line, **options)


def normalize_for_dedup(line: str) -> str:
    """Normalize ``line`` for deduplication using one combined regex pass.

    The line is stripped and, unless that leaves it empty, lowercased; every
    digit is replaced by "0" and all characters matched by
    PUNCT_OR_NON_PRINTING_CHARS_RE are deleted.
    """
    stripped = line.strip()
    if stripped:
        masked = DIGIT_RE.sub("0", stripped.lower())
        return PUNCT_OR_NON_PRINTING_CHARS_RE.sub("", masked)
    return stripped

## START OF MIMIR CODE
def normalize_text(line):
    """Normalize one line and make sure it ends with a punctuation mark.

    The line is NFKC-normalized, lowercased and stripped of trailing
    whitespace (which also removes the line terminator). If the result is
    non-empty and does not already end with punctuation, a "." is appended.
    The text is then passed through normalize() with accents kept, digits
    masked and punctuation replaced (punct mode 1).

    Args:
        line: Raw input line.

    Returns:
        The normalized line.
    """
    normalized_line = unicodedata.normalize('NFKC', line).lower().rstrip()

    if normalized_line:
        last_char = normalized_line[-1]
        # PUNCTS only covers ASCII symbols and guillemets. Also accept any
        # character in a Unicode punctuation category ("P*"), otherwise a
        # line ending in e.g. an ideographic full stop or a curly quote
        # would get a spurious second punctuation mark.
        ends_with_punct = (
            last_char in PUNCTS
            or unicodedata.category(last_char).startswith("P")
        )
        if not ends_with_punct:
            normalized_line += '.'

    return normalize(normalized_line, accent=False, case=True, numbers=True, punct=1)


def normalize_file(input_file, output_file, cutoff=None):
    """Normalize ``input_file`` line by line into ``output_file``.

    Each input line is passed through normalize_text() and written followed
    by a newline. Both files are handled as UTF-8 and progress is reported
    through tqdm.

    Args:
        input_file: Path of the text file to read.
        output_file: Path of the file to write; it is overwritten.
        cutoff: Maximum number of lines to write. ``None`` or 0 means no
            limit.
    """
    # Open the input first so that a missing or unreadable input does not
    # leave an existing output file truncated.
    with (open(input_file, 'r', encoding='utf-8') as lines,
          open(output_file, 'w', encoding='utf-8') as out):
        for line_count, line in tqdm(enumerate(lines), desc="Processing"):
            # line_count is zero-based: checking before the write stops after
            # exactly ``cutoff`` lines instead of ``cutoff + 1``.
            if cutoff and line_count >= cutoff:
                break
            out.write(normalize_text(line) + "\n")


if __name__ == "__main__":
    # Command-line entry point: parse the paths and the optional line limit,
    # then normalize the input file into the output file.
    parser = argparse.ArgumentParser(
        # The description lists only what the code does; newlines are not
        # replaced with spaces anywhere in the pipeline.
        description='Normalize text file line by line, ensure trailing punctuation, and show progress.'
    )
    parser.add_argument('input_file', type=str, help='Input file path')
    parser.add_argument('output_file', type=str, help='Output file path')
    parser.add_argument('--cutoff', required=False, type=int, help='Max number of lines to process')

    args = parser.parse_args()

    normalize_file(args.input_file, args.output_file, args.cutoff)