import json
from collections import Counter
from pathlib import Path
from typing import List, Dict, Optional

from matplotlib import pyplot as plt
from tqdm import tqdm

class TrieNode:
    """Node in the prefix tree (trie) for fast token matching"""

    def __init__(self):
        self.children = {}
        self.is_token = False
        self.token = None

class BytePairEncoder:
    def __init__(self, text: str):
        # Initialize vocabulary from characters
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        # Initial encoding of text
        self.data = [self.stoi[c] for c in text]
        # Statistics tracking
        self.stats = {
            "vocab_sizes": [len(self.chars)],
            "data_sizes": [len(self.data)],
            "compression_ratios": [1.0],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }
        # Store original length for compression ratio
        self.original_length = len(self.data)
        self.max_token_length = 1

    def get_digram_stats(self) -> Counter:
        """Count occurrences of each adjacent token pair (digram)"""
        counts = Counter()
        for pair in zip(self.data, self.data[1:]):
            pair = (int(pair[0]), int(pair[1]))
            counts[pair] += 1
        return counts

    def encode_to_vocab_size(self, target_vocab_size: int, plot_interval: Optional[int] = None,
                             print_interval: int = 100) -> None:
        """Train until reaching target vocabulary size"""
        pbar = tqdm(total=target_vocab_size, desc="Training BPE", initial=len(self.chars))
        iteration = 0
        while len(self.itos) < target_vocab_size:
            result = self._merge_step()
            if result is None:
                break
            iteration += 1
            pbar.update(1)
            if print_interval and iteration % print_interval == 0:
                self._print_progress(iteration)
            if plot_interval and iteration % plot_interval == 0:
                self.plot_statistics(iteration=iteration)
        pbar.close()

    def _merge_step(self):
        """Perform one merge operation"""
        stats = self.get_digram_stats()
        if not stats:
            return None
        top_pair, count = max(stats.items(), key=lambda x: x[1])
        new_token = self._add_token(top_pair)
        self.data = self._replace_pairs(top_pair, new_token)
        self._update_stats(count)
        return new_token, count

    def _add_token(self, pair: tuple) -> int:
        """Add new token to vocabulary"""
        token_str = self.itos[pair[0]] + self.itos[pair[1]]
        token_id = len(self.itos)
        self.stoi[token_str] = token_id
        self.itos[token_id] = token_str
        self.max_token_length = max(self.max_token_length, len(token_str))
        return token_id

    def _replace_pairs(self, pair: tuple, new_token: int) -> List[int]:
        """Replace all occurrences of pair with new token"""
        result = []
        i = 0
        while i < len(self.data):
            if i < len(self.data) - 1 and self.data[i] == pair[0] and self.data[i + 1] == pair[1]:
                result.append(new_token)
                i += 2
            else:
                result.append(self.data[i])
                i += 1
        return result

    def _update_stats(self, merge_count: int):
        """Update training statistics"""
        self.stats["vocab_sizes"].append(len(self.itos))
        self.stats["data_sizes"].append(len(self.data))
        compression = self.original_length / len(self.data)
        self.stats["compression_ratios"].append(compression)
        self.stats["merge_counts"].append(merge_count)
        self.stats["tokens_created"].append(self.itos[len(self.itos) - 1])
        self.stats["max_token_lengths"].append(self.max_token_length)

    def plot_statistics(self, iteration: Optional[int] = None):
        """Plot training statistics"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        if iteration is not None:
            fig.suptitle(f"Training statistics at iteration {iteration}")
        # Plot training metrics
        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_title("Vocabulary vs Dataset Size")
        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_title("Compression Ratio Progress")
        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_title("Merge Counts Distribution")
        if self.stats["tokens_created"]:
            lengths = [len(t) for t in self.stats["tokens_created"]]
            ax4.plot(range(len(lengths)), lengths)
            ax4.set_title("Token Length Evolution")
        plt.tight_layout()
        plt.show()

    def save_to_file(self, filepath: Path):
        """Save encoder state"""
        state = {
            "chars": self.chars,
            "stoi": self.stoi,
            "max_token_length": self.max_token_length,
            "stats": self.stats,
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: Path):
        """Load encoder state"""
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)
        instance = cls("")  # Create empty instance
        instance.chars = state["chars"]
        instance.stoi = state["stoi"]
        # Rebuild the reverse mapping; JSON preserves stoi as str -> int
        instance.itos = {int(i): s for s, i in state["stoi"].items()}
        instance.max_token_length = state["max_token_length"]
        instance.stats = state["stats"]
        return instance

    def _print_progress(self, iteration: int):
        """Print training progress"""
        print(f"\nIteration {iteration}:")
        print(f"Vocabulary size: {len(self.itos):,}")
        print(f"Data size: {len(self.data):,}")
        print(f"Compression ratio: {self.stats['compression_ratios'][-1]:.2f}")
        if self.stats["merge_counts"]:
            last_merge = self.stats["merge_counts"][-1]
            last_token = self.stats["tokens_created"][-1]
            print(f"Last merge count: {last_merge:,}")
            print(f"Last token created: '{last_token}'")
        print(f"Max token length: {self.max_token_length}")

class TokenizerInternal:
    """Tokenizer using trained BPE model"""

    def __init__(self, encoder: BytePairEncoder):
        self.stoi = encoder.stoi
        self.max_token_length = encoder.max_token_length
        self._trie = self._build_trie()

    def _build_trie(self) -> TrieNode:
        """Build trie for efficient tokenization"""
        root = TrieNode()
        for token in self.stoi:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_token = True
            node.token = token
        return root

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text using trie-based longest-match scanning"""
        tokens = []
        pos = 0
        while pos < len(text):
            token = self._find_longest_token(text[pos:])
            tokens.append(token)
            pos += len(token)
        return tokens

    def _find_longest_token(self, text: str) -> str:
        """Find longest matching token starting at current position"""
        node = self._trie
        # Fall back to the raw first character if no vocabulary token matches
        longest = text[0]
        for char in text[:self.max_token_length]:
            if char not in node.children:
                break
            node = node.children[char]
            if node.is_token:
                longest = node.token
        return longest
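
# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal example of how these pieces are meant to fit together: train a
# BytePairEncoder on a small corpus, round-trip its state through JSON, and
# tokenize new text with the trie-backed TokenizerInternal. The sample text,
# target vocabulary size, and output path below are assumptions chosen for
# the demo, not values taken from the original code.
if __name__ == "__main__":
    sample_text = "low lower lowest newer newest wider widest " * 50
    encoder = BytePairEncoder(sample_text)
    # Repeatedly merge the most frequent digram until the vocabulary holds
    # 40 entries (or no pairs remain to merge).
    encoder.encode_to_vocab_size(target_vocab_size=40)

    # Persist the trained state and reload it via the classmethod constructor.
    model_path = Path("bpe_state.json")
    encoder.save_to_file(model_path)
    restored = BytePairEncoder.load_from_file(model_path)

    # Tokenize unseen text; characters missing from the vocabulary (e.g. 'a')
    # pass through as single-character tokens.
    tokenizer = TokenizerInternal(restored)
    print(tokenizer.tokenize("lower and wider"))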