import re
import collections
from typing import Dict, List, Tuple, Set
import json

class TeluguBPE:
    """A simple byte-pair-encoding (BPE) tokenizer for Telugu text."""

    def __init__(self, vocab_size: int = 5000):
        self.vocab_size = vocab_size
        self.merges: Dict[Tuple[str, str], str] = {}
        self.vocab: Set[str] = set()
        
    def preprocess_telugu_text(self, text: str) -> str:
        """Normalize Telugu text: keep only Telugu characters, digits and
        sentence punctuation, then space out digits, punctuation marks and
        dependent vowel signs."""
        # Keep Telugu-block characters, danda/double danda, digits, the
        # punctuation handled below, and whitespace; drop everything else
        text = re.sub(r'[^\u0C00-\u0C7F\u0964\u0965\d,?!\s]', '', text)

        # Normalize runs of whitespace to a single space
        text = re.sub(r'\s+', ' ', text)

        # Separate numbers from surrounding Telugu text
        text = re.sub(r'(\d+)', r' \1 ', text)

        # Separate punctuation marks, including danda (।) and double danda (॥)
        text = re.sub(r'([।॥,?!])', r' \1 ', text)

        # Detach dependent vowel signs from their base consonant
        text = re.sub(r'([\u0C3E-\u0C4C])', r' \1', text)

        # Collapse the double spaces introduced by the rules above
        text = re.sub(r'\s+', ' ', text)

        return text.strip()
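    # An illustrative run of the rules above (hypothetical input):
    #   "నమస్కారం, 2024!"  ->  "నమస్క ారం , 2024 !"
    # Digits and punctuation become standalone tokens, and the dependent
    # vowel sign ా is split from its base consonant.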

    def get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """

        Count frequency of adjacent pairs in current vocabulary

        """
        pairs = collections.defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pairs[tuple(word[i:i + 2])] += 1
        return pairs
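    # For example, get_stats([['న', 'మ'], ['న', 'మ', 'ల']]) returns
    # {('న', 'మ'): 2, ('మ', 'ల'): 1}.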

    def merge_vocab(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """

        Merge all occurrences of the most frequent pair

        """
        first, second = pair
        new_words = []
        
        for word in words:
            i = 0
            new_word = []
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
            
        return new_words
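    # For example, merge_vocab([['న', 'మ', 'ల']], ('న', 'మ')) returns
    # [['నమ', 'ల']].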

    def learn_bpe(self, text: str) -> None:
        """

        Learn BPE merges from text

        """
        # Initial vocabulary: character level
        words = [[char for char in word] for word in text.split()]
        self.vocab = set(char for word in words for char in word)
        
        num_merges = self.vocab_size - len(self.vocab)
        
        for _ in range(num_merges):
            pairs = self.get_stats(words)
            if not pairs:
                break
                
            best_pair = max(pairs.items(), key=lambda x: x[1])[0]
            self.merges[best_pair] = best_pair[0] + best_pair[1]
            self.vocab.add(self.merges[best_pair])
            
            words = self.merge_vocab(words, best_pair)
            
            if len(self.vocab) >= self.vocab_size:
                break
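    # A hand-traced iteration (illustrative corpus): for "నమ నమల" the initial
    # words are [['న', 'మ'], ['న', 'మ', 'ల']]; ('న', 'మ') is the most frequent
    # pair, so the first merge adds 'నమ' to the vocab and the words become
    # [['నమ'], ['నమ', 'ల']].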

    def encode(self, text: str) -> List[str]:
        """Encode text by applying the learned merges, in the order they
        were learned, to each whitespace-separated word."""
        words = [[char for char in word] for word in text.split()]
        for pair in self.merges:
            words = self.merge_vocab(words, pair)
        return [token for word in words for token in word]
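    # Note: dicts preserve insertion order in Python 3.7+, so the merges are
    # replayed here in exactly the order learn_bpe discovered them, as BPE
    # encoding requires.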

    def save_model(self, path: str) -> None:
        """

        Save BPE model to file

        """
        model_data = {
            'vocab_size': self.vocab_size,
            'merges': {f'{k[0]} {k[1]}': v for k, v in self.merges.items()},
            'vocab': list(self.vocab)
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, ensure_ascii=False, indent=2)

    def load_model(self, path: str) -> None:
        """

        Load BPE model from file

        """
        with open(path, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        
        self.vocab_size = model_data['vocab_size']
        self.merges = {tuple(k.split()): v for k, v in model_data['merges'].items()}
        self.vocab = set(model_data['vocab'])
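
# A minimal load-and-reuse sketch (file name and sample text are
# illustrative, matching the defaults in main() below):
#
#   bpe = TeluguBPE()
#   bpe.load_model("telugu_bpe_model.json")
#   tokens = bpe.encode(bpe.preprocess_telugu_text("నమస్కారం"))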

def main():
    # Example usage
    input_file = "telugu_text.txt"
    model_file = "telugu_bpe_model.json"
    
    # Read input text
    with open(input_file, 'r', encoding='utf-8') as f:
        text = f.read()
    
    print("Learning BPE...")
    bpe = TeluguBPE(vocab_size=5000)
    
    # Preprocess text
    processed_text = bpe.preprocess_telugu_text(text)
    
    # Calculate original text statistics
    original_chars = len(processed_text)
    original_tokens = len(processed_text.split())
    
    # Learn BPE
    bpe.learn_bpe(processed_text)
    
    # Encode the entire text to calculate compression
    encoded_text = bpe.encode(processed_text)
    encoded_length = len(encoded_text)
    
    # Calculate compression ratio
    compression_ratio = original_chars / encoded_length
    
    # Save model
    bpe.save_model(model_file)
    
    # Print statistics
    print(f"\nCompression Statistics:")
    print(f"Original characters: {original_chars}")
    print(f"Original tokens (words): {original_tokens}")
    print(f"Encoded tokens: {encoded_length}")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Vocabulary size: {len(bpe.vocab)}")
    
    # Example encoding
    sample_text = "నమస్కారం"  # "Hello" in Telugu
    encoded = bpe.encode(bpe.preprocess_telugu_text(sample_text))
    print(f"\nExample encoding:")
    print(f"Sample text: {sample_text}")
    print(f"Encoded text: {encoded}")

if __name__ == "__main__":
    main()