# BytePairEncode_V0 / telugu_bpe.py
# anjikum's picture
# Upload 4 files
# 6e778dd verified
import re
import collections
from typing import Dict, List, Tuple, Set
import json
from pathlib import Path
class TeluguBPE:
    """Byte-pair-encoding (BPE) tokenizer for whitespace-separated Telugu text.

    Training (`learn_bpe`) starts from single characters and repeatedly merges
    the most frequent adjacent token pair. Encoding (`encode`) replays the
    learned merges in the order they were learned.
    """

    def __init__(self, vocab_size: int = 5000):
        """Create an untrained model.

        Args:
            vocab_size: Target vocabulary size. Training performs at most
                ``vocab_size - <number of distinct characters>`` merges.
        """
        self.vocab_size = vocab_size
        # (left, right) -> merged token. Insertion order is the order in which
        # merges were learned, and encode() relies on that order.
        self.merges: Dict[Tuple[str, str], str] = {}
        # Every known token: the initial characters plus all merged tokens.
        self.vocab: Set[str] = set()

    def preprocess_telugu_text(self, text: str) -> str:
        """Reduce *text* to Telugu characters separated by single spaces.

        Steps:
          1. Drop every character that is neither in the Telugu Unicode block
             (U+0C00-U+0C7F) nor whitespace. This removes Latin letters, ASCII
             digits and all punctuation (including danda marks), so no
             punctuation handling is needed afterwards.
          2. Surround runs of digits with spaces. After step 1 only Telugu
             digits can remain; ``\\d`` matches them because it is
             Unicode-aware for ``str`` patterns.
          3. Insert a space before each dependent vowel sign
             (U+0C3E-U+0C4C). Since training and encoding split on
             whitespace, every vowel sign becomes a word of its own.
          4. Collapse all whitespace runs to one space and strip the ends.
             This is done last so the spaces inserted by steps 2 and 3 cannot
             leave double spaces in the result.

        Args:
            text: Raw input text.

        Returns:
            The cleaned text; an empty string if nothing survives filtering.
        """
        # Keep only the Telugu block and whitespace.
        text = re.sub(r'[^\u0C00-\u0C7F\s]', '', text)
        # Set digit runs apart from neighbouring characters.
        text = re.sub(r'(\d+)', r' \1 ', text)
        # Detach dependent vowel signs from the preceding character.
        text = re.sub(r'([\u0C3E-\u0C4C])', r' \1', text)
        # Normalize whitespace after all insertions have been made.
        return re.sub(r'\s+', ' ', text).strip()

    def get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count how often each adjacent token pair occurs.

        Args:
            words: Tokenized words, each a list of tokens.

        Returns:
            Mapping ``(left, right) -> count``. Keys appear in order of first
            occurrence; pairs never span two words.
        """
        pairs: Dict[Tuple[str, str], int] = collections.defaultdict(int)
        for word in words:
            for pair in zip(word, word[1:]):
                pairs[pair] += 1
        return pairs

    def merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """Replace every occurrence of *pair* with its concatenation.

        Each word is scanned left to right and matches do not overlap, so
        ``['a', 'a', 'a']`` merged on ``('a', 'a')`` gives ``['aa', 'a']``.

        Args:
            words: Tokenized words, each a list of tokens.
            pair: The ``(left, right)`` tokens to merge.

        Returns:
            A new list of new token lists; the input is not modified.
        """
        first, second = pair
        merged = first + second
        new_words = []
        for word in words:
            new_word = []
            last = len(word) - 1
            i = 0
            while i < len(word):
                if i < last and word[i] == first and word[i + 1] == second:
                    new_word.append(merged)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words

    def learn_bpe(self, text: str) -> None:
        """Learn BPE merges from whitespace-separated *text*.

        Any previously learned merges and vocabulary are discarded, so the
        model reflects only *text* after this call.

        Args:
            text: Training text; words are separated by whitespace.
        """
        # Segment each distinct word once and weight it by its frequency.
        # Counter keeps first-occurrence order, so pair counts and the
        # tie-breaking order match a scan over every word occurrence.
        word_freqs = collections.Counter(text.split())
        words = [list(word) for word in word_freqs]
        freqs = list(word_freqs.values())

        # Start from scratch: stale merges from an earlier training run would
        # otherwise still be applied by encode() and written by save_model().
        self.merges = {}
        self.vocab = {char for word in words for char in word}

        # Each merge adds at most one token, so this bound alone keeps the
        # vocabulary within vocab_size. A non-positive value means no merges.
        num_merges = self.vocab_size - len(self.vocab)
        for _ in range(num_merges):
            pairs: Dict[Tuple[str, str], int] = collections.defaultdict(int)
            for word, freq in zip(words, freqs):
                for pair in zip(word, word[1:]):
                    pairs[pair] += freq
            if not pairs:
                # Every word is a single token; nothing left to merge.
                break
            # On ties max() returns the pair that was seen first.
            best_pair = max(pairs, key=pairs.get)
            merged = best_pair[0] + best_pair[1]
            self.merges[best_pair] = merged
            self.vocab.add(merged)
            words = self.merge_vocab(words, best_pair)

    def encode(self, text: str) -> List[str]:
        """Tokenize *text* with the learned merges.

        Args:
            text: Whitespace-separated text.

        Returns:
            A flat list of tokens for all words in order. Word boundaries are
            not marked in the output.
        """
        words = [list(word) for word in text.split()]
        # Replay merges in the order they were learned.
        for pair in self.merges:
            words = self.merge_vocab(words, pair)
        return [token for word in words for token in word]

    def save_model(self, path: str) -> None:
        """Write the model to *path* as UTF-8 JSON.

        Each merge pair is stored under the key ``"<left> <right>"``. Tokens
        come from whitespace-split text and therefore contain no whitespace,
        which keeps that key unambiguous.

        Args:
            path: Destination file; it is overwritten if it exists.
        """
        model_data = {
            'vocab_size': self.vocab_size,
            'merges': {f'{k[0]} {k[1]}': v for k, v in self.merges.items()},
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, ensure_ascii=False, indent=2)

    def load_model(self, path: str) -> None:
        """Replace this model's state with the one stored at *path*.

        Args:
            path: A JSON file in the format written by `save_model`.
        """
        with open(path, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        self.vocab_size = model_data['vocab_size']
        # JSON objects load in file order, so the merge order is preserved.
        self.merges = {tuple(k.split()): v for k, v in model_data['merges'].items()}
        self.vocab = set(model_data['vocab'])
def main():
    """Train a BPE model on ``telugu_text.txt`` and print statistics.

    Reads the corpus, preprocesses it, learns a 5000-token BPE vocabulary,
    encodes the corpus to measure compression, writes the model to
    ``telugu_bpe_model.json`` and prints an example encoding.

    Raises:
        OSError: If the input file cannot be read or the model file cannot be
            written.
    """
    input_file = "telugu_text.txt"
    model_file = "telugu_bpe_model.json"

    # Read input text
    with open(input_file, 'r', encoding='utf-8') as f:
        text = f.read()

    print('Started learning BPE')
    bpe = TeluguBPE(vocab_size=5000)

    # Preprocess text
    processed_text = bpe.preprocess_telugu_text(text)

    # Baseline statistics of the preprocessed text (characters include spaces)
    original_chars = len(processed_text)
    original_tokens = len(processed_text.split())

    # Learn BPE
    bpe.learn_bpe(processed_text)

    # Encode the entire text to calculate compression
    encoded_text = bpe.encode(processed_text)
    encoded_length = len(encoded_text)

    # A corpus with no usable characters encodes to zero tokens; report a
    # ratio of 0 instead of failing with ZeroDivisionError.
    compression_ratio = original_chars / encoded_length if encoded_length else 0.0

    # Save model
    bpe.save_model(model_file)

    # Print statistics
    print("\nCompression Statistics:")
    print(f"Original characters: {original_chars}")
    print(f"Original tokens (words): {original_tokens}")
    print(f"Encoded tokens: {encoded_length}")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Vocabulary size: {len(bpe.vocab)}")

    # Example encoding
    sample_text = "నమస్కారం"  # "Hello" in Telugu
    encoded = bpe.encode(bpe.preprocess_telugu_text(sample_text))
    print("\nExample encoding:")
    print(f"Sample text: {sample_text}")
    print(f"Encoded text: {encoded}")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()