iioSnail committed
Commit a2e10c0 · 1 Parent(s): 8d7c576

Upload 6 files

models/classifier.py ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # author: xiaoya li
+ # first created: 2021.01.25
+ # file: classifier.py (originally nn_modules.py)
+ #
+
+ import torch.nn as nn
+
+ class BertMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense_layer = nn.Linear(config.hidden_size, config.hidden_size)
+         self.dense_to_labels_layer = nn.Linear(config.hidden_size, config.num_labels)
+         self.activation = nn.Tanh()
+
+     def forward(self, sequence_hidden_states):
+         sequence_output = self.dense_layer(sequence_hidden_states)
+         sequence_output = self.activation(sequence_output)
+         sequence_output = self.dense_to_labels_layer(sequence_output)
+         return sequence_output
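Usage sketch: `BertMLP` only reads `hidden_size` and `num_labels` from the config, so any config-like object with those two fields is enough; the stand-in namespace below is illustrative, not a real `BertConfig`.

```python
import torch
from types import SimpleNamespace

from models.classifier import BertMLP

# stand-in config with only the fields BertMLP reads (illustrative)
config = SimpleNamespace(hidden_size=768, num_labels=4)
head = BertMLP(config)

hidden_states = torch.randn(2, 16, config.hidden_size)  # [batch, seq_len, hidden_size]
logits = head(hidden_states)                             # [batch, seq_len, num_labels]
print(logits.shape)  # torch.Size([2, 16, 4])
```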
models/fusion_embedding.py ADDED
@@ -0,0 +1,80 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ """
+ @file : fusion_embedding.py
+ @author: zijun
+ @contact : [email protected]
+ @date : 2020/8/23 10:40
+ @version: 1.0
+ @desc : [char embedding] + [pinyin embedding] + [glyph embedding] = fusion embedding
+ """
+ import os
+
+ import torch
+ from torch import nn
+
+ from models.glyph_embedding import GlyphEmbedding
+ from models.pinyin_embedding import PinyinEmbedding
+
+
+ class FusionBertEmbeddings(nn.Module):
+     """
+     Construct the embeddings from word, position, glyph, pinyin and token_type embeddings.
+     """
+
+     def __init__(self, config):
+         super(FusionBertEmbeddings, self).__init__()
+         config_path = os.path.join(config.name_or_path, 'config')
+         font_files = []
+         for file in os.listdir(config_path):
+             if file.endswith(".npy"):
+                 font_files.append(os.path.join(config_path, file))
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+         self.pinyin_embeddings = PinyinEmbedding(embedding_size=128, pinyin_out_dim=config.hidden_size,
+                                                  config_path=config_path)
+         self.glyph_embeddings = GlyphEmbedding(font_npy_files=font_files)
+
+         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable names and be able to load
+         # any TensorFlow checkpoint file
+         self.glyph_map = nn.Linear(1728, config.hidden_size)
+         self.map_fc = nn.Linear(config.hidden_size * 3, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+     def forward(self, input_ids=None, pinyin_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             input_shape = inputs_embeds.size()[:-1]
+
+         seq_length = input_shape[1]
+
+         if position_ids is None:
+             position_ids = self.position_ids[:, :seq_length]
+
+         if token_type_ids is None:
+             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.word_embeddings(input_ids)
+
+         # get char embedding, pinyin embedding and glyph embedding
+         word_embeddings = inputs_embeds  # [bs,l,hidden_size]
+         pinyin_embeddings = self.pinyin_embeddings(pinyin_ids)  # [bs,l,hidden_size]
+         glyph_embeddings = self.glyph_map(self.glyph_embeddings(input_ids))  # [bs,l,hidden_size]
+         # fusion layer
+         concat_embeddings = torch.cat((word_embeddings, pinyin_embeddings, glyph_embeddings), 2)
+         inputs_embeds = self.map_fc(concat_embeddings)
+
+         position_embeddings = self.position_embeddings(position_ids)
+         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
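A hedged construction sketch: `FusionBertEmbeddings` reads the glyph `.npy` arrays and `pinyin_map.json` from a `config/` folder inside the checkpoint directory named by `config.name_or_path`, so the path below is a placeholder for a local ChineseBERT checkpoint laid out that way.

```python
import torch
from transformers import BertConfig

from models.fusion_embedding import FusionBertEmbeddings

# placeholder path: a local ChineseBERT checkpoint whose config/ folder
# contains the *.npy glyph arrays and pinyin_map.json
config = BertConfig.from_pretrained("path/to/ChineseBERT-base")
embeddings = FusionBertEmbeddings(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))         # [batch, seq_len]
pinyin_ids = torch.zeros(1, 8, 8, dtype=torch.long)             # [batch, seq_len, 8 pinyin slots]
fused = embeddings(input_ids=input_ids, pinyin_ids=pinyin_ids)  # [batch, seq_len, hidden_size]
```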
models/glyph_embedding.py ADDED
@@ -0,0 +1,47 @@
+ # encoding: utf-8
+ """
+ @author: Yuxian Meng
+ @contact: [email protected]
+
+ @version: 1.0
+ @file: glyph_embedding
+ @time: 2020/8/4 15:04
+
+ """
+
+ from typing import List
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+
+ class GlyphEmbedding(nn.Module):
+     """Glyph2Image Embedding"""
+
+     def __init__(self, font_npy_files: List[str]):
+         super(GlyphEmbedding, self).__init__()
+         font_arrays = [
+             np.load(np_file).astype(np.float32) for np_file in font_npy_files
+         ]
+         self.vocab_size = font_arrays[0].shape[0]
+         self.font_num = len(font_arrays)
+         self.font_size = font_arrays[0].shape[-1]
+         # N, C, H, W
+         font_array = np.stack(font_arrays, axis=1)
+         self.embedding = nn.Embedding(
+             num_embeddings=self.vocab_size,
+             embedding_dim=self.font_size ** 2 * self.font_num,
+             _weight=torch.from_numpy(font_array.reshape([self.vocab_size, -1]))
+         )
+
+     def forward(self, input_ids):
+         """
+         get glyph images for batch inputs
+         Args:
+             input_ids: [batch, sentence_length]
+         Returns:
+             images: [batch, sentence_length, self.font_num*self.font_size*self.font_size]
+         """
+         # return self.embedding(input_ids).view([-1, self.font_num, self.font_size, self.font_size])
+         return self.embedding(input_ids)
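A self-contained sketch: the module only assumes each `.npy` file holds an array of shape `[vocab_size, font_size, font_size]`, so a zero-filled toy font array is enough to exercise it.

```python
import numpy as np
import torch

from models.glyph_embedding import GlyphEmbedding

# toy font array, written to disk purely for illustration
np.save("toy_font.npy", np.zeros((21128, 24, 24), dtype=np.float32))

glyph = GlyphEmbedding(font_npy_files=["toy_font.npy"])
input_ids = torch.tensor([[101, 2769, 102]])   # [batch, seq_len]
images = glyph(input_ids)                      # [batch, seq_len, font_num * 24 * 24]
print(images.shape)  # torch.Size([1, 3, 576])
```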
models/modeling_glycebert.py ADDED
@@ -0,0 +1,532 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ """
+ @file : modeling_glycebert.py
+ @author: zijun
+ @contact : [email protected]
+ @date : 2020/9/6 18:50
+ @version: 1.0
+ @desc : ChineseBert Model
+ """
+ import warnings
+
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss, MSELoss
+ try:
+     from transformers.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, BertModel
+ except ImportError:
+     from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, BertModel
+
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput, \
+     QuestionAnsweringModelOutput, TokenClassifierOutput
+
+ from models.fusion_embedding import FusionBertEmbeddings
+ from models.classifier import BertMLP
+
+ class GlyceBertModel(BertModel):
+     r"""
+     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+         **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
+             Sequence of hidden-states at the output of the last layer of the model.
+         **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
+             Last layer hidden-state of the first token of the sequence (classification token)
+             further processed by a Linear layer and a Tanh activation function. The Linear
+             layer weights are trained from the next sentence prediction (classification)
+             objective during Bert pretraining. This output is usually *not* a good summary
+             of the semantic content of the input, you're often better off averaging or pooling
+             the sequence of hidden-states for the whole input sequence.
+         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+             list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+             of shape ``(batch_size, sequence_length, hidden_size)``:
+             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+         **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+             list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+     Examples::
+
+         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+         model = BertModel.from_pretrained('bert-base-uncased')
+         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+         outputs = model(input_ids)
+         last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+
+     """
+
+     def __init__(self, config):
+         super(GlyceBertModel, self).__init__(config)
+         self.config = config
+
+         self.embeddings = FusionBertEmbeddings(config)
+         self.encoder = BertEncoder(config)
+         self.pooler = BertPooler(config)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         pinyin_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+             if the model is configured as a decoder.
+         encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+             Mask to avoid performing attention on the padding token indices of the encoder input. This mask
+             is used in the cross-attention if the model is configured as a decoder.
+             Mask values selected in ``[0, 1]``:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         if attention_mask is None:
+             attention_mask = torch.ones(input_shape, device=device)
+         if token_type_ids is None:
+             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+         # ourselves in which case we just need to make it broadcastable to all heads.
+         extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.config.is_decoder and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_extended_attention_mask = None
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids, pinyin_ids=pinyin_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds
+         )
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             head_mask=head_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_extended_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ class GlyceBertForMaskedLM(BertPreTrainedModel):
+     def __init__(self, config):
+         super(GlyceBertForMaskedLM, self).__init__(config)
+
+         self.bert = GlyceBertModel(config)
+         self.cls = BertOnlyMLMHead(config)
+
+         self.init_weights()
+
+     def get_output_embeddings(self):
+         return self.cls.predictions.decoder
+
+     def forward(
+         self,
+         input_ids=None,
+         pinyin_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         **kwargs
+     ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+             Labels for computing the masked language modeling loss.
+             Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
+             Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
+             in ``[0, ..., config.vocab_size]``.
+         kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+             Used to hide legacy arguments that have been deprecated.
+         """
+         if "masked_lm_labels" in kwargs:
+             warnings.warn(
+                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
+                 FutureWarning,
+             )
+             labels = kwargs.pop("masked_lm_labels")
+         assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
+         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             pinyin_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(sequence_output)
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()  # -100 index = padding token
+             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[2:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class GlyceBertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = GlyceBertModel(config)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         pinyin_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for computing the sequence classification/regression loss.
+             Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
+             If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+             If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             pinyin_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.num_labels == 1:
+                 # We are doing regression
+                 loss_fct = MSELoss()
+                 loss = loss_fct(logits.view(-1), labels.view(-1))
+             else:
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class GlyceBertForQuestionAnswering(BertPreTrainedModel):
+     """BERT model for Question Answering (span extraction).
+     This module is composed of the BERT model with a linear layer on top of
+     the sequence output that computes start_logits and end_logits.
+
+     Params:
+         `config`: a BertConfig class instance with the configuration to build a new model.
+
+     Inputs:
+         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+             with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
+             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
+         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+             type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+             a `sentence B` token (see BERT paper for more details).
+         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+             input sequence length in the current batch. It's the mask that we typically use for attention when
+             a batch has varying length sentences.
+         `start_positions`: position of the first token of the labeled span: torch.LongTensor of shape [batch_size].
+             Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
+             into account for computing the loss.
+         `end_positions`: position of the last token of the labeled span: torch.LongTensor of shape [batch_size].
+             Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
+             into account for computing the loss.
+
+     Outputs:
+         if `start_positions` and `end_positions` are not `None`:
+             Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
+         if `start_positions` or `end_positions` is `None`:
+             Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
+             position tokens of shape [batch_size, sequence_length].
+
+     Example usage:
+     ```python
+     # Already been converted into WordPiece token ids
+     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+     model = BertForQuestionAnswering(config)
+     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
+     ```
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = GlyceBertModel(config)
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         pinyin_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         start_positions=None,
+         end_positions=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for position (index) of the start of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (:obj:`sequence_length`).
+             Positions outside of the sequence are not taken into account for computing the loss.
+         end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for position (index) of the end of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (:obj:`sequence_length`).
+             Positions outside of the sequence are not taken into account for computing the loss.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             pinyin_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1)
+         end_logits = end_logits.squeeze(-1)
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, split adds an extra dimension; squeeze it
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+             # sometimes the start/end positions are outside our model inputs; we ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions.clamp_(0, ignored_index)
+             end_positions.clamp_(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             output = (start_logits, end_logits) + outputs[2:]
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return QuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ class GlyceBertForTokenClassification(BertPreTrainedModel):
+     def __init__(self, config, mlp=False):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.bert = GlyceBertModel(config)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         if mlp:
+             self.classifier = BertMLP(config)
+         else:
+             self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         self.init_weights()
+
+     def forward(self,
+                 input_ids=None,
+                 pinyin_ids=None,
+                 attention_mask=None,
+                 token_type_ids=None,
+                 position_ids=None,
+                 head_mask=None,
+                 inputs_embeds=None,
+                 labels=None,
+                 output_attentions=None,
+                 output_hidden_states=None,
+                 return_dict=None,
+                 ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+             Labels for computing the token classification loss.
+             Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             pinyin_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         sequence_output = self.dropout(sequence_output)
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             # Only keep the active parts of the loss
+             if attention_mask is not None:
+                 active_loss = attention_mask.view(-1) == 1
+                 active_logits = logits.view(-1, self.num_labels)
+                 active_labels = torch.where(
+                     active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                 )
+                 loss = loss_fct(active_logits, active_labels)
+             else:
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
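An end-to-end sketch using the tokenizer from this commit; the checkpoint path is a placeholder for a local ChineseBERT directory, and `pinyin_ids` is reshaped to the `[batch, seq_len, 8]` layout that `FusionBertEmbeddings` expects.

```python
import torch

from models.modeling_glycebert import GlyceBertForMaskedLM
from models.tokenizer import ChineseBertTokenizer

ckpt = "path/to/ChineseBERT-base"  # placeholder local checkpoint directory
tokenizer = ChineseBertTokenizer.from_pretrained(ckpt)
model = GlyceBertForMaskedLM.from_pretrained(ckpt)
model.eval()

input_ids, pinyin_ids = tokenizer.tokenize_sentence("我喜欢猫")
input_ids = input_ids.unsqueeze(0)      # [1, seq_len]
pinyin_ids = pinyin_ids.view(1, -1, 8)  # [1, seq_len, 8]

with torch.no_grad():
    outputs = model(input_ids, pinyin_ids, return_dict=True)
print(outputs.logits.shape)  # [1, seq_len, vocab_size]
```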
models/pinyin_embedding.py ADDED
@@ -0,0 +1,50 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ """
+ @file : pinyin_embedding.py
+ @author: zijun
+ @contact : [email protected]
+ @date : 2020/8/16 14:45
+ @version: 1.0
+ @desc : pinyin embedding
+ """
+ import json
+ import os
+
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class PinyinEmbedding(nn.Module):
+     def __init__(self, embedding_size: int, pinyin_out_dim: int, config_path):
+         """
+         Pinyin Embedding Module
+         Args:
+             embedding_size: the size of each embedding vector
+             pinyin_out_dim: kernel number of conv
+         """
+         super(PinyinEmbedding, self).__init__()
+         with open(os.path.join(config_path, 'pinyin_map.json')) as fin:
+             pinyin_dict = json.load(fin)
+         self.pinyin_out_dim = pinyin_out_dim
+         self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
+         self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2,
+                               stride=1, padding=0)
+
+     def forward(self, pinyin_ids):
+         """
+         Args:
+             pinyin_ids: (bs, sentence_length, pinyin_locs)
+
+         Returns:
+             pinyin_embed: (bs, sentence_length, pinyin_out_dim)
+         """
+         # input pinyin ids for 1-D conv
+         embed = self.embedding(pinyin_ids)  # [bs,sentence_length,pinyin_locs,embed_size]
+         bs, sentence_length, pinyin_locs, embed_size = embed.shape
+         view_embed = embed.view(-1, pinyin_locs, embed_size)  # [(bs*sentence_length),pinyin_locs,embed_size]
+         input_embed = view_embed.permute(0, 2, 1)  # [(bs*sentence_length),embed_size,pinyin_locs]
+         # conv + max_pooling
+         pinyin_conv = self.conv(input_embed)  # [(bs*sentence_length),pinyin_out_dim,H]
+         pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1])  # [(bs*sentence_length),pinyin_out_dim,1]
+         return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim)  # [bs,sentence_length,pinyin_out_dim]
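A shape-level sketch: the module reads `pinyin_map.json` from the given `config_path` (placeholder below) and maps `[batch, seq_len, 8]` pinyin ids to `[batch, seq_len, pinyin_out_dim]` via a 1-D conv followed by max-pooling.

```python
import torch

from models.pinyin_embedding import PinyinEmbedding

# placeholder: the config/ folder of a local ChineseBERT checkpoint
pinyin_embed = PinyinEmbedding(embedding_size=128, pinyin_out_dim=768,
                               config_path="path/to/ChineseBERT-base/config")

pinyin_ids = torch.zeros(2, 16, 8, dtype=torch.long)  # [batch, seq_len, pinyin_locs]
out = pinyin_embed(pinyin_ids)                        # [batch, seq_len, pinyin_out_dim]
print(out.shape)  # torch.Size([2, 16, 768])
```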
models/tokenizer.py ADDED
@@ -0,0 +1,83 @@
+ import json
+ import os
+ from typing import List
+
+ import tokenizers
+ import torch
+ from pypinyin import pinyin, Style
+
+ try:
+     from tokenizers import BertWordPieceTokenizer
+ except ImportError:
+     from tokenizers.implementations import BertWordPieceTokenizer
+
+ from transformers import BertTokenizerFast
+
+
+ class ChineseBertTokenizer(BertTokenizerFast):
+
+     def __init__(self, **kwargs):
+         super(ChineseBertTokenizer, self).__init__(**kwargs)
+
+         bert_path = self.name_or_path
+         vocab_file = os.path.join(bert_path, 'vocab.txt')
+         config_path = os.path.join(bert_path, 'config')
+         self.max_length = 512
+         self.tokenizer = BertWordPieceTokenizer(vocab_file)
+
+         # load pinyin map dict
+         with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
+             self.pinyin_dict = json.load(fin)
+         # load char id map tensor
+         with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
+             self.id2pinyin = json.load(fin)
+         # load pinyin map tensor
+         with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
+             self.pinyin2tensor = json.load(fin)
+
+     def tokenize_sentence(self, sentence):
+         # convert sentence to ids
+         tokenizer_output = self.tokenizer.encode(sentence)
+         bert_tokens = tokenizer_output.ids
+         pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
+         # sanity checks: sequence must fit max_length and token count must match pinyin token count
+         assert len(bert_tokens) <= self.max_length
+         assert len(bert_tokens) == len(pinyin_tokens)
+         # convert list to tensor
+         input_ids = torch.LongTensor(bert_tokens)
+         pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
+         return input_ids, pinyin_ids
+
+     def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
+         # get pinyin of a sentence
+         pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
+         pinyin_locs = {}
+         # get pinyin of each location
+         for index, item in enumerate(pinyin_list):
+             pinyin_string = item[0]
+             # not a Chinese character, skip
+             if pinyin_string == "not chinese":
+                 continue
+             if pinyin_string in self.pinyin2tensor:
+                 pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
+             else:
+                 ids = [0] * 8
+                 for i, p in enumerate(pinyin_string):
+                     if p not in self.pinyin_dict["char2idx"]:
+                         ids = [0] * 8
+                         break
+                     ids[i] = self.pinyin_dict["char2idx"][p]
+                 pinyin_locs[index] = ids
+
+         # find Chinese character locations and generate pinyin ids
+         pinyin_ids = []
+         for idx, (token, offset) in enumerate(zip(tokenizer_output.tokens, tokenizer_output.offsets)):
+             if offset[1] - offset[0] != 1:
+                 pinyin_ids.append([0] * 8)
+                 continue
+             if offset[0] in pinyin_locs:
+                 pinyin_ids.append(pinyin_locs[offset[0]])
+             else:
+                 pinyin_ids.append([0] * 8)
+
+         return pinyin_ids
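A tokenization sketch (placeholder checkpoint path; the directory is assumed to hold `vocab.txt` plus a `config/` folder with the pinyin JSON files):

```python
from models.tokenizer import ChineseBertTokenizer

tokenizer = ChineseBertTokenizer.from_pretrained("path/to/ChineseBERT-base")
input_ids, pinyin_ids = tokenizer.tokenize_sentence("今天天气很好")

print(input_ids.shape)   # [seq_len], includes [CLS] and [SEP]
print(pinyin_ids.shape)  # [seq_len * 8], eight pinyin slots per token, flattened
```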