import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import os
import json
from huggingface_hub import HfApi

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WAVELET_TOKENIZER_CONFIG = {
    "model_type": "wavelet",
    "tokenizer_class": "WaveletTokenizer",
    "auto_map": {
        "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
    }
}

@dataclass
class WaveletTokenizerConfig:
    vocab_size: int = 256
    padding_idx: int = 0
    eeg_channels: int = 74     # Source modality (EEG)
    mu: float = 255.0         # Static μ value for μ-law compression
    verbose: bool = True       # Control logging

class WaveletTokenizer(PreTrainedTokenizer):
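    """
    Tokenizer that maps continuous EEG signals to discrete token IDs via
    per-channel min-max normalization, μ-law compression, and uniform
    quantization into vocab_size levels.
    """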
    model_input_names = ["input_ids", "attention_mask", "position_ids"]
    
    def __init__(
        self,
        vocab_size: int = 256,
        mu: float = 255.0,
        verbose: bool = True,
        **kwargs
    ):
        self.auto_map = {
            "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
        }
        
        # Set vocab size first
        self._vocab_size = vocab_size
        self.mu = mu
        self.verbose = verbose
        
        # Store normalization state
        self.channel_mins = None
        self.channel_maxs = None
        
        # Initialize parent class after setting vocab_size
        super().__init__(**kwargs)
        
        if self.verbose:
            logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")
    
    @property
    def vocab_size(self) -> int:
        """Returns the size of vocabulary (number of possible quantization levels)."""
        return self._vocab_size
    
    @vocab_size.setter
    def vocab_size(self, size: int):
        self._vocab_size = size
    
    def save_pretrained(
        self, 
        save_directory: str,
        legacy_format: bool = True,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> Tuple[str, ...]:
        """Save tokenizer configuration to a directory."""
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)
            
        # Save tokenizer config
        config = {
            **WAVELET_TOKENIZER_CONFIG,
            "vocab_size": self.vocab_size,
            "mu": self.mu,
            "verbose": self.verbose
        }
        
        config_file = os.path.join(
            save_directory, 
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer_config.json"
        )
        
        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)
            
        # Save vocabulary
        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
        
        if push_to_hub:
            # Upload files over HTTP; save_directory is used directly as the Hub repo_id
            api = HfApi()
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config")
            )
            
            # Upload vocabulary file
            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary")
            )
            
        return vocab_files + (config_file,)
    
    @classmethod
    def from_pretrained(
        cls, 
        pretrained_model_name_or_path: str, 
        **kwargs
    ) -> "WaveletTokenizer":
        """Load tokenizer from HuggingFace Hub."""
        # Load config first
        config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            # Update with any passed kwargs
            config.update(kwargs)
        else:
            config = kwargs
            
        return cls(**config)
    
    def get_vocab(self) -> Dict[str, int]:
        """Returns vocab as a dict mapping token strings to ids."""
        # Create a minimal vocabulary with quantization levels
        return {str(i): i for i in range(self.vocab_size)}
    
    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token string to its ID."""
        try:
            return int(token)
        except ValueError:
            return 0  # Return 0 for unknown tokens
    
    def _convert_id_to_token(self, index: int) -> str:
        """Converts an ID back to its token string."""
        return str(index)
    
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens to a single string."""
        return " ".join(tokens)
    
    def _tokenize(self, text: str) -> List[str]:
        """Basic tokenization for compatibility."""
        if isinstance(text, str):
            return [text]
        return [str(t) for t in text]
    
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Save the vocabulary to a directory."""
        vocab_file = os.path.join(
            save_directory, 
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False)
            
        return (vocab_file,)
    
    def __call__(
        self,
        eeg_data: np.ndarray,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Main entry point for tokenization. Handles numpy array input.
        
        Args:
            eeg_data: Raw EEG array of shape (n_channels, time_points)
            
        Returns:
            Dictionary containing:
                - input_ids: Tokenized signal values
                - attention_mask: Binary mask (all ones since we don't pad)
                - position_ids: Sequential position indices
        """
        # Process through tokenization pipeline
        input_ids = self.encode(eeg_data)
        
        # Create attention mask (all ones since we're not padding)
        attention_mask = np.ones_like(input_ids)
        
        # Create position IDs
        n_channels, time_points = eeg_data.shape
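        # Each row is 0..time_points-1, giving shape (n_channels, time_points)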
        position_ids = np.tile(np.arange(time_points), (n_channels, 1))
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }
    
    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
        """Convert EEG data to token IDs."""
        # 1. Normalize to [0, 1]
        normalized = self.normalize(eeg_data)
        
        # 2. Convert to [-1, 1] for μ-law compression
        centered = 2 * normalized - 1
        
        # 3. Apply μ-law compression
        compressed = self.mu_law_encode(centered)
        
        # 4. Quantize to tokens
        input_values = (compressed + 1) / 2  # to [0, 1]
        token_ids = (input_values * (self.vocab_size - 1)).astype(np.int64)
        
        return token_ids
    
    def normalize(self, x: np.ndarray) -> np.ndarray:
        """
        Min-max normalize each channel to [0, 1], storing per-channel min/max
        so that decode() can restore the original scale.
        Input shape: (n_channels, time_points)
        """
        # Compute min/max per channel and expand dimensions to match input
        self.channel_mins = x.min(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        self.channel_maxs = x.max(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        
        normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)
        
        if self.verbose:
            logger.info(f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] → [{normalized.min():.3f}, {normalized.max():.3f}]")
        return normalized
    
    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
        """
        Apply μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
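        # μ-law: F(x) = sign(x) · log(1 + μ|x|) / log(1 + μ), mapping [-1, 1] → [-1, 1]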
        compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)
        
        if self.verbose:
            logger.info(f"μ-law compression (μ={self.mu:.2f}): variance before={np.var(x):.3f}, after={np.var(compressed):.3f}")
        return compressed
    
    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
        """
        Inverse μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
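        # Inverse μ-law: F⁻¹(y) = sign(y) · ((1 + μ)^|y| - 1) / μ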
        return np.sign(x) * (1/self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)
    
    def decode(self, token_ids: np.ndarray) -> np.ndarray:
        """
        Decode token IDs back to EEG signal.
        
        Args:
            token_ids: Array of token IDs of shape (n_channels, time_points)
            
        Returns:
            Array of shape (n_channels, time_points)
        """
        # Convert to continuous values in [-1, 1]
        values = token_ids.astype(np.float32) / (self.vocab_size - 1)  # [0, 1]
        values = 2 * values - 1  # [-1, 1]
        
        # Apply inverse μ-law compression
        values = self.mu_law_decode(values)
        
        # Convert back to [0, 1]
        values = (values + 1) / 2
        
        # Denormalize to original scale
        if self.channel_mins is not None and self.channel_maxs is not None:
            values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins
        
        return values
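

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, using synthetic data): runs the
# encode → decode round trip on a random "EEG" array of shape
# (n_channels, time_points). Reconstruction differs from the input only by
# quantization error.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = WaveletTokenizer(vocab_size=256, mu=255.0, verbose=False)

    eeg = np.random.randn(4, 1000).astype(np.float32)  # hypothetical 4-channel signal

    batch = tokenizer(eeg)
    print(batch["input_ids"].shape)       # (4, 1000), integer tokens in [0, 255]
    print(batch["attention_mask"].shape)  # (4, 1000), all ones (no padding)
    print(batch["position_ids"].shape)    # (4, 1000), 0..999 per channel

    reconstructed = tokenizer.decode(batch["input_ids"])
    print(float(np.abs(reconstructed - eeg).max()))  # small quantization error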