iioSnail commited on
Commit
3217624
·
1 Parent(s): de3a7d8

Upload bert_tokenizer.py

Browse files
Files changed (1) hide show
  1. bert_tokenizer.py +17 -7
bert_tokenizer.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
  from typing import List
4
 
5
  import requests
@@ -14,6 +15,8 @@ except:
14
 
15
  from transformers import BertTokenizerFast
16
 
 
 
17
  SOURCE_FILES_URL = {
18
  "vocab.txt": "https://huggingface.co/iioSnail/chinesebert-base/blob/main/vocab.txt",
19
  "pinyin_map.json": "https://huggingface.co/iioSnail/chinesebert-base/blob/main/config/pinyin_map.json",
@@ -22,13 +25,21 @@ SOURCE_FILES_URL = {
22
  }
23
 
24
 
25
- def download_file(url, filename):
26
  if os.path.exists(filename):
27
  return
28
 
29
- res = requests.get(url)
30
- with open(filename, 'wb') as file:
31
- file.write(res.content)
 
 
 
 
 
 
 
 
32
 
33
 
34
  class ChineseBertTokenizer(BertTokenizerFast):
@@ -36,9 +47,8 @@ class ChineseBertTokenizer(BertTokenizerFast):
36
  def __init__(self, **kwargs):
37
  super(ChineseBertTokenizer, self).__init__(**kwargs)
38
 
39
- bert_path = self.name_or_path
40
- vocab_file = os.path.join(bert_path, 'vocab.txt')
41
- config_path = os.path.join(bert_path, 'config')
42
  self.max_length = 512
43
 
44
  download_file(SOURCE_FILES_URL["vocab.txt"], vocab_file)
 
1
  import json
2
  import os
3
+ from pathlib import Path
4
  from typing import List
5
 
6
  import requests
 
15
 
16
  from transformers import BertTokenizerFast
17
 
18
+ cache_path = Path(os.path.abspath(__file__)).parent
19
+
20
  SOURCE_FILES_URL = {
21
  "vocab.txt": "https://huggingface.co/iioSnail/chinesebert-base/blob/main/vocab.txt",
22
  "pinyin_map.json": "https://huggingface.co/iioSnail/chinesebert-base/blob/main/config/pinyin_map.json",
 
25
  }
26
 
27
 
28
+ def download_file(url, filename: str):
29
  if os.path.exists(filename):
30
  return
31
 
32
+ headers = {
33
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/95.0.4638.54 Safari/537.36"
34
+ }
35
+ try:
36
+ res = requests.get(url, headers=headers)
37
+ res.raise_for_status()
38
+ with open(filename, 'wb') as file:
39
+ file.write(res.content)
40
+ except:
41
+ raise RuntimeError("Error download the file of '" + filename +
42
+ "'. You can download the model file into the current directory and rerun it.")
43
 
44
 
45
  class ChineseBertTokenizer(BertTokenizerFast):
 
47
  def __init__(self, **kwargs):
48
  super(ChineseBertTokenizer, self).__init__(**kwargs)
49
 
50
+ vocab_file = os.path.join(cache_path, 'vocab.txt')
51
+ config_path = os.path.join(cache_path, 'config')
 
52
  self.max_length = 512
53
 
54
  download_file(SOURCE_FILES_URL["vocab.txt"], vocab_file)