import json
import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from huggingface_hub import HfApi
from transformers import PreTrainedTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WAVELET_TOKENIZER_CONFIG = {
    "model_type": "wavelet",
    "tokenizer_class": "WaveletTokenizer",
    "auto_map": {
        "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
    }
}
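
# Note: with the auto_map above, a Hub repo that ships this file as tokenizer.py
# can load the custom class through the Auto API. A minimal sketch, assuming a
# hypothetical repo id (not one that exists):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("your-org/wavelet-eeg", trust_remote_code=True)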


@dataclass
class WaveletTokenizerConfig:
    """Default hyperparameters for the WaveletTokenizer."""

    vocab_size: int = 256
    padding_idx: int = 0
    eeg_channels: int = 74
    mu: float = 255.0
    verbose: bool = True


class WaveletTokenizer(PreTrainedTokenizer):
    """Tokenizer that maps continuous EEG signals to discrete token ids.

    Pipeline: per-channel min-max normalization → μ-law compression →
    uniform quantization into `vocab_size` levels.
    """

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_size: int = 256,
        mu: float = 255.0,
        verbose: bool = True,
        **kwargs
    ):
        self.auto_map = {
            "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
        }

        # Must be set before super().__init__, which may query the
        # vocab_size property.
        self._vocab_size = vocab_size
        self.mu = mu
        self.verbose = verbose

        # Per-channel statistics captured by normalize(); decode() needs them
        # to map values back to the original signal range.
        self.channel_mins = None
        self.channel_maxs = None

        super().__init__(**kwargs)

        if self.verbose:
            logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size (number of possible quantization levels)."""
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, size: int):
        self._vocab_size = size

    def save_pretrained(
        self,
        save_directory: str,
        legacy_format: bool = True,  # unused; kept for signature compatibility
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> Tuple[str, ...]:
        """Save the tokenizer configuration and vocabulary to a directory."""
        os.makedirs(save_directory, exist_ok=True)

        config = {
            **WAVELET_TOKENIZER_CONFIG,
            "vocab_size": self.vocab_size,
            "mu": self.mu,
            "verbose": self.verbose
        }

        config_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer_config.json"
        )

        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)

        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

        if push_to_hub:
            # save_directory doubles as the Hub repo id for the upload.
            api = HfApi()
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config")
            )

            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary")
            )

        return vocab_files + (config_file,)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs
    ) -> "WaveletTokenizer":
        """Load the tokenizer from a local directory.

        Falls back to constructing from kwargs alone when no
        tokenizer_config.json is found.
        """
        config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            # Explicit kwargs override values from the saved config.
            config.update(kwargs)
        else:
            config = kwargs

        return cls(**config)

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocab as a dict mapping token strings to ids."""
        return {str(i): i for i in range(self.vocab_size)}

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token string to its id; non-numeric tokens map to 0."""
        try:
            return int(token)
        except ValueError:
            return 0

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id back to its token string."""
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens to a single string."""
        return " ".join(tokens)

    def _tokenize(self, text: str) -> List[str]:
        """Basic tokenization, provided for PreTrainedTokenizer compatibility."""
        if isinstance(text, str):
            return [text]
        return [str(t) for t in text]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Save the vocabulary to a directory."""
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False)

        return (vocab_file,)

    def __call__(
        self,
        eeg_data: np.ndarray,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Main entry point for tokenization. Handles numpy array input.

        Args:
            eeg_data: Raw EEG array of shape (n_channels, time_points)

        Returns:
            Dictionary containing:
            - input_ids: Tokenized signal values
            - attention_mask: Binary mask (all ones, since no padding is applied)
            - position_ids: Sequential position indices per channel
        """
        input_ids = self.encode(eeg_data)

        attention_mask = np.ones_like(input_ids)

        n_channels, time_points = eeg_data.shape
        position_ids = np.tile(np.arange(time_points), (n_channels, 1))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }

    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
        """Convert EEG data to token ids."""
        # Per-channel min-max normalization to [0, 1].
        normalized = self.normalize(eeg_data)

        # Center to [-1, 1], the range μ-law compression expects.
        centered = 2 * normalized - 1

        compressed = self.mu_law_encode(centered)

        # Map back to [0, 1] and quantize to the nearest of vocab_size levels.
        # Rounding (rather than truncating) keeps the quantization error symmetric.
        input_values = (compressed + 1) / 2
        token_ids = np.round(input_values * (self.vocab_size - 1)).astype(np.int64)

        return token_ids

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """
        Apply min-max normalization per channel, storing the min/max values
        so that decode() can invert it.
        Input shape: (n_channels, time_points)
        """
        self.channel_mins = x.min(axis=1)[:, np.newaxis]
        self.channel_maxs = x.max(axis=1)[:, np.newaxis]

        # The epsilon guards against division by zero on flat channels.
        normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)

        if self.verbose:
            logger.info(
                f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] "
                f"→ [{normalized.min():.3f}, {normalized.max():.3f}]"
            )
        return normalized

    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
        """
        Apply μ-law compression.
        Expects input in the [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), \
            f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)

        if self.verbose:
            logger.info(
                f"μ-law compression (μ={self.mu:.2f}): "
                f"variance before={np.var(x):.3f}, after={np.var(compressed):.3f}"
            )
        return compressed

    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
        """
        Invert μ-law compression.
        Expects input in the [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), \
            f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        return np.sign(x) * (1 / self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)

    def decode(self, token_ids: np.ndarray) -> np.ndarray:
        """
        Decode token ids back to an EEG signal.

        Args:
            token_ids: Array of token ids of shape (n_channels, time_points)

        Returns:
            Array of shape (n_channels, time_points)
        """
        # Token ids → [0, 1] → [-1, 1], mirroring encode() in reverse.
        values = token_ids.astype(np.float32) / (self.vocab_size - 1)
        values = 2 * values - 1

        values = self.mu_law_decode(values)

        # Back to [0, 1], then undo the per-channel min-max normalization
        # captured during the last call to normalize().
        values = (values + 1) / 2

        if self.channel_mins is not None and self.channel_maxs is not None:
            values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins

        return values
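

# Minimal round-trip sketch (illustrative, not part of the class API): tokenize
# synthetic EEG, decode it back, and report the reconstruction error. The
# shapes and seed below are arbitrary example choices.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    eeg = rng.standard_normal((4, 128))  # (n_channels, time_points)

    tokenizer = WaveletTokenizer(vocab_size=256, mu=255.0, verbose=False)
    batch = tokenizer(eeg)
    assert batch["input_ids"].shape == eeg.shape

    reconstructed = tokenizer.decode(batch["input_ids"])
    max_err = np.max(np.abs(reconstructed - eeg))
    logger.info(f"Round-trip max abs error: {max_err:.4f}")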