# Page header captured along with this file (not code):
#   author: nragrawal
#   commit 73ba2f8: "Update to read_files and tokenizer.json"
#   size: 8.31 kB
import os
import json
import regex as re
from natsort import natsorted
from tqdm import tqdm
import gc # For garbage collection
# Pre-tokenisation pattern used by text_to_bytes() before text is turned into
# UTF-8 byte values. It is compiled with the third-party `regex` module
# (imported as `re`), which supports the \p{...} Unicode property classes.
# re.VERBOSE means whitespace and '#' comment lines inside the pattern are
# ignored. The alternatives, tried left to right, match:
#   * an apostrophe followed by one of the listed Devanagari suffixes,
#   * a run of letters and combining marks (\p{L}, \p{M}),
#   * a run of numeric characters (\p{N}),
#   * a run of any other non-whitespace characters,
#   * a run of whitespace.
# Every character falls into one of these classes, so the findall() matches
# concatenate back to the original text.
# NOTE(review): the apostrophe-suffix alternatives look modelled on GPT-2's
# English contraction rule; confirm such forms actually occur in the corpus.
MARATHI_PATTERN = re.compile(r"""
# Contractions and common affixes
'चा|'ची|'चे|'ला|'ले|'नी|
# Words with optional vowel signs and modifiers
[\p{L}\p{M}]+|
# Numbers
\p{N}+|
# Punctuation and special characters
[^\s\p{L}\p{N}\p{M}]+|
# Whitespace
\s+
""", re.VERBOSE)
def text_to_bytes(text):
    """
    Split ``text`` with MARATHI_PATTERN and return its UTF-8 byte values.

    Each regex match is UTF-8 encoded and its bytes are appended in order,
    giving one flat list of ints in the range 0..255.

    Args:
        text: The string to convert.

    Returns:
        A list of byte values.
    """
    byte_values = []
    for chunk in MARATHI_PATTERN.findall(text):
        # Iterating a bytes object yields ints, so extend() adds byte values.
        byte_values.extend(chunk.encode('utf-8'))
    return byte_values
def read_text_files(folder_path='train', limit=50000, batch_size=1000):
    """
    Read text files in batches and return their concatenated byte tokens.

    Regular files in ``folder_path`` whose names end in ``.txt`` or ``.text``
    are taken in natural sort order, truncated to the first ``limit``, and
    converted with ``text_to_bytes``. Files are handled ``batch_size`` at a
    time. Progress is printed each time the running file count crosses a
    multiple of 5000, and ``gc.collect()`` runs after every batch. Files
    that cannot be opened or are not valid UTF-8 are reported and skipped.

    Args:
        folder_path: Directory to scan (not recursive).
        limit: Maximum number of files to read.
        batch_size: Number of files per batch; must be at least 1.

    Returns:
        A flat list of byte values for every file read successfully, in
        file order. The list is empty when the folder is missing or holds
        no matching files, so the return value is always a list.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if not os.path.isdir(folder_path):
        print(f"Error: The folder '{folder_path}' does not exist or is not a directory.")
        return []

    # Only regular files: a directory named "x.txt" would fail to open.
    candidates = [
        name for name in os.listdir(folder_path)
        if name.endswith(('.txt', '.text'))
        and os.path.isfile(os.path.join(folder_path, name))
    ]
    if not candidates:
        print(f"No text files found in '{folder_path}' folder.")
        return []

    text_files = natsorted(candidates)[:limit]
    total_files = len(text_files)

    all_tokens = []
    files_read = 0
    report_every = 5000
    for start in tqdm(range(0, total_files, batch_size), desc="Processing files"):
        batch_files = text_files[start:start + batch_size]
        for file_name in batch_files:
            file_path = os.path.join(folder_path, file_name)
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    content = file.read()
            except (OSError, UnicodeError) as e:
                # Best effort: skip unreadable / non-UTF-8 files and continue.
                print(f"Error reading {file_name}: {e}")
                continue
            all_tokens.extend(text_to_bytes(content))
            files_read += 1

        processed = start + len(batch_files)
        # Report when this batch crosses a multiple of report_every; unlike a
        # modulo test on the batch start, this works for any batch_size.
        if processed // report_every > start // report_every:
            print(f"\nProcessed {processed}/{total_files} files")
            print(f"Current tokens: {len(all_tokens)}")

        # Garbage collection after each batch
        gc.collect()

    print("\n=== Final Statistics ===")
    print(f"Total files processed: {files_read}/{total_files}")
    print(f"Total tokens: {len(all_tokens)}")
    return all_tokens
def get_stats(ids):
    """
    Count how often each pair of adjacent ids occurs in ``ids``.

    Args:
        ids: Sequence of token ids.

    Returns:
        A dict mapping ``(left, right)`` tuples to occurrence counts, with
        keys in order of first appearance.
    """
    pair_counts = {}
    previous = ids[:-1]
    following = ids[1:]
    for pair in zip(previous, following):
        if pair in pair_counts:
            pair_counts[pair] += 1
        else:
            pair_counts[pair] = 1
    return pair_counts
def merge(ids, pair, idx):
    """
    Replace every non-overlapping occurrence of ``pair`` in ``ids`` with ``idx``.

    The scan goes left to right, so in a run like ``a a a`` merged on
    ``(a, a)`` only the first two elements combine.

    Args:
        ids: Sequence of token ids.
        pair: The ``(left, right)`` ids to combine.
        idx: The id that replaces each matched pair.

    Returns:
        A new list; ``ids`` is not modified.
    """
    first, second = pair[0], pair[1]
    out = []
    length = len(ids)
    pos = 0
    while pos < length:
        token = ids[pos]
        if token == first and pos + 1 < length and ids[pos + 1] == second:
            out.append(idx)
            pos += 2
            continue
        out.append(token)
        pos += 1
    return out
def encode(text, merges):
    """
    Encode ``text`` into token ids using the learned ``merges``.

    The text is first converted to UTF-8 byte values with ``text_to_bytes``.
    Every merge is then applied across the whole sequence in ascending order
    of its merged id, which is the order in which training created them.

    Args:
        text: String to encode.
        merges: Mapping ``(left_id, right_id) -> merged_id``.

    Returns:
        List of token ids.
    """
    def by_token_id(entry):
        return entry[1]

    tokens = text_to_bytes(text)
    for (left, right), new_id in sorted(merges.items(), key=by_token_id):
        tokens = merge(tokens, (left, right), new_id)
    return tokens
def decode(ids, merges, errors=None):
    """
    Decode token ids back to text using the learned merges.

    A byte table is built once per call. Ids 0..255 map to the single byte
    of that value, and each merged id maps to the concatenation of its two
    halves. Merges are processed in ascending id order. Training assigns ids
    sequentially and only merges tokens that already exist, so both halves
    are defined by the time a merged id is built. Building the table
    iteratively avoids the recursion limit that a recursive expansion hits
    on long merge chains, and expands each id only once.

    Args:
        ids: Iterable of token ids.
        merges: Mapping ``(left_id, right_id) -> merged_id``.
        errors: Optional ``bytes.decode`` error handler, e.g. ``'replace'``.
            When ``None`` (the default), invalid UTF-8 yields the sentinel
            string ``"[DECODE_ERROR]"`` instead.

    Returns:
        The decoded text, or ``"[DECODE_ERROR]"`` as described above.

    Raises:
        KeyError: If an id is neither a byte value nor produced by a merge,
            or if a merge refers to an id that has not been defined yet.
    """
    vocab = {byte: bytes([byte]) for byte in range(256)}
    for (left, right), idx in sorted(merges.items(), key=lambda item: item[1]):
        vocab[idx] = vocab[left] + vocab[right]

    bytes_data = b''.join(vocab[token] for token in ids)

    if errors is not None:
        return bytes_data.decode('utf-8', errors=errors)
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        # Default behavior: signal undecodable output with a sentinel string.
        return "[DECODE_ERROR]"
class Tokenizer:
    """
    Bundle a merges table with encode/decode helpers and JSON persistence.

    ``merges`` maps ``(left_id, right_id)`` tuples to merged token ids.
    """

    def __init__(self, merges=None):
        # A falsy argument (None or an empty mapping) becomes a fresh dict.
        self.merges = merges if merges else {}

    def encode(self, text):
        """Encode ``text`` with the module-level ``encode`` and these merges."""
        return encode(text, self.merges)

    def decode(self, ids):
        """Decode ``ids`` with the module-level ``decode`` and these merges."""
        return decode(ids, self.merges)

    def save(self, path):
        """
        Write the merges to ``path`` as JSON.

        JSON object keys must be strings, so each pair is stored as the key
        ``"left,right"`` mapped to its merged id.
        """
        payload = {}
        for pair, merged_id in self.merges.items():
            left, right = pair
            payload[f"{left},{right}"] = merged_id
        with open(path, 'w') as handle:
            json.dump(payload, handle)

    @classmethod
    def load(cls, path):
        """
        Build a tokenizer from a JSON file written by ``save``.

        Each ``"left,right"`` key is split on commas and turned back into a
        tuple of ints.
        """
        with open(path, 'r') as handle:
            payload = json.load(handle)
        restored = {}
        for key, merged_id in payload.items():
            restored[tuple(int(part) for part in key.split(','))] = merged_id
        return cls(restored)
def _fixed_width_bytes(num_tokens, max_token_id):
    """Return the bytes needed to store ``num_tokens`` ids at a fixed width that fits ``max_token_id``."""
    width = max(1, (max_token_id.bit_length() + 7) // 8)
    return num_tokens * width


def train_tokenizer(vocab_size=5000, input_folder='train', output_file='model/tokenizer.json',
                    file_limit=50000, batch_size=1000):
    """
    Train a byte-level BPE tokenizer on the text files in ``input_folder``.

    The corpus is read with ``read_text_files``. Merges are then learned
    greedily: each step merges the most frequent adjacent pair into the next
    id (``256 + i``). Training stops early if the sequence has no adjacent
    pairs left. A checkpoint is written to ``<output_file>.checkpoint``
    every 5000 merges, and the final tokenizer is saved to ``output_file``.

    The reported byte figures assume each id is stored at a fixed width.
    Initial ids are single bytes; final ids use the width needed for the
    largest id in the learned vocabulary.

    Args:
        vocab_size: Target vocabulary size (256 byte tokens plus merges).
        input_folder: Folder scanned for ``.txt``/``.text`` files.
        output_file: Path of the final JSON tokenizer.
        file_limit: Maximum number of files to read.
        batch_size: Files per batch while reading the corpus.

    Returns:
        The trained ``Tokenizer``.

    Raises:
        ValueError: If no training tokens could be read.
    """
    print("Reading files...")
    ids = read_text_files(folder_path=input_folder, limit=file_limit, batch_size=batch_size)
    if not ids:
        # Nothing to train on (missing/empty folder): fail clearly here
        # instead of on len()/max()/division further down.
        raise ValueError(f"No training data found in '{input_folder}'")

    initial_len = len(ids)
    # Every starting id is a raw byte value (< 256).
    initial_bytes = _fixed_width_bytes(initial_len, 255)

    print("\nTraining tokenizer...")
    num_merges = max(0, vocab_size - 256)
    merges = {}
    pbar = tqdm(range(num_merges), desc="Learning merges")
    for i in pbar:
        stats = get_stats(ids)
        if not stats:
            # The sequence collapsed to a single token: nothing left to merge.
            pbar.write(f"Stopping early after {i} merges: no pairs left to merge.")
            break
        # Most frequent pair; ties resolve to the first pair seen.
        pair = max(stats, key=stats.get)
        idx = 256 + i

        ids = merge(ids, pair, idx)
        merges[pair] = idx

        if (i + 1) % 100 == 0:
            current_ratio = initial_len / len(ids)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")

        if (i + 1) % 1000 == 0:
            gc.collect()

        if (i + 1) % 5000 == 0:
            Tokenizer(merges).save(f"{output_file}.checkpoint")
    pbar.close()

    final_tokenizer = Tokenizer(merges)
    final_tokenizer.save(output_file)

    final_len = len(ids)
    actual_vocab_size = 256 + len(merges)
    final_bytes = _fixed_width_bytes(final_len, actual_vocab_size - 1)
    token_ratio = initial_len / final_len
    byte_ratio = initial_bytes / final_bytes

    print("\n=== Final Statistics ===")
    print(f"Vocabulary size: {actual_vocab_size}")
    print(f"Initial tokens: {initial_len:,}")
    print(f"Final tokens: {final_len:,}")
    print(f"Initial bytes: {initial_bytes:,}")
    print(f"Final bytes: {final_bytes:,}")
    print(f"Token compression ratio: {token_ratio:.2f}X")
    print(f"Byte compression ratio: {byte_ratio:.2f}X (fixed-width id storage)")
    print(f"Saved tokenizer to: {output_file}")
    return final_tokenizer


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='train', help='Input folder containing text files')
    parser.add_argument('--output', default='model/tokenizer.json', help='Output tokenizer file')
    parser.add_argument('--vocab-size', type=int, default=5000, help='Desired vocabulary size')
    parser.add_argument('--file-limit', type=int, default=50000, help='Number of files to process')
    parser.add_argument('--batch-size', type=int, default=1000, help='Batch size for processing files')
    args = parser.parse_args()

    # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    tokenizer = train_tokenizer(
        vocab_size=args.vocab_size,
        input_folder=args.input,
        output_file=args.output,
        file_limit=args.file_limit,
        batch_size=args.batch_size,
    )