# eeg2meg-tokenizer / tokenizer.py

import json
import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from huggingface_hub import HfApi
from transformers import PreTrainedTokenizer

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WAVELET_TOKENIZER_CONFIG = {
"model_type": "wavelet",
"tokenizer_class": "WaveletTokenizer",
"auto_map": {
"AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
}
}


@dataclass
class WaveletTokenizerConfig:
vocab_size: int = 256
padding_idx: int = 0
eeg_channels: int = 74 # Source modality (EEG)
mu: float = 255.0 # Static μ value for μ-law compression
    verbose: bool = True  # Control logging


class WaveletTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
self,
vocab_size: int = 256,
mu: float = 255.0,
verbose: bool = True,
**kwargs
):
self.auto_map = {
"AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
}
# Set vocab size first
self._vocab_size = vocab_size
self.mu = mu
self.verbose = verbose
# Store normalization state
self.channel_mins = None
self.channel_maxs = None
# Initialize parent class after setting vocab_size
super().__init__(**kwargs)
if self.verbose:
logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")
@property
def vocab_size(self) -> int:
"""Returns the size of vocabulary (number of possible quantization levels)."""
return self._vocab_size
@vocab_size.setter
def vocab_size(self, size: int):
        self._vocab_size = size

    def save_pretrained(
self,
save_directory: str,
legacy_format: bool = True,
filename_prefix: Optional[str] = None,
push_to_hub: bool = False,
**kwargs
) -> Tuple[str, ...]:
"""Save tokenizer configuration to a directory."""
        os.makedirs(save_directory, exist_ok=True)
# Save tokenizer config
config = {
**WAVELET_TOKENIZER_CONFIG,
"vocab_size": self.vocab_size,
"mu": self.mu,
"verbose": self.verbose
}
config_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "tokenizer_config.json"
)
with open(config_file, "w") as f:
json.dump(config, f, indent=2)
# Save vocabulary
vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
        if push_to_hub:
            # Upload files over HTTP; allow the target repo id to be passed via
            # kwargs, falling back to the save directory name.
            repo_id = kwargs.get("repo_id", save_directory)
            api = HfApi()
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=repo_id,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config"),
            )
            # Upload vocabulary file
            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=repo_id,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary"),
            )
        return vocab_files + (config_file,)

    @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
**kwargs
) -> "WaveletTokenizer":
"""Load tokenizer from HuggingFace Hub."""
# Load config first
config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
if os.path.exists(config_file):
with open(config_file, "r") as f:
config = json.load(f)
# Update with any passed kwargs
config.update(kwargs)
else:
config = kwargs
return cls(**config)
def get_vocab(self) -> Dict[str, int]:
"""Returns vocab as a dict mapping token strings to ids."""
# Create a minimal vocabulary with quantization levels
        return {str(i): i for i in range(self.vocab_size)}

    def _convert_token_to_id(self, token: str) -> int:
"""Converts a token string to its ID."""
try:
return int(token)
except ValueError:
            return 0  # Return 0 for unknown tokens

    def _convert_id_to_token(self, index: int) -> str:
"""Converts an ID back to its token string."""
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens to a single string."""
return " ".join(tokens)
def _tokenize(self, text: str) -> List[str]:
"""Basic tokenization for compatibility."""
if isinstance(text, str):
return [text]
        return [str(t) for t in text]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
"""Save the vocabulary to a directory."""
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json"
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self.get_vocab(), f, ensure_ascii=False)
        return (vocab_file,)

    def __call__(
self,
eeg_data: np.ndarray,
**kwargs
) -> Dict[str, np.ndarray]:
"""
Main entry point for tokenization. Handles numpy array input.
Args:
eeg_data: Raw EEG array of shape (n_channels, time_points)
Returns:
Dictionary containing:
- input_ids: Tokenized signal values
- attention_mask: Binary mask (all ones since we don't pad)
- position_ids: Sequential position indices
"""
# Process through tokenization pipeline
input_ids = self.encode(eeg_data)
# Create attention mask (all ones since we're not padding)
attention_mask = np.ones_like(input_ids)
# Create position IDs
n_channels, time_points = eeg_data.shape
position_ids = np.tile(np.arange(time_points), (n_channels, 1))
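        # Every channel shares the same 0..time_points-1 position indices, so
        # position_ids has the same (n_channels, time_points) shape as input_ids.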
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids
        }

    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
"""Convert EEG data to token IDs."""
# 1. Normalize to [0, 1]
normalized = self.normalize(eeg_data)
# 2. Convert to [-1, 1] for μ-law compression
centered = 2 * normalized - 1
# 3. Apply μ-law compression
compressed = self.mu_law_encode(centered)
# 4. Quantize to tokens
input_values = (compressed + 1) / 2 # to [0, 1]
token_ids = (input_values * (self.vocab_size - 1)).astype(np.int64)
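        # Note: astype(np.int64) floors, so only samples that hit a channel's
        # exact maximum map to the top token id (vocab_size - 1).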
        return token_ids

    def normalize(self, x: np.ndarray) -> np.ndarray:
"""
Apply static normalization per channel and store min/max values.
Input shape: (n_channels, time_points)
"""
# Compute min/max per channel and expand dimensions to match input
self.channel_mins = x.min(axis=1)[:, np.newaxis] # Shape: (n_channels, 1)
self.channel_maxs = x.max(axis=1)[:, np.newaxis] # Shape: (n_channels, 1)
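        # Note: this state is overwritten on every call, so decode() can only
        # invert the most recently encoded array.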
normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)
if self.verbose:
logger.info(f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] → [{normalized.min():.3f}, {normalized.max():.3f}]")
        return normalized

    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
"""
Apply μ-law compression.
Expects input in [-1, 1] range.
"""
assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
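        # μ-law: f(x) = sign(x) · ln(1 + μ|x|) / ln(1 + μ), mapping [-1, 1] onto [-1, 1]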
compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)
if self.verbose:
logger.info(f"μ-law compression (μ={self.mu:.2f}): variance before={np.var(x):.3f}, after={np.var(compressed):.3f}")
        return compressed

    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
"""
Inverse μ-law compression.
Expects input in [-1, 1] range.
"""
assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
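        # Inverse μ-law: f⁻¹(y) = sign(y) · ((1 + μ)^|y| − 1) / μ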
        return np.sign(x) * (1 / self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)

    def decode(self, token_ids: np.ndarray) -> np.ndarray:
"""
Decode token IDs back to EEG signal.
Args:
token_ids: Array of token IDs of shape (n_channels, time_points)
Returns:
Array of shape (n_channels, time_points)
"""
# Convert to continuous values in [-1, 1]
values = token_ids.astype(np.float32) / (self.vocab_size - 1) # [0, 1]
values = 2 * values - 1 # [-1, 1]
# Apply inverse μ-law compression
values = self.mu_law_decode(values)
# Convert back to [0, 1]
values = (values + 1) / 2
# Denormalize to original scale
if self.channel_mins is not None and self.channel_maxs is not None:
values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins
return values
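
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal round trip on synthetic data, assuming only the classes above:
# the 74-channel shape, sample count, and seed below are made up for the demo.
if __name__ == "__main__":
    tokenizer = WaveletTokenizer(vocab_size=256, mu=255.0, verbose=False)
    rng = np.random.default_rng(0)
    eeg = rng.standard_normal((74, 1000)).astype(np.float32)  # (n_channels, time_points)
    batch = tokenizer(eeg)
    assert batch["input_ids"].shape == eeg.shape
    reconstructed = tokenizer.decode(batch["input_ids"])
    # μ-law quantization is coarser near each channel's extremes, so the
    # round-trip error is small but nonzero.
    print(f"max reconstruction error: {np.max(np.abs(reconstructed - eeg)):.4f}")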