Upload bert_tokenizer.py

bert_tokenizer.py (CHANGED, +91 -2)
@@ -2,20 +2,24 @@ import json
 import os
 import time
 from pathlib import Path
-from typing import List
+from types import NoneType
+from typing import List, Union, Optional
 
 import tokenizers
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.file_download import http_user_agent
 from pypinyin import pinyin, Style
+from transformers.tokenization_utils_base import TruncationStrategy
+from transformers.utils import PaddingStrategy
+from transformers.utils.generic import TensorType
 
 try:
     from tokenizers import BertWordPieceTokenizer
 except:
     from tokenizers.implementations import BertWordPieceTokenizer
 
-from transformers import BertTokenizerFast
+from transformers import BertTokenizerFast, BatchEncoding
 
 cache_path = Path(os.path.abspath(__file__)).parent
 
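Note on the new imports: types.NoneType exists only on Python 3.10+, so this file will not import on older interpreters. A minimal sketch of equivalent annotations that avoid the import (the names below are illustrative, not part of the uploaded file):

# Portability sketch: Optional[X] is Union[X, None], so NoneType is not needed
# in the annotations. Names here are illustrative only.
from typing import List, Optional, Union

TextInput = Union[str, List[str], List[List[str]]]

def call_signature_example(text: Optional[TextInput] = None,
                           text_pair: Optional[TextInput] = None) -> None:
    pass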
@@ -60,6 +64,64 @@ class ChineseBertTokenizer(BertTokenizerFast):
         with open(config_path / 'pinyin2tensor.json', encoding='utf8') as fin:
             self.pinyin2tensor = json.load(fin)
 
+    def __call__(self,
+                 text: Union[str, List[str], List[List[str]]] = None,
+                 text_pair: Union[str, List[str], List[List[str]], NoneType] = None,
+                 text_target: Union[str, List[str], List[List[str]]] = None,
+                 text_pair_target: Union[str, List[str], List[List[str]], NoneType] = None,
+                 add_special_tokens: bool = True,
+                 padding: Union[bool, str, PaddingStrategy] = False,
+                 truncation: Union[bool, str, TruncationStrategy] = None,
+                 max_length: Optional[int] = None,
+                 stride: int = 0,
+                 is_split_into_words: bool = False,
+                 pad_to_multiple_of: Optional[int] = None,
+                 return_tensors: Union[str, TensorType, NoneType] = None,
+                 return_token_type_ids: Optional[bool] = None,
+                 return_attention_mask: Optional[bool] = None,
+                 return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False,
+                 return_offsets_mapping: bool = False,
+                 return_length: bool = False,
+                 verbose: bool = True, **kwargs) -> BatchEncoding:
+        encoding = super().__call__(  # run the regular fast-tokenizer pipeline first
+            text=text,
+            text_pair=text_pair,
+            text_target=text_target,
+            text_pair_target=text_pair_target,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+        )
+
+        input_ids = encoding.input_ids
+
+        pinyin_ids = None
+        if type(text) == str:
+            pinyin_ids = self.convert_ids_to_pinyin_ids(input_ids)
+
+        if type(text) == list:
+            pinyin_ids = []
+            for ids in input_ids:
+                pinyin_ids.append(self.convert_ids_to_pinyin_ids(ids))
+
+        if torch.is_tensor(encoding.input_ids):
+            pinyin_ids = torch.LongTensor(pinyin_ids)
+
+        encoding['pinyin_ids'] = pinyin_ids
+
+        return encoding
+
     def tokenize_sentence(self, sentence):
         # convert sentence to ids
         tokenizer_output = self.tokenizer.encode(sentence)
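With this override the tokenizer is used like any other BertTokenizerFast, and the returned batch additionally carries pinyin_ids (one 8-slot vector per token). A usage sketch, assuming a placeholder Hub repo id and that loading custom tokenizer code from the Hub is enabled:

# Usage sketch. "your-namespace/ChineseBERT-base" is a placeholder repo id, and
# trust_remote_code=True is assumed because the tokenizer class is defined in
# this bert_tokenizer.py on the Hub rather than inside transformers itself.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-namespace/ChineseBERT-base",
                                          trust_remote_code=True)

batch = tokenizer(["我喜欢猫", "今天天气不错"], padding=True, return_tensors="pt")

print(batch["input_ids"].shape)   # e.g. torch.Size([2, 8])
print(batch["pinyin_ids"].shape)  # e.g. torch.Size([2, 8, 8]): 8 pinyin slots per token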
@@ -73,6 +135,33 @@ class ChineseBertTokenizer(BertTokenizerFast):
         pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
         return input_ids, pinyin_ids
 
+    def convert_ids_to_pinyin_ids(self, ids: List[int]):
+        pinyin_ids = []
+        tokens = self.convert_ids_to_tokens(ids)
+        for token in tokens:
+            if len(token) > 1:  # special tokens and multi-char word pieces get an all-zero vector
+                pinyin_ids.append([0] * 8)
+                continue
+
+            pinyin_string = pinyin(token, style=Style.TONE3, errors=lambda x: [['not chinese'] for _ in x])[0][0]
+
+            if pinyin_string == "not chinese":
+                pinyin_ids.append([0] * 8)
+                continue
+
+            if pinyin_string in self.pinyin2tensor:
+                pinyin_ids.append(self.pinyin2tensor[pinyin_string])
+            else:
+                ids = [0] * 8
+                for i, p in enumerate(pinyin_string):
+                    if p not in self.pinyin_dict["char2idx"]:
+                        ids = [0] * 8
+                        break
+                    ids[i] = self.pinyin_dict["char2idx"][p]
+                pinyin_ids.append(ids)
+
+        return pinyin_ids
+
     def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
         # get pinyin of a sentence
         pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
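convert_ids_to_pinyin_ids leans on two lookup tables from the tokenizer config: pinyin2tensor.json maps a full TONE3 pinyin string such as "ma1" to a precomputed 8-slot id list, while pinyin_dict["char2idx"] maps single pinyin characters to ids for strings not covered by that cache. A standalone sketch of that fallback encoding, with a purely illustrative char2idx table (the real one is loaded in __init__):

# Standalone sketch of the fallback 8-slot pinyin encoding. The char2idx table
# here is illustrative only; the real mapping lives in self.pinyin_dict["char2idx"].
char2idx = {"m": 1, "a": 2, "1": 3}

def encode_pinyin(pinyin_string: str, char2idx: dict) -> list:
    ids = [0] * 8                  # fixed width: 8 slots, zero-padded
    for i, p in enumerate(pinyin_string):
        if p not in char2idx:      # any unknown character voids the whole string
            return [0] * 8
        ids[i] = char2idx[p]
    return ids

print(encode_pinyin("ma1", char2idx))  # [1, 2, 3, 0, 0, 0, 0, 0]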