import gc  # For garbage collection
import json
import os
from collections import Counter
from itertools import islice

import regex as re
from natsort import natsorted
from tqdm import tqdm

# Pre-tokenization pattern for Marathi text, compiled with the third-party
# `regex` module (imported as `re`), which supports the \p{...} Unicode
# property classes used below. re.VERBOSE makes the engine ignore the
# whitespace and '#' comment lines inside the pattern string.
# Alternatives are tried left to right:
#   1. an ASCII apostrophe followed by one of the listed Devanagari suffixes
#      (चा, ची, चे, ला, ले, नी);
#   2. a run of letters and combining marks, so vowel signs stay attached
#      to the letters they modify;
#   3. a run of numeric characters;
#   4. a run of any other non-whitespace characters (punctuation, symbols);
#   5. a run of whitespace.
# Between them these classes cover every character, so the matches returned
# by findall() concatenate back to the original text.
MARATHI_PATTERN = re.compile(r"""
    # Contractions and common affixes
    'चा|'ची|'चे|'ला|'ले|'नी|
    # Words with optional vowel signs and modifiers
    [\p{L}\p{M}]+|
    # Numbers
    \p{N}+|
    # Punctuation and special characters
    [^\s\p{L}\p{N}\p{M}]+|
    # Whitespace
    \s+
""", re.VERBOSE)

def text_to_bytes(text):
    """Convert text to a flat list of UTF-8 byte values after MARATHI_PATTERN pre-tokenization.

    The text is split into chunks with ``MARATHI_PATTERN`` and each chunk is
    UTF-8 encoded. Chunk boundaries are NOT preserved: the bytes of all
    chunks are concatenated. Because the pattern matches every character,
    the result is the same as ``list(text.encode('utf-8'))``.
    NOTE(review): BPE merges learned on this output can therefore span
    word/whitespace boundaries; if that is not intended, the chunks would
    need to be kept separate during training and encoding.

    Args:
        text: The string to convert.

    Returns:
        list[int] of byte values in the range 0..255.
    """
    all_bytes = []
    for word in MARATHI_PATTERN.findall(text):
        # One encode call per chunk instead of one per character;
        # extending with a bytes object appends its integer byte values.
        all_bytes.extend(word.encode('utf-8'))
    return all_bytes

def read_text_files(folder_path='train', limit=50000, batch_size=1000):
    """
    Read text files in batches and return their concatenated byte tokens.

    Files ending in ``.txt`` or ``.text`` are taken from ``folder_path`` in
    natural-sort order (``natsorted``) and truncated to the first ``limit``.
    Each file is read as UTF-8 and converted with ``text_to_bytes``. The
    tokens of all files are concatenated into one flat list with no
    separator between files. Files that cannot be opened or are not valid
    UTF-8 are reported and skipped. ``gc.collect()`` runs after each batch
    of ``batch_size`` files.

    Args:
        folder_path: Directory whose entries are scanned (non-recursively).
        limit: Maximum number of files to read.
        batch_size: Number of files per batch; must be positive.

    Returns:
        list[int] of byte values, or ``None`` if ``folder_path`` is not a
        directory or contains no text files.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # isdir (not exists): a regular file at this path would make listdir fail.
    if not os.path.isdir(folder_path):
        print(f"Error: The folder '{folder_path}' does not exist or is not a directory.")
        return None

    # Natural sort so e.g. file2.txt comes before file10.txt
    text_files = natsorted([f for f in os.listdir(folder_path) if f.endswith(('.txt', '.text'))])

    if not text_files:
        print(f"No text files found in '{folder_path}' folder.")
        return None

    text_files = text_files[:limit]
    total_files = len(text_files)

    report_every = 5000  # print running statistics each time this many files have been processed
    all_tokens = []
    files_read = 0

    for start in tqdm(range(0, total_files, batch_size), desc="Processing files"):
        batch_files = text_files[start:start + batch_size]

        for file_name in batch_files:
            file_path = os.path.join(folder_path, file_name)
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    content = file.read()
            except (OSError, UnicodeDecodeError) as e:
                # Best effort: report the unreadable file and keep going.
                print(f"Error reading {file_name}: {str(e)}")
                continue
            all_tokens.extend(text_to_bytes(content))
            files_read += 1

        processed = start + len(batch_files)
        # Fire whenever this batch crossed a multiple of report_every, so
        # progress is reported for any batch_size, not only divisors of it.
        if processed // report_every > start // report_every:
            print(f"\nProcessed {processed}/{total_files} files")
            print(f"Current tokens: {len(all_tokens)}")

        # Garbage collection after each batch
        gc.collect()

    print("\n=== Final Statistics ===")
    print(f"Total files processed: {files_read}/{total_files}")
    print(f"Total tokens: {len(all_tokens)}")
    return all_tokens

def get_stats(ids):
    """
    Count how often each adjacent pair of token ids occurs in ``ids``.

    Overlapping pairs are all counted, e.g. ``[7, 7, 7]`` gives
    ``{(7, 7): 2}``.

    Args:
        ids: List (or other sequence) of token ids.

    Returns:
        A ``collections.Counter`` (a dict subclass) mapping ``(a, b)`` pairs
        to occurrence counts. Keys appear in first-occurrence order, so
        ``max(stats.items(), key=...)`` breaks ties in favour of the pair
        seen first.
    """
    # islice pairs each element with its successor without copying the list;
    # ids[1:] would duplicate the whole training corpus on every call.
    return Counter(zip(ids, islice(ids, 1, None)))

def merge(ids, pair, idx):
    """Return a new list where each occurrence of ``pair`` in ``ids`` is replaced by ``idx``.

    The scan runs left to right and consumes both elements of a matched
    pair, so overlapping occurrences are not double-merged: merging (1, 1)
    in ``[1, 1, 1]`` yields ``[idx, 1]``. The input list is not modified.
    """
    merged = []
    n = len(ids)
    k = 0
    while k < n:
        matches = k + 1 < n and ids[k] == pair[0] and ids[k + 1] == pair[1]
        if not matches:
            merged.append(ids[k])
            k += 1
            continue
        merged.append(idx)
        k += 2
    return merged

def encode(text, merges):
    """
    Encode text into token ids using the learned merges.

    The text is first converted to UTF-8 byte values with ``text_to_bytes``.
    Then, while at least two tokens remain, the adjacent pair with the
    lowest merge id (the earliest-learned merge) is merged everywhere, until
    no adjacent pair has a merge. Because every merged id is larger than
    the two ids it combines (``train_tokenizer`` assigns ``256 + i`` to a
    pair of already-existing tokens), this yields the same ids as applying
    every merge in id order over the whole sequence. However, it spends a
    pass over the sequence only on merges that actually occur in ``text``.

    Args:
        text: The string to encode.
        merges: Dict mapping ``(left_id, right_id)`` -> merged token id.

    Returns:
        list[int] of token ids.
    """
    ids = text_to_bytes(text)
    not_mergeable = float('inf')
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # Earliest-learned merge that is still applicable; unknown pairs rank last.
        pair = min(pairs, key=lambda p: merges.get(p, not_mergeable))
        if pair not in merges:
            break  # no adjacent pair has a learned merge
        ids = merge(ids, pair, merges[pair])
    return ids

def decode(ids, merges, errors=None):
    """
    Decode token ids back to text using the learned merges.

    Ids below 256 are raw byte values. Any other id is looked up in the
    inverse of ``merges`` and expanded recursively into its two component
    tokens until only bytes remain. The concatenated bytes are then decoded
    as UTF-8.

    Args:
        ids: Iterable of int token ids.
        merges: Dict mapping ``(left_id, right_id)`` -> merged token id.
        errors: ``None`` (default) returns the sentinel string
            ``"[DECODE_ERROR]"`` if the bytes are not valid UTF-8. Any codec
            error-handler name (e.g. ``"replace"``, ``"ignore"``,
            ``"strict"``) is passed to ``bytes.decode`` instead.

    Returns:
        The decoded string, or ``"[DECODE_ERROR]"`` (only when ``errors``
        is ``None``).

    Raises:
        KeyError: If an id >= 256 has no entry in ``merges``.
    """
    # Reverse mapping from merged token id to its (left, right) pair
    reverse_merges = {idx: pair for pair, idx in merges.items()}
    expanded = {}  # token id -> bytes; each distinct id is expanded once per call

    def expand_token(token):
        if token < 256:  # Base case: token is a byte
            return bytes([token])
        cached = expanded.get(token)
        if cached is None:
            left, right = reverse_merges[token]
            cached = expand_token(left) + expand_token(right)
            expanded[token] = cached
        return cached

    bytes_data = b''.join(expand_token(token) for token in ids)

    if errors is not None:
        return bytes_data.decode('utf-8', errors=errors)
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        return "[DECODE_ERROR]"

class Tokenizer:
    """Bundle a BPE merge table with encode/decode and JSON persistence.

    ``merges`` maps ``(left_id, right_id)`` tuples to the token id the pair
    is merged into. ``encode`` and ``decode`` delegate to the module-level
    functions of the same name.
    """

    def __init__(self, merges=None):
        # Any falsy argument (None or an empty mapping) becomes a fresh dict.
        if not merges:
            merges = {}
        self.merges = merges

    def encode(self, text):
        """Return the token ids for ``text`` using this tokenizer's merges."""
        merge_table = self.merges
        return encode(text, merge_table)

    def decode(self, ids):
        """Return the text for token ``ids`` using this tokenizer's merges."""
        merge_table = self.merges
        return decode(ids, merge_table)

    def save(self, path):
        """Write the merge table to ``path`` as a JSON object.

        JSON object keys must be strings, so each ``(p1, p2)`` pair is stored
        under the key ``"p1,p2"`` with its merged token id as the value.
        """
        payload = {}
        for (left, right), token_id in self.merges.items():
            payload[f"{left},{right}"] = token_id
        with open(path, 'w') as fh:
            json.dump(payload, fh)

    @classmethod
    def load(cls, path):
        """Build a tokenizer from a JSON file written by ``save``.

        Each ``"p1,p2"`` key is split on the comma and parsed back into a
        tuple of ints.
        """
        with open(path, 'r') as fh:
            payload = json.load(fh)
        merges = {}
        for key, token_id in payload.items():
            merges[tuple(int(part) for part in key.split(','))] = token_id
        return cls(merges)

def train_tokenizer(vocab_size=5000, input_folder='train', output_file='model/tokenizer.json', file_limit=50000,
                    batch_size=1000):
    """
    Train a byte-level BPE tokenizer on a folder of text files.

    Reads the files with ``read_text_files``, then greedily learns up to
    ``vocab_size - 256`` merges. Each iteration counts adjacent token pairs,
    merges the most frequent pair into a new id (256, 257, ...) and records
    it. Every 5000 merges the partial merge table is written to
    ``output_file + ".checkpoint"``; the final table is written to
    ``output_file``.

    Args:
        vocab_size: Requested vocabulary size (256 byte tokens + learned
            merges). Training stops early if fewer than two tokens remain.
        input_folder: Folder containing ``.txt``/``.text`` training files.
        output_file: Path of the JSON file the final tokenizer is saved to.
        file_limit: Maximum number of files to read.
        batch_size: Number of files ``read_text_files`` reads per batch.

    Returns:
        The trained ``Tokenizer``.

    Raises:
        ValueError: If no training tokens could be read from ``input_folder``.
    """
    print("Reading files...")
    # read_text_files returns a fresh list, so train on it directly; copying
    # it would double peak memory on a large corpus.
    ids = read_text_files(folder_path=input_folder, limit=file_limit, batch_size=batch_size)
    if not ids:
        # None (missing folder / no text files) or [] (no file was readable)
        raise ValueError(f"No training data could be read from '{input_folder}'.")
    initial_len = len(ids)  # one token per UTF-8 byte at this point

    print("\nTraining tokenizer...")
    num_merges = vocab_size - 256
    merges = {}

    pbar = tqdm(range(num_merges), desc="Learning merges")
    for i in pbar:
        stats = get_stats(ids)
        if not stats:
            # Fewer than two tokens left: nothing more can be merged.
            pbar.write(f"Stopping early after {i} merges: no pairs left to merge.")
            break
        # Most frequent pair; max() keeps the first maximal pair it sees.
        pair = max(stats.items(), key=lambda x: x[1])[0]
        idx = 256 + i

        # Apply merge
        ids = merge(ids, pair, idx)
        merges[pair] = idx

        # Show progress
        if (i + 1) % 100 == 0:
            current_ratio = initial_len / len(ids)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")

        # Garbage collection periodically
        if (i + 1) % 1000 == 0:
            gc.collect()

        # Save intermediate merges
        if (i + 1) % 5000 == 0:
            Tokenizer(merges).save(f"{output_file}.checkpoint")
    pbar.close()  # needed when the loop exits via break; a no-op otherwise

    # Create and save final tokenizer
    final_tokenizer = Tokenizer(merges)
    final_tokenizer.save(output_file)

    # Token ids are integers, not text, so the meaningful size measure is the
    # token count. Since the initial tokens are raw UTF-8 bytes, the ratio is
    # also the average number of bytes per token.
    final_len = len(ids)
    compression_ratio = initial_len / final_len

    print("\n=== Final Statistics ===")
    print(f"Vocabulary size: {256 + len(merges)} (requested {vocab_size})")
    print(f"Initial tokens (UTF-8 bytes): {initial_len:,}")
    print(f"Final tokens: {final_len:,}")
    print(f"Compression ratio (bytes per token): {compression_ratio:.2f}X")
    print(f"Saved tokenizer to: {output_file}")

    return final_tokenizer

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Train a byte-level BPE tokenizer on a folder of text files.')
    parser.add_argument('--input', default='train', help='Input folder containing text files')
    parser.add_argument('--output', default='model/tokenizer.json', help='Output tokenizer file')
    parser.add_argument('--vocab-size', type=int, default=5000, help='Desired vocabulary size')
    parser.add_argument('--file-limit', type=int, default=50000, help='Number of files to process')
    parser.add_argument('--batch-size', type=int, default=1000, help='Batch size for processing files')
    args = parser.parse_args()

    # Create the output directory if needed. os.path.dirname() is '' for a
    # bare filename such as 'tokenizer.json', and os.makedirs('') would raise.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Train tokenizer
    tokenizer = train_tokenizer(
        vocab_size=args.vocab_size,
        input_folder=args.input,
        output_file=args.output,
        file_limit=args.file_limit,
        batch_size=args.batch_size,
    )