# Urdu punctuation-restoration service (extracted from a Hugging Face Spaces file view;
# viewer chrome — status lines, file size, commit hash, line-number gutter — removed).
#!/usr/bin/env python
import os
import re
import string
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
from simpletransformers.ner import NERModel
class BERTmodel:
    """Urdu punctuation restoration built on a BERT token-classification model.

    The NER model tags each word with one of ``valid_labels``; tags map to a
    trailing punctuation mark via ``label_to_punct`` (F = full stop ``۔``,
    C = comma ``،``, Q = question mark ``؟``, O = none).
    """

    def __init__(self, normalization="full", wrds_per_pred=256):
        # "full" strips every punctuation character before prediction;
        # "partial" strips only diacritics and the marks this model re-inserts.
        self.normalization = normalization
        # Words per prediction chunk; consecutive chunks overlap so boundary
        # words are re-predicted with more left context by the next chunk.
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = 32
        self.valid_labels = ["O", "F", "C", "Q"]
        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
        self.model = NERModel(
            "bert",
            "/code/models/urdu",
            use_cuda=False,
            labels=self.valid_labels,
            args={"silent": True, "max_seq_length": 512},
        )
        # "partial": regex character class of Arabic diacritics + the target
        # punctuation; "full": a plain string of characters to drop.
        self.patterns = {
            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
            "full": string.punctuation + "،؛؟۔٪ء‘’",
        }

    def punctuation_removal(self, text: str) -> str:
        """Strip pre-existing punctuation so the model predicts from a clean slate."""
        if self.normalization == "partial":
            return re.sub(self.patterns[self.normalization], "", text).strip()
        return "".join(ch for ch in text if ch not in self.patterns["full"])

    def punctuate(self, text: str):
        """Return *text* with model-predicted punctuation inserted."""
        text = self.punctuation_removal(text)
        # BUGFIX: combine_results indexes text_slices[-1] and previously
        # raised IndexError on empty / whitespace-only input.
        if not text.strip():
            return ""
        splits = self.split_on_tokens(text)
        full_preds_lst = [self.predict(i["text"]) for i in splits]
        # NERModel.predict returns (predictions, raw_outputs); keep the
        # word->label dicts of the single sentence in each chunk.
        preds_lst = [i[0][0] for i in full_preds_lst]
        combined_preds = self.combine_results(text, preds_lst)
        punct_text = self.punctuate_texts(combined_preds)
        return punct_text

    def predict(self, input_slice):
        """Run the NER model on one chunk of text."""
        return self.model.predict([input_slice])

    def split_on_tokens(self, text):
        """Split *text* into overlapping word chunks for prediction.

        Each chunk records the character offsets (``start_idx``/``end_idx``)
        of its non-overlapping portion within the normalized text.
        """
        wrds = text.replace("\n", " ").split()
        response = []
        lst_chunk_idx = 0
        i = 0
        while True:
            wrds_len = wrds[i * self.wrds_per_pred : (i + 1) * self.wrds_per_pred]
            wrds_ovlp = wrds[
                (i + 1) * self.wrds_per_pred : (i + 1) * self.wrds_per_pred + self.overlap_wrds
            ]
            wrds_split = wrds_len + wrds_ovlp
            if not wrds_split:
                break
            response_obj = {
                "text": " ".join(wrds_split),
                "start_idx": lst_chunk_idx,
                "end_idx": lst_chunk_idx + len(" ".join(wrds_len)),
            }
            response.append(response_obj)
            # BUGFIX: was `lst_chunk_idx += end_idx + 1`, which re-added the
            # already-cumulative end offset and made the offsets drift
            # quadratically. The next chunk starts one space past this end.
            lst_chunk_idx = response_obj["end_idx"] + 1
            i += 1
        return response

    def combine_results(self, full_text: str, text_slices):
        """Merge per-chunk predictions back into one ``(word, label)`` list.

        For every chunk except the last, the final two words are skipped so
        the next chunk (which saw them with more context via the overlap)
        supplies their labels instead.
        """
        split_full_text = full_text.replace("\n", " ").split(" ")
        split_full_text = [i for i in split_full_text if i]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0
        # A trailing chunk of <= 3 words is pure overlap; drop it.
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]
        for chunk in text_slices:  # renamed from `slice` (shadowed builtin)
            chunk_wrds = len(chunk)
            is_last = text_slices[-1] == chunk
            for ix, wrd in enumerate(chunk):
                if index == split_full_text_len:
                    break
                word = str(list(wrd.keys())[0])
                # Accept the prediction when the word lines up with the full
                # text and it is not in the skipped tail of a non-last chunk.
                if split_full_text[index] == word and (is_last or ix <= chunk_wrds - 3):
                    index += 1
                    output_text.append(list(wrd.items())[0])
        # Internal consistency check: every word of the text was consumed.
        assert [i[0] for i in output_text] == split_full_text
        return output_text

    def punctuate_texts(self, full_pred: list):
        """Render ``(word, label)`` pairs into a single punctuated string."""
        punct_resp = []
        for wrd, label in full_pred:
            mark = self.label_to_punct[label]
            # BUGFIX: the original checked endswith("‘‘") AFTER appending the
            # mark, so the branch only fired when the mark was "" and was a
            # no-op. Intent: place the mark inside a closing quote.
            if wrd.endswith("‘‘"):
                punct_wrd = wrd[:-2] + mark + "‘‘"
            else:
                punct_wrd = wrd + mark
            punct_resp.append(punct_wrd)
        punct_resp = " ".join(punct_resp)
        # BUGFIX: guard empty output before indexing [-1]; close a sentence
        # that the model left hanging on a plain word.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "۔"
        return punct_resp
class Urdu:
    """Public entry point: delegates Urdu punctuation restoration to BERTmodel."""

    def __init__(self):
        # Model weights are loaded eagerly, once, at construction time.
        self.model = BERTmodel()

    def punctuate(self, data: str):
        """Return *data* with punctuation restored by the underlying model."""
        restored = self.model.punctuate(data)
        return restored
|