File size: 3,839 Bytes
1f30dbc
0def03f
 
1f30dbc
0def03f
1f30dbc
 
 
 
 
0def03f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f30dbc
 
9ec7b19
 
 
 
ab7449f
 
9ec7b19
1f30dbc
 
 
 
 
0def03f
 
 
1f30dbc
 
 
 
 
 
 
 
0def03f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f30dbc
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import re
import unicodedata
import urllib.request
from typing import Dict

import kenlm


class KenlmModel:
    """Score text with a KenLM model after CC-Net style text normalization.

    ``__init__`` calls ``download_kenlm_model`` and then loads
    ``{language}.arpa.bin`` from the current working directory.
    """

    # Any single decimal digit (str patterns match Unicode digits too).
    digit_re: re.Pattern = re.compile(r"\d")
    # Punctuation -> ASCII replacement table used by ``normalize(punct=1)``;
    # its keys are also the characters deleted by ``normalize(punct=2)``.
    # Full-width forms are spelled as \u escapes on purpose: they look almost
    # identical to their ASCII counterparts and are easily converted to ASCII
    # by editors or copy/paste. That would turn FULLWIDTH FULL STOP -> ". "
    # into "." -> ". " (a space after every period) and FULLWIDTH DIGIT ONE
    # into the ASCII digit "1".
    unicode_punct: Dict[str, str] = {
        "\uff0c": ",",  # FULLWIDTH COMMA
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "\uff11": '"',  # FULLWIDTH DIGIT ONE (odd, presumably mirrors upstream CC-Net)
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        "\uff1a": ":",  # FULLWIDTH COLON
        "\uff1f": "?",  # FULLWIDTH QUESTION MARK
        "\uff01": "!",  # FULLWIDTH EXCLAMATION MARK
        "\uff08": "(",  # FULLWIDTH LEFT PARENTHESIS
        "\uff09": ")",  # FULLWIDTH RIGHT PARENTHESIS
        "\uff1b": ";",  # FULLWIDTH SEMICOLON
        "–": "-",
        "—": " - ",
        "\uff0e": ". ",  # FULLWIDTH FULL STOP
        "\uff5e": "~",  # FULLWIDTH TILDE
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "\uff05": "%",  # FULLWIDTH PERCENT SIGN
        "►": "-",
    }
    # Precomputed translation table: one C-level pass per string.
    _unicode_punct_table = str.maketrans(unicode_punct)
    # Character class of every key above; escaped so any future key that is
    # a regex metacharacter inside [...] cannot break the pattern.
    unicode_punct_re = re.compile(f"[{''.join(map(re.escape, unicode_punct))}]")
    # C0 controls (0-31, which includes tab and newline) plus DEL and the
    # C1 controls (127-159).
    non_printing_chars_re = re.compile(f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]")

    def __init__(self, language: str):
        """Download (if needed) and load the KenLM binary for *language*.

        Raises:
            OSError: if ``kenlm.Model`` raises OSError while loading the
                binary. Both local model files are deleted first so that a
                retry downloads them again.
        """
        download_kenlm_model(language)
        bin_path = f"{language}.arpa.bin"
        sp_path = f"{language}.sp.model"
        try:
            self.model = kenlm.Model(bin_path)
        except OSError as err:
            # Remove both artifacts so the next attempt re-fetches them;
            # tolerate either one already being absent.
            for path in (bin_path, sp_path):
                if os.path.exists(path):
                    os.remove(path)
            raise OSError("File was corrupt and should have been removed. Please, retry.") from err

    @classmethod
    def from_pretrained(cls, language: str):
        """Alternate constructor; equivalent to ``cls(language)``."""
        return cls(language)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True) -> float:
        """Return the perplexity of *doc* under the loaded model.

        Each newline-separated line is scored with ``self.model.score``. Its
        length counts whitespace-separated tokens plus one (presumably the
        end-of-sentence token). The result is
        ``10 ** (-total_score / total_length)``; the base 10 implies
        ``score`` returns log10 probabilities.

        NOTE(review): ``normalize`` removes control characters, newline
        included. With ``normalize_cc_net=True`` the whole document is
        therefore scored as a single line. The ``.sp.model`` file fetched
        alongside the binary is never applied here either. If the model
        expects tokenized input, confirm against the intended pipeline.
        """
        if normalize_cc_net:
            doc = self.normalize(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            doc_log_score += self.model.score(line)
            doc_length += len(line.split()) + 1
        # doc_length >= 1: split() always yields at least one line.
        return 10.0 ** (-doc_log_score / doc_length)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        """Normalize text the way this model's input is expected to look.

        Steps, in order:
        1. Strip surrounding whitespace.
        2. Lowercase (``case``).
        3. Strip accents (``accent``).
        4. Map every digit to "0" (``numbers``).
        5. Replace (``punct == 1``) or delete (``punct == 2``) the characters
           listed in ``unicode_punct``; any other ``punct`` value skips this step.
        6. Always drop control characters.

        An input that is empty after stripping is returned as is.
        """
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strip accents: NFD-decompose and drop nonspacing marks (category Mn).

        The NFD form is always returned, even when no mark was removed. Pure
        canonical mappings (e.g. U+2126 OHM SIGN -> U+03A9) are therefore
        applied consistently.
        """
        nfd = unicodedata.normalize("NFD", line)
        return "".join(c for c in nfd if unicodedata.category(c) != "Mn")

    def replace_unicode_punct(self, text: str) -> str:
        """Replace every ``unicode_punct`` key in *text* by its ASCII value."""
        return text.translate(self._unicode_punct_table)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        """Delete C0/C1 control characters and DEL (tab/newline included)."""
        return self.non_printing_chars_re.sub("", text)


def download_kenlm_model(
    language: str, root_url: str = "https://dl.fbaipublicfiles.com/cc_net/lm"
) -> None:
    """Fetch ``{language}.arpa.bin`` and ``{language}.sp.model`` if missing.

    Files are saved in the current working directory. A file that already
    exists there is not downloaded again.

    Each file is first written to ``<name>.part`` and renamed into place only
    after the transfer completes. An interrupted download therefore never
    leaves a truncated file that a later call would mistake for a finished one.

    Args:
        language: language code used to build the remote file names.
        root_url: base URL hosting the files (HTTPS by default; the old
            plain-HTTP endpoint exposed the binaries to tampering in transit).

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``);
        the partial file is removed first.
    """
    for name in (f"{language}.arpa.bin", f"{language}.sp.model"):
        if os.path.isfile(name):
            continue
        partial = f"{name}.part"
        try:
            urllib.request.urlretrieve(f"{root_url}/{name}", partial)
            # Atomic on POSIX and Windows: the final name appears only once
            # the data is complete.
            os.replace(partial, name)
        except BaseException:
            # Also on Ctrl-C: never leave a half-written .part file behind.
            if os.path.exists(partial):
                os.remove(partial)
            raise