gabrycina committed · Commit 5caae43 · verified · 1 Parent(s): 05db8dc

Upload tokenizer implementation

Files changed (1):
  1. tokenizer.py +287 -0
tokenizer.py ADDED
@@ -0,0 +1,287 @@
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import os
import json
from huggingface_hub import Repository
from huggingface_hub import HfApi

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WAVELET_TOKENIZER_CONFIG = {
    "model_type": "wavelet",
    "tokenizer_class": "WaveletTokenizer",
    "auto_map": {
        "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
    }
}

@dataclass
class WaveletTokenizerConfig:
    vocab_size: int = 256
    padding_idx: int = 0
    eeg_channels: int = 74  # Source modality (EEG)
    mu: float = 255.0  # Static μ value for μ-law compression
    verbose: bool = True  # Control logging

class WaveletTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_size: int = 256,
        mu: float = 255.0,
        verbose: bool = True,
        **kwargs
    ):
        self.auto_map = {
            "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
        }

        # Set vocab size first
        self._vocab_size = vocab_size
        self.mu = mu
        self.verbose = verbose

        # Store normalization state
        self.channel_mins = None
        self.channel_maxs = None

        # Initialize parent class after setting vocab_size
        super().__init__(**kwargs)

        if self.verbose:
            logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")

    @property
    def vocab_size(self) -> int:
        """Returns the size of vocabulary (number of possible quantization levels)."""
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, size: int):
        self._vocab_size = size

    def save_pretrained(
        self,
        save_directory: str,
        legacy_format: bool = True,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> Tuple[str, ...]:
        """Save tokenizer configuration to a directory."""
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save tokenizer config
        config = {
            **WAVELET_TOKENIZER_CONFIG,
            "vocab_size": self.vocab_size,
            "mu": self.mu,
            "verbose": self.verbose
        }

        config_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer_config.json"
        )

        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)

        # Save vocabulary
        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

        if push_to_hub:
            # Upload files using HTTP
            api = HfApi()
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config")
            )

            # Upload vocabulary file
            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary")
            )

        return vocab_files + (config_file,)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs
    ) -> "WaveletTokenizer":
        """Load the tokenizer config from a local directory; fall back to kwargs if no config file is found."""
        # Load config first
        config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            # Update with any passed kwargs
            config.update(kwargs)
        else:
            config = kwargs

        return cls(**config)

    def get_vocab(self) -> Dict[str, int]:
        """Returns vocab as a dict mapping token strings to ids."""
        # Create a minimal vocabulary with quantization levels
        return {str(i): i for i in range(self.vocab_size)}

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token string to its ID."""
        try:
            return int(token)
        except ValueError:
            return 0  # Return 0 for unknown tokens

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an ID back to its token string."""
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens to a single string."""
        return " ".join(tokens)

    def _tokenize(self, text: str) -> List[str]:
        """Basic tokenization for compatibility."""
        if isinstance(text, str):
            return [text]
        return [str(t) for t in text]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Save the vocabulary to a directory."""
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False)

        return (vocab_file,)

    def __call__(
        self,
        eeg_data: np.ndarray,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Main entry point for tokenization. Handles numpy array input.

        Args:
            eeg_data: Raw EEG array of shape (n_channels, time_points)

        Returns:
            Dictionary containing:
            - input_ids: Tokenized signal values
            - attention_mask: Binary mask (all ones since we don't pad)
            - position_ids: Sequential position indices
        """
        # Process through tokenization pipeline
        input_ids = self.encode(eeg_data)

        # Create attention mask (all ones since we're not padding)
        attention_mask = np.ones_like(input_ids)

        # Create position IDs
        n_channels, time_points = eeg_data.shape
        position_ids = np.tile(np.arange(time_points), (n_channels, 1))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }

    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
        """Convert EEG data to token IDs."""
        # 1. Normalize to [0, 1]
        normalized = self.normalize(eeg_data)

        # 2. Convert to [-1, 1] for μ-law compression
        centered = 2 * normalized - 1

        # 3. Apply μ-law compression
        compressed = self.mu_law_encode(centered)

        # 4. Quantize to tokens
        input_values = (compressed + 1) / 2  # to [0, 1]
        token_ids = (input_values * (self.vocab_size - 1)).astype(np.int64)

        return token_ids

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """
        Apply static normalization per channel and store min/max values.
        Input shape: (n_channels, time_points)
        """
        # Compute min/max per channel and expand dimensions to match input
        self.channel_mins = x.min(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        self.channel_maxs = x.max(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)

        normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)

        if self.verbose:
            logger.info(f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] → [{normalized.min():.3f}, {normalized.max():.3f}]")
        return normalized

    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
        """
        Apply μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)

        if self.verbose:
            logger.info(f"μ-law compression (μ={self.mu:.2f}): variance before={np.var(x):.3f}, after={np.var(compressed):.3f}")
        return compressed

    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
        """
        Inverse μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        return np.sign(x) * (1 / self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)

    def decode(self, token_ids: np.ndarray) -> np.ndarray:
        """
        Decode token IDs back to EEG signal.

        Args:
            token_ids: Array of token IDs of shape (n_channels, time_points)

        Returns:
            Array of shape (n_channels, time_points)
        """
        # Convert to continuous values in [-1, 1]
        values = token_ids.astype(np.float32) / (self.vocab_size - 1)  # [0, 1]
        values = 2 * values - 1  # [-1, 1]

        # Apply inverse μ-law compression
        values = self.mu_law_decode(values)

        # Convert back to [0, 1]
        values = (values + 1) / 2

        # Denormalize to original scale
        if self.channel_mins is not None and self.channel_maxs is not None:
            values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins

        return values
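
For quick sanity-checking, a minimal usage sketch (not part of the committed file): the channel count, sequence length, and local directory name below are illustrative assumptions, and the round-trip differs from the input only by μ-law quantization error.

# Hypothetical usage example — assumes tokenizer.py is importable from the working directory.
import numpy as np
from tokenizer import WaveletTokenizer

tok = WaveletTokenizer(vocab_size=256, mu=255.0, verbose=False)

# Synthetic "EEG": 74 channels x 1000 samples, arbitrary amplitude (shapes are assumptions).
eeg = np.random.randn(74, 1000) * 50.0

# __call__ returns numpy arrays for input_ids, attention_mask and position_ids.
batch = tok(eeg)
assert batch["input_ids"].shape == eeg.shape
assert 0 <= batch["input_ids"].min() and batch["input_ids"].max() < 256

# decode() reuses the per-channel min/max stored by normalize() during encoding,
# so the reconstruction error is bounded by the quantization step.
recon = tok.decode(batch["input_ids"])
print("max abs reconstruction error:", float(np.abs(recon - eeg).max()))

# Local save/reload of tokenizer_config.json and vocab.json (directory name is arbitrary).
tok.save_pretrained("./wavelet_tokenizer")
tok2 = WaveletTokenizer.from_pretrained("./wavelet_tokenizer")

Because the saved tokenizer_config.json carries the auto_map entry pointing at tokenizer.WaveletTokenizer, loading through AutoTokenizer.from_pretrained with trust_remote_code=True should also resolve to this class once the file is on the Hub, though that path is not exercised in the sketch above.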