# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re


def tokenize_by_bpe_model(sp, txt):
    return _tokenize_by_seg_dic_or_bpe_model(txt, sp=sp, upper=True)


def tokenize_by_seg_dict(seg_dict, txt):
    return _tokenize_by_seg_dic_or_bpe_model(txt,
                                             seg_dict=seg_dict,
                                             upper=False)


def _tokenize_by_seg_dic_or_bpe_model(
    txt,
    sp=None,
    seg_dict=None,
    upper=True,
):
    # Exactly one tokenization backend must be supplied: `sp` (a SentencePiece
    # processor) or `seg_dict` (a dict mapping a word to pre-segmented pieces).
    if sp is None:
        assert seg_dict is not None
    if seg_dict is None:
        assert sp is not None
    tokens = []
    # CJK (China, Japan, Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r'([\u4e00-\u9fff])')
    # Example:
    #   txt   = "你好 IT'S OKAY 的"
    #   chars = ["你", "好", " IT'S OKAY ", "的"]
    chars = pattern.split(txt.upper() if upper else txt)
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (e.g., "你"): keep it as one token.
        if pattern.fullmatch(ch_or_w) is not None:
            tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "):
        # encode it with the BPE model, or fall back to seg_dict lookup.
        else:
            if sp is not None:
                for p in sp.encode_as_pieces(ch_or_w):
                    tokens.append(p)
            else:
                for en_token in ch_or_w.split():
                    en_token = en_token.strip()
                    if en_token in seg_dict:
                        tokens.extend(seg_dict[en_token].split(' '))
                    else:
                        tokens.append(en_token)

    return tokens
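

# --- Minimal usage sketch (not part of the original file) ---
# The BPE route assumes the `sentencepiece` package is installed and that
# "bpe.model" is a hypothetical path to a trained model, so it is left
# commented out. The seg_dict below is a toy example whose entries map an
# English word to its space-separated, pre-segmented pieces.
if __name__ == '__main__':
    # BPE route (requires a real model file):
    # import sentencepiece as spm
    # sp = spm.SentencePieceProcessor(model_file='bpe.model')
    # print(tokenize_by_bpe_model(sp, "你好 it's okay 的"))
    # CJK characters come out one per token; the English span is BPE-encoded.

    # seg_dict route, runnable as-is:
    seg_dict = {'OKAY': 'OK @@AY'}
    print(tokenize_by_seg_dict(seg_dict, "你好 OKAY 的"))
    # -> ['你', '好', 'OK', '@@AY', '的']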