#!/usr/bin/env python
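"""Urdu punctuation restoration.

Treats punctuation restoration as token classification (NER): a fine-tuned
BERT model labels each word O (none), F (full stop "۔"), C (comma "،") or
Q (question mark "؟"), and the labels are rendered back into the text.
"""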

import os
import re
import string

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

from simpletransformers.ner import NERModel


class BERTmodel:
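    """Restores Urdu punctuation using a BERT-based NER model."""
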
    def __init__(self, normalization="full", wrds_per_pred=256):
        self.normalization = normalization
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = 32
        # NER labels: O = none, F = full stop (۔), C = comma (،), Q = question mark (؟)
        self.valid_labels = ["O", "F", "C", "Q"]
        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
        # Fine-tuned BERT token classifier; CPU inference (use_cuda=False).
        self.model = NERModel(
            "bert",
            "/code/models/urdu",
            use_cuda=False,
            labels=self.valid_labels,
            args={"silent": True, "max_seq_length": 512},
        )
        # "partial" strips Urdu diacritics and punctuation with a regex;
        # "full" treats the string as a blacklist of characters to drop.
        self.patterns = {
            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
            "full": string.punctuation + "،؛؟۔٪ء‘’",
        }

    def punctuation_removal(self, text: str) -> str:
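        """Strip existing punctuation so the model predicts from clean text."""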
        if self.normalization == "partial":
            return re.sub(self.patterns[self.normalization], "", text).strip()
        else:
            return "".join(ch for ch in text if ch not in self.patterns[self.normalization])

    def punctuate(self, text: str):
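        """Full pipeline: clean, chunk, predict per chunk, merge, render."""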
        text = self.punctuation_removal(text)
        splits = self.split_on_tokens(text)
        # NERModel.predict returns (predictions, raw_outputs); keep the list
        # of {word: label} dicts for the single sentence in each chunk.
        full_preds_lst = [self.predict(i["text"]) for i in splits]
        preds_lst = [i[0][0] for i in full_preds_lst]
        combined_preds = self.combine_results(text, preds_lst)
        punct_text = self.punctuate_texts(combined_preds)
        return punct_text

    def predict(self, input_slice):
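        """Run the NER model on one chunk; returns (predictions, raw_outputs)."""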
        return self.model.predict([input_slice])

    def split_on_tokens(self, text):
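        """Split text into chunks of wrds_per_pred words, each followed by
        overlap_wrds words of lookahead context; records char offsets."""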
        wrds = text.replace("\n", " ").split()
        response = []
        lst_chunk_idx = 0
        i = 0

        while True:
            # Take the next wrds_per_pred words, plus overlap_wrds words of
            # lookahead so chunk-boundary words are predicted with context.
            wrds_len = wrds[i * self.wrds_per_pred : (i + 1) * self.wrds_per_pred]
            wrds_ovlp = wrds[
                (i + 1) * self.wrds_per_pred : (i + 1) * self.wrds_per_pred + self.overlap_wrds
            ]
            wrds_split = wrds_len + wrds_ovlp

            if not wrds_split:
                break

            response_obj = {
                "text": " ".join(wrds_split),
                "start_idx": lst_chunk_idx,
                "end_idx": lst_chunk_idx + len(" ".join(wrds_len)),
            }

            response.append(response_obj)
            # Advance to the character after this chunk's main region;
            # "+=" here would double-count the offset already in end_idx.
            lst_chunk_idx = response_obj["end_idx"] + 1
            i += 1

        return response

    def combine_results(self, full_text: str, text_slices):
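        """Merge overlapping per-chunk predictions into one (word, label)
        sequence aligned with the original text."""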
        split_full_text = full_text.replace("\n", " ").split(" ")
        split_full_text = [i for i in split_full_text if i]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0

        # Drop a near-empty trailing slice: its few words were already
        # covered by the previous chunk's overlap window.
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]

        for text_slice in text_slices:
            slice_wrds = len(text_slice)
            is_last_slice = text_slice == text_slices[-1]
            for ix, wrd in enumerate(text_slice):
                if index == split_full_text_len:
                    break

                # Consume matching words in order. A non-final chunk skips
                # its last two entries so the next chunk, which sees them
                # with fuller context, supplies those predictions instead.
                word = str(list(wrd.keys())[0])
                if split_full_text[index] == word and (is_last_slice or ix <= slice_wrds - 3):
                    index += 1
                    output_text.append(list(wrd.items())[0])

        # Sanity check: the merged words must reproduce the input exactly.
        assert [i[0] for i in output_text] == split_full_text
        return output_text

    def punctuate_texts(self, full_pred: list):
        """Render (word, label) pairs into punctuated Urdu text."""
        punct_resp = []
        for wrd, label in full_pred:
            punct = self.label_to_punct[label]
            if wrd.endswith("‘‘"):
                # Keep a closing quote outermost: insert the mark before it.
                wrd = wrd[:-2] + punct + "‘‘"
            else:
                wrd += punct
            punct_resp.append(wrd)

        punct_resp = " ".join(punct_resp)
        # Close the passage with a full stop if no final mark was predicted.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "۔"

        return punct_resp


class Urdu:
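    """Public entry point for Urdu punctuation restoration."""
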
    def __init__(self):
        self.model = BERTmodel()

    def punctuate(self, data: str):
        return self.model.punctuate(data)
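

# Illustrative usage sketch (assumes the fine-tuned model files exist at
# /code/models/urdu, per the path hard-coded above); the sample sentence
# below is only an example input, not from the original module.
if __name__ == "__main__":
    urdu = Urdu()
    # Unpunctuated input; the model decides where ۔ ، and ؟ belong.
    print(urdu.punctuate("آپ کیسے ہیں میں ٹھیک ہوں"))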