import os
import sys
import glob
import regex as re
import pandas as pd
import requests
import unicodedata
import json
from collections import defaultdict
from typing import List, Dict, Tuple
from tqdm import tqdm


class GujaratiBPETokenizer:
    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.vocab = {}
        self.inverse_vocab = {}
        self.compression_ratio = 0.
        self.merges = {}
        self.special_tokens = {
            '<PAD>': 0,
            '<UNK>': 1,
            '<BOS>': 2,
            '<EOS>': 3
        }
        # Applied to the whole corpus: pulls out runs of letters/marks/digits
        # (optionally with a leading space) and runs of every other character.
        self.global_pattern = re.compile(r""" [\p{L}\p{M}\p{N}]+|[\p{L}\p{M}\p{N}]+|[^\r\n\p{L}\p{M}\p{N}]+""")
        # Applied to each word: splits off a morphological ending formed by "ન" or "મ" plus a combining mark.
        self.local_pattern = re.compile(r"""([\s\p{L}\p{M}]+|[\s\p{L}\p{M}\p{N}]+)([નમ](?:\p{M}))$""")
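        # For instance, a word carrying the genitive suffix "નું" (the letter ન plus a
        # single combining vowel sign) should split into stem + "નું"; endings with more
        # than one trailing mark do not match and the word is kept whole.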
        self.eng2guj = self.get_eng_to_guj_digits_mapping()
        self.guj_unicode_df = self.get_guj_unicodes()
        # Initialize the base Gujarati character vocabulary
        self.base_vocab = set()
        # Add basic Gujarati characters (vowels, consonants, digits, marks)
        self._initialize_base_vocab()


    def get_guj_unicodes(self):
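        """Fetch UnicodeData.txt from unicode.org and return a DataFrame of every
        Gujarati code point (requires network access when the tokenizer is built)."""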
        res = requests.get("https://www.unicode.org/Public/UNIDATA/UnicodeData.txt")
        lines = res.text.splitlines()
        lines = [",".join(line.split(";")[:2]) for line in lines if "GUJARATI" in line]
        data = {
            "code": [l.split(",")[0] for l in lines],
            "name": [l.split(",")[-1] for l in lines],
            "char": [unicodedata.lookup(l.split(",")[1]) for l in lines],
        }
        df = pd.DataFrame(data)
        return df
    

    def _initialize_base_vocab(self):
        """Initialize vocabulary with basic Odia characters"""
        # Vowels
        self.base_vocab.update(self.guj_unicode_df["char"].to_list())
        # Whitespace characters with period.
        self.base_vocab.update([' ', '\n', '\t', "."])


    def _get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count frequency of adjacent pairs in the vocabulary"""
        pairs = defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pairs[tuple(word[i:i + 2])] += 1
        return pairs


    def _merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """Merge all occurrences of the most frequent pair"""
        first, second = pair
        new_words = []
        
        for word in words:
            i = 0
            new_word = []
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        
        return new_words

    def get_eng_to_guj_digits_mapping(self):
        e2g = dict()
        # Add digits 0 to 9
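        # The finished map sends each ASCII digit to its Gujarati counterpart,
        # e.g. '0' -> '૦' and '9' -> '૯'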
        for i in range(10):
            e2g[str(i)] = unicodedata.lookup(f"GUJARATI DIGIT {unicodedata.name(chr(48+i)).split()[-1]}")
        
        return e2g


    def remove_eng_words(self, text):
        """Drop Latin-script words and collapse the remaining whitespace."""
        pat = re.compile(r"[a-zA-Z]+")
        text = " ".join(re.sub(pat, "", text).split())
        return text


    def eng_to_guj_digits(self, text, e2g):
        new_text = ""
        for ch in text:
            if ch.isdigit() and ch not in e2g.values():
                # .get() guards against digits from scripts that have no mapping here
                new_text += e2g.get(ch, ch)
            else:
                new_text += ch

        return new_text
    
    
    def process_text_with_regex(self, text):
        """Split text with the global pattern, then split morphological endings off each chunk with the local pattern."""
        split_text = re.findall(self.global_pattern, text)
        new_text = []
        for t in split_text:
            split_words = re.findall(self.local_pattern, t)
            # print(f"word: {t} --> word split: {split_words}")
            if split_words:
                for item in split_words:
                    if isinstance(item, tuple):
                        w = [i for i in item if i != ""]
                        # print(f"item: {item} --> {w}")
                        new_text.extend(w)
            else:
                new_text.append(t)

        return new_text
    
    def tokenize_text(self, texts: List[str]):
        """
        Takes a list of text and provides list of processed words required for the encoding.

        Args:
            texts (List[str]): text lines

        Returns:
            list: list of extracted words from the text lines
        """
        processed_text = []
        for t in tqdm(texts, desc="preprocessing", colour="green", bar_format="{l_bar}{bar:30}{r_bar}"):
            processed_text.append(self.eng_to_guj_digits(self.remove_eng_words(t), self.eng2guj))

        processed_text = " ".join(processed_text)
        words = self.process_text_with_regex(processed_text)

        return words
    

    def train(self, texts: List[str], min_freq: int = 2) -> None:
        """Train BPE model on texts"""
        
        tokens = self.tokenize_text(texts)
        # Start from individual characters; merges gradually fuse them into subwords
        words = [list(word) for word in tokens]

        vocab = self.base_vocab.copy()
        num_merges = self.vocab_size - len(self.special_tokens) - len(vocab)
        # print("num_merges : ", num_merges)
        # Perform BPE merges
        train_bar = tqdm(range(num_merges),
                         desc="Merging pairs",
                         total=num_merges, 
                         colour="blue", 
                         file=sys.stdout, 
                         bar_format="{l_bar}{bar:30}{r_bar}"
        )
        for i in train_bar:
            pairs = self._get_stats(words)
            if not pairs:
                break

            # Find most frequent pair
            best_pair = max(pairs.items(), key=lambda x: x[1])
            if best_pair[1] < min_freq:
                break

            pair = best_pair[0]
            new_token = ''.join(pair)
            vocab.add(new_token)
            #print("merging ..", pair)
            # print(len(vocab))
            # Record the merge operation
            self.merges[pair] = new_token
            
            # Merge the pair in all words
            words = self._merge_vocab(words, pair)

        # Build final vocabulary
        self.vocab = {**self.special_tokens}
        idx = len(self.special_tokens)
        for token in sorted(vocab):
            self.vocab[token] = idx
            idx += 1

        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        total_chars = sum(len(word) for word in tokens)
        total_tokens = sum(len(word) for word in words)
        self.compression_ratio = total_chars / total_tokens
        print("characters before merging:", total_chars)
        print("tokens after merge operations:", total_tokens)
        print(f"compression ratio: {self.compression_ratio:.2f}X")


    def encode(self, text: str) -> List[int]:
        """Encode text using learned BPE merges"""

        tokenized_words = self.tokenize_text([text])
        words = [list(word) for word in tokenized_words]
        # print("Before merges: ", words)
        
        # Apply merges in order
        for pair, merged in self.merges.items():
            words = self._merge_vocab(words, pair)
        # print("After mergers: ", words)

        # Convert to token IDs
        result = []
        for word in words:
            for token in word:
                if token in self.vocab:
                    result.append(self.vocab[token])
                else:
                    result.append(self.special_tokens['<UNK>'])
        
        return result


    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text"""
        return ''.join(self.inverse_vocab.get(id, '<UNK>') for id in ids)


    def calculate_compression_ratio(self, text: str) -> float:
        """Calculate compression ratio"""
        encoded = self.encode(text)
        return len(text) / len(encoded)


    def save(self, path: str) -> None:
        """Save tokenizer state"""
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{first}|{second}": merged 
                              for (first, second), merged in self.merges.items()}
        
        data = {
            'vocab': self.vocab,
            'merges': serializable_merges,
            'vocab_size': self.vocab_size,
            'special_tokens': self.special_tokens,
            'compression_ratio': self.compression_ratio
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)


    @classmethod
    def load(cls, path: str) -> 'GujaratiBPETokenizer':
        """Load tokenizer from file"""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        tokenizer = cls(vocab_size=data['vocab_size'])
        tokenizer.vocab = data['vocab']
        
        # Convert string keys back to tuples
        tokenizer.merges = {tuple(k.split('|')): v 
                           for k, v in data['merges'].items()}
        
        tokenizer.special_tokens = data['special_tokens']
        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
        tokenizer.compression_ratio = data['compression_ratio']
        print(f"Tokenizer loaded!")
        return tokenizer


if __name__ == "__main__":
    # train
    data_path = "data"
    news_articles = glob.glob(os.path.join(data_path, "news dataset", "*.txt"))
    cc100_dataset = glob.glob(os.path.join(data_path, "cc100-Gujarati", "*.txt"))
    indic_dataset = glob.glob(os.path.join(data_path, "IndicCorp", "*.txt"))
    final_dataset = news_articles + cc100_dataset + indic_dataset

    texts = []
    for article in final_dataset:
        with open(article, "r", encoding='utf-8') as f:
            # only the first line of each corpus file is used
            texts.append(f.readline().strip())

    tokenizer = GujaratiBPETokenizer()
    tokenizer.train(texts)
    tokenizer.save("Gujarati_tokenizer.json")

    # # test
    # tokenizer = GujaratiBPETokenizer().load("Gujarati_tokenizer.json")
    # text1 = "ચામરાજનગર ભારત દેશના દક્ષિણ ભાગમાં આવેલા કર્ણાટક રાજ્યના ચામરાજનગર જિલ્લામાં આવેલું એક નગર છે. ચામરાજનગરમાં ચામરાજનગર જિલ્લાનું મુખ્યાલય છે."
    # enc_text1 = tokenizer.encode(text1)
    # print(enc_text1, len(enc_text1))
    # text2 = tokenizer.decode(enc_text1)
    # print(text2)

    # assert text1 == text2, "Problem with BPE!!"