Upload 6 files
Browse files- models/classifier.py +22 -0
- models/fusion_embedding.py +80 -0
- models/glyph_embedding.py +47 -0
- models/modeling_glycebert.py +532 -0
- models/pinyin_embedding.py +50 -0
- models/tokenizer.py +83 -0
models/classifier.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# author: xiaoya li
|
5 |
+
# first create: 2021.01.25
|
6 |
+
# files: nn_modules.py
|
7 |
+
#
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
class BertMLP(nn.Module):
|
12 |
+
def __init__(self, config,):
|
13 |
+
super().__init__()
|
14 |
+
self.dense_layer = nn.Linear(config.hidden_size, config.hidden_size)
|
15 |
+
self.dense_to_labels_layer = nn.Linear(config.hidden_size, config.num_labels)
|
16 |
+
self.activation = nn.Tanh()
|
17 |
+
|
18 |
+
def forward(self, sequence_hidden_states):
|
19 |
+
sequence_output = self.dense_layer(sequence_hidden_states)
|
20 |
+
sequence_output = self.activation(sequence_output)
|
21 |
+
sequence_output = self.dense_to_labels_layer(sequence_output)
|
22 |
+
return sequence_output
|
models/fusion_embedding.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@file : glyce_embedding.py
|
5 |
+
@author: zijun
|
6 |
+
@contact : [email protected]
|
7 |
+
@date : 2020/8/23 10:40
|
8 |
+
@version: 1.0
|
9 |
+
@desc : 【char embedding】+【pinyin embedding】+【glyph embedding】 = fusion embedding
|
10 |
+
"""
|
11 |
+
import os
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
from models.glyph_embedding import GlyphEmbedding
|
17 |
+
from models.pinyin_embedding import PinyinEmbedding
|
18 |
+
|
19 |
+
|
20 |
+
class FusionBertEmbeddings(nn.Module):
|
21 |
+
"""
|
22 |
+
Construct the embeddings from word, position, glyph, pinyin and token_type embeddings.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, config):
|
26 |
+
super(FusionBertEmbeddings, self).__init__()
|
27 |
+
config_path = os.path.join(config.name_or_path, 'config')
|
28 |
+
font_files = []
|
29 |
+
for file in os.listdir(config_path):
|
30 |
+
if file.endswith(".npy"):
|
31 |
+
font_files.append(os.path.join(config_path, file))
|
32 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
33 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
34 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
35 |
+
self.pinyin_embeddings = PinyinEmbedding(embedding_size=128, pinyin_out_dim=config.hidden_size,
|
36 |
+
config_path=config_path)
|
37 |
+
self.glyph_embeddings = GlyphEmbedding(font_npy_files=font_files)
|
38 |
+
|
39 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow models variable name and be able to load
|
40 |
+
# any TensorFlow checkpoint file
|
41 |
+
self.glyph_map = nn.Linear(1728, config.hidden_size)
|
42 |
+
self.map_fc = nn.Linear(config.hidden_size * 3, config.hidden_size)
|
43 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
44 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
45 |
+
|
46 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
47 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
48 |
+
|
49 |
+
def forward(self, input_ids=None, pinyin_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
50 |
+
if input_ids is not None:
|
51 |
+
input_shape = input_ids.size()
|
52 |
+
else:
|
53 |
+
input_shape = inputs_embeds.size()[:-1]
|
54 |
+
|
55 |
+
seq_length = input_shape[1]
|
56 |
+
|
57 |
+
if position_ids is None:
|
58 |
+
position_ids = self.position_ids[:, :seq_length]
|
59 |
+
|
60 |
+
if token_type_ids is None:
|
61 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
62 |
+
|
63 |
+
if inputs_embeds is None:
|
64 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
65 |
+
|
66 |
+
# get char embedding, pinyin embedding and glyph embedding
|
67 |
+
word_embeddings = inputs_embeds # [bs,l,hidden_size]
|
68 |
+
pinyin_embeddings = self.pinyin_embeddings(pinyin_ids) # [bs,l,hidden_size]
|
69 |
+
glyph_embeddings = self.glyph_map(self.glyph_embeddings(input_ids)) # [bs,l,hidden_size]
|
70 |
+
# fusion layer
|
71 |
+
concat_embeddings = torch.cat((word_embeddings, pinyin_embeddings, glyph_embeddings), 2)
|
72 |
+
inputs_embeds = self.map_fc(concat_embeddings)
|
73 |
+
|
74 |
+
position_embeddings = self.position_embeddings(position_ids)
|
75 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
76 |
+
|
77 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
78 |
+
embeddings = self.LayerNorm(embeddings)
|
79 |
+
embeddings = self.dropout(embeddings)
|
80 |
+
return embeddings
|
models/glyph_embedding.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# encoding: utf-8
|
2 |
+
"""
|
3 |
+
@author: Yuxian Meng
|
4 |
+
@contact: [email protected]
|
5 |
+
|
6 |
+
@version: 1.0
|
7 |
+
@file: glyph_embedding
|
8 |
+
@time: 2020/8/4 15:04
|
9 |
+
|
10 |
+
"""
|
11 |
+
|
12 |
+
from typing import List
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
|
19 |
+
class GlyphEmbedding(nn.Module):
|
20 |
+
"""Glyph2Image Embedding"""
|
21 |
+
|
22 |
+
def __init__(self, font_npy_files: List[str]):
|
23 |
+
super(GlyphEmbedding, self).__init__()
|
24 |
+
font_arrays = [
|
25 |
+
np.load(np_file).astype(np.float32) for np_file in font_npy_files
|
26 |
+
]
|
27 |
+
self.vocab_size = font_arrays[0].shape[0]
|
28 |
+
self.font_num = len(font_arrays)
|
29 |
+
self.font_size = font_arrays[0].shape[-1]
|
30 |
+
# N, C, H, W
|
31 |
+
font_array = np.stack(font_arrays, axis=1)
|
32 |
+
self.embedding = nn.Embedding(
|
33 |
+
num_embeddings=self.vocab_size,
|
34 |
+
embedding_dim=self.font_size ** 2 * self.font_num,
|
35 |
+
_weight=torch.from_numpy(font_array.reshape([self.vocab_size, -1]))
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, input_ids):
|
39 |
+
"""
|
40 |
+
get glyph images for batch inputs
|
41 |
+
Args:
|
42 |
+
input_ids: [batch, sentence_length]
|
43 |
+
Returns:
|
44 |
+
images: [batch, sentence_length, self.font_num*self.font_size*self.font_size]
|
45 |
+
"""
|
46 |
+
# return self.embedding(input_ids).view([-1, self.font_num, self.font_size, self.font_size])
|
47 |
+
return self.embedding(input_ids)
|
models/modeling_glycebert.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@file : modeling_glycebert.py
|
5 |
+
@author: zijun
|
6 |
+
@contact : [email protected]
|
7 |
+
@date : 2020/9/6 18:50
|
8 |
+
@version: 1.0
|
9 |
+
@desc : ChineseBert Model
|
10 |
+
"""
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
16 |
+
try:
|
17 |
+
from transformers.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, BertModel
|
18 |
+
except:
|
19 |
+
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, BertModel
|
20 |
+
|
21 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput, \
|
22 |
+
QuestionAnsweringModelOutput, TokenClassifierOutput
|
23 |
+
|
24 |
+
from models.fusion_embedding import FusionBertEmbeddings
|
25 |
+
from models.classifier import BertMLP
|
26 |
+
|
27 |
+
class GlyceBertModel(BertModel):
|
28 |
+
r"""
|
29 |
+
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
30 |
+
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
31 |
+
Sequence of hidden-states at the output of the last layer of the models.
|
32 |
+
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
33 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
34 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
35 |
+
layer weights are trained from the next sentence prediction (classification)
|
36 |
+
objective during Bert pretraining. This output is usually *not* a good summary
|
37 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
38 |
+
the sequence of hidden-states for the whole input sequence.
|
39 |
+
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
40 |
+
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
41 |
+
of shape ``(batch_size, sequence_length, hidden_size)``:
|
42 |
+
Hidden-states of the models at the output of each layer plus the initial embedding outputs.
|
43 |
+
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
44 |
+
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
45 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
46 |
+
|
47 |
+
Examples::
|
48 |
+
|
49 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
50 |
+
models = BertModel.from_pretrained('bert-base-uncased')
|
51 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
52 |
+
outputs = models(input_ids)
|
53 |
+
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
54 |
+
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, config):
|
58 |
+
super(GlyceBertModel, self).__init__(config)
|
59 |
+
self.config = config
|
60 |
+
|
61 |
+
self.embeddings = FusionBertEmbeddings(config)
|
62 |
+
self.encoder = BertEncoder(config)
|
63 |
+
self.pooler = BertPooler(config)
|
64 |
+
|
65 |
+
self.init_weights()
|
66 |
+
|
67 |
+
def forward(
|
68 |
+
self,
|
69 |
+
input_ids=None,
|
70 |
+
pinyin_ids=None,
|
71 |
+
attention_mask=None,
|
72 |
+
token_type_ids=None,
|
73 |
+
position_ids=None,
|
74 |
+
head_mask=None,
|
75 |
+
inputs_embeds=None,
|
76 |
+
encoder_hidden_states=None,
|
77 |
+
encoder_attention_mask=None,
|
78 |
+
output_attentions=None,
|
79 |
+
output_hidden_states=None,
|
80 |
+
return_dict=None,
|
81 |
+
):
|
82 |
+
r"""
|
83 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
84 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
85 |
+
if the models is configured as a decoder.
|
86 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
87 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
88 |
+
is used in the cross-attention if the models is configured as a decoder.
|
89 |
+
Mask values selected in ``[0, 1]``:
|
90 |
+
|
91 |
+
- 1 for tokens that are **not masked**,
|
92 |
+
- 0 for tokens that are **masked**.
|
93 |
+
"""
|
94 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
95 |
+
output_hidden_states = (
|
96 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
97 |
+
)
|
98 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
99 |
+
|
100 |
+
if input_ids is not None and inputs_embeds is not None:
|
101 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
102 |
+
elif input_ids is not None:
|
103 |
+
input_shape = input_ids.size()
|
104 |
+
elif inputs_embeds is not None:
|
105 |
+
input_shape = inputs_embeds.size()[:-1]
|
106 |
+
else:
|
107 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
108 |
+
|
109 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
110 |
+
|
111 |
+
if attention_mask is None:
|
112 |
+
attention_mask = torch.ones(input_shape, device=device)
|
113 |
+
if token_type_ids is None:
|
114 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
115 |
+
|
116 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
117 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
118 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
119 |
+
|
120 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
121 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
122 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
123 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
124 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
125 |
+
if encoder_attention_mask is None:
|
126 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
127 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
128 |
+
else:
|
129 |
+
encoder_extended_attention_mask = None
|
130 |
+
|
131 |
+
# Prepare head mask if needed
|
132 |
+
# 1.0 in head_mask indicate we keep the head
|
133 |
+
# attention_probs has shape bsz x n_heads x N x N
|
134 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
135 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
136 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
137 |
+
|
138 |
+
embedding_output = self.embeddings(
|
139 |
+
input_ids=input_ids, pinyin_ids=pinyin_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
140 |
+
inputs_embeds=inputs_embeds
|
141 |
+
)
|
142 |
+
encoder_outputs = self.encoder(
|
143 |
+
embedding_output,
|
144 |
+
attention_mask=extended_attention_mask,
|
145 |
+
head_mask=head_mask,
|
146 |
+
encoder_hidden_states=encoder_hidden_states,
|
147 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
148 |
+
output_attentions=output_attentions,
|
149 |
+
output_hidden_states=output_hidden_states,
|
150 |
+
return_dict=return_dict,
|
151 |
+
)
|
152 |
+
sequence_output = encoder_outputs[0]
|
153 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
154 |
+
|
155 |
+
if not return_dict:
|
156 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
157 |
+
|
158 |
+
return BaseModelOutputWithPooling(
|
159 |
+
last_hidden_state=sequence_output,
|
160 |
+
pooler_output=pooled_output,
|
161 |
+
hidden_states=encoder_outputs.hidden_states,
|
162 |
+
attentions=encoder_outputs.attentions,
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
class GlyceBertForMaskedLM(BertPreTrainedModel):
|
167 |
+
def __init__(self, config):
|
168 |
+
super(GlyceBertForMaskedLM, self).__init__(config)
|
169 |
+
|
170 |
+
self.bert = GlyceBertModel(config)
|
171 |
+
self.cls = BertOnlyMLMHead(config)
|
172 |
+
|
173 |
+
self.init_weights()
|
174 |
+
|
175 |
+
def get_output_embeddings(self):
|
176 |
+
return self.cls.predictions.decoder
|
177 |
+
|
178 |
+
def forward(
|
179 |
+
self,
|
180 |
+
input_ids=None,
|
181 |
+
pinyin_ids=None,
|
182 |
+
attention_mask=None,
|
183 |
+
token_type_ids=None,
|
184 |
+
position_ids=None,
|
185 |
+
head_mask=None,
|
186 |
+
inputs_embeds=None,
|
187 |
+
encoder_hidden_states=None,
|
188 |
+
encoder_attention_mask=None,
|
189 |
+
labels=None,
|
190 |
+
output_attentions=None,
|
191 |
+
output_hidden_states=None,
|
192 |
+
return_dict=None,
|
193 |
+
**kwargs
|
194 |
+
):
|
195 |
+
r"""
|
196 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
197 |
+
Labels for computing the masked language modeling loss.
|
198 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
199 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
200 |
+
in ``[0, ..., config.vocab_size]``
|
201 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
202 |
+
Used to hide legacy arguments that have been deprecated.
|
203 |
+
"""
|
204 |
+
if "masked_lm_labels" in kwargs:
|
205 |
+
warnings.warn(
|
206 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
207 |
+
FutureWarning,
|
208 |
+
)
|
209 |
+
labels = kwargs.pop("masked_lm_labels")
|
210 |
+
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
211 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
212 |
+
|
213 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
214 |
+
|
215 |
+
outputs = self.bert(
|
216 |
+
input_ids,
|
217 |
+
pinyin_ids,
|
218 |
+
attention_mask=attention_mask,
|
219 |
+
token_type_ids=token_type_ids,
|
220 |
+
position_ids=position_ids,
|
221 |
+
head_mask=head_mask,
|
222 |
+
inputs_embeds=inputs_embeds,
|
223 |
+
encoder_hidden_states=encoder_hidden_states,
|
224 |
+
encoder_attention_mask=encoder_attention_mask,
|
225 |
+
output_attentions=output_attentions,
|
226 |
+
output_hidden_states=output_hidden_states,
|
227 |
+
return_dict=return_dict,
|
228 |
+
)
|
229 |
+
|
230 |
+
sequence_output = outputs[0]
|
231 |
+
prediction_scores = self.cls(sequence_output)
|
232 |
+
|
233 |
+
masked_lm_loss = None
|
234 |
+
if labels is not None:
|
235 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
236 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
237 |
+
|
238 |
+
if not return_dict:
|
239 |
+
output = (prediction_scores,) + outputs[2:]
|
240 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
241 |
+
|
242 |
+
return MaskedLMOutput(
|
243 |
+
loss=masked_lm_loss,
|
244 |
+
logits=prediction_scores,
|
245 |
+
hidden_states=outputs.hidden_states,
|
246 |
+
attentions=outputs.attentions,
|
247 |
+
)
|
248 |
+
|
249 |
+
|
250 |
+
class GlyceBertForSequenceClassification(BertPreTrainedModel):
|
251 |
+
def __init__(self, config):
|
252 |
+
super().__init__(config)
|
253 |
+
self.num_labels = config.num_labels
|
254 |
+
|
255 |
+
self.bert = GlyceBertModel(config)
|
256 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
257 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
258 |
+
|
259 |
+
self.init_weights()
|
260 |
+
|
261 |
+
def forward(
|
262 |
+
self,
|
263 |
+
input_ids=None,
|
264 |
+
pinyin_ids=None,
|
265 |
+
attention_mask=None,
|
266 |
+
token_type_ids=None,
|
267 |
+
position_ids=None,
|
268 |
+
head_mask=None,
|
269 |
+
inputs_embeds=None,
|
270 |
+
labels=None,
|
271 |
+
output_attentions=None,
|
272 |
+
output_hidden_states=None,
|
273 |
+
return_dict=None,
|
274 |
+
):
|
275 |
+
r"""
|
276 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
277 |
+
Labels for computing the sequence classification/regression loss.
|
278 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
279 |
+
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
280 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
281 |
+
"""
|
282 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
283 |
+
|
284 |
+
outputs = self.bert(
|
285 |
+
input_ids,
|
286 |
+
pinyin_ids,
|
287 |
+
attention_mask=attention_mask,
|
288 |
+
token_type_ids=token_type_ids,
|
289 |
+
position_ids=position_ids,
|
290 |
+
head_mask=head_mask,
|
291 |
+
inputs_embeds=inputs_embeds,
|
292 |
+
output_attentions=output_attentions,
|
293 |
+
output_hidden_states=output_hidden_states,
|
294 |
+
return_dict=return_dict,
|
295 |
+
)
|
296 |
+
|
297 |
+
pooled_output = outputs[1]
|
298 |
+
|
299 |
+
pooled_output = self.dropout(pooled_output)
|
300 |
+
logits = self.classifier(pooled_output)
|
301 |
+
|
302 |
+
loss = None
|
303 |
+
if labels is not None:
|
304 |
+
if self.num_labels == 1:
|
305 |
+
# We are doing regression
|
306 |
+
loss_fct = MSELoss()
|
307 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
308 |
+
else:
|
309 |
+
loss_fct = CrossEntropyLoss()
|
310 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
311 |
+
|
312 |
+
if not return_dict:
|
313 |
+
output = (logits,) + outputs[2:]
|
314 |
+
return ((loss,) + output) if loss is not None else output
|
315 |
+
|
316 |
+
return SequenceClassifierOutput(
|
317 |
+
loss=loss,
|
318 |
+
logits=logits,
|
319 |
+
hidden_states=outputs.hidden_states,
|
320 |
+
attentions=outputs.attentions,
|
321 |
+
)
|
322 |
+
|
323 |
+
|
324 |
+
class GlyceBertForQuestionAnswering(BertPreTrainedModel):
|
325 |
+
"""BERT model for Question Answering (span extraction).
|
326 |
+
This module is composed of the BERT model with a linear layer on top of
|
327 |
+
the sequence output that computes start_logits and end_logits
|
328 |
+
|
329 |
+
Params:
|
330 |
+
`config`: a BertConfig class instance with the configuration to build a new model.
|
331 |
+
|
332 |
+
Inputs:
|
333 |
+
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
334 |
+
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
335 |
+
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
336 |
+
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
337 |
+
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
338 |
+
a `sentence B` token (see BERT paper for more details).
|
339 |
+
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
340 |
+
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
341 |
+
input sequence length in the current batch. It's the mask that we typically use for attention when
|
342 |
+
a batch has varying length sentences.
|
343 |
+
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
344 |
+
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
345 |
+
into account for computing the loss.
|
346 |
+
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
|
347 |
+
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
|
348 |
+
into account for computing the loss.
|
349 |
+
|
350 |
+
Outputs:
|
351 |
+
if `start_positions` and `end_positions` are not `None`:
|
352 |
+
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
|
353 |
+
if `start_positions` or `end_positions` is `None`:
|
354 |
+
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
|
355 |
+
position tokens of shape [batch_size, sequence_length].
|
356 |
+
|
357 |
+
Example usage:
|
358 |
+
```python
|
359 |
+
# Already been converted into WordPiece token ids
|
360 |
+
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
361 |
+
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
362 |
+
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
363 |
+
|
364 |
+
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
365 |
+
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
366 |
+
|
367 |
+
model = BertForQuestionAnswering(config)
|
368 |
+
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
369 |
+
```
|
370 |
+
"""
|
371 |
+
|
372 |
+
def __init__(self, config):
|
373 |
+
super().__init__(config)
|
374 |
+
self.num_labels = config.num_labels
|
375 |
+
|
376 |
+
self.bert = GlyceBertModel(config)
|
377 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
378 |
+
|
379 |
+
self.init_weights()
|
380 |
+
|
381 |
+
def forward(
|
382 |
+
self,
|
383 |
+
input_ids=None,
|
384 |
+
pinyin_ids=None,
|
385 |
+
attention_mask=None,
|
386 |
+
token_type_ids=None,
|
387 |
+
position_ids=None,
|
388 |
+
head_mask=None,
|
389 |
+
inputs_embeds=None,
|
390 |
+
start_positions=None,
|
391 |
+
end_positions=None,
|
392 |
+
output_attentions=None,
|
393 |
+
output_hidden_states=None,
|
394 |
+
return_dict=None,
|
395 |
+
):
|
396 |
+
r"""
|
397 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
398 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
399 |
+
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
|
400 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
401 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
402 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
403 |
+
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
|
404 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
405 |
+
"""
|
406 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
407 |
+
|
408 |
+
outputs = self.bert(
|
409 |
+
input_ids,
|
410 |
+
pinyin_ids,
|
411 |
+
attention_mask=attention_mask,
|
412 |
+
token_type_ids=token_type_ids,
|
413 |
+
position_ids=position_ids,
|
414 |
+
head_mask=head_mask,
|
415 |
+
inputs_embeds=inputs_embeds,
|
416 |
+
output_attentions=output_attentions,
|
417 |
+
output_hidden_states=output_hidden_states,
|
418 |
+
return_dict=return_dict,
|
419 |
+
)
|
420 |
+
|
421 |
+
sequence_output = outputs[0]
|
422 |
+
|
423 |
+
logits = self.qa_outputs(sequence_output)
|
424 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
425 |
+
start_logits = start_logits.squeeze(-1)
|
426 |
+
end_logits = end_logits.squeeze(-1)
|
427 |
+
|
428 |
+
total_loss = None
|
429 |
+
if start_positions is not None and end_positions is not None:
|
430 |
+
# If we are on multi-GPU, split add a dimension
|
431 |
+
if len(start_positions.size()) > 1:
|
432 |
+
start_positions = start_positions.squeeze(-1)
|
433 |
+
if len(end_positions.size()) > 1:
|
434 |
+
end_positions = end_positions.squeeze(-1)
|
435 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
436 |
+
ignored_index = start_logits.size(1)
|
437 |
+
start_positions.clamp_(0, ignored_index)
|
438 |
+
end_positions.clamp_(0, ignored_index)
|
439 |
+
|
440 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
441 |
+
start_loss = loss_fct(start_logits, start_positions)
|
442 |
+
end_loss = loss_fct(end_logits, end_positions)
|
443 |
+
total_loss = (start_loss + end_loss) / 2
|
444 |
+
|
445 |
+
if not return_dict:
|
446 |
+
output = (start_logits, end_logits) + outputs[2:]
|
447 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
448 |
+
|
449 |
+
return QuestionAnsweringModelOutput(
|
450 |
+
loss=total_loss,
|
451 |
+
start_logits=start_logits,
|
452 |
+
end_logits=end_logits,
|
453 |
+
hidden_states=outputs.hidden_states,
|
454 |
+
attentions=outputs.attentions,
|
455 |
+
)
|
456 |
+
|
457 |
+
class GlyceBertForTokenClassification(BertPreTrainedModel):
|
458 |
+
def __init__(self, config, mlp=False):
|
459 |
+
super().__init__(config)
|
460 |
+
self.num_labels = config.num_labels
|
461 |
+
|
462 |
+
self.bert = GlyceBertModel(config)
|
463 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
464 |
+
if mlp:
|
465 |
+
self.classifier = BertMLP(config)
|
466 |
+
else:
|
467 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
468 |
+
|
469 |
+
self.init_weights()
|
470 |
+
|
471 |
+
def forward(self,
|
472 |
+
input_ids=None,
|
473 |
+
pinyin_ids=None,
|
474 |
+
attention_mask=None,
|
475 |
+
token_type_ids=None,
|
476 |
+
position_ids=None,
|
477 |
+
head_mask=None,
|
478 |
+
inputs_embeds=None,
|
479 |
+
labels=None,
|
480 |
+
output_attentions=None,
|
481 |
+
output_hidden_states=None,
|
482 |
+
return_dict=None,
|
483 |
+
):
|
484 |
+
r"""
|
485 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
486 |
+
Labels for computing the token classification loss.
|
487 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
488 |
+
"""
|
489 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
490 |
+
|
491 |
+
outputs = self.bert(
|
492 |
+
input_ids,
|
493 |
+
pinyin_ids,
|
494 |
+
attention_mask=attention_mask,
|
495 |
+
token_type_ids=token_type_ids,
|
496 |
+
position_ids=position_ids,
|
497 |
+
head_mask=head_mask,
|
498 |
+
inputs_embeds=inputs_embeds,
|
499 |
+
output_attentions=output_attentions,
|
500 |
+
output_hidden_states=output_hidden_states,
|
501 |
+
return_dict=return_dict,
|
502 |
+
)
|
503 |
+
|
504 |
+
sequence_output = outputs[0]
|
505 |
+
|
506 |
+
sequence_output = self.dropout(sequence_output)
|
507 |
+
logits = self.classifier(sequence_output)
|
508 |
+
|
509 |
+
loss = None
|
510 |
+
if labels is not None:
|
511 |
+
loss_fct = CrossEntropyLoss()
|
512 |
+
# Only keep the active parts of the loss
|
513 |
+
if attention_mask is not None:
|
514 |
+
active_loss = attention_mask.view(-1) == 1
|
515 |
+
active_logits = logits.view(-1, self.num_labels)
|
516 |
+
active_labels = torch.where(
|
517 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
518 |
+
)
|
519 |
+
loss = loss_fct(active_logits, active_labels)
|
520 |
+
else:
|
521 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
522 |
+
|
523 |
+
if not return_dict:
|
524 |
+
output = (logits,) + outputs[2:]
|
525 |
+
return ((loss,) + output) if loss is not None else output
|
526 |
+
|
527 |
+
return TokenClassifierOutput(
|
528 |
+
loss=loss,
|
529 |
+
logits=logits,
|
530 |
+
hidden_states=outputs.hidden_states,
|
531 |
+
attentions=outputs.attentions,
|
532 |
+
)
|
models/pinyin_embedding.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@file : pinyin.py
|
5 |
+
@author: zijun
|
6 |
+
@contact : [email protected]
|
7 |
+
@date : 2020/8/16 14:45
|
8 |
+
@version: 1.0
|
9 |
+
@desc : pinyin embedding
|
10 |
+
"""
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
|
17 |
+
|
18 |
+
class PinyinEmbedding(nn.Module):
|
19 |
+
def __init__(self, embedding_size: int, pinyin_out_dim: int, config_path):
|
20 |
+
"""
|
21 |
+
Pinyin Embedding Module
|
22 |
+
Args:
|
23 |
+
embedding_size: the size of each embedding vector
|
24 |
+
pinyin_out_dim: kernel number of conv
|
25 |
+
"""
|
26 |
+
super(PinyinEmbedding, self).__init__()
|
27 |
+
with open(os.path.join(config_path, 'pinyin_map.json')) as fin:
|
28 |
+
pinyin_dict = json.load(fin)
|
29 |
+
self.pinyin_out_dim = pinyin_out_dim
|
30 |
+
self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
|
31 |
+
self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2,
|
32 |
+
stride=1, padding=0)
|
33 |
+
|
34 |
+
def forward(self, pinyin_ids):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
pinyin_ids: (bs*sentence_length*pinyin_locs)
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
pinyin_embed: (bs,sentence_length,pinyin_out_dim)
|
41 |
+
"""
|
42 |
+
# input pinyin ids for 1-D conv
|
43 |
+
embed = self.embedding(pinyin_ids) # [bs,sentence_length,pinyin_locs,embed_size]
|
44 |
+
bs, sentence_length, pinyin_locs, embed_size = embed.shape
|
45 |
+
view_embed = embed.view(-1, pinyin_locs, embed_size) # [(bs*sentence_length),pinyin_locs,embed_size]
|
46 |
+
input_embed = view_embed.permute(0, 2, 1) # [(bs*sentence_length), embed_size, pinyin_locs]
|
47 |
+
# conv + max_pooling
|
48 |
+
pinyin_conv = self.conv(input_embed) # [(bs*sentence_length),pinyin_out_dim,H]
|
49 |
+
pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1]) # [(bs*sentence_length),pinyin_out_dim,1]
|
50 |
+
return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim) # [bs,sentence_length,pinyin_out_dim]
|
models/tokenizer.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
import tokenizers
|
6 |
+
import torch
|
7 |
+
from pypinyin import pinyin, Style
|
8 |
+
|
9 |
+
try:
|
10 |
+
from tokenizers import BertWordPieceTokenizer
|
11 |
+
except:
|
12 |
+
from tokenizers.implementations import BertWordPieceTokenizer
|
13 |
+
|
14 |
+
from transformers import BertTokenizerFast
|
15 |
+
|
16 |
+
|
17 |
+
class ChineseBertTokenizer(BertTokenizerFast):
|
18 |
+
|
19 |
+
def __init__(self, **kwargs):
|
20 |
+
super(ChineseBertTokenizer, self).__init__(**kwargs)
|
21 |
+
|
22 |
+
bert_path = self.name_or_path
|
23 |
+
vocab_file = os.path.join(bert_path, 'vocab.txt')
|
24 |
+
config_path = os.path.join(bert_path, 'config')
|
25 |
+
self.max_length = 512
|
26 |
+
self.tokenizer = BertWordPieceTokenizer(vocab_file)
|
27 |
+
|
28 |
+
# load pinyin map dict
|
29 |
+
with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
|
30 |
+
self.pinyin_dict = json.load(fin)
|
31 |
+
# load char id map tensor
|
32 |
+
with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
|
33 |
+
self.id2pinyin = json.load(fin)
|
34 |
+
# load pinyin map tensor
|
35 |
+
with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
|
36 |
+
self.pinyin2tensor = json.load(fin)
|
37 |
+
|
38 |
+
def tokenize_sentence(self, sentence):
|
39 |
+
# convert sentence to ids
|
40 |
+
tokenizer_output = self.tokenizer.encode(sentence)
|
41 |
+
bert_tokens = tokenizer_output.ids
|
42 |
+
pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
|
43 |
+
# assert,token nums should be same as pinyin token nums
|
44 |
+
assert len(bert_tokens) <= self.max_length
|
45 |
+
assert len(bert_tokens) == len(pinyin_tokens)
|
46 |
+
# convert list to tensor
|
47 |
+
input_ids = torch.LongTensor(bert_tokens)
|
48 |
+
pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
|
49 |
+
return input_ids, pinyin_ids
|
50 |
+
|
51 |
+
def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
|
52 |
+
# get pinyin of a sentence
|
53 |
+
pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
|
54 |
+
pinyin_locs = {}
|
55 |
+
# get pinyin of each location
|
56 |
+
for index, item in enumerate(pinyin_list):
|
57 |
+
pinyin_string = item[0]
|
58 |
+
# not a Chinese character, pass
|
59 |
+
if pinyin_string == "not chinese":
|
60 |
+
continue
|
61 |
+
if pinyin_string in self.pinyin2tensor:
|
62 |
+
pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
|
63 |
+
else:
|
64 |
+
ids = [0] * 8
|
65 |
+
for i, p in enumerate(pinyin_string):
|
66 |
+
if p not in self.pinyin_dict["char2idx"]:
|
67 |
+
ids = [0] * 8
|
68 |
+
break
|
69 |
+
ids[i] = self.pinyin_dict["char2idx"][p]
|
70 |
+
pinyin_locs[index] = ids
|
71 |
+
|
72 |
+
# find chinese character location, and generate pinyin ids
|
73 |
+
pinyin_ids = []
|
74 |
+
for idx, (token, offset) in enumerate(zip(tokenizer_output.tokens, tokenizer_output.offsets)):
|
75 |
+
if offset[1] - offset[0] != 1:
|
76 |
+
pinyin_ids.append([0] * 8)
|
77 |
+
continue
|
78 |
+
if offset[0] in pinyin_locs:
|
79 |
+
pinyin_ids.append(pinyin_locs[offset[0]])
|
80 |
+
else:
|
81 |
+
pinyin_ids.append([0] * 8)
|
82 |
+
|
83 |
+
return pinyin_ids
|