jinhai-2012 committed on
Commit
241b23e
·
1 Parent(s): 0a519de

Refactor trie load and construct (#4083)

Browse files

### What problem does this PR solve?

1. Fix the initial building and loading of the trie
2. Update comments (translate them to English)

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <[email protected]>

Files changed (1) hide show
  1. rag/nlp/rag_tokenizer.py +23 -10
rag/nlp/rag_tokenizer.py CHANGED
@@ -36,7 +36,7 @@ class RagTokenizer:
36
  return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
37
 
38
  def loadDict_(self, fnm):
39
- logging.info(f"[HUQIE]:Build trie {fnm}")
40
  try:
41
  of = open(fnm, "r", encoding='utf-8')
42
  while True:
@@ -50,7 +50,10 @@ class RagTokenizer:
50
  if k not in self.trie_ or self.trie_[k][0] < F:
51
  self.trie_[self.key_(line[0])] = (F, line[2])
52
  self.trie_[self.rkey_(line[0])] = 1
53
- self.trie_.save(fnm + ".trie")
 
 
 
54
  of.close()
55
  except Exception:
56
  logging.exception(f"[HUQIE]:Build trie {fnm} failed")
@@ -58,20 +61,30 @@ class RagTokenizer:
58
  def __init__(self, debug=False):
59
  self.DEBUG = debug
60
  self.DENOMINATOR = 1000000
61
- self.trie_ = datrie.Trie(string.printable)
62
  self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
63
 
64
  self.stemmer = PorterStemmer()
65
  self.lemmatizer = WordNetLemmatizer()
66
 
67
  self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
68
- try:
69
- self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
70
- return
71
- except Exception:
72
- logging.exception("[HUQIE]:Build default trie")
 
 
 
 
 
 
 
 
 
 
73
  self.trie_ = datrie.Trie(string.printable)
74
 
 
75
  self.loadDict_(self.DIR_ + ".txt")
76
 
77
  def loadUserDict(self, fnm):
@@ -86,7 +99,7 @@ class RagTokenizer:
86
  self.loadDict_(fnm)
87
 
88
  def _strQ2B(self, ustring):
89
- """把字符串全角转半角"""
90
  rstring = ""
91
  for uchar in ustring:
92
  inside_code = ord(uchar)
@@ -94,7 +107,7 @@ class RagTokenizer:
94
  inside_code = 0x0020
95
  else:
96
  inside_code -= 0xfee0
97
- if inside_code < 0x0020 or inside_code > 0x7e: # 转完之后不是半角字符返回原来的字符
98
  rstring += uchar
99
  else:
100
  rstring += chr(inside_code)
 
36
  return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
37
 
38
  def loadDict_(self, fnm):
39
+ logging.info(f"[HUQIE]:Build trie from {fnm}")
40
  try:
41
  of = open(fnm, "r", encoding='utf-8')
42
  while True:
 
50
  if k not in self.trie_ or self.trie_[k][0] < F:
51
  self.trie_[self.key_(line[0])] = (F, line[2])
52
  self.trie_[self.rkey_(line[0])] = 1
53
+
54
+ dict_file_cache = fnm + ".trie"
55
+ logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
56
+ self.trie_.save(dict_file_cache)
57
  of.close()
58
  except Exception:
59
  logging.exception(f"[HUQIE]:Build trie {fnm} failed")
 
61
  def __init__(self, debug=False):
62
  self.DEBUG = debug
63
  self.DENOMINATOR = 1000000
 
64
  self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
65
 
66
  self.stemmer = PorterStemmer()
67
  self.lemmatizer = WordNetLemmatizer()
68
 
69
  self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
70
+
71
+ trie_file_name = self.DIR_ + ".txt.trie"
72
+ # check if trie file existence
73
+ if os.path.exists(trie_file_name):
74
+ try:
75
+ # load trie from file
76
+ self.trie_ = datrie.Trie.load(trie_file_name)
77
+ return
78
+ except Exception:
79
+ # fail to load trie from file, build default trie
80
+ logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
81
+ self.trie_ = datrie.Trie(string.printable)
82
+ else:
83
+ # file not exist, build default trie
84
+ logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
85
  self.trie_ = datrie.Trie(string.printable)
86
 
87
+ # load data from dict file and save to trie file
88
  self.loadDict_(self.DIR_ + ".txt")
89
 
90
  def loadUserDict(self, fnm):
 
99
  self.loadDict_(fnm)
100
 
101
  def _strQ2B(self, ustring):
102
+ """Convert full-width characters to half-width characters"""
103
  rstring = ""
104
  for uchar in ustring:
105
  inside_code = ord(uchar)
 
107
  inside_code = 0x0020
108
  else:
109
  inside_code -= 0xfee0
110
+ if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
111
  rstring += uchar
112
  else:
113
  rstring += chr(inside_code)