Tuchuanhuhuhu committed on
Commit
0a2de58
·
1 Parent(s): 7dbc9ca

加快了加载索引的速度

Browse files
Files changed (2) hide show
  1. modules/chat_func.py +1 -1
  2. modules/llama_func.py +10 -6
modules/chat_func.py CHANGED
@@ -272,7 +272,7 @@ def predict(
272
  if reply_language == "跟随问题语言(不稳定)":
273
  reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
274
  if files:
275
- msg = "构建索引中……(这可能需要比较久的时间)"
276
  logging.info(msg)
277
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
278
  index = construct_index(openai_api_key, file_src=files)
 
272
  if reply_language == "跟随问题语言(不稳定)":
273
  reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
274
  if files:
275
+ msg = "加载索引中……(这可能需要几分钟)"
276
  logging.info(msg)
277
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
278
  index = construct_index(openai_api_key, file_src=files)
modules/llama_func.py CHANGED
@@ -16,15 +16,19 @@ import colorama
16
  from modules.presets import *
17
  from modules.utils import *
18
 
 
 
 
 
 
 
19
 
20
  def get_documents(file_src):
21
  documents = []
22
- index_name = ""
23
  logging.debug("Loading documents...")
24
  logging.debug(f"file_src: {file_src}")
25
  for file in file_src:
26
- logging.debug(f"file: {file.name}")
27
- index_name += file.name
28
  if os.path.splitext(file.name)[1] == ".pdf":
29
  logging.debug("Loading PDF...")
30
  CJKPDFReader = download_loader("CJKPDFReader")
@@ -46,8 +50,7 @@ def get_documents(file_src):
46
  text_raw = f.read()
47
  text = add_space(text_raw)
48
  documents += [Document(text)]
49
- index_name = sha1sum(index_name)
50
- return documents, index_name
51
 
52
 
53
  def construct_index(
@@ -78,7 +81,8 @@ def construct_index(
78
  chunk_size_limit,
79
  separator=separator,
80
  )
81
- documents, index_name = get_documents(file_src)
 
82
  if os.path.exists(f"./index/{index_name}.json"):
83
  logging.info("找到了缓存的索引文件,加载中……")
84
  return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
 
16
  from modules.presets import *
17
  from modules.utils import *
18
 
19
+ def get_index_name(file_src):
20
+ index_name = ""
21
+ for file in file_src:
22
+ index_name += os.path.basename(file.name)
23
+ index_name = sha1sum(index_name)
24
+ return index_name
25
 
26
  def get_documents(file_src):
27
  documents = []
 
28
  logging.debug("Loading documents...")
29
  logging.debug(f"file_src: {file_src}")
30
  for file in file_src:
31
+ logging.info(f"loading file: {file.name}")
 
32
  if os.path.splitext(file.name)[1] == ".pdf":
33
  logging.debug("Loading PDF...")
34
  CJKPDFReader = download_loader("CJKPDFReader")
 
50
  text_raw = f.read()
51
  text = add_space(text_raw)
52
  documents += [Document(text)]
53
+ return documents
 
54
 
55
 
56
  def construct_index(
 
81
  chunk_size_limit,
82
  separator=separator,
83
  )
84
+ index_name = get_index_name(file_src)
85
+ documents = get_documents(file_src)
86
  if os.path.exists(f"./index/{index_name}.json"):
87
  logging.info("找到了缓存的索引文件,加载中……")
88
  return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")