|
import json |
|
import os |
|
import shutil |
|
import time |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from transformers import BertPreTrainedModel, BertModel |
|
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPooling |
|
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertLMPredictionHead |
|
|
|
# Directory that contains this module. Resource files fetched by
# ``download_file`` (e.g. ``config/pinyin_map.json``) are stored beneath it.
cache_path = Path(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
def download_file(filename: str, path: Path):
    """Make ``filename`` available under the module-level ``cache_path``.

    Resolution order:

    1. If ``cache_path / filename`` already exists, do nothing.
    2. Else, if ``path / filename`` exists, copy it into the cache.
    3. Otherwise download it from the ``iioSnail/ChineseBERT-for-csc``
       repository on the Hugging Face Hub with ``local_dir=cache_path``.

    Args:
        filename: File path relative to the model root, e.g.
            ``"config/pinyin_map.json"``; may contain sub-directories.
        path: Local model directory to check before downloading. A plain
            ``str`` is accepted as well as a ``Path``.
    """
    target = cache_path / filename
    if os.path.exists(target):
        return

    source = Path(path) / filename
    if os.path.exists(source):
        # ``shutil.copyfile`` does not create missing parent directories
        # (e.g. ``config/``), so make sure they exist before copying.
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(source, target)
        return

    hf_hub_download(
        "iioSnail/ChineseBERT-for-csc",
        filename,
        local_dir=cache_path,
    )
    # NOTE(review): the purpose of this short pause is not visible here
    # (presumably to let the downloaded file settle on disk) — confirm
    # before removing.
    time.sleep(0.2)
|
|
|
|
|
class ChineseBertForCSC(BertPreTrainedModel):
    """ChineseBERT wrapper for Chinese spelling correction (CSC).

    Holds a ``Dynamic_GlyceBertForMultiTask`` as ``self.model`` and adds
    sentence-level ``predict`` helpers. A tokenizer must be attached with
    ``set_tokenizer`` before predicting.
    """

    def __init__(self, config):
        super(ChineseBertForCSC, self).__init__(config)
        self.model = Dynamic_GlyceBertForMultiTask(config)
        self.tokenizer = None

    def forward(self, **kwargs):
        """Delegate directly to the wrapped model."""
        return self.model(**kwargs)

    def set_tokenizer(self, tokenizer):
        """Attach the tokenizer used by ``predict`` / ``_predict``."""
        self.tokenizer = tokenizer

    def _predict(self, sentence):
        """Run a single correction pass over ``sentence``.

        Returns:
            The list of argmax-predicted tokens for ``sentence`` with the
            first and last positions dropped (presumably the special
            [CLS]/[SEP] tokens — TODO confirm against the tokenizer), or a
            hint string if no tokenizer has been set.
        """
        if self.tokenizer is None:
            return "Please init tokenizer by `set_tokenizer(tokenizer)` before predict."

        inputs = self.tokenizer([sentence], return_tensors='pt')
        # Move tensors to the model's device so inference also works after the
        # model has been moved to a GPU.
        device = self.device
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        predicted_ids = logits.argmax(-1)[0, 1:-1].tolist()
        return self.tokenizer.convert_ids_to_tokens(predicted_ids)

    def predict(self, sentence, window=1):
        """Correct ``sentence``, optionally refining the result iteratively.

        The first pass may change any position. Each of the ``window``
        further passes re-runs the model on the previous output. It keeps a
        change at position ``i`` only if some position within distance 1 of
        ``i`` was changed by the previous pass. Every other position is reset
        to the previous output.

        Args:
            sentence: Input text; it is compared with predictions per character.
            window: Number of refinement passes after the first (0 disables).

        Returns:
            The corrected sentence. If no tokenizer is set, returns the same
            hint string that ``_predict`` returns.
        """
        if self.tokenizer is None:
            # ``_predict`` returns a hint string in this case; hand it back
            # instead of treating it as a (mutable) token list.
            return self._predict(sentence)

        src_tokens = list(sentence)
        pred_tokens = self._predict(sentence)

        for _ in range(window):
            # Positions changed by the previous pass, widened by +/-1: only
            # these may be changed again in this pass.
            changed = [i for i, (a, b) in enumerate(zip(src_tokens, pred_tokens)) if a != b]
            allowed = {j for i in changed for j in (i - 1, i, i + 1)}

            src_tokens = pred_tokens
            pred_tokens = self._predict(''.join(pred_tokens))
            for i, (a, b) in enumerate(zip(src_tokens, pred_tokens)):
                if a == b or i not in allowed:
                    pred_tokens[i] = src_tokens[i]

        return ''.join(pred_tokens)
|
|
|
|
|
|
|
class Dynamic_GlyceBertForMultiTask(BertPreTrainedModel):
    """``GlyceBertModel`` encoder followed by an LM-style prediction head.

    ``forward`` returns per-token vocabulary scores as ``logits``; no loss is
    computed here.
    """

    def __init__(self, config):
        super(Dynamic_GlyceBertForMultiTask, self).__init__(config)

        self.bert = GlyceBertModel(config)
        self.cls = MultiTaskHeads(config)

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids=None,
        pinyin_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        """Encode the inputs and score every position over the vocabulary.

        All arguments are forwarded to ``GlyceBertModel.forward``. Any extra
        keyword argument is rejected.

        Returns:
            If ``return_dict`` (defaulting to ``config.use_return_dict``) is
            true, a ``MaskedLMOutput`` with ``logits``, ``hidden_states`` and
            ``attentions``. Otherwise the tuple
            ``(logits,) + encoder_outputs[2:]``, i.e. the encoder tuple
            without its sequence and pooled outputs.
        """
        # NOTE(review): assert-based validation is skipped under ``python -O``.
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs_x = self.bert(
            input_ids,
            pinyin_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        encoded_x = outputs_x[0]

        prediction_scores = self.cls(encoded_x)

        if not return_dict:
            # ``self.bert`` returned a plain tuple
            # (sequence_output, pooled_output, *extras), which has no
            # ``.hidden_states`` / ``.attentions`` attributes.
            return (prediction_scores,) + tuple(outputs_x[2:])

        return MaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs_x.hidden_states,
            attentions=outputs_x.attentions,
        )
|
|
|
|
|
class GlyceBertModel(BertModel):
    r"""
    ``BertModel`` subclass whose embedding layer is ``FusionBertEmbeddings``.
    That layer fuses word, pinyin and glyph embeddings, so ``forward`` also
    takes ``pinyin_ids``.

    Outputs:
        With ``return_dict`` true, a ``BaseModelOutputWithPooling``. Otherwise
        the tuple ``(last_hidden_state, pooler_output) + encoder_outputs[1:]``.

        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Output of ``BertPooler`` applied to the last hidden states
            (``None`` if ``self.pooler`` is ``None``).
        **hidden_states**: (`optional`, returned when ``output_hidden_states`` is enabled)
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``output_attentions`` is enabled)
            Attention weights after the attention softmax, one tensor per layer.

    ``forward_with_embedding`` runs the same encoder/pooler stack on a
    caller-supplied embedding output, bypassing ``self.embeddings``.
    """

    def __init__(self, config):
        super(GlyceBertModel, self).__init__(config)
        self.config = config

        # Replace the stock BERT embeddings with the fused word/pinyin/glyph
        # embeddings; encoder and pooler are (re)built here as well.
        self.embeddings = FusionBertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resolve_output_flags(self, output_attentions, output_hidden_states, return_dict):
        """Fill any unset output flag from ``self.config``."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return output_attentions, output_hidden_states, return_dict

    @staticmethod
    def _input_shape_and_device(input_ids, inputs_embeds):
        """Return ``(input_shape, device)`` from exactly one of ``input_ids`` / ``inputs_embeds``.

        Raises:
            ValueError: if both or neither are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            return input_ids.size(), input_ids.device
        if inputs_embeds is not None:
            return inputs_embeds.size()[:-1], inputs_embeds.device
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    def _encode_embeddings(
        self,
        embedding_output,
        input_shape,
        device,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
    ):
        """Build the attention/head masks and run encoder + pooler on ``embedding_output``.

        This is the shared tail of ``forward`` and ``forward_with_embedding``.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        # Broadcastable self-attention mask, built by the transformers base class.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # Cross-attention mask; only needed when configured as a decoder.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def forward(
        self,
        input_ids=None,
        pinyin_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Embed the inputs with ``FusionBertEmbeddings`` and run the encoder and pooler.

        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )
        input_shape, device = self._input_shape_and_device(input_ids, inputs_embeds)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids, pinyin_ids=pinyin_ids, position_ids=position_ids, token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds
        )
        return self._encode_embeddings(
            embedding_output,
            input_shape,
            device,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
        )

    def forward_with_embedding(
        self,
        input_ids=None,
        pinyin_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        embedding=None
    ):
        r"""
        Like ``forward``, but feed ``embedding`` straight into the encoder.

        The embedding layer is skipped, so ``pinyin_ids``, ``token_type_ids``
        and ``position_ids`` are unused. ``input_ids`` / ``inputs_embeds`` only
        determine the input shape and device. If both are omitted, shape and
        device are taken from ``embedding`` itself, which is assumed to be
        ``(batch_size, sequence_length, hidden_size)``.

        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Raises:
            ValueError: if ``embedding`` is None, or both ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if embedding is None:
            raise ValueError("`embedding` must be provided to forward_with_embedding")

        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )
        if input_ids is None and inputs_embeds is None:
            input_shape, device = embedding.size()[:-1], embedding.device
        else:
            input_shape, device = self._input_shape_and_device(input_ids, inputs_embeds)

        return self._encode_embeddings(
            embedding,
            input_shape,
            device,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
|
|
|
|
|
class MultiTaskHeads(nn.Module):
    """Container for the model's prediction head.

    The single ``BertLMPredictionHead`` is stored under the attribute name
    ``predictions``. ``Dynamic_GlyceBertForMultiTask.get_output_embeddings``
    reads ``self.cls.predictions.decoder``, so that name must stay stable.
    """

    def __init__(self, config):
        super(MultiTaskHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Map encoder hidden states to per-token prediction scores."""
        return self.predictions(sequence_output)
|
|
|
|
|
class FusionBertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position, glyph, pinyin and token_type embeddings.

    Word, pinyin and (projected) glyph embeddings are concatenated and mapped
    back to ``hidden_size`` by ``map_fc``. Position and token-type embeddings
    are then added, followed by LayerNorm and dropout.
    """

    def __init__(self, config):
        super(FusionBertEmbeddings, self).__init__()

        # Local model directory checked for resource files before downloading.
        self.path = Path(config._name_or_path)
        config_path = cache_path / 'config'
        os.makedirs(config_path, exist_ok=True)

        download_file("config/STFANGSO.TTF24.npy", self.path)
        download_file("config/STXINGKA.TTF24.npy", self.path)
        download_file("config/方正古隶繁体.ttf24.npy", self.path)
        # Every ``.npy`` file in the cache's config directory is used as a font
        # table. Sort the listing so the stacking order (and therefore the
        # layout of each glyph vector) does not depend on ``os.listdir`` order.
        font_files = [
            config_path / file
            for file in sorted(os.listdir(config_path))
            if file.endswith(".npy")
        ]

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.pinyin_embeddings = PinyinEmbedding(embedding_size=128, pinyin_out_dim=config.hidden_size, config=config)
        self.glyph_embeddings = GlyphEmbedding(font_npy_files=font_files)

        # Project flattened glyph images to hidden_size. The input width comes
        # from the glyph table itself (font_num * font_size ** 2) instead of the
        # previously hard-coded 1728 (== 3 * 24 ** 2).
        self.glyph_map = nn.Linear(self.glyph_embeddings.embedding.embedding_dim, config.hidden_size)
        # Fuses [word; pinyin; glyph] (3 * hidden_size) back to hidden_size.
        self.map_fc = nn.Linear(config.hidden_size * 3, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Shape (1, max_position_embeddings); sliced to the sequence length when
        # position_ids is not given.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, pinyin_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Return the fused, normalized embeddings for a batch.

        ``input_ids`` and ``pinyin_ids`` are required even when
        ``inputs_embeds`` is given: the glyph and pinyin lookups index them
        directly. ``inputs_embeds`` only replaces the word-embedding part.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        word_embeddings = inputs_embeds
        pinyin_embeddings = self.pinyin_embeddings(pinyin_ids)
        glyph_embeddings = self.glyph_map(self.glyph_embeddings(input_ids))

        # Concatenate along the feature dimension, then fuse back to hidden_size.
        concat_embeddings = torch.cat((word_embeddings, pinyin_embeddings, glyph_embeddings), 2)
        inputs_embeds = self.map_fc(concat_embeddings)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
class PinyinEmbedding(nn.Module):
    """Per-token pinyin embedding.

    Embeds each pinyin symbol, runs a width-2 ``Conv1d`` over a token's symbol
    sequence, and max-pools over the whole sequence into one vector per token.
    """

    def __init__(self, embedding_size: int, pinyin_out_dim: int, config):
        """
        Pinyin Embedding Module

        Args:
            embedding_size: the size of each pinyin-symbol embedding vector
            pinyin_out_dim: number of conv kernels (output channels), i.e. the
                size of the resulting per-token pinyin embedding
            config: model config; ``config._name_or_path`` is the local
                directory checked for ``config/pinyin_map.json`` before it is
                downloaded
        """
        super(PinyinEmbedding, self).__init__()
        download_file("config/pinyin_map.json", Path(config._name_or_path))
        # JSON is UTF-8 by specification; don't depend on the platform's
        # locale encoding (e.g. cp936 on Chinese Windows).
        with open(cache_path / 'config' / 'pinyin_map.json', encoding='utf-8') as fin:
            pinyin_dict = json.load(fin)
        self.pinyin_out_dim = pinyin_out_dim
        # One embedding row per entry of the map's ``idx2char`` list.
        self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
        self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2,
                              stride=1, padding=0)

    def forward(self, pinyin_ids):
        """
        Args:
            pinyin_ids: integer tensor of shape (bs, sentence_length, pinyin_locs)

        Returns:
            pinyin_embed: (bs, sentence_length, pinyin_out_dim)
        """
        embed = self.embedding(pinyin_ids)  # (bs, sentence_length, pinyin_locs, embed_size)
        bs, sentence_length, pinyin_locs, embed_size = embed.shape
        # Each token's pinyin symbols become one Conv1d sample whose channels
        # are the embedding features: (bs*sentence_length, embed_size, pinyin_locs).
        view_embed = embed.view(-1, pinyin_locs, embed_size)
        input_embed = view_embed.permute(0, 2, 1)

        pinyin_conv = self.conv(input_embed)
        # Max-pool over the full conv output length -> one vector per token.
        pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1])
        return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim)
|
|
|
|
|
class GlyphEmbedding(nn.Module):
    """Glyph2Image Embedding.

    For each token id, looks up the flattened glyph images of that character
    in every loaded font, concatenated font by font.
    """

    def __init__(self, font_npy_files: List[str]):
        """
        Args:
            font_npy_files: paths of ``.npy`` font tables loaded with
                ``np.load``. Each table is indexed by token id along axis 0 and
                must hold ``font_size ** 2`` values per token, where
                ``font_size`` is its last dimension (presumably
                ``(vocab_size, font_size, font_size)`` — TODO confirm). All
                tables must share the same shape.

        Raises:
            ValueError: if no font files are given.
        """
        super(GlyphEmbedding, self).__init__()
        font_arrays = [
            np.load(np_file).astype(np.float32) for np_file in font_npy_files
        ]
        if not font_arrays:
            raise ValueError("GlyphEmbedding needs at least one .npy font file")
        self.vocab_size = font_arrays[0].shape[0]
        self.font_num = len(font_arrays)
        self.font_size = font_arrays[0].shape[-1]

        # Stack on axis 1 so that, after flattening, each row holds font 0's
        # glyph followed by font 1's glyph, and so on.
        font_array = np.stack(font_arrays, axis=1)
        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.font_size ** 2 * self.font_num,
            _weight=torch.from_numpy(font_array.reshape([self.vocab_size, -1])),
        )

    def forward(self, input_ids):
        """
        get glyph images for batch inputs

        Args:
            input_ids: [batch, sentence_length]

        Returns:
            images: [batch, sentence_length, self.font_num*self.font_size*self.font_size]
        """
        return self.embedding(input_ids)
|
|