jaygala24 commited on
Commit
a067af2
·
verified ·
1 Parent(s): 6516479

Create tokenization_indictrans.py

Browse files
Files changed (1) hide show
  1. tokenization_indictrans.py +239 -0
tokenization_indictrans.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from typing import Dict, List, Optional, Union, Tuple
5
+
6
+ from transformers.utils import logging
7
+ from sentencepiece import SentencePieceProcessor
8
+ from transformers.tokenization_utils import PreTrainedTokenizer
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+ SPIECE_UNDERLINE = "▁"
14
+ SUPPORTED_LANGUAGES = [
15
+ "asm_Beng",
16
+ "awa_Deva",
17
+ "ben_Beng",
18
+ "bho_Deva",
19
+ "brx_Deva",
20
+ "doi_Deva",
21
+ "eng_Latn",
22
+ "gom_Deva",
23
+ "gon_Deva",
24
+ "guj_Gujr",
25
+ "hin_Deva",
26
+ "hne_Deva",
27
+ "kan_Knda",
28
+ "kas_Arab",
29
+ "kas_Deva",
30
+ "kha_Latn",
31
+ "lus_Latn",
32
+ "mag_Deva",
33
+ "mai_Deva",
34
+ "mal_Mlym",
35
+ "mar_Deva",
36
+ "mni_Beng",
37
+ "mni_Mtei",
38
+ "npi_Deva",
39
+ "ory_Orya",
40
+ "pan_Guru",
41
+ "san_Deva",
42
+ "sat_Olck",
43
+ "snd_Arab",
44
+ "snd_Deva",
45
+ "tam_Taml",
46
+ "tel_Telu",
47
+ "urd_Arab",
48
+ "unr_Deva",
49
+ ]
50
+
51
+ VOCAB_FILES_NAMES = {
52
+ "src_vocab_fp": "dict.SRC.json",
53
+ "tgt_vocab_fp": "dict.TGT.json",
54
+ "src_spm_fp": "model.SRC",
55
+ "tgt_spm_fp": "model.TGT",
56
+ }
57
+
58
+
59
+ class IndicTransTokenizer(PreTrainedTokenizer):
60
+ _added_tokens_encoder = {}
61
+ _added_tokens_decoder = {}
62
+
63
+ vocab_files_names = VOCAB_FILES_NAMES
64
+ model_input_names = ["input_ids", "attention_mask"]
65
+
66
+ def __init__(
67
+ self,
68
+ src_vocab_fp=None,
69
+ tgt_vocab_fp=None,
70
+ src_spm_fp=None,
71
+ tgt_spm_fp=None,
72
+ unk_token="<unk>",
73
+ bos_token="<s>",
74
+ eos_token="</s>",
75
+ pad_token="<pad>",
76
+ do_lower_case=False,
77
+ **kwargs
78
+ ):
79
+
80
+ self.src = True
81
+
82
+ self.src_vocab_fp = src_vocab_fp
83
+ self.tgt_vocab_fp = tgt_vocab_fp
84
+ self.src_spm_fp = src_spm_fp
85
+ self.tgt_spm_fp = tgt_spm_fp
86
+
87
+ self.unk_token = unk_token
88
+ self.pad_token = pad_token
89
+ self.eos_token = eos_token
90
+ self.bos_token = bos_token
91
+
92
+ self.encoder = self._load_json(self.src_vocab_fp)
93
+ if self.unk_token not in self.encoder:
94
+ raise KeyError("<unk> token must be in vocab")
95
+ assert self.pad_token in self.encoder
96
+ self.encoder_rev = {v: k for k, v in self.encoder.items()}
97
+
98
+ self.decoder = self._load_json(self.tgt_vocab_fp)
99
+ if self.unk_token not in self.encoder:
100
+ raise KeyError("<unk> token must be in vocab")
101
+ assert self.pad_token in self.encoder
102
+ self.decoder_rev = {v: k for k, v in self.decoder.items()}
103
+
104
+ # load SentencePiece model for pre-processing
105
+ self.src_spm = self._load_spm(self.src_spm_fp)
106
+ self.tgt_spm = self._load_spm(self.tgt_spm_fp)
107
+
108
+ self.current_spm = self.src_spm
109
+ self.current_encoder = self.encoder
110
+ self.current_encoder_rev = self.encoder_rev
111
+
112
+ self.unk_token_id = self.encoder[self.unk_token]
113
+ self.pad_token_id = self.encoder[self.pad_token]
114
+ self.eos_token_id = self.encoder[self.eos_token]
115
+ self.bos_token_id = self.encoder[self.bos_token]
116
+
117
+ super().__init__(
118
+ src_vocab_file=self.src_vocab_fp,
119
+ tgt_vocab_file=self.src_vocab_fp,
120
+ do_lower_case=do_lower_case,
121
+ unk_token=unk_token,
122
+ bos_token=bos_token,
123
+ eos_token=eos_token,
124
+ pad_token=pad_token,
125
+ **kwargs,
126
+ )
127
+
128
+ def _switch_to_input_mode(self):
129
+ self.src = True
130
+ self.padding_side = "left"
131
+ self.current_spm = self.src_spm
132
+ self.current_encoder = self.encoder
133
+ self.current_encoder_rev = self.encoder_rev
134
+
135
+ def _switch_to_target_mode(self):
136
+ self.src = False
137
+ self.padding_side = "right"
138
+ self.current_spm = self.tgt_spm
139
+ self.current_encoder = self.decoder
140
+ self.current_encoder_rev = self.decoder_rev
141
+
142
+ def _load_spm(self, path: str) -> SentencePieceProcessor:
143
+ return SentencePieceProcessor(model_file=path)
144
+
145
+ def _save_json(self, data, path: str) -> None:
146
+ with open(path, "w", encoding="utf-8") as f:
147
+ json.dump(data, f, indent=2)
148
+
149
+ def _load_json(self, path: str) -> Union[Dict, List]:
150
+ with open(path, "r", encoding="utf-8") as f:
151
+ return json.load(f)
152
+
153
+ @property
154
+ def src_vocab_size(self) -> int:
155
+ return len(self.encoder)
156
+
157
+ @property
158
+ def tgt_vocab_size(self) -> int:
159
+ return len(self.decoder)
160
+
161
+ def get_src_vocab(self) -> Dict[str, int]:
162
+ return dict(self.encoder, **self.added_tokens_encoder)
163
+
164
+ def get_tgt_vocab(self) -> Dict[str, int]:
165
+ return dict(self.decoder, **self.added_tokens_decoder)
166
+
167
+ # hack override
168
+ def get_vocab(self) -> Dict[str, int]:
169
+ return self.get_src_vocab()
170
+
171
+ # hack override
172
+ @property
173
+ def vocab_size(self) -> int:
174
+ return self.src_vocab_size
175
+
176
+ def _convert_token_to_id(self, token: str) -> int:
177
+ """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
178
+ return self.current_encoder.get(token, self.current_encoder[self.unk_token])
179
+
180
+ def _convert_id_to_token(self, index: int) -> str:
181
+ """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
182
+ return self.current_encoder_rev.get(index, self.unk_token)
183
+
184
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
185
+ """Uses sentencepiece model for detokenization"""
186
+ pad_tokens = [token for token in tokens if token == self.pad_token]
187
+ tokens = [token for token in tokens if token != self.pad_token]
188
+ if self.src:
189
+ return (
190
+ " ".join(pad_tokens)
191
+ + " "
192
+ + " ".join(tokens[:2])
193
+ + " "
194
+ + "".join(tokens[2:]).replace(SPIECE_UNDERLINE, " ").strip()
195
+ )
196
+ return (
197
+ "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
198
+ + " "
199
+ + " ".join(pad_tokens)
200
+ )
201
+
202
+ def _tokenize(self, text) -> List[str]:
203
+ if self.src:
204
+ tokens = text.split(" ")
205
+ tags = tokens[:2]
206
+ text = " ".join(tokens[2:])
207
+ tokens = self.current_spm.EncodeAsPieces(text)
208
+ return tags + tokens
209
+ else:
210
+ return self.current_spm.EncodeAsPieces(text)
211
+
212
+ def build_inputs_with_special_tokens(
213
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
214
+ ) -> List[int]:
215
+ if token_ids_1 is None:
216
+ return token_ids_0 + [self.eos_token_id]
217
+ # We don't expect to process pairs, but leave the pair logic for API consistency
218
+ return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
219
+
220
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
221
+ if not os.path.isdir(save_directory):
222
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
223
+ return
224
+
225
+ src_spm_fp = os.path.join(save_directory, "model.SRC")
226
+ tgt_spm_fp = os.path.join(save_directory, "model.TGT")
227
+ src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
228
+ tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
229
+
230
+ self._save_json(self.encoder, src_vocab_fp)
231
+ self._save_json(self.decoder, tgt_vocab_fp)
232
+
233
+ with open(src_spm_fp, 'wb') as f:
234
+ f.write(self.src_spm.serialized_model_proto())
235
+
236
+ with open(tgt_spm_fp, 'wb') as f:
237
+ f.write(self.tgt_spm.serialized_model_proto())
238
+
239
+ return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp