# NOTE: non-code page residue ("Spaces:", "Runtime error") from the extracted
# Hugging Face Spaces page was removed here; the module code starts below.
import os
import re
import unicodedata
import urllib.request
from typing import Dict

import kenlm
class KenlmModel: | |
digit_re: re.Pattern = re.compile(r"\d") | |
unicode_punct: Dict[str, str] = { | |
",": ",", | |
"。": ".", | |
"、": ",", | |
"„": '"', | |
"”": '"', | |
"“": '"', | |
"«": '"', | |
"»": '"', | |
"1": '"', | |
"」": '"', | |
"「": '"', | |
"《": '"', | |
"》": '"', | |
"´": "'", | |
"∶": ":", | |
":": ":", | |
"?": "?", | |
"!": "!", | |
"(": "(", | |
")": ")", | |
";": ";", | |
"–": "-", | |
"—": " - ", | |
".": ". ", | |
"~": "~", | |
"’": "'", | |
"…": "...", | |
"━": "-", | |
"〈": "<", | |
"〉": ">", | |
"【": "[", | |
"】": "]", | |
"%": "%", | |
"►": "-", | |
} | |
unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]") | |
non_printing_chars_re = re.compile(f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]") | |
def __init__(self, language): | |
download_kenlm_model(language) | |
try: | |
self.model = kenlm.Model(f"{language}.arpa.bin") | |
except OSError: | |
os.remove(f"{language}.arpa.bin") | |
if os.path.exists(f"{language}.sp.model"): | |
os.remove(f"{language}.sp.model") | |
raise OSError("File was corrupt and should have been removed. Please, retry.") | |
def from_pretrained(cls, language: str): | |
return cls(language) | |
def get_perplexity(self, doc: str, normalize_cc_net: bool = True): | |
if normalize_cc_net: | |
doc = self.normalize(doc) | |
doc_log_score, doc_length = 0, 0 | |
for line in doc.split("\n"): | |
log_score = self.model.score(line) | |
length = len(line.split()) + 1 | |
doc_log_score += log_score | |
doc_length += length | |
return 10.0 ** (-doc_log_score / doc_length) | |
def normalize( | |
self, | |
line: str, | |
accent: bool = True, | |
case: bool = True, | |
numbers: bool = True, | |
punct: int = 1, | |
) -> str: | |
line = line.strip() | |
if not line: | |
return line | |
if case: | |
line = line.lower() | |
if accent: | |
line = self.strip_accents(line) | |
if numbers: | |
line = self.digit_re.sub("0", line) | |
if punct == 1: | |
line = self.replace_unicode_punct(line) | |
elif punct == 2: | |
line = self.remove_unicode_punct(line) | |
line = self.remove_non_printing_char(line) | |
return line | |
def strip_accents(self, line: str) -> str: | |
"""Strips accents from a piece of text.""" | |
nfd = unicodedata.normalize("NFD", line) | |
output = [c for c in nfd if unicodedata.category(c) != "Mn"] | |
if len(output) == line: | |
return line | |
return "".join(output) | |
def replace_unicode_punct(self, text: str) -> str: | |
return "".join((self.unicode_punct.get(c, c) for c in text)) | |
def remove_unicode_punct(self, text: str) -> str: | |
"""More aggressive version of replace_unicode_punct but also faster.""" | |
return self.unicode_punct_re.sub("", text) | |
def remove_non_printing_char(self, text: str) -> str: | |
return self.non_printing_chars_re.sub("", text) | |
def download_kenlm_model(language: str):
    """Fetch the KenLM binary and SentencePiece model for ``language``.

    Downloads ``{language}.arpa.bin`` and ``{language}.sp.model`` from the
    cc_net public bucket into the current working directory. Files that
    already exist locally are not re-downloaded.
    """
    # Use https rather than http: the bucket serves TLS and these files are
    # model data loaded into the process, so avoid a tamperable plaintext fetch.
    root_url = "https://dl.fbaipublicfiles.com/cc_net/lm"
    bin_name = f"{language}.arpa.bin"
    model_name = f"{language}.sp.model"
    if not os.path.isfile(bin_name):
        urllib.request.urlretrieve(f"{root_url}/{bin_name}", bin_name)
    if not os.path.isfile(model_name):
        urllib.request.urlretrieve(f"{root_url}/{model_name}", model_name)