iioSnail committed on
Commit 24dc275 · 1 Parent(s): 8135e8c

Upload bert_tokenizer.py

Files changed (1)
  1. bert_tokenizer.py +91 -2
bert_tokenizer.py CHANGED
@@ -2,20 +2,24 @@ import json
import os
import time
from pathlib import Path
- from typing import List
+ from types import NoneType
+ from typing import List, Union, Optional

import tokenizers
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import http_user_agent
from pypinyin import pinyin, Style
+ from transformers.tokenization_utils_base import TruncationStrategy
+ from transformers.utils import PaddingStrategy
+ from transformers.utils.generic import TensorType

try:
    from tokenizers import BertWordPieceTokenizer
except:
    from tokenizers.implementations import BertWordPieceTokenizer

- from transformers import BertTokenizerFast
+ from transformers import BertTokenizerFast, BatchEncoding

cache_path = Path(os.path.abspath(__file__)).parent

@@ -60,6 +64,64 @@ class ChineseBertTokenizer(BertTokenizerFast):
        with open(config_path / 'pinyin2tensor.json', encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)

+     def __call__(self,
+                  text: Union[str, List[str], List[List[str]]] = None,
+                  text_pair: Union[str, List[str], List[List[str]], NoneType] = None,
+                  text_target: Union[str, List[str], List[List[str]]] = None,
+                  text_pair_target: Union[str, List[str], List[List[str]], NoneType] = None,
+                  add_special_tokens: bool = True,
+                  padding: Union[bool, str, PaddingStrategy] = False,
+                  truncation: Union[bool, str, TruncationStrategy] = None,
+                  max_length: Optional[int] = None,
+                  stride: int = 0,
+                  is_split_into_words: bool = False,
+                  pad_to_multiple_of: Optional[int] = None,
+                  return_tensors: Union[str, TensorType, NoneType] = None,
+                  return_token_type_ids: Optional[bool] = None,
+                  return_attention_mask: Optional[bool] = None,
+                  return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False,
+                  return_offsets_mapping: bool = False,
+                  return_length: bool = False,
+                  verbose: bool = True, **kwargs) -> BatchEncoding:
+         encoding = super().__call__(
+             text=text,
+             text_pair=text_pair,
+             text_target=text_target,
+             text_pair_target=text_pair_target,
+             add_special_tokens=add_special_tokens,
+             padding=padding,
+             truncation=truncation,
+             max_length=max_length,
+             stride=stride,
+             is_split_into_words=is_split_into_words,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_tensors=return_tensors,
+             return_token_type_ids=return_token_type_ids,
+             return_attention_mask=return_attention_mask,
+             return_overflowing_tokens=return_overflowing_tokens,
+             return_offsets_mapping=return_offsets_mapping,
+             return_length=return_length,
+             verbose=verbose,
+         )
+
+         input_ids = encoding.input_ids
+
+         pinyin_ids = None
+         if type(text) == str:
+             pinyin_ids = self.convert_ids_to_pinyin_ids(input_ids)
+
+         if type(text) == list:
+             pinyin_ids = []
+             for ids in input_ids:
+                 pinyin_ids.append(self.convert_ids_to_pinyin_ids(ids))
+
+         if torch.is_tensor(encoding.input_ids):
+             pinyin_ids = torch.LongTensor(pinyin_ids)
+
+         encoding['pinyin_ids'] = pinyin_ids
+
+         return encoding
+
    def tokenize_sentence(self, sentence):
        # convert sentence to ids
        tokenizer_output = self.tokenizer.encode(sentence)

@@ -73,6 +135,33 @@ class ChineseBertTokenizer(BertTokenizerFast):
        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
        return input_ids, pinyin_ids

+     def convert_ids_to_pinyin_ids(self, ids: List[int]):
+         pinyin_ids = []
+         tokens = self.convert_ids_to_tokens(ids)
+         for token in tokens:
+             if len(token) > 1:
+                 pinyin_ids.append([0] * 8)
+                 continue
+
+             pinyin_string = pinyin(token, style=Style.TONE3, errors=lambda x: [['not chinese'] for _ in x])[0][0]
+
+             if pinyin_string == "not chinese":
+                 pinyin_ids.append([0] * 8)
+                 continue
+
+             if pinyin_string in self.pinyin2tensor:
+                 pinyin_ids.append(self.pinyin2tensor[pinyin_string])
+             else:
+                 ids = [0] * 8
+                 for i, p in enumerate(pinyin_string):
+                     if p not in self.pinyin_dict["char2idx"]:
+                         ids = [0] * 8
+                         break
+                     ids[i] = self.pinyin_dict["char2idx"][p]
+                 pinyin_ids.append(ids)
+
+         return pinyin_ids
+
    def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
        # get pinyin of a sentence
        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
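A minimal usage sketch of the overridden __call__ added in this commit (the sketch itself is not part of the commit; the repository id and the example sentence are illustrative assumptions). With this change the returned BatchEncoding should carry a pinyin_ids entry with 8 pinyin-character slots per input id:

# hypothetical usage; assumes this file is importable as bert_tokenizer and that
# the vocab and pinyin config files the class loads are available from the repo
from bert_tokenizer import ChineseBertTokenizer

tokenizer = ChineseBertTokenizer.from_pretrained("iioSnail/ChineseBERT-base")  # repo id is an assumption

encoding = tokenizer("我喜欢猫")          # single sentence, no return_tensors
print(len(encoding["input_ids"]))        # number of wordpiece ids, incl. [CLS]/[SEP]
print(len(encoding["pinyin_ids"]))       # one pinyin entry per input id
print(len(encoding["pinyin_ids"][0]))    # 8 slots per token; non-Chinese tokens are all zeros

The sketch leaves return_tensors unset, since the pinyin conversion added here operates on plain Python id lists.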