# Telugu byte-pair-encoding (BPE) tokenizer: training, encoding, and model save/load.
import re | |
import collections | |
from typing import Dict, List, Tuple, Set | |
import json | |
from pathlib import Path | |
class TeluguBPE:
    """Byte-pair-encoding (BPE) tokenizer for whitespace-separated Telugu text.

    Attributes:
        vocab_size: Target upper bound on the vocabulary size during training.
        merges: Learned merges, mapping a ``(left, right)`` token pair to the
            merged token. Insertion order is the order in which the merges
            were learned; ``encode`` replays them in that same order.
        vocab: All known tokens (single characters plus merged tokens).
    """

    def __init__(self, vocab_size: int = 5000):
        """Create an untrained tokenizer with the given target vocabulary size."""
        self.vocab_size = vocab_size
        self.merges: Dict[Tuple[str, str], str] = {}
        self.vocab: Set[str] = set()

    def preprocess_telugu_text(self, text: str) -> str:
        """Reduce ``text`` to Telugu-block characters and apply spacing rules.

        Args:
            text: Raw input text.

        Returns:
            The filtered text with leading/trailing whitespace stripped. The
            result may contain consecutive spaces (whitespace is collapsed
            before, not after, the spacing rules run).
        """
        # Drop every character that is neither whitespace nor inside the
        # Telugu Unicode block (U+0C00-U+0C7F). This removes ASCII letters,
        # ASCII digits and punctuation, and also the danda marks.
        text = re.sub(r'[^\u0C00-\u0C7F\s\n]', '', text)
        # Collapse each whitespace run (newlines included) to a single space.
        text = re.sub(r'\s+', ' ', text)
        # Surround digit runs with spaces. In a str pattern ``\d`` matches any
        # Unicode decimal digit, so Telugu digits that survived the filter
        # above are the ones affected here.
        text = re.sub(r'(\d+)', r' \1 ', text)
        # NOTE: the two punctuation rules below cannot match anything at this
        # point, because the filter above has already removed every one of
        # these characters. They are kept unchanged so that preprocessing
        # output (and therefore previously saved models) stays stable; widen
        # the filter first if punctuation tokens are actually wanted.
        text = re.sub(r'([।॥,?!])', r' \1 ', text)
        text = re.sub(r'([।॥])', r'\1 ', text)
        # Insert a space before every dependent vowel sign (U+0C3E-U+0C4C),
        # which splits a word in front of each such sign.
        text = re.sub(r'([\u0C3E-\u0C4C])', r' \1', text)
        return text.strip()

    def get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count how often each pair of adjacent tokens occurs.

        Args:
            words: Tokenized words, each a list of tokens.

        Returns:
            A mapping from ``(left, right)`` token pair to its occurrence
            count, in order of first appearance.
        """
        pairs = collections.defaultdict(int)
        for word in words:
            # zip against the word shifted by one yields every adjacent pair;
            # words with fewer than two tokens contribute nothing.
            for pair in zip(word, word[1:]):
                pairs[pair] += 1
        return pairs

    def merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """Replace every occurrence of ``pair`` with its concatenation.

        Words are scanned left to right and matches do not overlap, so
        ``['a', 'a', 'a']`` merged on ``('a', 'a')`` becomes ``['aa', 'a']``.

        Args:
            words: Tokenized words, each a list of tokens.
            pair: The ``(left, right)`` tokens to merge.

        Returns:
            A new list of words; the input is not modified.
        """
        first, second = pair
        merged = first + second
        new_words = []
        for word in words:
            new_word = []
            i = 0
            last = len(word) - 1
            while i < len(word):
                if i < last and word[i] == first and word[i + 1] == second:
                    new_word.append(merged)
                    i += 2  # skip both halves of the merged pair
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words

    def learn_bpe(self, text: str) -> None:
        """Learn BPE merges from ``text``, replacing any previous training.

        Training starts from a character-level vocabulary and repeatedly
        merges the most frequent adjacent pair (ties go to the pair seen
        first) until no pairs remain, the merge budget
        ``vocab_size - initial_vocab_size`` is used up, or the vocabulary
        reaches ``vocab_size``.

        Args:
            text: Training text; words are separated by whitespace.
        """
        # Initial vocabulary: character level.
        words = [list(word) for word in text.split()]
        self.vocab = {char for word in words for char in word}
        # Start from a clean slate so merges from an earlier training run or
        # a loaded model cannot leak into (and disagree with) the new vocab.
        self.merges = {}

        num_merges = self.vocab_size - len(self.vocab)
        for _ in range(num_merges):
            pairs = self.get_stats(words)
            if not pairs:
                break
            best_pair = max(pairs.items(), key=lambda item: item[1])[0]
            merged = best_pair[0] + best_pair[1]
            self.merges[best_pair] = merged
            self.vocab.add(merged)
            words = self.merge_vocab(words, best_pair)
            if len(self.vocab) >= self.vocab_size:
                break

    def encode(self, text: str) -> List[str]:
        """Tokenize ``text`` by replaying the learned merges in order.

        Args:
            text: Text to encode; words are separated by whitespace.

        Returns:
            A flat list of tokens. Word boundaries are not represented in the
            output.
        """
        words = [list(word) for word in text.split()]
        # Dict iteration follows insertion order, i.e. the order the merges
        # were learned (or stored in the model file).
        for pair in self.merges:
            words = self.merge_vocab(words, pair)
        return [token for word in words for token in word]

    def save_model(self, path: str) -> None:
        """Write the model to ``path`` as UTF-8 JSON.

        Each merge is stored under the key ``"<left> <right>"``. This relies
        on tokens containing no whitespace, which holds for merges produced
        by ``learn_bpe`` because it splits its input on whitespace.

        Args:
            path: Destination file path; an existing file is overwritten.
        """
        model_data = {
            'vocab_size': self.vocab_size,
            'merges': {f'{k[0]} {k[1]}': v for k, v in self.merges.items()},
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, ensure_ascii=False, indent=2)

    def load_model(self, path: str) -> None:
        """Load a model written by ``save_model``, replacing current state.

        Args:
            path: Path of the JSON model file.
        """
        with open(path, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        self.vocab_size = model_data['vocab_size']
        # Keys were written as "<left> <right>"; split them back into tuples.
        self.merges = {tuple(k.split()): v for k, v in model_data['merges'].items()}
        self.vocab = set(model_data['vocab'])
def main():
    """Train a BPE model on ``telugu_text.txt`` and report compression stats.

    Reads the input file, preprocesses it, learns the merges, saves the model
    to ``telugu_bpe_model.json``, then prints compression statistics and the
    encoding of a sample word.
    """
    # Example usage
    input_file = "telugu_text.txt"
    model_file = "telugu_bpe_model.json"

    # Read input text
    with open(input_file, 'r', encoding='utf-8') as f:
        text = f.read()

    print('Started learning BPE')
    bpe = TeluguBPE(vocab_size=5000)

    # Preprocess text
    processed_text = bpe.preprocess_telugu_text(text)

    # Calculate original text statistics
    original_chars = len(processed_text)
    original_tokens = len(processed_text.split())

    # Learn BPE
    bpe.learn_bpe(processed_text)

    # Encode the entire text to calculate compression
    encoded_text = bpe.encode(processed_text)
    encoded_length = len(encoded_text)

    # Calculate compression ratio. An empty input file, or one with no Telugu
    # characters, encodes to zero tokens; report 0.0 instead of dividing by
    # zero.
    compression_ratio = original_chars / encoded_length if encoded_length else 0.0

    # Save model
    bpe.save_model(model_file)

    # Print statistics
    print("\nCompression Statistics:")
    print(f"Original characters: {original_chars}")
    print(f"Original tokens (words): {original_tokens}")
    print(f"Encoded tokens: {encoded_length}")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Vocabulary size: {len(bpe.vocab)}")

    # Example encoding
    sample_text = "నమస్కారం"  # "Hello" in Telugu
    encoded = bpe.encode(bpe.preprocess_telugu_text(sample_text))
    print("\nExample encoding:")
    print(f"Sample text: {sample_text}")
    print(f"Encoded text: {encoded}")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()