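"""Perplexity scoring of text with cc_net's pretrained KenLM models.

Downloads a per-language KenLM binary and SentencePiece tokenizer from the
cc_net release and scores documents using (approximately) the same
normalization pipeline cc_net applies when mining Common Crawl.
"""
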
import os
import re
import unicodedata
import urllib.request
from typing import Dict

import kenlm
import sentencepiece


class SentencePiece:
    def __init__(
        self,
        model: str,
    ):
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)


class KenlmModel:
    digit_re: re.Pattern = re.compile(r"\d")
    unicode_punct: Dict[str, str] = {
        ",": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "1": '"',  # NOTE: present verbatim in cc_net's table; likely a mis-encoded quote character
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        ":": ":",
        "?": "?",
        "!": "!",
        "(": "(",
        ")": ")",
        ";": ";",
        "–": "-",
        "—": " - ",
        ".": ". ",
        "~": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "%": "%",
        "►": "-",
    }
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
    )

    def __init__(self, language: str):
        download_kenlm_model(language)
        try:
            self.model = kenlm.Model(f"{language}.arpa.bin")
            self.tokenizer = SentencePiece(f"{language}.sp.model")
        except OSError:
            os.remove(f"{language}.arpa.bin")
            if os.path.exists(f"{language}.sp.model"):
                os.remove(f"{language}.sp.model")
            raise OSError(
                "Downloaded model file was corrupt and has been removed. Please retry."
            )

    @classmethod
    def from_pretrained(cls, language: str):
        return cls(language)

    def pp(self, log_score: float, length: int) -> float:
        # KenLM reports base-10 log probabilities, so perplexity is 10^(-log10(p) / N).
        return 10.0 ** (-log_score / length)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
        if normalize_cc_net:
            doc = self.normalize(doc)
        # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strips accents from a piece of text."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        if len(output) == len(line):
            return line
        return "".join(output)

    def replace_unicode_punct(self, text: str) -> str:
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        return self.non_printing_chars_re.sub("", text)


def download_kenlm_model(language: str):
    root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"
    bin_name = f"{language}.arpa.bin"
    model_name = f"{language}.sp.model"
    bin_url = f"{root_url}/{bin_name}"
    model_url = f"{root_url}/{model_name}"

    if not os.path.isfile(bin_name):
        urllib.request.urlretrieve(bin_url, bin_name)

    if not os.path.isfile(model_name):
        urllib.request.urlretrieve(model_url, model_name)
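

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes the "en" model is still hosted at the cc_net URL above; note
    # that the .arpa.bin files are large (several GB for major languages).
    model = KenlmModel.from_pretrained("en")

    # cc_net's normalization: lowercase, strip accents, digits -> 0,
    # unicode punctuation mapped to ASCII equivalents.
    print(model.normalize("Où est-il ? Café 42!"))  # "ou est-il ? cafe 00!"

    # Lower perplexity = closer to the model's training data (Wikipedia for
    # cc_net's models); high values flag noisy or unnatural text.
    print(model.get_perplexity("He is a very perplexed man."))
    print(model.get_perplexity("asdf ghjk qwerty zxcvb"))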