iioSnail committed
Commit 7436a15
1 Parent(s): 68fbe9c

Upload 8 files
config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "ReaLiseForCSC"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoModel": "csc_model.ReaLiseForCSC"
+   },
+   "classifier_dropout": null,
+   "directionality": "bidi",
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "image_model_type": 0,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_fonts": 3,
+   "num_hidden_layers": 12,
+   "output_past": true,
+   "pad_token_id": 0,
+   "pooler_fc_size": 768,
+   "pooler_num_attention_heads": 12,
+   "pooler_num_fc_layers": 3,
+   "pooler_size_per_head": 128,
+   "pooler_type": "first_token_transform",
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.33.2",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 21128,
+   "vocab_size_or_config_json_file": 21128
+ }
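
The "auto_map" entry is what lets transformers resolve the custom ReaLiseForCSC class from csc_model.py when the repository is loaded with trust_remote_code=True. A minimal loading sketch; the repo id below is a placeholder, not taken from this commit:

from transformers import AutoConfig, AutoModel

repo_id = "<user>/<repo>"  # placeholder: substitute the actual Hub path of this repository

config = AutoConfig.from_pretrained(repo_id)   # reads the config.json above
print(config.num_fonts, config.vocab_size)     # 3, 21128

# trust_remote_code=True is required so AutoModel follows the auto_map entry
# and instantiates csc_model.ReaLiseForCSC instead of a plain BertModel.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)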
csc_model.py ADDED
@@ -0,0 +1,410 @@
+ import os
+ from copy import deepcopy
+
+ import numpy as np
+ import opencc
+ import pypinyin
+ import torch
+ from PIL import ImageFont
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import MaskedLMOutput
+
+ from transformers import BertPreTrainedModel, BertModel
+
+
+ def _is_chinese_char(cp):
+     if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
+             (cp >= 0x3400 and cp <= 0x4DBF) or  #
+             (cp >= 0x20000 and cp <= 0x2A6DF) or  #
+             (cp >= 0x2A700 and cp <= 0x2B73F) or  #
+             (cp >= 0x2B740 and cp <= 0x2B81F) or  #
+             (cp >= 0x2B820 and cp <= 0x2CEAF) or
+             (cp >= 0xF900 and cp <= 0xFAFF) or  #
+             (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+         return True
+     return False
+
+
+ class Pinyin2(object):
+     def __init__(self):
+         super(Pinyin2, self).__init__()
+         pho_vocab = ['P']
+         pho_vocab += [chr(x) for x in range(ord('1'), ord('5') + 1)]
+         pho_vocab += [chr(x) for x in range(ord('a'), ord('z') + 1)]
+         pho_vocab += ['U']
+         assert len(pho_vocab) == 33
+         self.pho_vocab_size = len(pho_vocab)
+         self.pho_vocab = {c: idx for idx, c in enumerate(pho_vocab)}
+
+     def get_pho_size(self):
+         return self.pho_vocab_size
+
+     @staticmethod
+     def get_pinyin(c):
+         if len(c) > 1:
+             return 'U'
+         s = pypinyin.pinyin(
+             c,
+             style=pypinyin.Style.TONE3,
+             neutral_tone_with_five=True,
+             errors=lambda x: ['U' for _ in x],
+         )[0][0]
+         if s == 'U':
+             return s
+         assert isinstance(s, str)
+         assert s[-1] in '12345'
+         s = s[-1] + s[:-1]
+         return s
+
+     def convert(self, chars):
+         pinyins = list(map(self.get_pinyin, chars))
+         pinyin_ids = [list(map(self.pho_vocab.get, pinyin)) for pinyin in pinyins]
+         pinyin_lens = [len(pinyin) for pinyin in pinyins]
+         pinyin_ids = torch.nn.utils.rnn.pad_sequence(
+             [torch.tensor(x) for x in pinyin_ids],
+             batch_first=True,
+             padding_value=0,
+         )
+         return pinyin_ids, pinyin_lens
+
+
+ pho2_convertor = Pinyin2()
+
+
+ class CharResNet(torch.nn.Module):
+
+     def __init__(self, in_channels=1):
+         super().__init__()
+         # input_image: bxcx32x32, output_image: bx768x1x1
+         self.res_block1 = BasicBlock(in_channels, 64, stride=2)  # channels: 64, size: 16x16
+         self.res_block2 = BasicBlock(64, 128, stride=2)  # channels: 128, size: 8x8
+         self.res_block3 = BasicBlock(128, 256, stride=2)  # channels: 256, size: 4x4
+         self.res_block4 = BasicBlock(256, 512, stride=2)  # channels: 512, size: 2x2
+         self.res_block5 = BasicBlock(512, 768, stride=2)  # channels: 768, size: 1x1
+
+     def forward(self, x):
+         # input_shape: bxcx32x32, output_image: bx768
+         # x = x.unsqueeze(1)
+         h = self.res_block1(x)
+         h = self.res_block2(h)
+         h = self.res_block3(h)
+         h = self.res_block4(h)
+         h = self.res_block5(h)
+         h = h.squeeze(-1).squeeze(-1)
+         return h
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_channels, out_channels, stride=1):
+         super().__init__()
+
+         self.residual_function = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+         )
+
+         self.shortcut = nn.Sequential()
+
+         if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+             )
+
+     def forward(self, x):
+         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
+
+
+ class CharResNet1(torch.nn.Module):
+
+     def __init__(self, in_channels=1):
+         super().__init__()
+         self.res_block1 = BasicBlock(in_channels, 64, stride=2)  # channels: 64, size: 16x16
+         self.res_block2 = BasicBlock(64, 128, stride=2)  # channels: 128, size: 8x8
+         self.res_block3 = BasicBlock(128, 192, stride=2)  # channels: 256, size: 4x4
+         self.res_block4 = BasicBlock(192, 192, stride=2)
+
+     def forward(self, x):
+         # input_shape: bxcx32x32, output_shape: bx128x8x8
+         h = x
+         h = self.res_block1(h)
+         h = self.res_block2(h)
+         h = self.res_block3(h)
+         h = self.res_block4(h)
+         h = h.view(h.shape[0], -1)
+         return h
+
+
+ class ReaLiseForCSC(BertPreTrainedModel):
+
+     def __init__(self, config):
+         super(ReaLiseForCSC, self).__init__(config)
+         self.config = config
+
+         self.vocab_size = config.vocab_size
+         self.bert = BertModel(config)
+
+         self.pho_embeddings = nn.Embedding(pho2_convertor.get_pho_size(), config.hidden_size, padding_idx=0)
+         self.pho_gru = nn.GRU(
+             input_size=config.hidden_size,
+             hidden_size=config.hidden_size,
+             num_layers=1,
+             batch_first=True,
+             dropout=0,
+             bidirectional=False,
+         )
+         pho_config = deepcopy(config)
+         pho_config.num_hidden_layers = 4
+         self.pho_model = BertModel(pho_config)
+
+         self.char_images_multifonts = torch.nn.Parameter(torch.rand(21128, 3, 32, 32))
+         self.char_images_multifonts.requires_grad = False
+
+         self.resnet = CharResNet(in_channels=3)
+         self.resnet_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         self.gate_net = nn.Linear(4 * config.hidden_size, 3)
+
+         out_config = deepcopy(config)
+         out_config.num_hidden_layers = 3
+         self.output_block = BertModel(out_config)
+
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, config.vocab_size)
+
+         self.init_weights()
+
+         self.loss_fnt = CrossEntropyLoss(ignore_index=0)
+
+         self.tokenizer = None
+
+     def tie_cls_weight(self):
+         self.classifier.weight = self.bert.embeddings.word_embeddings.weight
+
+     def build_glyce_embed(self, vocab_dir, font_path, font_size=32):
+         vocab_path = os.path.join(vocab_dir, 'vocab.txt')
+         with open(vocab_path, 'r', encoding='utf-8') as f:
+             vocab = [s.strip() for s in f]
+
+         font = ImageFont.truetype(font_path, size=font_size)
+
+         char_images = []
+         for char in vocab:
+             if len(char) != 1 or (not _is_chinese_char(ord(char))):
+                 char_images.append(np.zeros((font_size, font_size)).astype(np.float32))
+                 continue
+             image = font.getmask(char)
+             image = np.asarray(image).astype(np.float32).reshape(image.size[::-1])  # Must be [::-1]
+
+             # Crop
+             image = image[:font_size, :font_size]
+
+             # Pad
+             if image.size != (font_size, font_size):
+                 back_image = np.zeros((font_size, font_size)).astype(np.float32)
+                 offset0 = (font_size - image.shape[0]) // 2
+                 offset1 = (font_size - image.shape[1]) // 2
+                 back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image
+                 image = back_image
+
+             char_images.append(image)
+         char_images = np.array(char_images)
+         char_images = (char_images - np.mean(char_images)) / np.std(char_images)
+         char_images = torch.from_numpy(char_images).reshape(char_images.shape[0], -1)
+         assert char_images.shape == (21128, 1024)
+         self.char_images.weight.data.copy_(char_images)
+
+     # Add by hengdaxu
+     def build_glyce_embed_multifonts(self, vocab_dir, num_fonts, use_traditional_font, font_size=32):
+         font_paths = [
+             ('simhei.ttf', False),
+             ('xiaozhuan.ttf', False),
+             ('simhei.ttf', True),
+         ]
+         font_paths = font_paths[:num_fonts]
+         if use_traditional_font:
+             font_paths = font_paths[:-1]
+             font_paths.append(('simhei.ttf', True))
+             self.converter = opencc.OpenCC('s2t.json')
+
+         images_list = []
+         for font_path, use_traditional in font_paths:
+             images = self.build_glyce_embed_onefont(
+                 vocab_dir=vocab_dir,
+                 font_path=font_path,
+                 font_size=font_size,
+                 use_traditional=use_traditional,
+             )
+             images_list.append(images)
+
+         char_images = torch.stack(images_list, dim=1).contiguous()
+         self.char_images_multifonts.data.copy_(char_images)
+
+     # Add by hengdaxu
+     def build_glyce_embed_onefont(self, vocab_dir, font_path, font_size, use_traditional):
+         vocab_path = os.path.join(vocab_dir, 'vocab.txt')
+         with open(vocab_path, encoding='utf-8') as f:
+             vocab = [s.strip() for s in f.readlines()]
+         if use_traditional:
+             vocab = [self.converter.convert(c) if len(c) == 1 else c for c in vocab]
+
+         font = ImageFont.truetype(font_path, size=font_size)
+
+         char_images = []
+         for char in vocab:
+             if len(char) > 1:
+                 char_images.append(np.zeros((font_size, font_size)).astype(np.float32))
+                 continue
+             image = font.getmask(char)
+             image = np.asarray(image).astype(np.float32).reshape(image.size[::-1])  # Must be [::-1]
+
+             # Crop
+             image = image[:font_size, :font_size]
+
+             # Pad
+             if image.size != (font_size, font_size):
+                 back_image = np.zeros((font_size, font_size)).astype(np.float32)
+                 offset0 = (font_size - image.shape[0]) // 2
+                 offset1 = (font_size - image.shape[1]) // 2
+                 back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image
+                 image = back_image
+
+             char_images.append(image)
+         char_images = np.array(char_images)
+         char_images = (char_images - np.mean(char_images)) / np.std(char_images)
+         char_images = torch.from_numpy(char_images).contiguous()
+         return char_images
+
+     @staticmethod
+     def build_batch(batch, tokenizer):
+         src_idx = batch['src_idx'].flatten().tolist()
+         chars = tokenizer.convert_ids_to_tokens(src_idx)
+         pho_idx, pho_lens = pho2_convertor.convert(chars)
+         batch['pho_idx'] = pho_idx
+         batch['pho_lens'] = pho_lens
+         return batch
+
+     def forward(self,
+                 input_ids=None,
+                 pho_idx=None,
+                 pho_lens=None,
+                 attention_mask=None,
+                 labels=None,
+                 **kwargs):
+         input_shape = input_ids.size()
+
+         bert_hiddens = self.bert(input_ids, attention_mask=attention_mask)[0]
+
+         pho_embeddings = self.pho_embeddings(pho_idx)
+         pho_embeddings = torch.nn.utils.rnn.pack_padded_sequence(
+             input=pho_embeddings,
+             lengths=pho_lens,
+             batch_first=True,
+             enforce_sorted=False,
+         )
+         _, pho_hiddens = self.pho_gru(pho_embeddings)
+         pho_hiddens = pho_hiddens.squeeze(0).reshape(input_shape[0], input_shape[1], -1).contiguous()
+         pho_hiddens = self.pho_model(inputs_embeds=pho_hiddens, attention_mask=attention_mask)[0]
+
+         src_idxs = input_ids.view(-1)
+
+         if self.config.num_fonts == 1:
+             images = self.char_images(src_idxs).reshape(src_idxs.shape[0], 1, 32, 32).contiguous()
+         else:
+             images = self.char_images_multifonts.index_select(dim=0, index=src_idxs)
+
+         res_hiddens = self.resnet(images)
+         res_hiddens = res_hiddens.reshape(input_shape[0], input_shape[1], -1).contiguous()
+         res_hiddens = self.resnet_layernorm(res_hiddens)
+
+         bert_hiddens_mean = (bert_hiddens * attention_mask.to(torch.float).unsqueeze(2)).sum(dim=1) / attention_mask.to(
+             torch.float).sum(dim=1, keepdim=True)
+         bert_hiddens_mean = bert_hiddens_mean.unsqueeze(1).expand(-1, bert_hiddens.size(1), -1)
+
+         concated_outputs = torch.cat((bert_hiddens, pho_hiddens, res_hiddens, bert_hiddens_mean), dim=-1)
+         gated_values = self.gate_net(concated_outputs)
+         # B * S * 3
+         g0 = torch.sigmoid(gated_values[:, :, 0].unsqueeze(-1))
+         g1 = torch.sigmoid(gated_values[:, :, 1].unsqueeze(-1))
+         g2 = torch.sigmoid(gated_values[:, :, 2].unsqueeze(-1))
+
+         hiddens = g0 * bert_hiddens + g1 * pho_hiddens + g2 * res_hiddens
+
+         outputs = self.output_block(inputs_embeds=hiddens,
+                                     position_ids=torch.zeros(input_ids.size(), dtype=torch.long,
+                                                              device=input_ids.device),
+                                     attention_mask=attention_mask)
+
+         sequence_output = outputs[0]
+
+         sequence_output = self.dropout(sequence_output)
+         logits = self.classifier(sequence_output)
+
+         outputs = MaskedLMOutput(
+             logits=logits,
+             hidden_states=outputs.last_hidden_state,
+         )
+
+         if labels is not None:
+             # Only keep active parts of the loss
+             labels[labels == 101] = 0
+             labels[labels == 102] = 0
+             loss = self.loss_fnt(logits.view(-1, logits.size(-1)), labels.view(-1))
+             outputs.loss = loss
+
+         return outputs
+
+     def set_tokenizer(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def predict(self, sentences):
+         if self.tokenizer is None:
+             raise RuntimeError("Please init tokenizer by `set_tokenizer(tokenizer)` before predict.")
+
+         str_flag = False
+         if type(sentences) == str:
+             sentences = [sentences]
+             str_flag = True
+
+         inputs = self.tokenizer(sentences, padding=True, return_tensors="pt")
+         outputs = self.forward(**inputs).logits
+
+         ids_list = outputs.argmax(-1)
+
+         preds = []
+         for i, ids in enumerate(ids_list):
+             ids = ids[inputs['attention_mask'][i].bool()]
+             pred_tokens = self.tokenizer.convert_ids_to_tokens(ids)
+             pred_tokens = [t if not t.startswith('##') else t[2:] for t in pred_tokens]
+             pred_tokens = [t if t != self.tokenizer.unk_token else '×' for t in pred_tokens]
+
+             offsets = inputs[i].offsets
+             src_tokens = list(sentences[i])
+             for (start, end), pred_token in zip(offsets, pred_tokens):
+                 if end - start <= 0:
+                     continue
+
+                 if (end - start) != len(pred_token):
+                     continue
+
+                 if pred_token == '×':
+                     continue
+
+                 if (end - start) == 1 and not _is_chinese_char(ord(src_tokens[start])):
+                     continue
+
+                 src_tokens[start:end] = pred_token
+
+             pred = ''.join(src_tokens)
+             preds.append(pred)
+
+         if str_flag:
+             return preds[0]
+
+         return preds
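
With both custom files in place, correction runs through set_tokenizer and predict. A minimal usage sketch; the repo id is a placeholder and the input sentence is only illustrative:

from transformers import AutoTokenizer, AutoModel

repo_id = "<user>/<repo>"  # placeholder: substitute the actual Hub path of this repository

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# predict() raises a RuntimeError until a tokenizer has been attached.
model.set_tokenizer(tokenizer)

# A single string returns a single corrected string; a list of strings also works.
print(model.predict("我是炼习时长两年半的个人练习生"))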
csc_tokenizer.py ADDED
@@ -0,0 +1,126 @@
+ from typing import List, Union, Optional
+
+ import pypinyin
+ import torch
+
+ from transformers import BertTokenizerFast
+
+
+ class Pinyin2(object):
+     def __init__(self):
+         super(Pinyin2, self).__init__()
+         pho_vocab = ['P']
+         pho_vocab += [chr(x) for x in range(ord('1'), ord('5') + 1)]
+         pho_vocab += [chr(x) for x in range(ord('a'), ord('z') + 1)]
+         pho_vocab += ['U']
+         assert len(pho_vocab) == 33
+         self.pho_vocab_size = len(pho_vocab)
+         self.pho_vocab = {c: idx for idx, c in enumerate(pho_vocab)}
+
+     def get_pho_size(self):
+         return self.pho_vocab_size
+
+     @staticmethod
+     def get_pinyin(c):
+         if len(c) > 1:
+             return 'U'
+         s = pypinyin.pinyin(
+             c,
+             style=pypinyin.Style.TONE3,
+             neutral_tone_with_five=True,
+             errors=lambda x: ['U' for _ in x],
+         )[0][0]
+         if s == 'U':
+             return s
+         assert isinstance(s, str)
+         assert s[-1] in '12345'
+         s = s[-1] + s[:-1]
+         return s
+
+     def convert(self, chars):
+         pinyins = list(map(self.get_pinyin, chars))
+         pinyin_ids = [list(map(self.pho_vocab.get, pinyin)) for pinyin in pinyins]
+         pinyin_lens = [len(pinyin) for pinyin in pinyins]
+         pinyin_ids = torch.nn.utils.rnn.pad_sequence(
+             [torch.tensor(x) for x in pinyin_ids],
+             batch_first=True,
+             padding_value=0,
+         )
+         return pinyin_ids, pinyin_lens
+
+
+ class ReaLiSeTokenizer(BertTokenizerFast):
+
+     def __init__(self, **kwargs):
+         super(ReaLiSeTokenizer, self).__init__(**kwargs)
+
+         self.pho2_convertor = Pinyin2()
+
+     def __call__(self,
+                  text: Union[str, List[str], List[List[str]]] = None,
+                  text_pair: Union[str, List[str], List[List[str]], None] = None,
+                  text_target: Union[str, List[str], List[List[str]]] = None,
+                  text_pair_target: Union[str, List[str], List[List[str]], None] = None,
+                  add_special_tokens: bool = True,
+                  padding=False,
+                  truncation=None,
+                  max_length: Optional[int] = None,
+                  stride: int = 0,
+                  is_split_into_words: bool = False,
+                  pad_to_multiple_of: Optional[int] = None,
+                  return_tensors=None,
+                  return_token_type_ids: Optional[bool] = None,
+                  return_attention_mask: Optional[bool] = None,
+                  return_overflowing_tokens: bool = False,
+                  return_special_tokens_mask: bool = False,
+                  return_offsets_mapping: bool = False,
+                  return_length: bool = False,
+                  verbose: bool = True, **kwargs):
+         encoding = super(ReaLiSeTokenizer, self).__call__(
+             text=text,
+             text_pair=text_pair,
+             text_target=text_target,
+             text_pair_target=text_pair_target,
+             add_special_tokens=add_special_tokens,
+             padding=padding,
+             truncation=truncation,
+             max_length=max_length,
+             stride=stride,
+             is_split_into_words=is_split_into_words,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_tensors=return_tensors,
+             return_token_type_ids=return_token_type_ids,
+             return_attention_mask=return_attention_mask,
+             return_overflowing_tokens=return_overflowing_tokens,
+             return_special_tokens_mask=return_special_tokens_mask,
+             return_offsets_mapping=return_offsets_mapping,
+             return_length=return_length,
+             verbose=verbose,
+         )
+
+         input_ids = encoding['input_ids']
+         if type(text) == str and return_tensors is None:
+             input_ids = [input_ids]
+
+         pho_idx_list = []
+         pho_lens_list = []
+         for ids in input_ids:
+             chars = self.convert_ids_to_tokens(ids)
+             pho_idx, pho_lens = self.pho2_convertor.convert(chars)
+             if return_tensors is None:
+                 pho_idx = pho_idx.tolist()
+             pho_idx_list.append(pho_idx)
+             pho_lens_list += pho_lens
+
+         pho_idx = pho_idx_list
+         pho_lens = pho_lens_list
+         if return_tensors == 'pt':
+             pho_idx = torch.vstack(pho_idx)
+             pho_lens = torch.LongTensor(pho_lens)
+
+         if type(text) == str and return_tensors is None:
+             pho_idx = pho_idx[0]
+
+         encoding['pho_idx'] = pho_idx
+         encoding['pho_lens'] = pho_lens
+
+         return encoding
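
The subclassed __call__ augments the standard BERT encoding with pho_idx (padded pinyin character ids per token) and pho_lens (pinyin length per token), which is exactly what ReaLiseForCSC.forward consumes. A small sketch of the extra fields, assuming the tokenizer files from this commit sit in the current directory:

from csc_tokenizer import ReaLiSeTokenizer

# Assumes vocab.txt / tokenizer.json from this commit are next to the script.
tokenizer = ReaLiSeTokenizer.from_pretrained("./")

enc = tokenizer(["天气真好"], return_tensors="pt")
print(enc["input_ids"].shape)   # (1, 6): [CLS] + 4 characters + [SEP]
print(enc["pho_idx"].shape)     # (6, max_pinyin_len): pinyin char ids per token
print(enc["pho_lens"])          # LongTensor of pinyin lengths, one per token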
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d2c1d619e580891dffbe99d430cc116a82389433ab2a7ed15f8b43088135cf8d
+ size 1140770411
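
pytorch_model.bin is tracked with Git LFS, so the diff only records the pointer (hash and size, roughly 1.1 GB); the weights are fetched on demand. A hedged sketch of pulling just this file with huggingface_hub, the repo id again being a placeholder:

from huggingface_hub import hf_hub_download

# Placeholder repo id; replace with the actual Hub path of this repository.
path = hf_hub_download(repo_id="<user>/<repo>", filename="pytorch_model.bin")
print(path)  # local cache path of the ~1.1 GB weight file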
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": [
+       "csc_tokenizer.ReaLiSeTokenizer",
+       null
+     ]
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": false,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "ReaLiSeTokenizer",
+   "unk_token": "[UNK]"
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff