import json
import os
import time
from pathlib import Path
from typing import List

import tokenizers
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import http_user_agent
from pypinyin import pinyin, Style

try:
    from tokenizers import BertWordPieceTokenizer
except ImportError:
    # Fall back to the implementations module when the name is not re-exported at top level.
    from tokenizers.implementations import BertWordPieceTokenizer

from transformers import BertTokenizerFast

# Downloaded assets are cached next to this file.
cache_path = Path(os.path.abspath(__file__)).parent

# Direct download URLs of the assets this tokenizer depends on.
SOURCE_FILES_URL = {
    "vocab.txt": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/vocab.txt",
    "pinyin_map.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin_map.json",
    "id2pinyin.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/id2pinyin.json",
    "pinyin2tensor.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin2tensor.json",
}


def download_file(filename: str):
    """Fetch `filename` from the iioSnail/chinesebert-base repo unless it already exists locally."""
    if os.path.exists(cache_path / filename):
        return

    hf_hub_download(
        "iioSnail/chinesebert-base",
        filename,
        cache_dir=cache_path,
        user_agent=http_user_agent(None),
    )
    # Give the download a moment to settle before the file is read back.
    time.sleep(0.2)


class ChineseBertTokenizer(BertTokenizerFast):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        vocab_file = os.path.join(cache_path, 'vocab.txt')
        config_path = os.path.join(cache_path, 'config')
        self.max_length = 512

        # WordPiece tokenizer used to obtain token ids and character offsets.
        download_file('vocab.txt')
        self.tokenizer = BertWordPieceTokenizer(vocab_file)

        # Maps single pinyin characters (letters and tone digits) to ids.
        download_file('config/pinyin_map.json')
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)

        download_file('config/id2pinyin.json')
        with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)

        # Maps a full pinyin string (e.g. "ma1") to its precomputed 8-slot id list.
        download_file('config/pinyin2tensor.json')
        with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)

    def tokenize_sentence(self, sentence):
        # Token ids from the WordPiece tokenizer, plus one 8-slot pinyin id list per token.
        tokenizer_output = self.tokenizer.encode(sentence)
        bert_tokens = tokenizer_output.ids
        pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)

        assert len(bert_tokens) <= self.max_length
        assert len(bert_tokens) == len(pinyin_tokens)

        input_ids = torch.LongTensor(bert_tokens)
        # Flatten the per-token pinyin lists into a single 1-D tensor.
        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
        return input_ids, pinyin_ids

    def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
        # One pinyin entry per character; non-Chinese characters are mapped to the
        # placeholder "not chinese" by the errors callback.
        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])

        # Character position -> 8-slot pinyin id list (only for Chinese characters).
        pinyin_locs = {}
        for index, item in enumerate(pinyin_list):
            pinyin_string = item[0]

            if pinyin_string == "not chinese":
                continue
            if pinyin_string in self.pinyin2tensor:
                pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
            else:
                # Encode the pinyin string character by character, zero-padded to 8 slots;
                # fall back to all zeros if any character is unknown.
                ids = [0] * 8
                for i, p in enumerate(pinyin_string):
                    if p not in self.pinyin_dict["char2idx"]:
                        ids = [0] * 8
                        break
                    ids[i] = self.pinyin_dict["char2idx"][p]
                pinyin_locs[index] = ids
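
        # Illustration of the fixed-width encoding built above (slot values are hypothetical,
        # they come from the downloaded pinyin_map.json):
        #   "ma1" -> [char2idx['m'], char2idx['a'], char2idx['1'], 0, 0, 0, 0, 0]
        # Any token that does not cover exactly one Chinese character receives [0] * 8 below.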

        # Align pinyin ids with the WordPiece tokens via their character offsets.
        pinyin_ids = []
        for token, offset in zip(tokenizer_output.tokens, tokenizer_output.offsets):
            # Only tokens spanning exactly one character can carry pinyin information
            # (special tokens such as [CLS]/[SEP] fall into this case and get zeros).
            if offset[1] - offset[0] != 1:
                pinyin_ids.append([0] * 8)
                continue
            if offset[0] in pinyin_locs:
                pinyin_ids.append(pinyin_locs[offset[0]])
            else:
                pinyin_ids.append([0] * 8)

        return pinyin_ids
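

# Usage sketch (not part of the original module; assumes hf_hub_download places the
# downloaded files directly under cache_path, as the rest of this file expects):
if __name__ == "__main__":
    download_file("vocab.txt")
    tokenizer = ChineseBertTokenizer(vocab_file=str(cache_path / "vocab.txt"))
    input_ids, pinyin_ids = tokenizer.tokenize_sentence("我喜欢猫")
    # One id per WordPiece token; eight pinyin slots per token, flattened to 1-D.
    print(input_ids.shape, pinyin_ids.shape)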