|
import json |
|
import os |
|
from typing import List |
|
|
|
import tokenizers |
|
import torch |
|
from pypinyin import pinyin, Style |
|
|
|
# BertWordPieceTokenizer moved between `tokenizers` releases; support both
# layouts. Catch only ImportError — a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit.
try:
    from tokenizers import BertWordPieceTokenizer
except ImportError:
    from tokenizers.implementations import BertWordPieceTokenizer
|
|
|
from transformers import BertTokenizerFast |
|
|
|
|
|
class ChineseBertTokenizer(BertTokenizerFast):
    """ChineseBERT tokenizer pairing WordPiece ids with pinyin feature ids.

    On top of the regular fast BERT tokenizer, this loads three pinyin
    lookup tables from the ``config`` directory under the pretrained model
    path (``name_or_path``) and can emit, for every token, an 8-slot list
    of pinyin-character ids (all zeros for non-Chinese characters,
    multi-character wordpieces, and special tokens).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        bert_path = self.name_or_path
        vocab_file = os.path.join(bert_path, 'vocab.txt')
        config_path = os.path.join(bert_path, 'config')
        self.max_length = 512
        # A second, plain WordPiece tokenizer: its Encoding exposes the
        # character offsets needed to align pinyin ids with tokens.
        self.tokenizer = BertWordPieceTokenizer(vocab_file)

        # char2idx table: single pinyin character -> integer id.
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)

        # Loaded for external callers; not referenced by the methods below.
        with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)

        # Full pinyin string (TONE3 style) -> precomputed 8-slot id list.
        with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)

    def tokenize_sentence(self, sentence: str):
        """Encode *sentence* into aligned (input_ids, pinyin_ids) tensors.

        Args:
            sentence: raw text to tokenize.

        Returns:
            input_ids: 1-D LongTensor of WordPiece token ids.
            pinyin_ids: 1-D LongTensor of length ``8 * len(input_ids)``
                (the per-token 8-slot pinyin ids, flattened).

        Raises:
            ValueError: if the encoded sentence exceeds ``self.max_length``
                tokens, or the pinyin sequence is misaligned with the tokens.
        """
        tokenizer_output = self.tokenizer.encode(sentence)
        bert_tokens = tokenizer_output.ids
        pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)

        # Explicit checks instead of `assert`, which is stripped under -O.
        if len(bert_tokens) > self.max_length:
            raise ValueError(
                f"sentence encodes to {len(bert_tokens)} tokens, "
                f"exceeding max_length={self.max_length}")
        if len(bert_tokens) != len(pinyin_tokens):
            raise ValueError("pinyin ids are misaligned with token ids")

        input_ids = torch.LongTensor(bert_tokens)
        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
        return input_ids, pinyin_ids

    def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
        """Return one 8-slot pinyin id list per token of *tokenizer_output*.

        Non-Chinese characters, multi-character wordpieces, and special
        tokens (zero-width offsets) all map to ``[0] * 8``.
        """
        # Heteronym mode keeps one list entry per input character, and the
        # `errors` callback expands each non-Chinese char into its own
        # placeholder — so list indices line up with character offsets.
        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])

        # char offset -> 8-slot pinyin id list (Chinese chars only).
        pinyin_locs = {}
        for index, item in enumerate(pinyin_list):
            pinyin_string = item[0]
            if pinyin_string == "not chinese":
                continue
            if pinyin_string in self.pinyin2tensor:
                pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
            else:
                # Fall back to encoding the pinyin string char by char;
                # any character missing from char2idx voids the whole entry.
                ids = [0] * 8
                for i, p in enumerate(pinyin_string):
                    if p not in self.pinyin_dict["char2idx"]:
                        ids = [0] * 8
                        break
                    ids[i] = self.pinyin_dict["char2idx"][p]
                pinyin_locs[index] = ids

        # Align per-character pinyin with tokens via their character offsets.
        # (Unused `idx`/`token` loop variables from the original removed.)
        pinyin_ids = []
        for start, end in tokenizer_output.offsets:
            # Only single-character tokens can carry pinyin: special tokens
            # have zero-width offsets and wordpieces span several characters.
            if end - start == 1 and start in pinyin_locs:
                pinyin_ids.append(pinyin_locs[start])
            else:
                pinyin_ids.append([0] * 8)
        return pinyin_ids
|
|