iioSnail committed on
Commit
de3a7d8
·
1 Parent(s): 601e637

Upload bert_tokenizer.py

Browse files
Files changed (1) hide show
  1. bert_tokenizer.py +25 -3
bert_tokenizer.py CHANGED
@@ -1,8 +1,8 @@
1
  import json
2
  import os
3
- from pathlib import Path
4
  from typing import List
5
 
 
6
  import tokenizers
7
  import torch
8
  from pypinyin import pinyin, Style
@@ -14,26 +14,48 @@ except:
14
 
15
  from transformers import BertTokenizerFast
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class ChineseBertTokenizer(BertTokenizerFast):
19
 
20
  def __init__(self, **kwargs):
21
  super(ChineseBertTokenizer, self).__init__(**kwargs)
22
 
23
- bert_path = Path(os.path.abspath(__file__)).parent
24
- print("bert_path", bert_path)
25
  vocab_file = os.path.join(bert_path, 'vocab.txt')
26
  config_path = os.path.join(bert_path, 'config')
27
  self.max_length = 512
 
 
28
  self.tokenizer = BertWordPieceTokenizer(vocab_file)
29
 
30
  # load pinyin map dict
 
31
  with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
32
  self.pinyin_dict = json.load(fin)
 
33
  # load char id map tensor
 
34
  with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
35
  self.id2pinyin = json.load(fin)
 
36
  # load pinyin map tensor
 
37
  with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
38
  self.pinyin2tensor = json.load(fin)
39
 
 
1
  import json
2
  import os
 
3
  from typing import List
4
 
5
+ import requests
6
  import tokenizers
7
  import torch
8
  from pypinyin import pinyin, Style
 
14
 
15
  from transformers import BertTokenizerFast
16
 
17
# Raw-download URLs for the tokenizer's auxiliary resource files.
# NOTE: the path segment must be "resolve", not "blob" — "blob" URLs return
# the Hub's HTML viewer page, while "resolve" returns the raw file content
# needed by download_file().
SOURCE_FILES_URL = {
    "vocab.txt": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/vocab.txt",
    "pinyin_map.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin_map.json",
    "id2pinyin.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/id2pinyin.json",
    # BUG FIX: this entry previously pointed at id2pinyin.json (copy-paste error),
    # so pinyin2tensor.json could never be fetched correctly.
    "pinyin2tensor.json": "https://huggingface.co/iioSnail/chinesebert-base/resolve/main/config/pinyin2tensor.json",
}
23
+
24
+
25
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless the file already exists.

    Existing files are kept as-is (acts as a simple on-disk cache).
    The destination directory is created if missing, and an HTTP error
    response raises ``requests.HTTPError`` instead of silently caching
    the server's error page to disk.
    """
    if os.path.exists(filename):
        return

    # The target may live in a sub-directory (e.g. "config/") that does
    # not exist yet on a fresh checkout — create it before writing.
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    res = requests.get(url)
    # Fail loudly on 4xx/5xx; otherwise the error page would be written to
    # disk and the exists-check above would cache it forever.
    res.raise_for_status()
    with open(filename, 'wb') as file:
        file.write(res.content)
32
+
33
 
34
class ChineseBertTokenizer(BertTokenizerFast):
    """BERT tokenizer for ChineseBERT that additionally loads the pinyin
    lookup tables (pinyin map, id-to-pinyin, pinyin-to-tensor) needed to
    build pinyin feature inputs.

    On first use the auxiliary files are downloaded from the Hub into the
    model directory via :func:`download_file`.
    """

    def __init__(self, **kwargs):
        super(ChineseBertTokenizer, self).__init__(**kwargs)

        # NOTE(review): name_or_path is assumed to be a LOCAL directory the
        # model was loaded from; if callers pass a bare hub id instead, these
        # joins produce a non-existent path — confirm against callers.
        bert_path = self.name_or_path
        vocab_file = os.path.join(bert_path, 'vocab.txt')
        config_path = os.path.join(bert_path, 'config')
        self.max_length = 512

        # The "config" sub-directory is not part of the base checkout; create
        # it up front so the downloads below can write into it.
        os.makedirs(config_path, exist_ok=True)

        download_file(SOURCE_FILES_URL["vocab.txt"], vocab_file)
        self.tokenizer = BertWordPieceTokenizer(vocab_file)

        # Load pinyin map dict (pinyin symbol -> id).
        download_file(SOURCE_FILES_URL["pinyin_map.json"], os.path.join(config_path, 'pinyin_map.json'))
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)

        # Load char-id -> pinyin map.
        download_file(SOURCE_FILES_URL["id2pinyin.json"], os.path.join(config_path, 'id2pinyin.json'))
        with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)

        # Load pinyin -> tensor map.
        download_file(SOURCE_FILES_URL["pinyin2tensor.json"], os.path.join(config_path, 'pinyin2tensor.json'))
        with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)
61