import os
import sys
import glob
import regex as re
import pandas as pd
import requests
import unicodedata
import json
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set
from tqdm import tqdm


class GujaratiBPETokenizer:
    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.vocab = {}
        self.inverse_vocab = {}
        self.compression_ratio = 0.0
        self.merges = {}
        self.special_tokens = {
            '<pad>': 0,
            '<unk>': 1,
            '<s>': 2,
            '</s>': 3
        }
        # Applies to the entire corpus: keeps runs of letters/marks/digits
        # (optionally with a leading space) and runs of other non-newline characters.
        self.global_pattern = re.compile(r""" [\p{L}\p{M}\p{N}]+|[\p{L}\p{M}\p{N}]+|[^\r\n\p{L}\p{M}\p{N}]+""")
        # Applies to each word: splits off a morphological ending in "ન" or "મ"
        # (followed by a combining mark) from the stem.
        self.local_pattern = re.compile(r"""([\s\p{L}\p{M}]+|[\s\p{L}\p{M}\p{N}]+)([નમ](?:\p{M}))$""")
        self.eng2guj = self.get_eng_to_guj_digits_mapping()
        self.guj_unicode_df = self.get_guj_unicodes()
        # Initialize the basic Gujarati character vocabulary
        self.base_vocab = set()
        # Add basic Gujarati characters (vowels, consonants, marks)
        self._initialize_base_vocab()

    def get_guj_unicodes(self):
        res = requests.get("https://www.unicode.org/Public/UNIDATA/UnicodeData.txt")
        lines = res.text.splitlines()
        lines = [",".join(line.split(";")[:2]) for line in lines if "GUJARATI" in line]
        data = {
            "code": [l.split(",")[0] for l in lines],
            "name": [l.split(",")[-1] for l in lines],
            "char": [unicodedata.lookup(l.split(",")[1]) for l in lines],
        }
        df = pd.DataFrame(data)
        return df

    def _initialize_base_vocab(self):
        """Initialize the vocabulary with basic Gujarati characters."""
        # All Gujarati code points (vowels, consonants, marks, digits)
        self.base_vocab.update(self.guj_unicode_df["char"].to_list())
        # Whitespace characters plus the period
        self.base_vocab.update([' ', '\n', '\t', "."])

    def _get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count the frequency of adjacent symbol pairs across all words."""
        pairs = defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pairs[tuple(word[i:i + 2])] += 1
        return pairs

    def _merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """Merge all occurrences of the given pair in every word."""
        first, second = pair
        new_words = []
        for word in words:
            i = 0
            new_word = []
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words
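    # Toy illustration of the two BPE primitives above (comments only):
    #   words = ["નગર", "નગર"]          # word occurrences treated as character sequences
    #   _get_stats(words)                -> {('ન', 'ગ'): 2, ('ગ', 'ર'): 2}
    #   _merge_vocab(words, ('ન', 'ગ'))  -> [['નગ', 'ર'], ['નગ', 'ર']]
    # train() repeats this: count adjacent pairs, pick the most frequent one,
    # and merge it across the whole corpus until the vocabulary budget is used.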
    def get_eng_to_guj_digits_mapping(self):
        e2g = dict()
        # Map the ASCII digits 0-9 to the corresponding Gujarati digits
        for i in range(10):
            e2g[str(i)] = unicodedata.lookup(f"GUJARATI DIGIT {unicodedata.name(chr(48 + i)).split()[-1]}")
        return e2g

    def remove_eng_words(self, text):
        pat = re.compile(r"[a-zA-Z]+", re.IGNORECASE)
        text = " ".join(re.sub(pat, "", text).split())
        return text

    def eng_to_guj_digits(self, text, e2g):
        new_text = ""
        for ch in text:
            if ch.isdigit() and ch not in e2g.values():
                new_text += e2g[ch]
            else:
                new_text += ch
        return new_text

    def process_text_with_regex(self, text):
        split_text = re.findall(self.global_pattern, text)
        new_text = []
        for t in split_text:
            split_words = re.findall(self.local_pattern, t)
            if split_words:
                for item in split_words:
                    if isinstance(item, tuple):
                        w = [i for i in item if i != ""]
                        new_text.extend(w)
            else:
                new_text.append(t)
        return new_text

    def tokenize_text(self, texts: List[str]):
        """
        Takes a list of text lines and produces the list of processed words
        required for encoding.

        Args:
            texts (List[str]): text lines
        Returns:
            list: list of extracted words from the text lines
        """
        processed_text = []
        for t in tqdm(texts, desc="preprocessing", colour="green",
                      bar_format="{l_bar}{bar:30}{r_bar}"):
            processed_text.append(self.eng_to_guj_digits(self.remove_eng_words(t), self.eng2guj))
        processed_text = " ".join(processed_text)
        words = self.process_text_with_regex(processed_text)
        return words

    def train(self, texts: List[str], min_freq: int = 2) -> None:
        """Train the BPE model on texts."""
        tokens = self.tokenize_text(texts)
        words = tokens
        vocab = self.base_vocab.copy()
        num_merges = self.vocab_size - len(self.special_tokens) - len(vocab)

        # Perform BPE merges
        train_bar = tqdm(range(num_merges), desc="Merging pairs", total=num_merges,
                         colour="blue", file=sys.stdout,
                         bar_format="{l_bar}{bar:30}{r_bar}")
        for _ in train_bar:
            pairs = self._get_stats(words)
            if not pairs:
                break

            # Find the most frequent pair
            best_pair = max(pairs.items(), key=lambda x: x[1])
            if best_pair[1] < min_freq:
                break

            pair = best_pair[0]
            new_token = ''.join(pair)
            vocab.add(new_token)

            # Record the merge operation
            self.merges[pair] = new_token

            # Merge the pair in all words
            words = self._merge_vocab(words, pair)

        # Build the final vocabulary
        self.vocab = {**self.special_tokens}
        idx = len(self.special_tokens)
        for token in sorted(vocab):
            self.vocab[token] = idx
            idx += 1
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}

        # Compression ratio: total symbols before merging (characters)
        # versus total tokens left after all merges.
        num_chars = sum(len(word) for word in tokens)
        num_merged_tokens = sum(len(word) for word in words)
        self.compression_ratio = num_chars / num_merged_tokens
        print("token count before merges:", num_chars)
        print("token count after merges:", num_merged_tokens)
        print(f"compression ratio: {self.compression_ratio:.2f}X")

    def encode(self, text: str) -> List[int]:
        """Encode text using the learned BPE merges."""
        tokenized_words = self.tokenize_text([text])
        words = [list(word) for word in tokenized_words]

        # Apply merges in the order they were learned
        for pair, merged in self.merges.items():
            words = self._merge_vocab(words, pair)

        # Convert tokens to IDs, falling back to <unk> for unseen tokens
        result = []
        for word in words:
            for token in word:
                if token in self.vocab:
                    result.append(self.vocab[token])
                else:
                    result.append(self.special_tokens['<unk>'])
        return result

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text."""
        return ''.join(self.inverse_vocab.get(id, '') for id in ids)

    def calculate_compression_ratio(self, text: str) -> float:
        """Calculate the compression ratio (characters per token) for a text."""
        encoded = self.encode(text)
        return len(text) / len(encoded)

    def save(self, path: str) -> None:
        """Save the tokenizer state to a JSON file."""
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{first}|{second}": merged
                               for (first, second), merged in self.merges.items()}
        data = {
            'vocab': self.vocab,
            'merges': serializable_merges,
            'vocab_size': self.vocab_size,
            'special_tokens': self.special_tokens,
            'compression_ratio': self.compression_ratio
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> 'GujaratiBPETokenizer':
        """Load a tokenizer from file."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        tokenizer = cls(vocab_size=data['vocab_size'])
        tokenizer.vocab = data['vocab']
        # Convert string keys back to tuples
        tokenizer.merges = {tuple(k.split('|')): v for k, v in data['merges'].items()}
        tokenizer.special_tokens = data['special_tokens']
        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
        tokenizer.compression_ratio = data['compression_ratio']
        print("Tokenizer loaded!")
        return tokenizer
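
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `_usage_sketch` is a hypothetical
# helper and is never called by the pipeline). It assumes a small vocab_size
# of 300 to keep the demo fast, reuses the sample sentence from the
# commented-out test below, and needs network access because the constructor
# downloads UnicodeData.txt.
# ---------------------------------------------------------------------------
def _usage_sketch():
    sample = "ચામરાજનગર ભારત દેશના દક્ષિણ ભાગમાં આવેલા કર્ણાટક રાજ્યના ચામરાજનગર જિલ્લામાં આવેલું એક નગર છે."
    tok = GujaratiBPETokenizer(vocab_size=300)
    tok.train([sample], min_freq=2)  # merges stop early on such a tiny corpus
    ids = tok.encode(sample)
    print("ids:", ids[:20], "... total:", len(ids))
    print("decoded:", tok.decode(ids))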

if __name__ == "__main__":
    # train
    data_path = "data"
    news_articles = glob.glob(os.path.join(data_path, "news dataset", "*.txt"))
    cc100_dataset = glob.glob(os.path.join(data_path, "cc100-Gujarati", "*.txt"))
    indic_dataset = glob.glob(os.path.join(data_path, "IndicCorp", "*.txt"))
    final_dataset = news_articles + cc100_dataset + indic_dataset

    texts = []
    for article in final_dataset:
        with open(article, "r", encoding='utf-8') as f:
            texts.append(f.readline().strip())

    tokenizer = GujaratiBPETokenizer()
    tokenizer.train(texts)
    tokenizer.save("Gujarati_tokenizer.json")

    # # test
    # tokenizer = GujaratiBPETokenizer.load("Gujarati_tokenizer.json")
    # text1 = "ચામરાજનગર ભારત દેશના દક્ષિણ ભાગમાં આવેલા કર્ણાટક રાજ્યના ચામરાજનગર જિલ્લામાં આવેલું એક નગર છે. ચામરાજનગરમાં ચામરાજનગર જિલ્લાનું મુખ્યાલય છે."
    # enc_text1 = tokenizer.encode(text1)
    # print(enc_text1, len(enc_text1))
    # text2 = tokenizer.decode(enc_text1)
    # print(text2)
    # assert text1 == text2, "Problem with BPE!!"
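    # Optional sanity check (illustrative): report the final vocabulary size and
    # the character-per-token compression ratio on the first training line,
    # mirroring calculate_compression_ratio(). Assumes that line still yields
    # at least one token after preprocessing.
    if texts:
        sample = texts[0]
        ids = tokenizer.encode(sample)
        if ids:
            print("final vocab size:", len(tokenizer.vocab))
            print(f"sample compression ratio: {len(sample) / len(ids):.2f}X")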