import json
import os
import time
from pathlib import Path
from typing import List

import tokenizers
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import http_user_agent
from pypinyin import pinyin, Style

try:
    from tokenizers import BertWordPieceTokenizer
except ImportError:
    from tokenizers.implementations import BertWordPieceTokenizer

from transformers import BertTokenizerFast

cache_path = Path(os.path.abspath(__file__)).parent

SOURCE_FILES_URL = {
    "vocab.txt": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/vocab.txt",
    "pinyin_map.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin_map.json",
    "id2pinyin.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/id2pinyin.json",
    "pinyin2tensor.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/id2pinyin.json",
}
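
# NOTE: these direct download URLs mirror the files fetched via hf_hub_download() in
# download_file() below; the dict itself is not referenced elsewhere in this module.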


def download_file(filename: str):
    # Fetch a single file from the iioSnail/chinesebert-base repo unless a local copy already exists.
    if os.path.exists(cache_path / filename):
        return

    hf_hub_download(
        "iioSnail/chinesebert-base",
        filename,
        cache_dir=cache_path,
        user_agent=http_user_agent(None),
    )
    time.sleep(0.2)  # brief pause so repeated calls do not hammer the Hub


class ChineseBertTokenizer(BertTokenizerFast):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        vocab_file = os.path.join(cache_path, 'vocab.txt')
        config_path = os.path.join(cache_path, 'config')
        self.max_length = 512

        download_file('vocab.txt')
        self.tokenizer = BertWordPieceTokenizer(vocab_file)

        # load the pinyin alphabet map (contains the "char2idx" table used below)
        download_file('config/pinyin_map.json')
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)

        # load the token-id to pinyin-id mapping
        download_file('config/id2pinyin.json')
        with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)

        # load the precomputed pinyin-string to 8-slot id list mapping
        download_file('config/pinyin2tensor.json')
        with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)
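
        # As used below: pinyin_dict carries the "char2idx" alphabet table, pinyin2tensor maps a
        # full pinyin string such as "ma1" straight to its 8-slot id list, and id2pinyin (judging
        # by its name) maps vocab token ids to pinyin ids.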

    def tokenize_sentence(self, sentence):
        # convert sentence to ids
        tokenizer_output = self.tokenizer.encode(sentence)
        bert_tokens = tokenizer_output.ids
        pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
        # sanity checks: the sequence must fit max_length and every token needs a pinyin entry
        assert len(bert_tokens) <= self.max_length
        assert len(bert_tokens) == len(pinyin_tokens)
        # convert list to tensor
        input_ids = torch.LongTensor(bert_tokens)
        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
        return input_ids, pinyin_ids
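
    # Note on shapes: input_ids has length L (one id per wordpiece, including [CLS]/[SEP]), while
    # pinyin_ids is flattened by .view(-1) to length 8 * L, i.e. eight pinyin slots per token.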

    def convert_sentence_to_pinyin_ids(self, sentence: str, tokenizer_output: tokenizers.Encoding) -> List[List[int]]:
        # get the pinyin of each character; non-Chinese characters are marked with a placeholder
        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
        pinyin_locs = {}
        # map each character position to its pinyin ids
        for index, item in enumerate(pinyin_list):
            # heteronym=True returns all candidate readings; keep the first one
            pinyin_string = item[0]
            # not a Chinese character, skip it
            if pinyin_string == "not chinese":
                continue
            if pinyin_string in self.pinyin2tensor:
                pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
            else:
                # fall back to encoding the pinyin string character by character, zero-padded to 8 slots
                ids = [0] * 8
                for i, p in enumerate(pinyin_string):
                    if p not in self.pinyin_dict["char2idx"]:
                        ids = [0] * 8
                        break
                    ids[i] = self.pinyin_dict["char2idx"][p]
                pinyin_locs[index] = ids
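
        # Illustration (values assumed, not from the original source): with Style.TONE3 the
        # character 妈 yields "ma1", which becomes
        #   [char2idx["m"], char2idx["a"], char2idx["1"], 0, 0, 0, 0, 0]
        # i.e. the string is left-aligned into the fixed 8-slot layout and zero-padded.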

        # align pinyin ids with the tokenizer output; only single-character tokens can carry pinyin
        pinyin_ids = []
        for offset in tokenizer_output.offsets:
            # special tokens and multi-character pieces get all-zero pinyin ids
            if offset[1] - offset[0] != 1:
                pinyin_ids.append([0] * 8)
                continue
            if offset[0] in pinyin_locs:
                pinyin_ids.append(pinyin_locs[offset[0]])
            else:
                pinyin_ids.append([0] * 8)

        return pinyin_ids
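

# --- Minimal usage sketch (not part of the original module). Assumes the
# --- iioSnail/chinesebert-base checkpoint is reachable and that from_pretrained()
# --- supplies the BertTokenizerFast kwargs the constructor forwards to super().
if __name__ == "__main__":
    tokenizer = ChineseBertTokenizer.from_pretrained("iioSnail/chinesebert-base")
    input_ids, pinyin_ids = tokenizer.tokenize_sentence("我喜欢猫")
    print(input_ids.shape, pinyin_ids.shape)  # expected: torch.Size([L]) and torch.Size([8 * L])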