Upload 8 files
Browse files
- config.json +37 -0
- csc_model.py +410 -0
- csc_tokenizer.py +126 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +21 -0
- vocab.txt +0 -0
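
The two config files register the custom classes through their auto_map entries, so the upload should load via the Auto classes once trust_remote_code is enabled. A minimal loading sketch; the repo id and the example sentence are placeholders, not part of this upload:

    from transformers import AutoModel, AutoTokenizer

    repo_id = "your-name/realise-csc"  # placeholder: substitute the actual Hub repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

    # ReaLiseForCSC.predict() raises unless a tokenizer is attached first.
    model.set_tokenizer(tokenizer)
    print(model.predict("今天天汽真好"))  # a str input returns a single corrected str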
config.json
ADDED
@@ -0,0 +1,37 @@
{
  "architectures": [
    "ReaLiseForCSC"
  ],
  "attention_probs_dropout_prob": 0.1,
  "auto_map": {
    "AutoModel": "csc_model.ReaLiseForCSC"
  },
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "image_model_type": 0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_fonts": 3,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.33.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128,
  "vocab_size_or_config_json_file": 21128
}
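
Everything here is a stock BERT-base config except num_fonts (the number of glyph channels fed to the visual branch) and the auto_map entry pointing AutoModel at the custom class. A quick sanity check, reusing the placeholder repo id from above:

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("your-name/realise-csc", trust_remote_code=True)  # placeholder id
    print(cfg.model_type)  # bert
    print(cfg.num_fonts)   # 3, so forward() takes the multi-font glyph path
    print(cfg.auto_map)    # {'AutoModel': 'csc_model.ReaLiseForCSC'}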
csc_model.py
ADDED
@@ -0,0 +1,410 @@
import os
from copy import deepcopy

import numpy as np
import opencc
import pypinyin
import torch
from PIL import ImageFont
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MaskedLMOutput

from transformers import BertPreTrainedModel, BertModel


def _is_chinese_char(cp):
    # True if the code point lies in any CJK ideograph block.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or    # CJK Unified Ideographs
            (cp >= 0x3400 and cp <= 0x4DBF) or    # Extension A
            (cp >= 0x20000 and cp <= 0x2A6DF) or  # Extension B
            (cp >= 0x2A700 and cp <= 0x2B73F) or  # Extension C
            (cp >= 0x2B740 and cp <= 0x2B81F) or  # Extension D
            (cp >= 0x2B820 and cp <= 0x2CEAF) or  # Extension E
            (cp >= 0xF900 and cp <= 0xFAFF) or    # Compatibility Ideographs
            (cp >= 0x2F800 and cp <= 0x2FA1F)):   # Compatibility Supplement
        return True
    return False


class Pinyin2(object):
    # Maps characters to phoneme-level pinyin id sequences.
    # Vocab: 'P' (padding), tone digits 1-5, letters a-z, 'U' (unknown).

    def __init__(self):
        super(Pinyin2, self).__init__()
        pho_vocab = ['P']
        pho_vocab += [chr(x) for x in range(ord('1'), ord('5') + 1)]
        pho_vocab += [chr(x) for x in range(ord('a'), ord('z') + 1)]
        pho_vocab += ['U']
        assert len(pho_vocab) == 33
        self.pho_vocab_size = len(pho_vocab)
        self.pho_vocab = {c: idx for idx, c in enumerate(pho_vocab)}

    def get_pho_size(self):
        return self.pho_vocab_size

    @staticmethod
    def get_pinyin(c):
        if len(c) > 1:
            return 'U'
        s = pypinyin.pinyin(
            c,
            style=pypinyin.Style.TONE3,
            neutral_tone_with_five=True,
            errors=lambda x: ['U' for _ in x],
        )[0][0]
        if s == 'U':
            return s
        assert isinstance(s, str)
        assert s[-1] in '12345'
        # Move the tone digit to the front, e.g. 'ma1' -> '1ma'.
        s = s[-1] + s[:-1]
        return s

    def convert(self, chars):
        pinyins = list(map(self.get_pinyin, chars))
        pinyin_ids = [list(map(self.pho_vocab.get, pinyin)) for pinyin in pinyins]
        pinyin_lens = [len(pinyin) for pinyin in pinyins]
        pinyin_ids = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(x) for x in pinyin_ids],
            batch_first=True,
            padding_value=0,
        )
        return pinyin_ids, pinyin_lens


pho2_convertor = Pinyin2()


class CharResNet(torch.nn.Module):

    def __init__(self, in_channels=1):
        super().__init__()
        # input_image: bxcx32x32, output_image: bx768x1x1
        self.res_block1 = BasicBlock(in_channels, 64, stride=2)  # channels: 64, size: 16x16
        self.res_block2 = BasicBlock(64, 128, stride=2)  # channels: 128, size: 8x8
        self.res_block3 = BasicBlock(128, 256, stride=2)  # channels: 256, size: 4x4
        self.res_block4 = BasicBlock(256, 512, stride=2)  # channels: 512, size: 2x2
        self.res_block5 = BasicBlock(512, 768, stride=2)  # channels: 768, size: 1x1

    def forward(self, x):
        # input_shape: bxcx32x32, output_shape: bx768
        # x = x.unsqueeze(1)
        h = self.res_block1(x)
        h = self.res_block2(h)
        h = self.res_block3(h)
        h = self.res_block4(h)
        h = self.res_block5(h)
        h = h.squeeze(-1).squeeze(-1)
        return h


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        # Identity shortcut, or a 1x1 projection when the shape changes.
        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.functional.relu(self.residual_function(x) + self.shortcut(x))


class CharResNet1(torch.nn.Module):

    def __init__(self, in_channels=1):
        super().__init__()
        self.res_block1 = BasicBlock(in_channels, 64, stride=2)  # channels: 64, size: 16x16
        self.res_block2 = BasicBlock(64, 128, stride=2)  # channels: 128, size: 8x8
        self.res_block3 = BasicBlock(128, 192, stride=2)  # channels: 192, size: 4x4
        self.res_block4 = BasicBlock(192, 192, stride=2)  # channels: 192, size: 2x2

    def forward(self, x):
        # input_shape: bxcx32x32, output_shape: bx768 (192 x 2 x 2 flattened)
        h = x
        h = self.res_block1(h)
        h = self.res_block2(h)
        h = self.res_block3(h)
        h = self.res_block4(h)
        h = h.view(h.shape[0], -1)
        return h


class ReaLiseForCSC(BertPreTrainedModel):

    def __init__(self, config):
        super(ReaLiseForCSC, self).__init__(config)
        self.config = config

        # Semantic stream: the main BERT encoder.
        self.vocab_size = config.vocab_size
        self.bert = BertModel(config)

        # Phonetic stream: GRU over per-token pinyin ids, then a 4-layer BERT.
        self.pho_embeddings = nn.Embedding(pho2_convertor.get_pho_size(), config.hidden_size, padding_idx=0)
        self.pho_gru = nn.GRU(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size,
            num_layers=1,
            batch_first=True,
            dropout=0,
            bidirectional=False,
        )
        pho_config = deepcopy(config)
        pho_config.num_hidden_layers = 4
        self.pho_model = BertModel(pho_config)

        # Visual stream: a frozen per-character glyph bank (3 fonts, 32x32) fed to a ResNet.
        self.char_images_multifonts = torch.nn.Parameter(torch.rand(21128, 3, 32, 32))
        self.char_images_multifonts.requires_grad = False

        self.resnet = CharResNet(in_channels=3)
        self.resnet_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Selective gate over the three streams.
        self.gate_net = nn.Linear(4 * config.hidden_size, 3)

        out_config = deepcopy(config)
        out_config.num_hidden_layers = 3
        self.output_block = BertModel(out_config)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

        self.loss_fnt = CrossEntropyLoss(ignore_index=0)

        self.tokenizer = None

    def tie_cls_weight(self):
        self.classifier.weight = self.bert.embeddings.word_embeddings.weight

    def build_glyce_embed(self, vocab_dir, font_path, font_size=32):
        # Single-font variant; writes into self.char_images, which must be created
        # separately when config.num_fonts == 1.
        vocab_path = os.path.join(vocab_dir, 'vocab.txt')
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = [s.strip() for s in f]

        font = ImageFont.truetype(font_path, size=font_size)

        char_images = []
        for char in vocab:
            if len(char) != 1 or (not _is_chinese_char(ord(char))):
                char_images.append(np.zeros((font_size, font_size)).astype(np.float32))
                continue
            image = font.getmask(char)
            image = np.asarray(image).astype(np.float32).reshape(image.size[::-1])  # Must be [::-1]

            # Crop
            image = image[:font_size, :font_size]

            # Pad
            if image.size != (font_size, font_size):
                back_image = np.zeros((font_size, font_size)).astype(np.float32)
                offset0 = (font_size - image.shape[0]) // 2
                offset1 = (font_size - image.shape[1]) // 2
                back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image
                image = back_image

            char_images.append(image)
        char_images = np.array(char_images)
        char_images = (char_images - np.mean(char_images)) / np.std(char_images)
        char_images = torch.from_numpy(char_images).reshape(char_images.shape[0], -1)
        assert char_images.shape == (21128, 1024)
        self.char_images.weight.data.copy_(char_images)

    # Add by hengdaxu
    def build_glyce_embed_multifonts(self, vocab_dir, num_fonts, use_traditional_font, font_size=32):
        font_paths = [
            ('simhei.ttf', False),
            ('xiaozhuan.ttf', False),
            ('simhei.ttf', True),
        ]
        font_paths = font_paths[:num_fonts]
        if use_traditional_font:
            font_paths = font_paths[:-1]
            font_paths.append(('simhei.ttf', True))
            self.converter = opencc.OpenCC('s2t.json')

        images_list = []
        for font_path, use_traditional in font_paths:
            images = self.build_glyce_embed_onefont(
                vocab_dir=vocab_dir,
                font_path=font_path,
                font_size=font_size,
                use_traditional=use_traditional,
            )
            images_list.append(images)

        char_images = torch.stack(images_list, dim=1).contiguous()
        self.char_images_multifonts.data.copy_(char_images)

    # Add by hengdaxu
    def build_glyce_embed_onefont(self, vocab_dir, font_path, font_size, use_traditional):
        vocab_path = os.path.join(vocab_dir, 'vocab.txt')
        with open(vocab_path, encoding='utf-8') as f:
            vocab = [s.strip() for s in f.readlines()]
        if use_traditional:
            vocab = [self.converter.convert(c) if len(c) == 1 else c for c in vocab]

        font = ImageFont.truetype(font_path, size=font_size)

        char_images = []
        for char in vocab:
            if len(char) > 1:
                char_images.append(np.zeros((font_size, font_size)).astype(np.float32))
                continue
            image = font.getmask(char)
            image = np.asarray(image).astype(np.float32).reshape(image.size[::-1])  # Must be [::-1]

            # Crop
            image = image[:font_size, :font_size]

            # Pad
            if image.size != (font_size, font_size):
                back_image = np.zeros((font_size, font_size)).astype(np.float32)
                offset0 = (font_size - image.shape[0]) // 2
                offset1 = (font_size - image.shape[1]) // 2
                back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image
                image = back_image

            char_images.append(image)
        char_images = np.array(char_images)
        char_images = (char_images - np.mean(char_images)) / np.std(char_images)
        char_images = torch.from_numpy(char_images).contiguous()
        return char_images

    @staticmethod
    def build_batch(batch, tokenizer):
        src_idx = batch['src_idx'].flatten().tolist()
        chars = tokenizer.convert_ids_to_tokens(src_idx)
        pho_idx, pho_lens = pho2_convertor.convert(chars)
        batch['pho_idx'] = pho_idx
        batch['pho_lens'] = pho_lens
        return batch

    def forward(self,
                input_ids=None,
                pho_idx=None,
                pho_lens=None,
                attention_mask=None,
                labels=None,
                **kwargs):
        input_shape = input_ids.size()

        # Semantic stream: contextual hidden states from the main encoder.
        bert_hiddens = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Phonetic stream: GRU over each token's pinyin sequence, then the 4-layer BERT.
        pho_embeddings = self.pho_embeddings(pho_idx)
        pho_embeddings = torch.nn.utils.rnn.pack_padded_sequence(
            input=pho_embeddings,
            lengths=pho_lens,
            batch_first=True,
            enforce_sorted=False,
        )
        _, pho_hiddens = self.pho_gru(pho_embeddings)
        pho_hiddens = pho_hiddens.squeeze(0).reshape(input_shape[0], input_shape[1], -1).contiguous()
        pho_hiddens = self.pho_model(inputs_embeds=pho_hiddens, attention_mask=attention_mask)[0]

        # Visual stream: look up each input token's glyph image(s) and encode with the ResNet.
        src_idxs = input_ids.view(-1)

        if self.config.num_fonts == 1:
            images = self.char_images(src_idxs).reshape(src_idxs.shape[0], 1, 32, 32).contiguous()
        else:
            images = self.char_images_multifonts.index_select(dim=0, index=src_idxs)

        res_hiddens = self.resnet(images)
        res_hiddens = res_hiddens.reshape(input_shape[0], input_shape[1], -1).contiguous()
        res_hiddens = self.resnet_layernorm(res_hiddens)

        # Mean-pooled sentence vector, broadcast over the sequence as gate context.
        bert_hiddens_mean = (bert_hiddens * attention_mask.to(torch.float).unsqueeze(2)).sum(dim=1) / attention_mask.to(
            torch.float).sum(dim=1, keepdim=True)
        bert_hiddens_mean = bert_hiddens_mean.unsqueeze(1).expand(-1, bert_hiddens.size(1), -1)

        concated_outputs = torch.cat((bert_hiddens, pho_hiddens, res_hiddens, bert_hiddens_mean), dim=-1)
        gated_values = self.gate_net(concated_outputs)  # B * S * 3
        g0 = torch.sigmoid(gated_values[:, :, 0].unsqueeze(-1))
        g1 = torch.sigmoid(gated_values[:, :, 1].unsqueeze(-1))
        g2 = torch.sigmoid(gated_values[:, :, 2].unsqueeze(-1))

        hiddens = g0 * bert_hiddens + g1 * pho_hiddens + g2 * res_hiddens

        outputs = self.output_block(inputs_embeds=hiddens,
                                    position_ids=torch.zeros(input_ids.size(), dtype=torch.long,
                                                             device=input_ids.device),
                                    attention_mask=attention_mask)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = MaskedLMOutput(
            logits=logits,
            hidden_states=outputs.last_hidden_state,
        )

        if labels is not None:
            # Only keep active parts of the loss: remap [CLS] (101) and [SEP] (102)
            # to the ignore_index (0) so they do not contribute.
            labels[labels == 101] = 0
            labels[labels == 102] = 0
            loss = self.loss_fnt(logits.view(-1, logits.size(-1)), labels.view(-1))
            outputs.loss = loss

        return outputs

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def predict(self, sentences):
        if self.tokenizer is None:
            raise RuntimeError("Call `set_tokenizer(tokenizer)` before `predict`.")

        str_flag = False
        if isinstance(sentences, str):
            sentences = [sentences]
            str_flag = True

        inputs = self.tokenizer(sentences, padding=True, return_tensors="pt")
        outputs = self.forward(**inputs).logits

        ids_list = outputs.argmax(-1)

        preds = []
        for i, ids in enumerate(ids_list):
            ids = ids[inputs['attention_mask'][i].bool()]
            pred_tokens = self.tokenizer.convert_ids_to_tokens(ids)
            pred_tokens = [t if not t.startswith('##') else t[2:] for t in pred_tokens]
            pred_tokens = [t if t != self.tokenizer.unk_token else '×' for t in pred_tokens]

            offsets = inputs[i].offsets
            src_tokens = list(sentences[i])
            for (start, end), pred_token in zip(offsets, pred_tokens):
                # Skip special tokens (zero-width offsets) ...
                if end - start <= 0:
                    continue
                # ... predictions whose length differs from the source span ...
                if (end - start) != len(pred_token):
                    continue
                # ... [UNK] predictions ...
                if pred_token == '×':
                    continue
                # ... and single non-Chinese source characters.
                if (end - start) == 1 and not _is_chinese_char(ord(src_tokens[start])):
                    continue

                src_tokens[start:end] = pred_token

            pred = ''.join(src_tokens)
            preds.append(pred)

        if str_flag:
            return preds[0]

        return preds
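
One subtlety worth calling out: pypinyin's TONE3 style puts the tone digit at the end ('ma1'), and Pinyin2.get_pinyin moves it to the front so every phoneme sequence begins with its tone symbol. An illustrative check of the conversion:

    conv = Pinyin2()
    print(conv.get_pinyin('妈'))     # '1ma': TONE3 gives 'ma1'; the tone digit moves to the front
    print(conv.get_pinyin('[CLS]'))  # 'U': multi-character tokens map to the unknown symbol
    ids, lens = conv.convert(['妈', '的'])
    print(ids.shape, lens)           # torch.Size([2, 3]) and [3, 3]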
csc_tokenizer.py
ADDED
@@ -0,0 +1,126 @@
from typing import List, Optional, Union

import pypinyin
import torch

from transformers import BertTokenizerFast


class Pinyin2(object):
    # Duplicated from csc_model.py so this file stays self-contained for remote-code loading.

    def __init__(self):
        super(Pinyin2, self).__init__()
        pho_vocab = ['P']
        pho_vocab += [chr(x) for x in range(ord('1'), ord('5') + 1)]
        pho_vocab += [chr(x) for x in range(ord('a'), ord('z') + 1)]
        pho_vocab += ['U']
        assert len(pho_vocab) == 33
        self.pho_vocab_size = len(pho_vocab)
        self.pho_vocab = {c: idx for idx, c in enumerate(pho_vocab)}

    def get_pho_size(self):
        return self.pho_vocab_size

    @staticmethod
    def get_pinyin(c):
        if len(c) > 1:
            return 'U'
        s = pypinyin.pinyin(
            c,
            style=pypinyin.Style.TONE3,
            neutral_tone_with_five=True,
            errors=lambda x: ['U' for _ in x],
        )[0][0]
        if s == 'U':
            return s
        assert isinstance(s, str)
        assert s[-1] in '12345'
        # Move the tone digit to the front, e.g. 'ma1' -> '1ma'.
        s = s[-1] + s[:-1]
        return s

    def convert(self, chars):
        pinyins = list(map(self.get_pinyin, chars))
        pinyin_ids = [list(map(self.pho_vocab.get, pinyin)) for pinyin in pinyins]
        pinyin_lens = [len(pinyin) for pinyin in pinyins]
        pinyin_ids = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(x) for x in pinyin_ids],
            batch_first=True,
            padding_value=0,
        )
        return pinyin_ids, pinyin_lens


class ReaLiSeTokenizer(BertTokenizerFast):

    def __init__(self, **kwargs):
        super(ReaLiSeTokenizer, self).__init__(**kwargs)

        self.pho2_convertor = Pinyin2()

    def __call__(self,
                 text: Optional[Union[str, List[str], List[List[str]]]] = None,
                 text_pair: Optional[Union[str, List[str], List[List[str]]]] = None,
                 text_target: Optional[Union[str, List[str], List[List[str]]]] = None,
                 text_pair_target: Optional[Union[str, List[str], List[List[str]]]] = None,
                 add_special_tokens: bool = True,
                 padding=False,
                 truncation=None,
                 max_length: Optional[int] = None,
                 stride: int = 0,
                 is_split_into_words: bool = False,
                 pad_to_multiple_of: Optional[int] = None,
                 return_tensors=None,
                 return_token_type_ids: Optional[bool] = None,
                 return_attention_mask: Optional[bool] = None,
                 return_overflowing_tokens: bool = False,
                 return_special_tokens_mask: bool = False,
                 return_offsets_mapping: bool = False,
                 return_length: bool = False,
                 verbose: bool = True,
                 **kwargs):
        encoding = super(ReaLiSeTokenizer, self).__call__(
            text=text,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
        )

        input_ids = encoding['input_ids']
        if isinstance(text, str) and return_tensors is None:
            # A single un-tensorized sentence yields a flat id list; wrap it for the loop below.
            input_ids = [input_ids]

        pho_idx_list = []
        pho_lens_list = []
        for ids in input_ids:
            chars = self.convert_ids_to_tokens(ids)
            pho_idx, pho_lens = self.pho2_convertor.convert(chars)
            if return_tensors is None:
                pho_idx = pho_idx.tolist()
            pho_idx_list.append(pho_idx)
            pho_lens_list += pho_lens

        pho_idx = pho_idx_list
        pho_lens = pho_lens_list
        if return_tensors == 'pt':
            # Flatten to (batch * seq_len, pinyin_len), the layout ReaLiseForCSC.forward expects.
            pho_idx = torch.vstack(pho_idx)
            pho_lens = torch.LongTensor(pho_lens)

        if isinstance(text, str) and return_tensors is None:
            pho_idx = pho_idx[0]

        encoding['pho_idx'] = pho_idx
        encoding['pho_lens'] = pho_lens

        return encoding
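
Observable behaviour differs from a stock BertTokenizerFast only in the two extra fields appended to the encoding. A sketch, assuming the repo files sit in a local directory ./realise-csc (hypothetical path):

    tok = ReaLiSeTokenizer.from_pretrained("./realise-csc")  # hypothetical local checkout
    enc = tok(["天气不错"], return_tensors="pt")
    print(enc["input_ids"].shape)  # torch.Size([1, 6]): [CLS] + four characters + [SEP]
    print(enc["pho_idx"].shape)    # (6, max_pinyin_len): one row of pinyin ids per token
    print(enc["pho_lens"])         # LongTensor of the six per-token pinyin lengths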
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2c1d619e580891dffbe99d430cc116a82389433ab2a7ed15f8b43088135cf8d
size 1140770411
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "auto_map": {
    "AutoTokenizer": [
      "csc_tokenizer.ReaLiSeTokenizer",
      null
    ]
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": false,
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "ReaLiSeTokenizer",
  "unk_token": "[UNK]"
}
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff