Upload tokenizer implementation
tokenizer.py (ADDED): +287 -0
@@ -0,0 +1,287 @@
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import os
import json
from huggingface_hub import HfApi

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WAVELET_TOKENIZER_CONFIG = {
    "model_type": "wavelet",
    "tokenizer_class": "WaveletTokenizer",
    "auto_map": {
        "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
    }
}
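
# NOTE: with this auto_map, the tokenizer can be loaded from a Hub repo that
# contains tokenizer.py via AutoTokenizer.from_pretrained(repo_id,
# trust_remote_code=True).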

@dataclass
class WaveletTokenizerConfig:
    vocab_size: int = 256
    padding_idx: int = 0
    eeg_channels: int = 74  # Source modality (EEG)
    mu: float = 255.0  # Static μ value for μ-law compression
    verbose: bool = True  # Control logging

class WaveletTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_size: int = 256,
        mu: float = 255.0,
        verbose: bool = True,
        **kwargs
    ):
        self.auto_map = {
            "AutoTokenizer": ["tokenizer.WaveletTokenizer", None]
        }

        # Set vocab size first
        self._vocab_size = vocab_size
        self.mu = mu
        self.verbose = verbose

        # Store normalization state
        self.channel_mins = None
        self.channel_maxs = None

        # Initialize parent class after setting vocab_size
        super().__init__(**kwargs)

        if self.verbose:
            logger.info(f"Initialized WaveletTokenizer with μ={self.mu:.2f}")

    @property
    def vocab_size(self) -> int:
        """Returns the size of the vocabulary (number of possible quantization levels)."""
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, size: int):
        self._vocab_size = size

    def save_pretrained(
        self,
        save_directory: str,
        legacy_format: bool = True,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ) -> Tuple[str, ...]:
        """Save the tokenizer configuration and vocabulary to a directory."""
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save tokenizer config
        config = {
            **WAVELET_TOKENIZER_CONFIG,
            "vocab_size": self.vocab_size,
            "mu": self.mu,
            "verbose": self.verbose
        }

        config_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "tokenizer_config.json"
        )

        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)

        # Save vocabulary
        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

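        # NOTE: the push_to_hub path below passes `save_directory` as the Hub
        # repo_id, so it only works when the directory name matches an existing
        # repo (e.g. "user/wavelet-tokenizer"); with a plain local path the
        # upload_file calls would fail.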
        if push_to_hub:
            # Upload files using HTTP
            api = HfApi()
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="tokenizer_config.json",
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer config")
            )

            # Upload vocabulary file
            vocab_file = vocab_files[0]
            api.upload_file(
                path_or_fileobj=vocab_file,
                path_in_repo=os.path.basename(vocab_file),
                repo_id=save_directory,
                commit_message=kwargs.get("commit_message", "Upload tokenizer vocabulary")
            )

        return vocab_files + (config_file,)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs
    ) -> "WaveletTokenizer":
        """Load the tokenizer from a local directory; falls back to kwargs if no config file is found."""
        # Load config first
        config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            # Update with any passed kwargs
            config.update(kwargs)
        else:
            config = kwargs

        return cls(**config)

    def get_vocab(self) -> Dict[str, int]:
        """Returns vocab as a dict mapping token strings to ids."""
        # Create a minimal vocabulary with quantization levels
        return {str(i): i for i in range(self.vocab_size)}

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token string to its ID."""
        try:
            return int(token)
        except ValueError:
            return 0  # Return 0 for unknown tokens

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an ID back to its token string."""
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens to a single string."""
        return " ".join(tokens)

    def _tokenize(self, text: str) -> List[str]:
        """Basic tokenization for compatibility."""
        if isinstance(text, str):
            return [text]
        return [str(t) for t in text]
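
    # The string-based helpers above exist only to satisfy the
    # PreTrainedTokenizer text interface; real EEG inputs are numeric arrays
    # handled by __call__, encode, and decode below.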

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Save the vocabulary to a directory."""
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False)

        return (vocab_file,)

    def __call__(
        self,
        eeg_data: np.ndarray,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Main entry point for tokenization. Handles numpy array input.

        Args:
            eeg_data: Raw EEG array of shape (n_channels, time_points)

        Returns:
            Dictionary containing:
                - input_ids: Tokenized signal values
                - attention_mask: Binary mask (all ones since we don't pad)
                - position_ids: Sequential position indices
        """
        # Process through tokenization pipeline
        input_ids = self.encode(eeg_data)

        # Create attention mask (all ones since we're not padding)
        attention_mask = np.ones_like(input_ids)

        # Create position IDs
        n_channels, time_points = eeg_data.shape
        position_ids = np.tile(np.arange(time_points), (n_channels, 1))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }

    def encode(self, eeg_data: np.ndarray) -> np.ndarray:
        """Convert EEG data to token IDs."""
        # 1. Normalize to [0, 1]
        normalized = self.normalize(eeg_data)

        # 2. Convert to [-1, 1] for μ-law compression
        centered = 2 * normalized - 1

        # 3. Apply μ-law compression
        compressed = self.mu_law_encode(centered)

        # 4. Quantize to tokens
        input_values = (compressed + 1) / 2  # to [0, 1]
        token_ids = (input_values * (self.vocab_size - 1)).astype(np.int64)
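        # Worked example (vocab_size=256): a channel's minimum maps to token 0,
        # its maximum to token 255, and the exact midpoint to token 127
        # (astype(np.int64) truncates 127.5 down rather than rounding).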

        return token_ids

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """
        Apply static normalization per channel and store min/max values.
        Input shape: (n_channels, time_points)
        """
        # Compute min/max per channel and expand dimensions to match input
        self.channel_mins = x.min(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
        self.channel_maxs = x.max(axis=1)[:, np.newaxis]  # Shape: (n_channels, 1)
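        # These stored per-channel statistics are reused by decode() to map
        # tokens back to the original signal scale, so decode() is only
        # meaningful after an encode() on the same recording.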

        normalized = (x - self.channel_mins) / (self.channel_maxs - self.channel_mins + 1e-8)

        if self.verbose:
            logger.info(f"Min-max normalization: input range [{x.min():.3f}, {x.max():.3f}] → [{normalized.min():.3f}, {normalized.max():.3f}]")
        return normalized

    def mu_law_encode(self, x: np.ndarray) -> np.ndarray:
        """
        Apply μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
        compressed = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)
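        # i.e. F(x) = sign(x) · ln(1 + μ|x|) / ln(1 + μ), which maps [-1, 1]
        # onto [-1, 1] while allocating extra resolution near zero.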

        if self.verbose:
            logger.info(f"μ-law compression (μ={self.mu:.2f}): variance before={np.var(x):.3f}, after={np.var(compressed):.3f}")
        return compressed

    def mu_law_decode(self, x: np.ndarray) -> np.ndarray:
        """
        Inverse μ-law compression.
        Expects input in [-1, 1] range.
        """
        assert np.all(x >= -1.0) and np.all(x <= 1.0), f"Input must be in [-1, 1] range, got min={x.min():.3f}, max={x.max():.3f}"
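        # Inverse of mu_law_encode: F⁻¹(y) = sign(y) · ((1 + μ)^|y| − 1) / μ,
        # exact up to floating-point error.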
        return np.sign(x) * (1 / self.mu) * (np.power(1 + self.mu, np.abs(x)) - 1.0)

    def decode(self, token_ids: np.ndarray) -> np.ndarray:
        """
        Decode token IDs back to EEG signal.

        Args:
            token_ids: Array of token IDs of shape (n_channels, time_points)

        Returns:
            Array of shape (n_channels, time_points)
        """
        # Convert to continuous values in [-1, 1]
        values = token_ids.astype(np.float32) / (self.vocab_size - 1)  # [0, 1]
        values = 2 * values - 1  # [-1, 1]

        # Apply inverse μ-law compression
        values = self.mu_law_decode(values)

        # Convert back to [0, 1]
        values = (values + 1) / 2

        # Denormalize to original scale
        if self.channel_mins is not None and self.channel_maxs is not None:
            values = values * (self.channel_maxs - self.channel_mins) + self.channel_mins

        return values
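

# Minimal usage sketch: round-trips a synthetic signal through the tokenizer.
# The shapes and values below are illustrative assumptions, not fixtures from
# the repo.
if __name__ == "__main__":
    tokenizer = WaveletTokenizer(vocab_size=256, mu=255.0, verbose=False)

    # Fake "EEG": 4 channels, 1000 time points
    rng = np.random.default_rng(0)
    eeg = rng.standard_normal((4, 1000)).astype(np.float32)

    batch = tokenizer(eeg)
    assert batch["input_ids"].shape == eeg.shape
    assert batch["input_ids"].min() >= 0 and batch["input_ids"].max() <= 255

    # Reconstruction error is bounded by the 8-bit μ-law quantization
    reconstructed = tokenizer.decode(batch["input_ids"])
    print("max abs round-trip error:", np.abs(eeg - reconstructed).max())

    # Persist and reload (local directory only; see from_pretrained above)
    tokenizer.save_pretrained("./wavelet_tokenizer")
    reloaded = WaveletTokenizer.from_pretrained("./wavelet_tokenizer")
    assert reloaded.vocab_size == tokenizer.vocab_size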