mgelard committed (verified)
Commit e308d04 · 1 Parent(s): 5c47e2d

Upload tokenizer

special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {}
tokenizer.py ADDED
@@ -0,0 +1,267 @@
+ import json
+ import os
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ from transformers import PreTrainedTokenizer
+
+
+ class BinnedOmicTokenizer(PreTrainedTokenizer):
+     def __init__(
+         self,
+         n_expressions_bins: int = 64,
+         min_omic_value: float = 0.0,
+         max_omic_value: float = 1.0,
+         use_max_normalization: bool = True,
+         normalization_factor: float = 1.0,
+         prepend_cls_token: bool = False,
+         fixed_sequence_length: Optional[int] = None,
+         unpadded_length: Optional[int] = None,
+         **kwargs,
+     ):
+         bin_tokens = [str(i) for i in range(n_expressions_bins)]
+         special_tokens = ["<pad>", "<mask>", "<cls>"]
+
+         vocab = {tok: i for i, tok in enumerate(bin_tokens)}
+         offset = len(vocab)
+         for i, tok in enumerate(special_tokens):
+             vocab[tok] = offset + i
+
+         ids_to_tokens = {i: tok for tok, i in vocab.items()}
+
+         self.vocab = vocab
+         self.ids_to_tokens = ids_to_tokens
+
+         self.n_expressions_bins = n_expressions_bins
+         self.min_omic_value = min_omic_value
+         self.max_omic_value = max_omic_value
+         self.use_max_normalization = use_max_normalization
+         self.normalization_factor = normalization_factor
+         self.prepend_cls_token = prepend_cls_token
+         self.fixed_sequence_length = fixed_sequence_length
+         self.unpadded_length = unpadded_length
+
+         self.bin_edges = np.linspace(min_omic_value, max_omic_value, n_expressions_bins)
+
+         self.pad_token = "<pad>"
+         self.mask_token = "<mask>"
+         self.cls_token = "<cls>"
+
+         super().__init__(**kwargs)
+
+         self.add_special_tokens(
+             {
+                 "pad_token": "<pad>",
+                 "mask_token": "<mask>",
+                 "cls_token": "<cls>",
+                 "unk_token": "<pad>",
+             }
+         )
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self.vocab.get(token, self.vocab[self.unk_token])
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     def get_vocab(self) -> dict:
+         return self.vocab
+
+     def _tokenize(self, text, **kwargs):
+         raise NotImplementedError("Use `encode` or `batch_encode_plus` methods.")
+
+     def decode(self, token_ids, **kwargs):
+         return [self._convert_id_to_token(i) for i in token_ids]
+
+     def encode(
+         self,
+         gene_expr: Union[np.ndarray, List[float]],
+         pad_to_fixed_length: bool = False,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[str] = None,
+         **kwargs,
+     ) -> Union[List[int], torch.Tensor]:
+         gene_expr = np.array(gene_expr)
+
+         if self.use_max_normalization:
+             gene_expr = gene_expr / self.normalization_factor
+
+         token_ids = np.digitize(gene_expr, self.bin_edges).astype(int)
+         token_ids[gene_expr == 0.0] = 0
+
+         if self.prepend_cls_token:
+             token_ids = np.concatenate([[self.cls_token_id], token_ids])
+
+         if pad_to_fixed_length:
+             current_max_length = self.fixed_sequence_length or max_length
+             if current_max_length is None:
+                 raise ValueError("fixed_sequence_length or max_length must be set.")
+             pad_len = current_max_length - len(token_ids)
+             if pad_len > 0:
+                 token_ids = np.concatenate([token_ids, [self.pad_token_id] * pad_len])
+             else:
+                 token_ids = token_ids[:current_max_length]
+
+         if return_tensors == "pt":
+             return torch.tensor(token_ids).unsqueeze(0)
+         return token_ids.tolist()  # type: ignore
+
+     def batch_encode_plus(
+         self,
+         batch_gene_expr: Union[np.ndarray, List[np.ndarray]],
+         pad_to_fixed_length: bool = False,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[str] = None,
+         **kwargs,
+     ):
+         if isinstance(batch_gene_expr, list):
+             batch_gene_expr = np.array(batch_gene_expr)
+
+         encoded = [
+             self.encode(
+                 gene_expr,
+                 pad_to_fixed_length=pad_to_fixed_length,
+                 max_length=max_length,
+                 return_tensors=None,
+                 **kwargs,
+             )
+             for gene_expr in batch_gene_expr
+         ]
+
+         encoded = np.array(encoded, dtype=np.int64)
+
+         if return_tensors == "pt":
+             return {"input_ids": torch.tensor(encoded)}
+         return {"input_ids": encoded}
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self.vocab)
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ):
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+         )
+         with open(vocab_file, "w") as f:
+             json.dump(self.vocab, f)
+         return (vocab_file,)
+
+
+ class MOJOTokenizer(PreTrainedTokenizer):
+     def __init__(
+         self,
+         n_expressions_bins: dict[str, int],
+         min_omic_value: dict[str, float],
+         max_omic_value: dict[str, float],
+         use_max_normalization: dict[str, bool],
+         normalization_factor: dict[str, float],
+         prepend_cls_token: bool,
+         fixed_sequence_length: int,
+         unpadded_length: int,
+         **kwargs,
+     ):
+         self.omics = n_expressions_bins.keys()
+         self.omic_tokenizers = {
+             omic: BinnedOmicTokenizer(
+                 n_expressions_bins=n_expressions_bins[omic],
+                 min_omic_value=min_omic_value[omic],
+                 max_omic_value=max_omic_value[omic],
+                 use_max_normalization=use_max_normalization[omic],
+                 normalization_factor=normalization_factor[omic],
+                 prepend_cls_token=prepend_cls_token,
+                 fixed_sequence_length=fixed_sequence_length,
+                 unpadded_length=unpadded_length,
+                 **kwargs,
+             )
+             for omic in n_expressions_bins.keys()
+         }
+
+         self.vocab = {omic: self.omic_tokenizers[omic].vocab for omic in self.omics}
+         self.ids_to_tokens = {
+             omic: self.omic_tokenizers[omic].ids_to_tokens for omic in self.omics
+         }
+
+         super().__init__(**kwargs)
+
+     def _convert_token_to_id(self, token: dict[str, str]) -> dict[str, int]:
+         return {
+             omic: self.vocab[omic].get(token[omic], self.vocab[omic][self.unk_token])
+             for omic in token
+         }
+
+     def _convert_id_to_token(self, index: dict[str, int]) -> dict[str, str]:
+         return {
+             omic: self.omic_tokenizers[omic]._convert_id_to_token(index[omic])
+             for omic in index
+         }
+
+     def get_vocab(self) -> dict:
+         return self.vocab
+
+     def _tokenize(self, text, **kwargs):
+         raise NotImplementedError("Use `encode` or `batch_encode_plus` methods.")
+
+     def decode(self, token_ids: dict[str, list[int]], **kwargs):
+         return {
+             omic: self.omic_tokenizers[omic].decode(token_ids[omic])
+             for omic in token_ids
+         }
+
+     def encode(
+         self,
+         omic_array: Union[dict[str, np.ndarray], dict[str, List[float]]],
+         pad_to_fixed_length: bool = False,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[str] = None,
+         **kwargs,
+     ) -> Union[dict[str, List[int]], dict[str, torch.Tensor]]:
+         return {
+             omic: self.omic_tokenizers[omic].encode(
+                 omic_array[omic],
+                 pad_to_fixed_length=pad_to_fixed_length,
+                 max_length=max_length,
+                 return_tensors=return_tensors,
+             )
+             for omic in omic_array
+         }
+
+     def batch_encode_plus(
+         self,
+         batch_omic_array: Union[dict[str, np.ndarray], dict[str, List[np.ndarray]]],
+         pad_to_fixed_length: bool = False,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[str] = None,
+         **kwargs,
+     ):
+         return {
+             omic: self.omic_tokenizers[omic].batch_encode_plus(
+                 batch_omic_array[omic],
+                 pad_to_fixed_length=pad_to_fixed_length,
+                 max_length=max_length,
+                 return_tensors=return_tensors,
+             )
+             for omic in batch_omic_array
+         }
+
+     @property
+     def vocab_size(self) -> int:
+         return sum(len(self.vocab[omic]) for omic in self.vocab)
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ):
+         vocab_files = []
+         for omic in self.omics:
+             vocab_file = os.path.join(
+                 save_directory,
+                 (filename_prefix + "-" if filename_prefix else "")
+                 + f"vocab_{omic}.json",
+             )
+             with open(vocab_file, "w") as f:
+                 json.dump(self.vocab[omic], f)
+             vocab_files.append(vocab_file)
+         return tuple(vocab_files)
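
A minimal usage sketch for the two classes above, assuming tokenizer.py is importable from the working directory. The omic names, bin counts, normalization settings, and sequence lengths below are illustrative placeholders rather than a released configuration.

import numpy as np

from tokenizer import MOJOTokenizer

# Placeholder per-omic settings; real values would come from the model's config.
tok = MOJOTokenizer(
    n_expressions_bins={"rnaseq": 64, "methylation": 64},
    min_omic_value={"rnaseq": 0.0, "methylation": 0.0},
    max_omic_value={"rnaseq": 1.0, "methylation": 1.0},
    use_max_normalization={"rnaseq": True, "methylation": True},
    normalization_factor={"rnaseq": 1.0, "methylation": 1.0},
    prepend_cls_token=True,
    fixed_sequence_length=8,
    unpadded_length=7,
)

# One profile of 7 values per omic, already in [0, 1].
omic_array = {
    "rnaseq": np.random.rand(7),
    "methylation": np.random.rand(7),
}

ids = tok.encode(omic_array, pad_to_fixed_length=True, return_tensors="pt")
# ids["rnaseq"] and ids["methylation"] are (1, 8) tensors: a <cls> id followed by 7 bin ids.

Note that encode maps exact zeros to bin 0 and, when use_max_normalization is set, only divides by normalization_factor before binning against edges spanning [min_omic_value, max_omic_value].
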
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "added_tokens_decoder": {},
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenizer.MOJOTokenizer",
+       null
+     ]
+   },
+   "clean_up_tokenization_spaces": true,
+   "model_max_length": 1000000000000000019884624838656,
+   "tokenizer_class": "MOJOTokenizer"
+ }
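
Because auto_map registers tokenizer.MOJOTokenizer for AutoTokenizer, the class can be loaded with trust_remote_code=True. This minimal tokenizer_config.json does not store the per-omic constructor arguments, so the sketch below passes them explicitly; the repo id and all values are placeholders.

from transformers import AutoTokenizer

# Placeholder repo id and settings; extra keyword arguments given to
# from_pretrained are forwarded to MOJOTokenizer.__init__.
tok = AutoTokenizer.from_pretrained(
    "org-name/model-repo",
    trust_remote_code=True,
    n_expressions_bins={"rnaseq": 64, "methylation": 64},
    min_omic_value={"rnaseq": 0.0, "methylation": 0.0},
    max_omic_value={"rnaseq": 1.0, "methylation": 1.0},
    use_max_normalization={"rnaseq": True, "methylation": True},
    normalization_factor={"rnaseq": 1.0, "methylation": 1.0},
    prepend_cls_token=True,
    fixed_sequence_length=8,
    unpadded_length=7,
)
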
vocab_methylation.json ADDED
@@ -0,0 +1 @@
+ {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10, "11": 11, "12": 12, "13": 13, "14": 14, "15": 15, "16": 16, "17": 17, "18": 18, "19": 19, "20": 20, "21": 21, "22": 22, "23": 23, "24": 24, "25": 25, "26": 26, "27": 27, "28": 28, "29": 29, "30": 30, "31": 31, "32": 32, "33": 33, "34": 34, "35": 35, "36": 36, "37": 37, "38": 38, "39": 39, "40": 40, "41": 41, "42": 42, "43": 43, "44": 44, "45": 45, "46": 46, "47": 47, "48": 48, "49": 49, "50": 50, "51": 51, "52": 52, "53": 53, "54": 54, "55": 55, "56": 56, "57": 57, "58": 58, "59": 59, "60": 60, "61": 61, "62": 62, "63": 63, "<pad>": 64, "<mask>": 65, "<cls>": 66}
vocab_rnaseq.json ADDED
@@ -0,0 +1 @@
+ {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10, "11": 11, "12": 12, "13": 13, "14": 14, "15": 15, "16": 16, "17": 17, "18": 18, "19": 19, "20": 20, "21": 21, "22": 22, "23": 23, "24": 24, "25": 25, "26": 26, "27": 27, "28": 28, "29": 29, "30": 30, "31": 31, "32": 32, "33": 33, "34": 34, "35": 35, "36": 36, "37": 37, "38": 38, "39": 39, "40": 40, "41": 41, "42": 42, "43": 43, "44": 44, "45": 45, "46": 46, "47": 47, "48": 48, "49": 49, "50": 50, "51": 51, "52": 52, "53": 53, "54": 54, "55": 55, "56": 56, "57": 57, "58": 58, "59": 59, "60": 60, "61": 61, "62": 62, "63": 63, "<pad>": 64, "<mask>": 65, "<cls>": 66}