import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import os
import json
from huggingface_hub import HfApi

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WAVELET_TOKENIZER_CONFIG = {
    "model_type": "wavelet",
    "tokenizer_class": "WaveletTokenizer",
    "auto_map": {
        "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
    }
}


@dataclass
class WaveletTokenizerConfig:
    vocab_size: int = 256
    padding_idx: int = 0
    eeg_channels: int = 74   # Source modality (EEG)
    mu: float = 255.0        # Static μ value for μ-law compression
    verbose: bool = True     # Control logging


class WaveletTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_size: int = 256,
        mu: float = 255.0,
        verbose: bool = True,
        **kwargs
    ):
        self.auto_map = {
            "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
        }
        # Set vocab size first
        self._vocab_size = vocab_size
        self.mu = mu
        self.verbose = verbose

        # Store normalization state
        self.channel_mins = None
        self.channel_maxs = None

        # Initialize parent class after setting vocab_size
        super().__init__(**kwargs)

        if self.verbose:
            logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")

    @property
    def vocab_size(self) -> int:
        """Returns the size of the vocabulary (number of possible quantization levels)."""
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, size: int):
        self._vocab_size = size

    def save_pretrained(
        self,
        save_directory: str,
        legacy_format: bool = True,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> Tuple[str, ...]:
        """Save tokenizer configuration to a directory."""
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save tokenizer config
        config = {
            **WAVELET_TOKENIZER_CONFIG,
            "vocab_size": self.vocab_size,
            "mu": self.mu,
            "verbose": self.verbose
        }
        config_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer_config.json"
        )
        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)

        # Save vocabulary
        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

        if push_to_hub:
            # Upload files over HTTP; the repo id defaults to the save
            # directory name unless one is passed explicitly.
            repo_id = kwargs.get("repo_id", save_directory)
            api = HfApi()
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=repo_id,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config")
            )
            # Upload vocabulary file
            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=repo_id,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary")
            )

        return vocab_files + (config_file,)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs
    ) -> "WaveletTokenizer":
        """Load tokenizer from a local directory, falling back to kwargs if no config file is found."""
        # Load config first
        config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            # Update with any passed kwargs
            config.update(kwargs)
        else:
            config = kwargs

        return cls(**config)

    def get_vocab(self) -> Dict[str, int]:
        """Returns the vocab as a dict mapping token strings to ids."""
        # Create a minimal vocabulary with one entry per quantization level
        return {str(i): i for i in range(self.vocab_size)}

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token string to its ID."""
        try:
            return int(token)
        except ValueError:
            return 0  # Return 0 for unknown tokens

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an ID back to its token string."""
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens to a single string."""
        return " ".join(tokens)

    def _tokenize(self, text: str) -> List[str]:
        """Basic tokenization for compatibility."""
        if isinstance(text, str):
            return [text]
        return [str(t) for t in text]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Save the vocabulary to a directory."""
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False)
        return (vocab_file,)

    def __call__(
        self,
        eeg_data: np.ndarray,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Main entry point for tokenization. Handles numpy array input.

        Args:
            eeg_data: Raw EEG array of shape (n_channels, time_points)

        Returns:
            Dictionary containing:
                - input_ids: Tokenized signal values
                - attention_mask: Binary mask (all ones since we don't pad)
                - position_ids: Sequential position indices
        """
        # Process through the tokenization pipeline
        input_ids = self.encode(eeg_data)

        # Create attention mask (all ones since we're not padding)
        attention_mask = np.ones_like(input_ids)

        # Create position IDs
        n_channels, time_points = eeg_data.shape
        position_ids = np.tile(np.arange(time_points), (n_channels, 1))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }

    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
        """Convert EEG data to token IDs."""
        # 1. Normalize to [0, 1]
        normalized = self.normalize(eeg_data)

        # 2. Convert to [-1, 1] for μ-law compression
        centered = 2 * normalized - 1

        # 3. Apply μ-law compression
        compressed = self.mu_law_encode(centered)

        # 4. Quantize to tokens, rounding to the nearest level
        #    (truncation would bias every value toward zero)
        input_values = (compressed + 1) / 2  # to [0, 1]
        token_ids = np.round(input_values * (self.vocab_size - 1)).astype(np.int64)

        return token_ids

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """
        Apply min-max normalization per channel and store the min/max values.
        Input shape: (n_channels, time_points)
        """
        # Compute min/max per channel and expand dimensions to match input
        self.channel_mins = x.min(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        self.channel_maxs = x.max(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)

        normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)

        if self.verbose:
            logger.info(
                f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] "
                f"→ [{normalized.min():.3f}, {normalized.max():.3f}]"
            )

        return normalized

    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
        """
        Apply μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), \
            f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"

        compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)

        if self.verbose:
            logger.info(
                f"μ-law compression (μ={self.mu:.2f}): "
                f"variance before={np.var(x):.3f}, after={np.var(compressed):.3f}"
            )

        return compressed

    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
        """
        Inverse μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), \
            f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"

        return np.sign(x) * (1 / self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)

    def decode(self, token_ids: np.ndarray) -> np.ndarray:
        """
        Decode token IDs back to EEG signal.

        Args:
            token_ids: Array of token IDs of shape (n_channels, time_points)

        Returns:
            Array of shape (n_channels, time_points)
        """
        # Convert to continuous values in [-1, 1]
        values = token_ids.astype(np.float32) / (self.vocab_size - 1)  # [0, 1]
        values = 2 * values - 1  # [-1, 1]

        # Apply inverse μ-law compression
        values = self.mu_law_decode(values)

        # Convert back to [0, 1]
        values = (values + 1) / 2

        # Denormalize to the original scale using the stored per-channel stats
        if self.channel_mins is not None and self.channel_maxs is not None:
            values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins

        return values