# bye-hindi / byte_pair_encoder.py
# Uploaded by aayushraina ("Upload 13 files", commit 46759b2, verified).
from typing import List, Dict, Optional
from tqdm import tqdm
from collections import Counter
from matplotlib import pyplot as plt
import json
from pathlib import Path
class TrieNode:
    """One node of the prefix tree (trie) used for fast longest-match token lookup.

    Every node starts with no outgoing edges and no token; the tokenizer fills
    these in while inserting vocabulary entries.
    """

    def __init__(self):
        # Full token text stored at this node (only meaningful when is_token is True).
        self.token: Optional[str] = None
        # True when some vocabulary entry ends exactly at this node.
        self.is_token = False
        # Outgoing edges: next character -> child node (a fresh dict per node).
        self.children: Dict[str, "TrieNode"] = {}
class BytePairEncoder:
    """Character-level byte-pair-encoding (BPE) trainer.

    The vocabulary starts as the distinct characters of the training text.
    Each merge step finds the most frequent adjacent pair of token ids in
    ``self.data``, registers the concatenation of their strings as a new
    token, and rewrites ``self.data`` with the non-overlapping occurrences of
    that pair (scanned left to right) replaced by the new id.

    Attributes:
        chars: Sorted distinct characters of the training text.
        stoi: Token string -> token id.
        itos: Token id -> token string.
        data: The training text as a list of token ids (shrinks with merges).
        stats: Per-merge training statistics, one list per metric name.
        original_length: Length of ``data`` before any merge; the numerator of
            the compression ratio.
        max_token_length: Length in characters of the longest token so far.
    """

    def __init__(self, text: str):
        """Build the initial character vocabulary and encode ``text`` with it."""
        # Initialize vocabulary from characters
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        # Initial encoding of text
        self.data = [self.stoi[c] for c in text]
        # Statistics tracking
        self.stats = {
            "vocab_sizes": [len(self.chars)],
            "data_sizes": [len(self.data)],
            "compression_ratios": [1.0],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }
        # Store original length for compression ratio
        self.original_length = len(self.data)
        self.max_token_length = 1

    def get_digram_stats(self) -> Counter:
        """Count every adjacent ``(left_id, right_id)`` pair in ``self.data``.

        Overlapping pairs are all counted, so a run ``x x x`` yields a count
        of 2 for ``(x, x)`` even though only one replacement fits.
        """
        return Counter((int(a), int(b)) for a, b in zip(self.data, self.data[1:]))

    def encode_to_vocab_size(self, target_vocab_size: int, plot_interval: Optional[int] = None,
                             print_interval: int = 100) -> None:
        """Run merge steps until the vocabulary reaches ``target_vocab_size``.

        Training also stops early once no adjacent pair is left to merge.

        Args:
            target_vocab_size: Desired number of entries in ``self.itos``.
            plot_interval: If set, call ``plot_statistics`` every this many merges.
            print_interval: If non-zero, print progress every this many merges.
        """
        # Start the bar at the current vocabulary size rather than the
        # character count, so a resumed training run reports correct progress.
        with tqdm(total=target_vocab_size, desc="Training BPE",
                  initial=len(self.itos)) as pbar:
            iteration = 0
            while len(self.itos) < target_vocab_size:
                if self._merge_step() is None:
                    break
                iteration += 1
                pbar.update(1)
                if print_interval and iteration % print_interval == 0:
                    self._print_progress(iteration)
                if plot_interval and iteration % plot_interval == 0:
                    self.plot_statistics(iteration=iteration)

    def _merge_step(self):
        """Merge the most frequent adjacent pair once.

        Returns:
            ``(new_token_id, pair_count)``, or ``None`` when ``self.data`` has
            no adjacent pair left to merge.
        """
        stats = self.get_digram_stats()
        if not stats:
            return None
        # On ties, max() keeps the first maximum, i.e. the pair seen first in data.
        top_pair, count = max(stats.items(), key=lambda x: x[1])
        new_token = self._add_token(top_pair)
        self.data = self._replace_pairs(top_pair, new_token)
        self._update_stats(count)
        return new_token, count

    def _add_token(self, pair: tuple) -> int:
        """Register the concatenation of ``pair`` as a new token and return its id.

        Ids are assigned densely as ``len(self.itos)``. Two different pairs can
        concatenate to the same string (e.g. "ab"+"c" and "a"+"bc"); ``itos``
        then keeps both ids while ``stoi`` maps the string to the newest id.
        """
        token_str = self.itos[pair[0]] + self.itos[pair[1]]
        token_id = len(self.itos)
        self.stoi[token_str] = token_id
        self.itos[token_id] = token_str
        self.max_token_length = max(self.max_token_length, len(token_str))
        return token_id

    def _replace_pairs(self, pair: tuple, new_token: int) -> List[int]:
        """Return ``self.data`` with each left-to-right, non-overlapping
        occurrence of ``pair`` replaced by ``new_token``."""
        data = self.data
        last = len(data) - 1
        result = []
        i = 0
        while i < len(data):
            if i < last and data[i] == pair[0] and data[i + 1] == pair[1]:
                result.append(new_token)
                i += 2
            else:
                result.append(data[i])
                i += 1
        return result

    def _update_stats(self, merge_count: int):
        """Append the post-merge metrics to ``self.stats``.

        Args:
            merge_count: Occurrence count of the pair that was just merged.
        """
        self.stats["vocab_sizes"].append(len(self.itos))
        self.stats["data_sizes"].append(len(self.data))
        compression = self.original_length / len(self.data)
        self.stats["compression_ratios"].append(compression)
        self.stats["merge_counts"].append(merge_count)
        # Ids are dense, so the newest token has the highest id.
        self.stats["tokens_created"].append(self.itos[len(self.itos) - 1])
        self.stats["max_token_lengths"].append(self.max_token_length)

    def plot_statistics(self, iteration: Optional[int] = None):
        """Plot training statistics in a 2x2 grid and display the figure.

        Args:
            iteration: Optional merge iteration; when given it becomes the
                figure title so periodic plots can be told apart.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        if iteration is not None:
            fig.suptitle(f"Iteration {iteration}")
        # Plot training metrics
        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_title("Vocabulary vs Dataset Size")
        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_title("Compression Ratio Progress")
        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_title("Merge Counts Distribution")
        if self.stats["tokens_created"]:
            lengths = [len(t) for t in self.stats["tokens_created"]]
            ax4.plot(range(len(lengths)), lengths)
            ax4.set_title("Token Length Evolution")
        plt.tight_layout()
        plt.show()
        # Release the figure: training may call this every ``plot_interval``
        # merges, and pyplot keeps every unclosed figure alive.
        plt.close(fig)

    def save_to_file(self, filepath: Path):
        """Save the encoder vocabulary and statistics to ``filepath`` as UTF-8 JSON.

        ``itos`` is saved explicitly because it cannot always be rebuilt from
        ``stoi`` (two ids may share a token string; see ``_add_token``).
        The token-id sequence ``data`` is not saved.
        """
        state = {
            "chars": self.chars,
            "stoi": self.stoi,
            "itos": self.itos,
            "max_token_length": self.max_token_length,
            "stats": self.stats,
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: Path):
        """Load an encoder written by ``save_to_file``.

        The returned instance has the saved vocabulary and statistics but an
        empty ``data`` list, because the training data is not persisted.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)
        instance = cls("")  # Create empty instance
        instance.chars = state["chars"]
        instance.stoi = state["stoi"]
        if "itos" in state:
            # JSON object keys are always strings; restore integer ids.
            instance.itos = {int(i): s for i, s in state["itos"].items()}
        else:
            # Older files without "itos": rebuild it from stoi (lossy when
            # two ids share a token string).
            instance.itos = {int(i): s for s, i in state["stoi"].items()}
        instance.max_token_length = state["max_token_length"]
        instance.stats = state["stats"]
        return instance

    def _print_progress(self, iteration: int):
        """Print a short training progress report for ``iteration``."""
        print(f"\nIteration {iteration}:")
        print(f"Vocabulary size: {len(self.itos):,}")
        print(f"Data size: {len(self.data):,}")
        print(f"Compression ratio: {self.stats['compression_ratios'][-1]:.2f}")
        if self.stats["merge_counts"]:
            last_merge = self.stats["merge_counts"][-1]
            last_token = self.stats["tokens_created"][-1]
            print(f"Last merge count: {last_merge:,}")
            print(f"Last token created: '{last_token}'")
            print(f"Max token length: {self.max_token_length}")
class TokenizerInternal:
    """Greedy longest-match tokenizer over a trained BPE vocabulary.

    A trie is built from the keys of ``encoder.stoi``. At each position the
    longest vocabulary entry that prefixes the remaining text is emitted; a
    character with no vocabulary match is emitted as a one-character token.
    """

    def __init__(self, encoder: BytePairEncoder):
        """Capture the encoder's vocabulary and build the lookup trie."""
        self.stoi = encoder.stoi
        self.max_token_length = encoder.max_token_length
        self._trie = self._build_trie()

    def _build_trie(self) -> TrieNode:
        """Build a trie containing every token string in ``self.stoi``."""
        root = TrieNode()
        for token in self.stoi:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_token = True
            node.token = token
        return root

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into token strings using greedy longest matching.

        Returns an empty list for empty input.
        """
        tokens = []
        pos = 0
        end = len(text)
        while pos < end:
            # Pass an offset instead of slicing text[pos:], which copied the
            # remaining text on every step and made tokenization O(n^2).
            token = self._find_longest_token(text, pos)
            tokens.append(token)
            pos += len(token)
        return tokens

    def _find_longest_token(self, text: str, start: int = 0) -> str:
        """Return the longest vocabulary token that begins at ``text[start]``.

        Falls back to the single character ``text[start]`` when no vocabulary
        entry matches. ``start`` must be a valid index into ``text``.
        """
        node = self._trie
        longest = text[start]
        # No token is longer than max_token_length, so never look further ahead.
        for char in text[start:start + self.max_token_length]:
            node = node.children.get(char)
            if node is None:
                break
            if node.is_token:
                longest = node.token
        return longest