Spaces:
Sleeping
Sleeping
File size: 5,840 Bytes
6e778dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import re
import collections
from typing import Dict, List, Tuple, Set
import json
from pathlib import Path
class TeluguBPE:
    """Byte-pair-encoding (BPE) tokenizer specialised for Telugu text."""

    def __init__(self, vocab_size: int = 5000):
        # Target size of the final vocabulary (base characters + merged tokens).
        self.vocab_size = vocab_size
        # Learned merge rules: (left, right) -> fused token, in learned order.
        self.merges: Dict[Tuple[str, str], str] = {}
        # Every known token: single characters plus all merge products.
        self.vocab: Set[str] = set()

    def preprocess_telugu_text(self, text: str) -> str:
        """
        Normalize raw text down to Telugu-only content.

        Keeps only characters in the Telugu Unicode block (U+0C00-U+0C7F)
        plus whitespace, collapses whitespace runs, isolates digit runs
        (Telugu digits U+0C66-U+0C6F survive the filter and the digit
        class matches them), and splits dependent vowel signs from their
        base consonant so BPE sees them as separate symbols.
        """
        # Drop everything outside the Telugu block except whitespace.
        text = re.sub(r'[^\u0C00-\u0C7F\s\n]', '', text)
        # Collapse all whitespace runs to a single space.
        text = re.sub(r'\s+', ' ', text)
        # Surround digit runs with spaces; only Telugu digits can still
        # be present at this point, and \d matches them (Unicode mode).
        text = re.sub(r'(\d+)', r' \1 ', text)
        # NOTE(review): the original also tried to space out danda marks
        # (U+0964/U+0965) and ASCII punctuation here, but the Telugu-only
        # filter above has already removed those characters, so the rules
        # could never match and have been dropped as dead code.
        # Split dependent vowel signs (U+0C3E-U+0C4C) from the preceding
        # consonant.
        text = re.sub(r'([\u0C3E-\u0C4C])', r' \1', text)
        return text.strip()

    def get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count how often each adjacent symbol pair occurs across all words."""
        pairs: Dict[Tuple[str, str], int] = collections.defaultdict(int)
        for word in words:
            for left, right in zip(word, word[1:]):
                pairs[(left, right)] += 1
        return pairs

    def merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """Replace every adjacent occurrence of ``pair`` with the fused token."""
        first, second = pair
        merged = first + second
        new_words: List[List[str]] = []
        for word in words:
            new_word: List[str] = []
            i = 0
            n = len(word)
            while i < n:
                if i + 1 < n and word[i] == first and word[i + 1] == second:
                    new_word.append(merged)
                    i += 2  # consume both halves of the pair
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words

    def learn_bpe(self, text: str) -> None:
        """
        Learn merge rules from ``text``.

        Stops when the vocabulary reaches ``vocab_size`` or when no
        adjacent pair remains to merge.
        """
        # Initial segmentation: each whitespace-delimited word as characters.
        words = [list(word) for word in text.split()]
        self.vocab = {char for word in words for char in word}
        # Clamp at zero in case the character inventory already exceeds
        # the requested vocabulary size (range() would be empty anyway,
        # but the intent should be explicit).
        num_merges = max(0, self.vocab_size - len(self.vocab))
        for _ in range(num_merges):
            pairs = self.get_stats(words)
            if not pairs:
                break  # nothing left to merge
            best_pair = max(pairs, key=pairs.get)
            self.merges[best_pair] = best_pair[0] + best_pair[1]
            self.vocab.add(self.merges[best_pair])
            words = self.merge_vocab(words, best_pair)
            if len(self.vocab) >= self.vocab_size:
                break

    def encode(self, text: str) -> List[str]:
        """Tokenize ``text`` by replaying the learned merges."""
        words = [list(word) for word in text.split()]
        # dicts preserve insertion order, so merges replay in learned order.
        for pair in self.merges:
            words = self.merge_vocab(words, pair)
        return [token for word in words for token in word]

    def save_model(self, path: str) -> None:
        """Serialize the model to ``path`` as human-readable UTF-8 JSON."""
        model_data = {
            'vocab_size': self.vocab_size,
            # JSON keys must be strings: encode each pair as "left right".
            'merges': {f'{k[0]} {k[1]}': v for k, v in self.merges.items()},
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, ensure_ascii=False, indent=2)

    def load_model(self, path: str) -> None:
        """Restore a model previously written by :meth:`save_model`."""
        with open(path, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        self.vocab_size = model_data['vocab_size']
        # Reverse the "left right" key encoding used by save_model.
        self.merges = {tuple(k.split()): v for k, v in model_data['merges'].items()}
        self.vocab = set(model_data['vocab'])
def main(input_file: str = "telugu_text.txt",
         model_file: str = "telugu_bpe_model.json") -> None:
    """
    Train a Telugu BPE model from ``input_file``, print compression
    statistics, and save the trained model to ``model_file``.

    Paths are parameters (with the original hard-coded values as
    defaults) so the script can be reused on other corpora.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        text = f.read()
    print("Started learning BPE")

    bpe = TeluguBPE(vocab_size=5000)
    processed_text = bpe.preprocess_telugu_text(text)

    # Pre-training statistics for the compression report.
    original_chars = len(processed_text)
    original_tokens = len(processed_text.split())

    bpe.learn_bpe(processed_text)

    # Encode the full corpus to measure how well the merges compress it.
    encoded_text = bpe.encode(processed_text)
    encoded_length = len(encoded_text)
    # Guard the ratio against an empty corpus (no tokens -> no division).
    compression_ratio = original_chars / encoded_length if encoded_length else 0.0

    bpe.save_model(model_file)

    print("\nCompression Statistics:")
    print(f"Original characters: {original_chars}")
    print(f"Original tokens (words): {original_tokens}")
    print(f"Encoded tokens: {encoded_length}")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Vocabulary size: {len(bpe.vocab)}")

    # Demonstrate encoding on a single sample word.
    sample_text = "నమస్కారం"  # "Hello" in Telugu
    encoded = bpe.encode(bpe.preprocess_telugu_text(sample_text))
    print("\nExample encoding:")
    print(f"Sample text: {sample_text}")
    print(f"Encoded text: {encoded}")


if __name__ == "__main__":
    main()