import json
import os
import time
from pathlib import Path
from typing import List
import tokenizers
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import http_user_agent
from pypinyin import pinyin, Style
try:
    from tokenizers import BertWordPieceTokenizer
except ImportError:
    from tokenizers.implementations import BertWordPieceTokenizer
from transformers import BertTokenizerFast

cache_path = Path(os.path.abspath(__file__)).parent

SOURCE_FILES_URL = {
    "vocab.txt": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/vocab.txt",
    "pinyin_map.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin_map.json",
    "id2pinyin.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/id2pinyin.json",
    "pinyin2tensor.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin2tensor.json",
}


def download_file(filename: str):
    # skip the download if the file is already cached next to this script
    if os.path.exists(cache_path / filename):
        return
    # download into this script's directory; local_dir preserves the repo's
    # relative layout, so 'config/*.json' lands where the loaders below expect it
    hf_hub_download(
        "iioSnail/chinesebert-base",
        filename,
        local_dir=cache_path,
        user_agent=http_user_agent(None),
    )
    # brief pause between consecutive downloads to avoid hammering the Hub
    time.sleep(0.2)
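
# Usage sketch: download_file('config/pinyin_map.json') fetches the file to
# <script dir>/config/pinyin_map.json on the first call and is a no-op afterwards.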


class ChineseBertTokenizer(BertTokenizerFast):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vocab_file = os.path.join(cache_path, 'vocab.txt')
        config_path = os.path.join(cache_path, 'config')
        self.max_length = 512

        download_file('vocab.txt')
        self.tokenizer = BertWordPieceTokenizer(vocab_file)

        # load the pinyin character-to-id map (letters and tone digits)
        download_file('config/pinyin_map.json')
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)

        # load the token-id-to-pinyin-ids map
        download_file('config/id2pinyin.json')
        with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)

        # load the pinyin-string-to-pinyin-ids map
        download_file('config/pinyin2tensor.json')
        with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)

    def tokenize_sentence(self, sentence):
        # convert the sentence to BERT token ids
        tokenizer_output = self.tokenizer.encode(sentence)
        bert_tokens = tokenizer_output.ids
        pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
        # the sequence must fit in max_length, and the number of BERT tokens
        # must equal the number of pinyin id sequences
        assert len(bert_tokens) <= self.max_length
        assert len(bert_tokens) == len(pinyin_tokens)
        # convert the lists to tensors; pinyin ids are flattened to 1-D
        input_ids = torch.LongTensor(bert_tokens)
        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
        return input_ids, pinyin_ids

    def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
        # get the pinyin of every character in the sentence; non-Chinese
        # characters are replaced by the placeholder 'not chinese'
        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True,
                             errors=lambda x: [['not chinese'] for _ in x])

        # map each character position to its fixed-length (8) pinyin id list
        pinyin_locs = {}
        for index, item in enumerate(pinyin_list):
            pinyin_string = item[0]
            # skip positions that are not Chinese characters
            if pinyin_string == "not chinese":
                continue
            if pinyin_string in self.pinyin2tensor:
                pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
            else:
                # encode the pinyin string character by character (letters plus
                # the tone digit), zero-padded to length 8; any character missing
                # from the map zeroes out the whole slot
                ids = [0] * 8
                for i, p in enumerate(pinyin_string):
                    if p not in self.pinyin_dict["char2idx"]:
                        ids = [0] * 8
                        break
                    ids[i] = self.pinyin_dict["char2idx"][p]
                pinyin_locs[index] = ids

        # align pinyin ids with the BERT tokens: only single-character tokens can
        # be Chinese characters; special tokens and multi-character tokens get zeros
        pinyin_ids = []
        for offset in tokenizer_output.offsets:
            if offset[1] - offset[0] != 1:
                pinyin_ids.append([0] * 8)
                continue
            pinyin_ids.append(pinyin_locs.get(offset[0], [0] * 8))
        return pinyin_ids
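

# A minimal usage sketch, assuming the iioSnail/chinesebert-base checkpoint is
# reachable and that from_pretrained can build this tokenizer from its vocab
# files (the example sentence and printed shapes are illustrative only):
if __name__ == "__main__":
    tokenizer = ChineseBertTokenizer.from_pretrained("iioSnail/chinesebert-base")
    input_ids, pinyin_ids = tokenizer.tokenize_sentence("我喜欢猫")
    # input_ids has shape (seq_len,); pinyin_ids is flattened to (seq_len * 8,)
    print(input_ids.shape, pinyin_ids.shape)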